# In [None]:
"""
Plotting voxelwise stress effects
"""

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalize
from nilearn import plotting, image
import numpy as np

# Load the second-level NIfTI file.
img_path = '/path/stress_effect_tvals_vox.nii'
stat_img = image.load_img(img_path)

# Extract the voxel data. NaN and +/-inf voxels are replaced by 0 so they do
# not distort the min/max used for the colour scale below.
img_data = stat_img.get_fdata()
img_data = np.nan_to_num(img_data, nan=0.0, posinf=0.0, neginf=0.0)
vmin, vmax = np.min(img_data), np.max(img_data)
if vmax == vmin:
    # A constant image would make the index computation below divide by zero
    # and fail with an opaque "cannot convert float NaN to integer".
    raise ValueError(
        f"{img_path!r} contains a single value ({vmin}) after NaN/inf "
        "removal; cannot build a colour scale.")
stat_img = image.new_img_like(stat_img, img_data)

# Colormap: coolwarm with a flat grey band for values in [-5.46, 5.46).
# NOTE(review): 5.46 is presumably the critical t-value of the analysis --
# confirm against the statistics that produced this map.
GREY_T_LIMIT = 5.46
N_COLORS = 256
original_cmap = plt.cm.coolwarm
colors = original_cmap(np.linspace(0, 1, N_COLORS))
# Map +/-GREY_T_LIMIT from data units to rows of the colour table. The indices
# are clamped to [0, N_COLORS]: without the clamp a data range that does not
# reach -GREY_T_LIMIT yields a negative start index, which NumPy interprets as
# "count from the end" and the wrong colours (or none) would be greyed out.
grey_idx_min, grey_idx_max = (
    int(np.clip((t_bound - vmin) / (vmax - vmin) * N_COLORS, 0, N_COLORS))
    for t_bound in (-GREY_T_LIMIT, GREY_T_LIMIT)
)
colors[grey_idx_min:grey_idx_max, :] = np.array([0.8, 0.8, 0.8, 1])
custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)

# Slice coordinates and the display mode (cut axis) for each panel.
coords_display_modes = [
    ([-48], 'x'),  # sagittal plane
    ([16], 'y'),   # coronal plane
    ([3], 'z'),    # horizontal plane
    ([6], 'x'),    # sagittal plane
    ([-5], 'y'),   # coronal plane
    ([33], 'z')]   # horizontal plane

# Settings shared by every panel; vmin/vmax match the colour table above so
# the grey band lines up with +/-GREY_T_LIMIT in the plots.
plot_settings = {
    'threshold': 0.01,
    'cmap': custom_cmap,
    'dim': -0.5,
    'vmin': vmin,
    'vmax': vmax,
    'alpha': 0.8}

# Create a 2x3 grid for the plots.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

# Explicit subplot positions as [left, bottom, width, height] in figure
# coordinates, one per panel in row-major order.
positions = [
    [0.05, 0.55, 0.25, 0.35],  # first subplot
    [0.35, 0.55, 0.25, 0.35],  # second subplot
    [0.65, 0.55, 0.25, 0.35],  # third subplot
    [0.05, 0.1, 0.25, 0.35],   # fourth subplot
    [0.35, 0.1, 0.25, 0.35],   # fifth subplot
    [0.65, 0.1, 0.25, 0.35]]   # sixth subplot

# Plot each (coordinate, display mode) pair into its panel of the 2x3 grid.
# Per-panel colorbars are disabled; a single shared one is added afterwards.
for ax, position, (cut_coords, display_mode) in zip(
        axes, positions, coords_display_modes):
    ax.set_position(position)
    display = plotting.plot_stat_map(
        stat_img,
        cut_coords=cut_coords,
        display_mode=display_mode,
        axes=ax,
        colorbar=False,
        **plot_settings)

# Build the shared colorbar from a ScalarMappable that uses the same colormap
# and value range as the panels.
sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])
cbar = fig.colorbar(sm, ax=axes, orientation='vertical', fraction=0.025, pad=0.04)
cbar.set_label('T-values', fontsize=20)
cbar.ax.tick_params(labelsize=20)

# Save & show. Saving through `fig` guarantees that this figure is written,
# regardless of which figure pyplot currently considers active.
fig.savefig('/path/stress_effect_tvals_vox_2by3_alpha.png', bbox_inches='tight')
plt.show()