In [1]:
import os
import glob
import numpy as np
import einops


# Per-subject prediction folders for each model variant. Every list is sorted so
# the lists can be zipped together by position.
# NOTE(review): this assumes every submission directory contains the same set of
# subject sub-folders -- confirm before relying on positional alignment.
# Judging by the directory names these are image-only and text-only CLIP runs.
clip_img_preds = sorted(glob.glob("/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip/*/"))
clip_txt_preds = sorted(glob.glob("/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_textonly/*/"))

# The two prediction sets that are actually blended further down (read as
# `clip_pred` and `sam_pred` inside the blending loop).
clip_preds = sorted(glob.glob("/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_text_n_image/*/"))
sam_preds = sorted(glob.glob("/SSD/slava/algonauts/algonauts_2023_challenge_submission_sam_extended_kernel16/*/"))

# `data_dir` holds the per-subject `roi_masks` folders that are read by the
# blending loop; `save_dir` is where the blended predictions are written.
data_dir = "/SSD/slava/algonauts/algonauts_2023_challenge_data/"
save_dir = "/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2"

# Quick visual check that the first entry of each list is the same subject.
clip_img_preds[0], clip_txt_preds[0], sam_preds[0]

('/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip/subj01/',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_textonly/subj01/',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_sam_extended_kernel16/subj01/')

In [2]:
# Which model's prediction is written into each ROI by the blending loop:
#   0 -> the CLIP prediction (`clip_pred`)
#   1 -> the SAM prediction (`sam_pred`)
# ROIs that do not appear here are skipped by the loop and keep the plain
# CLIP/SAM average. Insertion order matches the original literal.
best_ids_dict = {
    # prf-visualrois
    **dict.fromkeys(("V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4"), 1),
    # floc-bodies, floc-faces, floc-places, floc-words
    **dict.fromkeys(
        (
            "EBA", "FBA-1", "FBA-2",
            "OFA", "FFA-1", "FFA-2",
            "OPA", "PPA", "RSC",
            "OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words",
        ),
        0,
    ),
    # streams
    **dict.fromkeys(("early", "midventral", "midlateral", "midparietal"), 1),
    **dict.fromkeys(("ventral", "lateral", "parietal"), 0),
}

In [3]:
# Every ROI name known to `get_roi_class`, one ROI family per line. The blending
# loop walks this list in order, so where ROI masks overlap a later entry
# overwrites an earlier one.
rois = [
    # prf-visualrois
    "V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4",
    # floc-bodies
    "EBA", "FBA-1", "FBA-2", "mTL-bodies",
    # floc-faces
    "OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces",
    # floc-places
    "OPA", "PPA", "RSC",
    # floc-words
    "OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words",
    # streams
    "early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal",
]

def get_roi_class(roi):
    """Return the ROI-class name that the given ROI belongs to.

    The returned string is the middle part of the mask/mapping file names read
    by the blending loop (``<h>h.<roi_class>_challenge_space.npy`` and
    ``mapping_<roi_class>.npy``).

    Parameters
    ----------
    roi : str
        ROI name, e.g. ``"V1v"`` or ``"ventral"``.

    Returns
    -------
    str
        One of ``'prf-visualrois'``, ``'floc-bodies'``, ``'floc-faces'``,
        ``'floc-places'``, ``'floc-words'`` or ``'streams'``.

    Raises
    ------
    ValueError
        If ``roi`` is not one of the known ROI names. (Previously this case
        fell through every branch and failed with an ``UnboundLocalError`` on
        the ``return`` statement.)
    """
    if roi in ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4"]:
        roi_class = 'prf-visualrois'
    elif roi in ["EBA", "FBA-1", "FBA-2", "mTL-bodies"]:
        roi_class = 'floc-bodies'
    elif roi in ["OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces"]:
        roi_class = 'floc-faces'
    elif roi in ["OPA", "PPA", "RSC"]:
        roi_class = 'floc-places'
    elif roi in ["OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words"]:
        roi_class = 'floc-words'
    elif roi in ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"]:
        roi_class = 'streams'
    else:
        # Fail with a message that names the offending ROI instead of an
        # unbound-local error.
        raise ValueError(f"Unknown ROI name: {roi!r}")

    return roi_class

In [4]:
def read_load(path):
    """Load and return the NumPy data stored in the file at ``path``.

    Parameters
    ----------
    path : str or path-like
        Location of a file written by ``np.save`` (or another format that
        ``np.load`` understands).

    Returns
    -------
    Whatever ``np.load`` returns for the file (an ``ndarray`` for ``.npy``).
    """
    # Hand the path to np.load so that it owns the file handle and closes it
    # once the array has been read. The previous version opened the file itself
    # and never closed it, leaking one handle per call.
    return np.load(path)

def write_save(data, path):
    """Write ``data`` to ``path`` in NumPy ``.npy`` format.

    The file is opened explicitly, so ``path`` is used exactly as given
    (``np.save`` does not append a ``.npy`` suffix to an open file object).

    Parameters
    ----------
    data : array-like
        Data to serialise with ``np.save``.
    path : str or path-like
        Destination file; an existing file is overwritten.
    """
    # The context manager guarantees the handle is flushed and closed even if
    # np.save raises; the previous version never closed the file.
    with open(path, 'wb') as f:
        np.save(f, data)
    

# Build the blended submission. For every subject and hemisphere file:
#   1. start from the element-wise average of the CLIP and SAM predictions;
#   2. for each ROI listed in `best_ids_dict`, overwrite the vertices of that ROI
#      with the prediction of the model chosen for it (0 -> CLIP, 1 -> SAM);
#   3. save the result as float32 under `save_dir`.

# zip() silently stops at the shortest list, and a missing subject folder in one
# list would pair predictions of different subjects. Fail loudly instead.
if not (len(clip_img_preds) == len(clip_txt_preds) == len(sam_preds) == len(clip_preds)):
    raise ValueError(
        "Prediction folders are not aligned across models: "
        f"{len(clip_img_preds)} clip-image, {len(clip_txt_preds)} clip-text, "
        f"{len(sam_preds)} sam and {len(clip_preds)} clip subject folders found"
    )

for idx, (clip_img_path, clip_txt_path, sam_path, clip_path) in enumerate(zip(clip_img_preds, clip_txt_preds, sam_preds, clip_preds)):
    # NOTE(review): the subject id is derived from the position in the sorted
    # folder lists, which assumes the folders are subj01..subj0N with no gaps.
    subject = f'subj0{idx+1}'
    roi_masks_dir = os.path.join(data_dir, subject, 'roi_masks')

    # for rh and lh
    # File names are taken from the image-only CLIP folder, as before. The
    # image-only and text-only CLIP arrays themselves are no longer loaded:
    # they were read on every iteration but never used in the blend.
    for part in os.listdir(clip_img_path):
        clip_pred = read_load(os.path.join(clip_path, part))
        sam_pred = read_load(os.path.join(sam_path, part))
        # First character of the file name selects the hemisphere mask file
        # (the saved files are named lh_pred_test.npy / rh_pred_test.npy).
        hemisphere = part[0]

        ensemble_pred = (clip_pred + sam_pred)/2

        # Mask and mapping files depend only on the ROI class, so load each pair
        # once per hemisphere instead of once per ROI.
        roi_class_cache = {}

        for roi in rois:
            if roi not in best_ids_dict:
                continue

            roi_class = get_roi_class(roi)

            if roi_class not in roi_class_cache:
                # Load the ROI brain surface maps
                challenge_roi_class_dir = os.path.join(
                    roi_masks_dir, hemisphere+'h.'+roi_class+'_challenge_space.npy')
                roi_map_dir = os.path.join(roi_masks_dir, 'mapping_'+roi_class+'.npy')
                # NOTE(review): allow_pickle=True unpickles the mapping file and
                # can execute arbitrary code; only use with trusted data.
                roi_class_cache[roi_class] = (
                    np.load(challenge_roi_class_dir),
                    np.load(roi_map_dir, allow_pickle=True).item(),
                )
            challenge_roi_class, roi_map = roi_class_cache[roi_class]

            # Select the vertices corresponding to the ROI of interest
            # (reverse lookup: ROI name -> integer label used in the mask).
            roi_mapping = list(roi_map.keys())[list(roi_map.values()).index(roi)]
            challenge_roi = np.asarray(challenge_roi_class == roi_mapping, dtype=int)

            # Repeat the per-vertex mask along the first axis so it has the same
            # shape as the prediction array.
            challenge_roi = einops.repeat(challenge_roi, 'h -> n h', n=ensemble_pred.shape[0])

            best_id = best_ids_dict[roi]

            if best_id == 0:
                best_pred = clip_pred
            elif best_id == 1:
                best_pred = sam_pred
            else:
                # Previously an unexpected id left `best_pred` unbound, or
                # silently reused the array chosen for the previous ROI.
                raise ValueError(f"Unsupported model id {best_id!r} for ROI {roi!r}")

            # Keep the running blend outside the ROI, take `best_pred` inside it.
            ensemble_pred = ((1-challenge_roi)*ensemble_pred) + (best_pred * challenge_roi)
            assert abs(challenge_roi.sum() - ((best_pred * challenge_roi) == ensemble_pred).sum())<5, \
                f'{roi} {challenge_roi.sum(), ((best_pred * challenge_roi) == ensemble_pred).sum()}'

        os.makedirs(f'{save_dir}/{subject}', exist_ok=True)
        write_save(np.float32(ensemble_pred),
                  f'{save_dir}/{subject}/{part}')

    print(idx+1)



1
2
3
4
5
6
7
8


In [5]:
# List every saved prediction file: one sub-folder level below `save_dir`,
# any name ending in "npy".
glob.glob(f"{save_dir}/*/*npy")

['/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj04/rh_pred_test.npy',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj04/lh_pred_test.npy',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj05/rh_pred_test.npy',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj05/lh_pred_test.npy',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj08/rh_pred_test.npy',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj08/lh_pred_test.npy',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj01/rh_pred_test.npy',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj01/lh_pred_test.npy',
 '/SSD/slava/algonauts/algonauts_2023_challenge_submission_clip_w_text_sam_roi_v2/subj03/rh_pred_test.npy',
 '/SSD/slava/algonauts/algon

In [9]:
# test the equation
#
# Sanity check of the masked-overwrite formula used by the blending loop:
#     ens = (1 - mask) * ens + best_pred * mask
# Entries where mask == 1 must take the value of `best_pred`; entries where
# mask == 0 must keep their previous value.

# Seeded generator so the printed masks and results are the same on every run.
rng = np.random.default_rng(0)

ens_1 = np.ones((5, 5)) * 5
ens_2 = np.ones((5, 5)) * 3

ens = (ens_1 + ens_2)/2

print(ens)

for i in range(2):
    mask = rng.choice([0, 1], size=(5, 5), p=[.5, .5])

    if i == 0:
        best_pred = ens_1
    else:
        best_pred = ens_2

    prev_ens = ens
    ens = ((1-mask)*ens) + (best_pred * mask)

    # Check the property instead of only eyeballing the printed arrays.
    assert np.array_equal(ens[mask == 1], best_pred[mask == 1])
    assert np.array_equal(ens[mask == 0], prev_ens[mask == 0])

    print(mask)
    print(ens)
    


[[4. 4. 4. 4. 4.]
 [4. 4. 4. 4. 4.]
 [4. 4. 4. 4. 4.]
 [4. 4. 4. 4. 4.]
 [4. 4. 4. 4. 4.]]
[[1 0 0 0 1]
 [0 0 0 1 0]
 [1 0 1 1 0]
 [0 1 0 1 1]
 [0 1 1 1 0]]
[[5. 4. 4. 4. 5.]
 [4. 4. 4. 5. 4.]
 [5. 4. 5. 5. 4.]
 [4. 5. 4. 5. 5.]
 [4. 5. 5. 5. 4.]]
[[1 1 0 0 0]
 [0 1 1 1 1]
 [1 0 0 0 1]
 [1 0 1 0 0]
 [0 1 0 1 1]]
[[3. 3. 4. 4. 5.]
 [4. 3. 3. 3. 3.]
 [3. 4. 5. 5. 3.]
 [3. 5. 3. 5. 5.]
 [4. 3. 5. 3. 3.]]
