In [1]:
import numpy as np
import torch
from bin import representation_geometry
import nibabel as nib
import pandas as pd
# Release cached, unused GPU memory held by PyTorch's allocator. This is a
# no-op unless CUDA has already been initialized in this kernel.
if torch.cuda.is_initialized():
    torch.cuda.empty_cache()


In [2]:
# Root of the MTurk1 project data; change this one constant to run the notebook elsewhere.
MTURK_DIR = "/home/spencer/Projects/monkey_fmri/MTurk1"
# Run numbers of the registered per-run beta volumes to load.
SHAPE_COLOR_RUN_IDS = [2, 3, 4, 5]

shape_color_attention_niis = [
    f"{MTURK_DIR}/misc_testing_files/shapecolorlow_runs/reg_run_beta_{run_id}.nii.gz"
    for run_id in SHAPE_COLOR_RUN_IDS
]
atlas_nii = nib.load(f"{MTURK_DIR}/D99_v2.0_dist/simplified_atlas.nii")
# header=1: column names are on the file's second line (the first line is skipped).
lookup = pd.read_csv(f"{MTURK_DIR}/D99_v2.0_dist/simplified_color_map.txt", header=1, index_col=0, delimiter='\t')

In [3]:
# Each run's beta volume holds regressors along axis 3 (assumed 4-D). Three
# consecutive groups of 7 regressors are pulled out of every run; regressor 0
# and anything from index 22 on are not used.
beta_groups = {
    'color': slice(1, 8),
    'achrom_color': slice(8, 15),
    'achrom': slice(15, 22),
}
group_chunks = {group: [] for group in beta_groups}

for run_file in shape_color_attention_niis:
    run_volume = nib.load(run_file).get_fdata()
    for group, regressor_slice in beta_groups.items():
        group_chunks[group].append(run_volume[:, :, :, regressor_slice])

# Join runs along the regressor axis: one tensor per group with 7 * n_runs regressors.
color_betas, achrom_color_betas, achrom_betas = (
    torch.from_numpy(np.concatenate(group_chunks[group], axis=3))
    for group in ('color', 'achrom_color', 'achrom')
)

# Atlas labels are stored as floats on disk; round and cast to integer label ids.
atlas = torch.round(torch.from_numpy(atlas_nii.get_fdata())).int()

In [4]:
# Class labels for the concatenated betas: conditions 0..6 repeated once per
# run, matching the run-major order in which each run's 7 betas were concatenated.
N_CONDITIONS = 7
targets = torch.tile(torch.arange(N_CONDITIONS), (len(shape_color_attention_niis),))

# Map every atlas label id to its name in the lookup table. Label 0 is
# excluded explicitly (assumed to be background); the previous `[1:]` dropped
# whichever label sorted first, which removes a real ROI if 0 is absent.
lookup_dict = {
    label: lookup.loc[label]['Label Name:']
    for label in torch.unique(atlas).tolist()
    if label != 0
}

In [None]:
# Decoder over the atlas ROIs (ROIDecoder is project code; its internals are not shown here).
sd = representation_geometry.ROIDecoder(
    atlas,        # integer label volume
    lookup_dict,  # label id -> ROI name
    out_dim=7,    # one output per condition (7 per run)
)

In [None]:
# Train the ROI decoder on the achromatic-color betas. Stopping criteria are
# inferred from the argument names (loss threshold / max epochs) — confirm in ROIDecoder.
sd.fit(
    achrom_color_betas,
    targets,
    optim_threshold=1e-9,
    cutoff_epoch=10000,
)

In [None]:
# Cross-decode: evaluate the decoder (fit on achrom_color_betas) on the color betas.
res = sd.predict(
    color_betas,
    targets,
)

In [5]:
# now try searchlight
# Use the first GPU when present, otherwise fall back to CPU so the cell does
# not fail on machines without CUDA. Assumes SearchLightDecoder accepts any
# torch device string for `dev` — confirm.
searchlight_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
searchlight = representation_geometry.SearchLightDecoder(None, None, 7, kernel=5, dev=searchlight_device)

In [35]:
# load from pickle
import pickle as pkl

# Replaces the freshly constructed `searchlight` with a previously saved one.
# NOTE(review): pickle.load can execute arbitrary code — only load files you
# produced yourself. This path also differs from the one the dump cell writes.
with open('../MTurk1/misc_testing_files/searchlight_decode_uncolored_shapes_20000weird.pkl', 'rb') as pkl_file:
    searchlight = pkl.load(pkl_file)

In [None]:
# Train the searchlight decoder on the achromatic-color betas.
res = searchlight.fit(
    achrom_color_betas,
    targets,
    optim_threshold=-1,  # presumably disables threshold-based early stopping — confirm
    cutoff_epoch=10000,
    lr=1e-4,
    batch_size=7,
)

128
torch.Size([28, 1, 128, 128, 128])
2
CE on epoch 0 is tensor(7.8866)
CE on epoch 1 is tensor(7.8848)
CE on epoch 2 is tensor(7.8833)
CE on epoch 3 is tensor(7.8817)
CE on epoch 4 is tensor(7.8804)
CE on epoch 5 is tensor(7.8795)
CE on epoch 6 is tensor(7.8783)
CE on epoch 7 is tensor(7.8770)
CE on epoch 8 is tensor(7.8761)
CE on epoch 9 is tensor(7.8752)
CE on epoch 10 is tensor(7.8743)
CE on epoch 11 is tensor(7.8734)
CE on epoch 12 is tensor(7.8726)
CE on epoch 13 is tensor(7.8719)
CE on epoch 14 is tensor(7.8710)
CE on epoch 15 is tensor(7.8704)
CE on epoch 16 is tensor(7.8695)
CE on epoch 17 is tensor(7.8689)
CE on epoch 18 is tensor(7.8682)
CE on epoch 19 is tensor(7.8675)
CE on epoch 20 is tensor(7.8668)
CE on epoch 21 is tensor(7.8661)
CE on epoch 22 is tensor(7.8656)
CE on epoch 23 is tensor(7.8650)
CE on epoch 24 is tensor(7.8644)
CE on epoch 25 is tensor(7.8639)
CE on epoch 26 is tensor(7.8634)
CE on epoch 27 is tensor(7.8628)
CE on epoch 28 is tensor(7.8622)
CE on epoch 

In [40]:
# Evaluate on the same achrom_color betas used for fitting (train-on-train).
# `res` is consumed later as per-trial cross-entropy values.
yhat, res = searchlight.predict(
    achrom_color_betas,
    targets,
)

128


In [42]:
# Remove trials where decoding is not possible: keep only trials whose minimum
# cross-entropy falls below the threshold. For reference, a uniform guess over
# 7 classes gives CE = ln(7) ≈ 1.946 (assuming natural-log, per-sample mean CE).
DECODABLE_CE_THRESHOLD = 1.8
decodable = [trial for trial in res if torch.min(trial) < DECODABLE_CE_THRESHOLD]
if not decodable:
    # torch.stack([]) would fail with an opaque error; say what actually happened.
    raise ValueError(f"No trials have a minimum cross-entropy below {DECODABLE_CE_THRESHOLD}")
good_ce = torch.stack(decodable)

In [43]:
# Average the kept trials into one CE map (mean over the stacked trial axis).
# .cpu() is a no-op for CPU tensors but required before .numpy() when the
# results live on a GPU.
npres = good_ce.detach().cpu().numpy().mean(axis=0)

In [44]:
# Lowest mean cross-entropy anywhere in the map.
npres.min()

1.8704255

In [17]:
import pickle as pkl

# Persist the trained searchlight decoder. The `with` block guarantees the file
# is flushed and closed, even if pickling fails part-way.
with open('../MTurk1/misc_testing_files/searchlight_decode_k5_uncolored_shapes_10000.pkl', 'wb') as pkl_file:
    pkl.dump(searchlight, pkl_file)

In [45]:
# Save the mean CE map in the atlas's space. Passing the atlas header also
# carries over its on-disk data dtype (nibabel only infers the dtype from the
# array when no header is given). Reset it to the CE map's float dtype so
# values are not cast/scaled into the atlas's (possibly integer) label dtype.
ce_nii = nib.Nifti1Image(npres, affine=atlas_nii.affine, header=atlas_nii.header)
ce_nii.set_data_dtype(npres.dtype)
nib.save(ce_nii, '../MTurk1/misc_testing_files/k5_train_train_uncolored_shape_searchlight_ce.nii.gz')