In [None]:
import glob 
import os
from PIL import Image
from PIL import ImageDraw
from PIL import ImageOps
from PIL import ImageFont

In [None]:
#creates subcortical (coronal) figure for one contrast/group/timepoint

def create_subcortical_figure_by_group(group,ses,task,contrast):
    """Build the subcortical (coronal) panel for one group/session/task/contrast.

    Reads the raw coronal render from ``raw_indiv_figures/task-{task}/``, cuts it
    into four slice crops plus the colour-bar crop, arranges the slices as a 2x2
    grid with a rearranged, rescaled colour bar to the right, and saves the result
    as ``indiv_figures/task-{task}/..._cifti_figure_subcortical.png``. The threshold
    and participant count in the output name are copied from the raw file name.

    Parameters
    ----------
    group, ses, task, contrast : str
        Identifiers interpolated into the file names.

    Returns
    -------
    None
        Prints a message and returns early when no raw coronal image matches.
    """
    raw_pattern = (f'../../../derivatives/task_analysis_surface/visualization/raw_indiv_figures/task-{task}/'
                   f'group-{group}_ses-{ses}_task-{task}_contrast-{contrast}_threshold-*_n-*_display-coronal.png')
    coronal_matches = glob.glob(raw_pattern)

    #check if file exists
    if len(coronal_matches)==0:
        print(f'no data for group {group} session {ses} task {task} contrast {contrast} - cannot make subcortical figure for this group')
        return

    # NOTE(review): if several thresholds match, the first glob hit is used; glob order is not guaranteed
    coronal_path = coronal_matches[0]

    #get coronal image and its dim; crop() copies the pixels, so the file can be closed afterwards
    with Image.open(coronal_path,mode='r') as coronal:
        w_c,h_c=coronal.size

        #find area locations of the 4 brains + cbar (fixed pixel offsets tuned to the raw render's layout)
        coronal_crop_1 = coronal.crop((0, 0, 219, h_c))
        coronal_crop_2 = coronal.crop((219, 0, 438, h_c))
        coronal_crop_3 = coronal.crop((438, 0, 657, h_c))
        coronal_crop_4 = coronal.crop((657, 0, 876, h_c))
        coronal_crop_cbar = coronal.crop((897, 0, w_c, h_c))

    # Get the width and height of the colour-bar crop
    cbar_width, cbar_height = coronal_crop_cbar.size

    # split the colour-bar crop at 60% of its width
    flip_point = int(cbar_width*0.6)
    cbar_left_half = coronal_crop_cbar.crop((0, 0, flip_point, cbar_height))
    cbar_right_half = coronal_crop_cbar.crop((flip_point, 0, cbar_width, cbar_height))
    # stretch the right part horizontally by 10 px
    cbar_right_half = cbar_right_half.resize((cbar_right_half.width+10, cbar_right_half.height))

    # swap the two parts: right part first (after a 14 px pad; Image.new defaults to black), then the left part
    cbar_offset_horizontal_left = 14
    flipped_coronal_crop_cbar = Image.new('RGB', (cbar_left_half.width+cbar_right_half.width+cbar_offset_horizontal_left, cbar_height))
    flipped_coronal_crop_cbar.paste(cbar_right_half, (cbar_offset_horizontal_left, 0))
    flipped_coronal_crop_cbar.paste(cbar_left_half, (cbar_offset_horizontal_left+cbar_right_half.width, 0))

    coronal_cut_height = coronal_crop_1.height
    coronal_cut_width_no_cbar = coronal_crop_1.width

    #resize color bar to span both rows (minus a vertical margin) while maintaining the aspect ratio
    cbar_offset_vertical = 17
    desired_cbar_height = coronal_cut_height*2-cbar_offset_vertical*2
    aspect_ratio = flipped_coronal_crop_cbar.width/flipped_coronal_crop_cbar.height
    desired_cbar_width = int(desired_cbar_height*aspect_ratio)
    resized_coronal_cbar = flipped_coronal_crop_cbar.resize((desired_cbar_width, desired_cbar_height))

    #set up new (black) canvas: left pad + two slice columns + gap + colour bar, two slice rows
    space_between_brains = 40
    left_offset = 23
    new_im = Image.new('RGB', ( (coronal_cut_width_no_cbar*2+resized_coronal_cbar.width+space_between_brains+left_offset) , coronal_cut_height*2) ,(0, 0, 0, 1))

    #paste in the cut out images: slices 1|2 on top, 3|4 below, colour bar to the right of the top row
    new_im.paste(coronal_crop_1,(left_offset,0))
    new_im.paste(coronal_crop_2,(left_offset+coronal_cut_width_no_cbar+space_between_brains,0))
    new_im.paste(resized_coronal_cbar,(left_offset+coronal_cut_width_no_cbar*2+space_between_brains,cbar_offset_vertical))

    new_im.paste(coronal_crop_3,(left_offset,coronal_cut_height))
    new_im.paste(coronal_crop_4,(left_offset+coronal_cut_width_no_cbar+space_between_brains,coronal_cut_height))

    threshold = coronal_path.split('threshold-')[1].split('_')[0]
    part_count = coronal_path.split('_n-')[1].split('_')[0]

    #create output dir if it does not exist and save into it
    derivatives_path = '../../../derivatives'
    nilearn_output_path = os.path.join(derivatives_path, 'task_analysis_surface','visualization','indiv_figures',f'task-{task}')
    os.makedirs(nilearn_output_path, exist_ok=True)

    new_im.save(os.path.join(nilearn_output_path, f'group-{group}_ses-{ses}_task-{task}_contrast-{contrast}_threshold-{threshold}_n-{part_count}_cifti_figure_subcortical.png'))

    return


In [None]:
#creates surface figure for one contrast/group/timepoint

def create_surf_figure_by_group(group,ses,task,contrast):
    """Build the cortical-surface panel for one group/session/task/contrast.

    Combines four raw renders from ``raw_indiv_figures/task-{task}/`` into a 2x2 grid:
    left/right hemisphere on top, and left flat map (mirrored horizontally) / right
    flat map below. The colour bar cropped from the right-hemisphere render is scaled
    to the grid height and placed on the right. Saves
    ``indiv_figures/task-{task}/..._cifti_figure_surface.png``, creating the directory
    if needed.

    Returns
    -------
    None
        Prints a message and returns early if any of the four raw renders
        (righthemi, rightflat, lefthemi, leftflat) is missing.
    """
    area = (30, 60, 370, 330)        # crop box applied to every brain view
    cbar_area = (390, 80, 460, 310)  # crop box of the colour bar inside the right-hemisphere render

    raw_prefix = (f'../../../derivatives/task_analysis_surface/visualization/raw_indiv_figures/task-{task}/'
                  f'group-{group}_ses-{ses}_task-{task}_contrast-{contrast}_threshold-*_n-*_display-')

    #all four views are required; check each one instead of indexing an empty glob result
    view_paths = {}
    for view in ('righthemi', 'rightflat', 'lefthemi', 'leftflat'):
        matches = glob.glob(f'{raw_prefix}{view}.png')
        if len(matches)==0:
            print(f'no data for group {group} session {ses} task {task} contrast {contrast} - cannot make surface figure for this group')
            return
        view_paths[view] = matches[0]

    #crop just the parts of each image that are wanted; crop() copies pixels so files can be closed
    right_hemi_path = view_paths['righthemi']
    with Image.open(right_hemi_path,mode='r') as right_hemi:
        rH = right_hemi.crop(area)
        cbar = right_hemi.crop(cbar_area)
    with Image.open(view_paths['rightflat'],mode='r') as right_flat:
        rF = right_flat.crop(area)
    with Image.open(view_paths['lefthemi'],mode='r') as left_hemi:
        lH = left_hemi.crop(area)
    with Image.open(view_paths['leftflat'],mode='r') as left_flat:
        lF = left_flat.crop(area).transpose(Image.FLIP_LEFT_RIGHT)

    w,h=lF.size

    #resize color bar to span both rows (minus a margin) while maintaining the aspect ratio
    cbar_offset = 15
    desired_cbar_height = h*2-cbar_offset*2
    aspect_ratio = cbar.width/cbar.height
    desired_cbar_width = int(desired_cbar_height*aspect_ratio)
    resized_cbar = cbar.resize((desired_cbar_width, desired_cbar_height))

    new_im = Image.new('RGB', ( (w*2+resized_cbar.width) , h*2) ,(255, 255, 255, 1)) #make a new image

    #paste in the cut out images
    new_im.paste(lH,(0,0))
    new_im.paste(rH,(w,0))
    new_im.paste(lF,(0,h))
    new_im.paste(rF,(w,h))
    new_im.paste(resized_cbar,(w*2,cbar_offset))

    threshold = right_hemi_path.split('threshold-')[1].split('_')[0]
    part_count = right_hemi_path.split('_n-')[1].split('_')[0]

    #the output dir is not guaranteed to exist (the subcortical step may have returned early)
    output_dir = os.path.join('../../../derivatives', 'task_analysis_surface','visualization','indiv_figures',f'task-{task}')
    os.makedirs(output_dir, exist_ok=True)

    new_im.save(os.path.join(output_dir, f'group-{group}_ses-{ses}_task-{task}_contrast-{contrast}_threshold-{threshold}_n-{part_count}_cifti_figure_surface.png'))

    return


In [None]:
#stitches together surface and subcortical for one contrast/group/timepoint

def create_cifti_figure_by_group(group,ses,task,contrast):
    """Stack the surface panel above the subcortical panel for one group/session/contrast.

    Reads the ``_cifti_figure_surface.png`` and ``_cifti_figure_subcortical.png``
    panels from ``indiv_figures/task-{task}/``. The subcortical panel is scaled with
    ``ImageOps.contain`` to fit inside a ``w_surf x w_surf`` box, and the two panels
    are stacked on a white canvas. The result is saved as ``..._cifti_figure_by_group.png``.

    Returns
    -------
    tuple
        ``(h_surf, h_subcortical)`` pixel heights of the two stacked parts, or
        ``(None, None)`` (after printing a message) when inputs are missing.
    """
    raw_dir = f'../../../derivatives/task_analysis_surface/visualization/raw_indiv_figures/task-{task}'
    indiv_dir = f'../../../derivatives/task_analysis_surface/visualization/indiv_figures/task-{task}'
    stem = f'group-{group}_ses-{ses}_task-{task}_contrast-{contrast}_threshold-*_n-*_'

    if len(glob.glob(f'{raw_dir}/{stem}display-righthemi.png'))==0 or len(glob.glob(f'{raw_dir}/{stem}display-coronal.png'))==0:
        print(f'no data for group {group} session {ses} task {task} contrast {contrast} - cannot make whole grayordinate figure for this group')
        return (None,None)

    #raw data can exist while a panel was still not produced (e.g. a missing surface view)
    surface_matches = glob.glob(f'{indiv_dir}/{stem}cifti_figure_surface.png')
    subcortical_matches = glob.glob(f'{indiv_dir}/{stem}cifti_figure_subcortical.png')
    if len(surface_matches)==0 or len(subcortical_matches)==0:
        print(f'surface or subcortical panel missing for group {group} session {ses} task {task} contrast {contrast} - cannot make whole grayordinate figure for this group')
        return (None,None)

    surface_path = surface_matches[0]
    subcortical_path = subcortical_matches[0]

    with Image.open(surface_path,mode='r') as surface, Image.open(subcortical_path,mode='r') as subcortical_raw:
        w_surf, h_surf = surface.size
        subcortical = ImageOps.contain(subcortical_raw, (w_surf,w_surf))
        w_subcortical, h_subcortical = subcortical.size

        #contain() can return a width < w_surf for a tall subcortical panel; size the canvas so the surface is never clipped
        canvas_width = max(w_surf, w_subcortical)
        new_im = Image.new('RGB', ( canvas_width , h_surf+h_subcortical) ,(255, 255, 255, 1)) #make a new image
        #paste in the cut out images
        new_im.paste(surface,(0,0))
        new_im.paste(subcortical,(0,h_surf))

    threshold = surface_path.split('threshold-')[1].split('_')[0]
    part_count = surface_path.split('_n-')[1].split('_')[0]

    new_im.save(f'{indiv_dir}/group-{group}_ses-{ses}_task-{task}_contrast-{contrast}_threshold-{threshold}_n-{part_count}_cifti_figure_by_group.png')

    return (h_surf, h_subcortical)

In [None]:
#stitch together the groups/timepoints for one contrast

def create_cifti_figure_by_contrast(task, contrast):
    """Place the HC-baseline, MM-baseline and MM-1year group figures side by side.

    Requires exactly three ``..._cifti_figure_by_group.png`` files for the contrast
    in ``indiv_figures/task-{task}/``, one for each of HC/baseline, MM/baseline and
    MM/1year. The panels are pasted left to right at multiples of the HC-baseline
    width. The result goes to
    ``per_contrast_figures/task-{task}/task-{task}_contrast-{contrast}_<per-panel threshold/n tags>_cifti_figure_by_contrast.png``.

    Returns
    -------
    int or None
        Width in pixels of the HC-baseline panel (the column pitch), or None
        (after printing a message) when the inputs are incomplete.
    """
    indiv_dir = f'../../../derivatives/task_analysis_surface/visualization/indiv_figures/task-{task}'

    all_paths = glob.glob(f'{indiv_dir}/group-*_ses-*_*task-{task}_contrast-{contrast}_*_cifti_figure_by_group.png')
    if len(all_paths)!=3:
        print(f'incomplete or too much data for contrast {contrast} of task {task} - cannot make grayordinate figure for this contrast')
        print('number of figures for same contrast:',len(all_paths))
        return None

    #locate each expected panel; the three files found above may not be exactly these combinations
    panel_paths = {}
    for key, group, ses in (('HCb', 'HC', 'baseline'), ('MMb', 'MM', 'baseline'), ('MM1', 'MM', '1year')):
        matches = glob.glob(f'{indiv_dir}/group-{group}_ses-{ses}_task-{task}_contrast-{contrast}_*_cifti_figure_by_group.png')
        if len(matches)==0:
            print(f'no group-{group} ses-{ses} figure for contrast {contrast} of task {task} - cannot make grayordinate figure for this contrast')
            return None
        panel_paths[key] = matches[0]

    panels = {}
    for key, path in panel_paths.items():
        with Image.open(path,mode='r') as img:
            panels[key] = img.copy()

    w = panels['HCb'].width
    #use the tallest panel so none is clipped (identical to the HC height when panels match)
    h = max(img.height for img in panels.values())

    new_im = Image.new('RGB', ( (w*3) , h) ,(255, 255, 255, 1)) #make a new image
    #paste in the cut out images
    for column, key in enumerate(('HCb', 'MMb', 'MM1')):
        new_im.paste(panels[key],(w*column,0))

    #create output dir if it does not exist
    derivatives_path = '../../../derivatives'
    nilearn_output_path = os.path.join(derivatives_path, 'task_analysis_surface','visualization','per_contrast_figures',f'task-{task}')
    os.makedirs(nilearn_output_path, exist_ok=True)

    #e.g. "HCb_threshold-X_HCb_n-Y": threshold and participant count parsed from each panel's file name
    tags = []
    for key, path in panel_paths.items():
        threshold = path.split('threshold-')[1].split('_')[0]
        part_count = path.split('_n-')[1].split('_')[0]
        tags.append(f'{key}_threshold-{threshold}_{key}_n-{part_count}')

    new_im.save(os.path.join(nilearn_output_path, f'task-{task}_contrast-{contrast}_{"_".join(tags)}_cifti_figure_by_contrast.png'))

    return w


In [None]:
#stitch together the various contrasts for a task

def create_surf_figure_all(task,width_surf,height_surf,height_subcortical,contrasts,supplemental):
    """Stack the per-contrast figures of one task into the final letter-width figure.

    Each requested contrast becomes one row, in the order of ``contrasts``. Rows are
    separated by 50 px below a 45 px top margin. The stack is rescaled to 8.5 in at
    300 px/in minus a 350 px left margin. Column headings and per-row labels (from
    ``get_labels_dict``) are drawn on top. The figure is displayed, then saved under
    ``complete_figures/task-{task}/``.

    Parameters
    ----------
    task : str
        Task name, used for labels and paths.
    width_surf : int
        Column pitch of the per-contrast figure (as returned by
        ``create_cifti_figure_by_contrast``); used to centre the column headings.
    height_surf : int
        Height of the surface part of a group panel; used to place row labels.
    height_subcortical : int
        Not used here; kept for interface compatibility.
    contrasts : list of str
        Contrasts to include. Each must have exactly one per-contrast figure on disk.
    supplemental : bool
        If True, saves as ``task-{task}_figure_supplemental.png``, else ``task-{task}_figure.png``.

    Returns
    -------
    None
        Prints a message and returns early when a requested contrast has no figure,
        or more than one.
    """
    labels_dict = get_labels_dict(task)

    font = ImageFont.truetype("../templates/fonts/G_ari_bd.TTF",40)
    font_heading = ImageFont.truetype("../templates/fonts/G_ari_bd.TTF",50)

    #create output dir if it does not exist
    derivatives_path = '../../../derivatives'
    nilearn_output_path = os.path.join(derivatives_path, 'task_analysis_surface','visualization','complete_figures',f'task-{task}')
    os.makedirs(nilearn_output_path, exist_ok=True)

    if len(contrasts)==0:
        print(f'no contrasts requested for task {task} - cannot make complete figure')
        return

    #get the paths to the contrast figures that were possible to be generated, grouped by contrast name
    per_contrast_figures_paths = glob.glob(f'../../../derivatives/task_analysis_surface/visualization/per_contrast_figures/task-{task}/task-{task}_contrast-*_HCb_threshold-*_HCb_n-*_MMb_threshold-*_MMb_n-*_MM1_threshold-*_MM1_n-*_cifti_figure_by_contrast.png')
    paths_by_contrast = {}
    for img_path in per_contrast_figures_paths:
        paths_by_contrast.setdefault(img_path.split('contrast-')[1].split('_')[0], []).append(img_path)

    #only duplicates of contrasts that go into this figure are a problem
    duplicated = [c for c in contrasts if len(paths_by_contrast.get(c, [])) > 1]
    if duplicated:
        print('two or more per contrast figures for the same contrast found -- address this by deleting the undesired one first to make sure correct one is added to complete figure')
        print('duplicated contrasts:', duplicated)
        return

    missing = [c for c in contrasts if c not in paths_by_contrast]
    if missing:
        print(f'no per contrast figure for contrast(s) {missing} of task {task} - cannot make complete figure')
        return

    contrast_paths_dict = {c: paths_by_contrast[c][0] for c in contrasts}
    contrast_count = len(contrasts)

    #row size taken from the first requested contrast (all rows assumed to share it)
    with Image.open(contrast_paths_dict[contrasts[0]],mode='r') as first_img:
        w,h=first_img.size

    margin_height = 45 #height of top margin
    row_gap = 50 #vertical gap above each contrast row
    new_im = Image.new('RGB', (w, (row_gap+h)*contrast_count+margin_height) ,(255, 255, 255, 1)) #make a new image

    for i, contrast in enumerate(contrasts):
        with Image.open(contrast_paths_dict[contrast],mode='r') as img:
            new_im.paste(img,(0,margin_height+row_gap+(h+row_gap)*i))

    #resize to have consistent width of letterhead page
    letterhead_width = 8.5 * 300  #assuming 300 pixels per inch
    left_margin = 350

    #calculate the new height to maintain the aspect ratio
    aspect_ratio = new_im.width/new_im.height
    new_height = int((letterhead_width-left_margin)/aspect_ratio)

    #resize the image
    resized_image = new_im.resize((int(letterhead_width-left_margin), new_height))

    #white page; the left margin is simply left blank
    final_im = Image.new('RGB', (int(letterhead_width), new_height) ,(255, 255, 255, 1)) #make a new image
    final_im.paste(resized_image,(left_margin,0))

    #get ratios for scaling
    resize_width_ratio = resized_image.width/new_im.width
    resize_height_ratio = resized_image.height/new_im.height

    draw = ImageDraw.Draw(final_im)

    #column headings centred over each group panel (0.5, 1.5, 2.5 panel widths, shifted 65 px left)
    top_offset = 25
    for column, heading in enumerate(("HC baseline", "MCC baseline", "MCC one-year")):
        draw.text((((column+0.5)*width_surf-65)*resize_width_ratio+left_margin,top_offset),heading,fill=(0,0,0),font=font_heading,anchor="ma")

    #row labels
    for i, contrast in enumerate(contrasts):
        label = labels_dict[contrast]
        draw.multiline_text((10,(height_surf+(row_gap+h)*i)*resize_height_ratio),f'{label}',fill=(0,0,0),font=font,anchor="la")

    #display complete figure and save it
    display(final_im)

    figure_name = f'task-{task}_figure_supplemental.png' if supplemental else f'task-{task}_figure.png'
    final_im.save(os.path.join(nilearn_output_path, figure_name))

    return

In [None]:
def get_labels_dict(task):
    """Return the mapping from contrast name to its (multi-line) row label for ``task``.

    Parameters
    ----------
    task : str
        One of ``'mid'``, ``'sst'`` or ``'nback'``.

    Returns
    -------
    dict
        Contrast name -> label text; ``\\n`` marks line breaks in the drawn label.

    Raises
    ------
    ValueError
        If ``task`` is not one of the supported tasks.
    """
    if task == 'mid':
        labels_dict={ 'HiRewCue-NeuCue':"High Reward\nCue vs.\nNeutral Cue", #high reward anticipation
                      'HiRewCue-LoRewCue':"High Reward\nCue vs.\nLow Reward Cue", #high vs. low reward anticipation
                      'RewCue-NeuCue':"Reward Cue vs.\nNeutral Cue", #combined reward anticipation
                      'LoRewCue-NeuCue':"Low Reward\nCue vs.\nNeutral Cue", #low reward anticipation -- ABCD

                      'HiLossCue-NeuCue':"High Loss\nCue vs.\nNeutral Cue", #high loss anticipation
                      'HiLossCue-LoLossCue':"High Loss\nCue vs.\nLow Loss Cue", #high vs. low loss anticipation
                      'LossCue-NeuCue':"Loss Cue vs.\nNeutral Cue", #combined loss anticipation
                      'LoLossCue-NeuCue':"Low Loss\nCue vs.\nNeutral Cue", #low loss anticipation -- ABCD

                      'HiRewCue-HiLossCue':"High Reward Cue vs.\nHigh Loss Cue", #high reward vs. high loss anticipation
                      'RewCue-LossCue':"Reward Cue vs.\nLoss Cue", #combined reward vs. loss anticipation

                      'HiRewCue-Baseline':"High Reward\nCue vs.\nBaseline", #high reward anticipation vs. baseline -- paper A4

                      'HiWin-NeuHit':"High Reward vs.\nNeutral Hit", #high reward outcome cp. to neutral hit
                      'Win-NeuHit':"Reward vs.\nNeutral Hit", #combined reward outcome cp. to neutral hit

                      'HiWin-HiNoWin':"High Reward vs.\nHigh Missed Reward", #high reward outcome cp. to high reward miss
                      'Win-NoWin':"Reward vs.\nMissed Reward", #combined reward outcome cp. to combined reward miss

                      'HiLoss-NeuMiss':"High Loss vs.\nNeutral Miss", #high loss cp. to neutral miss
                      'Loss-NeuMiss':"Loss vs.\nNeutral Miss", #combined loss cp. to neutral miss

                      'HiLoss-AvoidHiLoss':"High Loss vs.\nAvoided High Loss", #high loss cp. to high avoid loss
                      'Loss-AvoidLoss':"Loss vs.\nAvoided Loss", #combined loss cp. to combined avoid loss

                      'HiLoss-NeuHit':"High Loss vs.\nNeutral Hit", #high loss cp. to neutral hit
                      'Loss-NeuHit':"Loss vs.\nNeutral Hit", #combined loss cp. to neutral hit

                      'HiWin-HiLoss':"High Reward vs.\nHigh Loss", #high reward outcome cp. to high loss
                      'Win-Loss':"Reward vs.\nLoss", #combined reward outcome cp. to combined loss
        }

    elif task == 'sst':
        labels_dict={'SuccStop-Go':"Successful\nSTOP vs.\nGO",'UnsuccStop-Go':"Unsuccessful\nSTOP vs.\nGO",'UnsuccStop-SuccStop':"Unsuccessful\nSTOP vs.\nSuccessful\nSTOP"}

    elif task == 'nback':
        labels_dict={'twoback-zeroback': "Two-back vs.\nZero-back"}

    else:
        #without this branch an unknown task surfaced as an UnboundLocalError on the return below
        raise ValueError(f"unknown task {task!r}; expected one of 'mid', 'sst', 'nback'")

    return labels_dict

In [None]:
def create_figure(groups,sessions,task,supplemental=False):
    """Run the whole figure pipeline for one task.

    For each contrast of the main-text (or, with ``supplemental=True``, the
    supplemental) contrast list, this:

    1. builds the subcortical and surface panels for every group/session (HC/1year is skipped);
    2. stacks each pair into a per-group figure;
    3. joins the groups into a per-contrast figure;
    4. assembles all contrasts into the complete figure.

    Parameters
    ----------
    groups, sessions : list of str
        Group and session identifiers to iterate over.
    task : str
        One of ``'mid'``, ``'sst'``, ``'nback'``.
    supplemental : bool, default False
        Build the supplemental rather than the main-text figure. For nback this
        prints a message and returns immediately.

    Raises
    ------
    ValueError
        If ``task`` is not one of the supported tasks.
    """
    if task == 'mid':
        if supplemental:
            #supplemental figure
            contrasts = [ 'HiRewCue-NeuCue', #high reward anticipation -- paper A2, ABCD
                          'RewCue-NeuCue', #combined reward anticipation -- paper A1
                          'HiRewCue-LoRewCue', #high vs. low reward anticipation -- paper A3, ABCD
                          'HiLossCue-NeuCue', #high loss anticipation -- paper A5, ABCD
                          'LoLossCue-NeuCue', #low loss anticipation -- ABCD
                          'HiLossCue-LoLossCue', #high vs. low loss anticipation -- ABCD
                         ]
        else:
            #main text figure
            contrasts = [ 'HiRewCue-Baseline', #high reward anticipation vs. baseline -- paper A4
                          'Win-NoWin', #combined reward outcome cp. to combined reward miss -- ABCD
                        ]

    elif task == 'sst':
        #supplemental figure vs. main text figure
        contrasts = ['UnsuccStop-SuccStop'] if supplemental else ['SuccStop-Go','UnsuccStop-Go']

    elif task == 'nback':
        if supplemental:
            print('supplemental figure not created for nback')
            return
        #main text figure
        contrasts=['twoback-zeroback']

    else:
        #without this branch an unknown task surfaced as an UnboundLocalError on `contrasts`
        raise ValueError(f"unknown task {task!r}; expected one of 'mid', 'sst', 'nback'")

    height_surf, height_subcortical, width_surf = None, None, None

    #create parts of figures
    for contrast in contrasts:
        for group in groups:
            for ses in sessions:
                if group == 'HC' and ses == '1year':
                    continue  #the HC/1year combination is not part of the figure
                #make individual figures for subcortical and surface images
                create_subcortical_figure_by_group(group,ses,task,contrast)
                create_surf_figure_by_group(group,ses,task,contrast)
                #combine subcortical and surface images for one group and one timepoint and one contrast
                height_surf_out, height_subcortical_out = create_cifti_figure_by_group(group,ses,task,contrast)
                if height_surf_out is not None:
                    height_surf = height_surf_out
                    height_subcortical = height_subcortical_out
        #combine the subcortical and surface images for one contrast across groups and timepoints
        width_surf_out = create_cifti_figure_by_contrast(task, contrast)
        if width_surf_out is not None:
            width_surf = width_surf_out

    #the complete figure needs the panel geometry from at least one successful contrast
    if width_surf is None or height_surf is None:
        print(f'no per contrast figure could be made for task {task} - cannot make complete figure')
        return

    #make complete figure
    create_surf_figure_all(task,width_surf,height_surf,height_subcortical,contrasts,supplemental)

    return

In [None]:
#run the full pipeline (individual panels -> per-contrast figures -> complete figure) for one task
groups = ['HC', 'MM']
sessions = ['baseline', '1year']
task = 'mid'
supplemental = False

create_figure(groups, sessions, task, supplemental=supplemental)
