In [1]:
import pickle as pkl

import nibabel as nib
import numpy as np
import pandas as pd
import torch

from bin import representation_geometry

In [2]:
# Data locations. Every input lives under one project root, so only PROJECT_ROOT
# needs editing when the notebook is run on another machine.
PROJECT_ROOT = "/Users/loggiasr/Projects/fmri/monkey_fmri/MTurk1"
SESSION_DIR = f"{PROJECT_ROOT}/subjects/wooster/sessions/20220603"
ATLAS_DIR = f"{PROJECT_ROOT}/D99_v2.0_dist"
RUN_IDS = (2, 3, 4, 5)  # run sub-directories of the session whose betas are loaded

# One registered beta volume per run, in RUN_IDS order.
shape_color_attention_niis = [f"{SESSION_DIR}/{run}/reg_run_beta.nii.gz" for run in RUN_IDS]
atlas_nii = nib.load(f"{ATLAS_DIR}/simplified_atlas.nii")
# Label lookup table: tab-separated, column names on the file's second line (header=1),
# first column used as the index (the atlas label value).
lookup = pd.read_csv(f"{ATLAS_DIR}/simplified_color_map.txt", header=1, index_col=0, delimiter='\t')

In [3]:
# Split every run's 4-D beta volume into three consecutive groups of N_CONDITIONS
# volumes along axis 3, then concatenate each group across runs (run-major order).
# Volume 0 is skipped — presumably a baseline/nuisance regressor; TODO confirm.
N_CONDITIONS = 7
FIRST_BETA = 1
color_slice = slice(FIRST_BETA, FIRST_BETA + N_CONDITIONS)                            # 1:8
achrom_color_slice = slice(FIRST_BETA + N_CONDITIONS, FIRST_BETA + 2 * N_CONDITIONS)  # 8:15
achrom_slice = slice(FIRST_BETA + 2 * N_CONDITIONS, FIRST_BETA + 3 * N_CONDITIONS)    # 15:22

color_parts = []
achrom_color_parts = []
achrom_parts = []

for nii_path in shape_color_attention_niis:
    run_beta = nib.load(nii_path).get_fdata()
    # NumPy slicing silently truncates past the end of an axis; a short run would give
    # fewer than N_CONDITIONS volumes and misalign every later label, so check explicitly.
    if run_beta.ndim != 4 or run_beta.shape[3] < achrom_slice.stop:
        raise ValueError(
            f"{nii_path}: expected a 4-D volume with >= {achrom_slice.stop} betas, "
            f"got shape {run_beta.shape}"
        )
    color_parts.append(run_beta[:, :, :, color_slice])
    achrom_color_parts.append(run_beta[:, :, :, achrom_color_slice])
    achrom_parts.append(run_beta[:, :, :, achrom_slice])

color_betas = torch.from_numpy(np.concatenate(color_parts, axis=3))
achrom_betas = torch.from_numpy(np.concatenate(achrom_parts, axis=3))
achrom_color_betas = torch.from_numpy(np.concatenate(achrom_color_parts, axis=3))

# Atlas voxel values are rounded to the nearest integer and cast to int32 labels
# (stored values are presumably near-integral labels, e.g. after resampling — confirm).
atlas = torch.round(torch.from_numpy(atlas_nii.get_fdata())).int()

In [4]:
# One class label per beta volume. Betas were concatenated run by run with 7 volumes
# per run, so the label sequence 0..6 repeats once per loaded run.
targets = torch.tile(torch.arange(7), (len(shape_color_attention_niis),))
# Labels and concatenated beta volumes must line up one-to-one.
assert targets.numel() == achrom_color_betas.shape[-1], (
    f"{targets.numel()} targets vs {achrom_color_betas.shape[-1]} beta volumes"
)

# Map each atlas label present in the volume to its name. Label 0 is assumed to be
# background (TODO confirm for this atlas) and is excluded by value, rather than by
# dropping the smallest unique label, which would discard a real region if 0 is absent.
atlas_labels = [label for label in torch.unique(atlas).tolist() if label != 0]
lookup_dict = {label: lookup.loc[label, 'Label Name:'] for label in atlas_labels}

In [None]:
# Build an atlas-ROI decoder from the integer label volume and the label -> name map.
# out_dim=7 presumably matches the 7 condition labels in `targets`.
sd = representation_geometry.ROIDecoder(
    atlas,
    lookup_dict,
    out_dim=7,
)

In [None]:
# Train the ROI decoder on the achrom_color betas. optim_threshold / cutoff_epoch are
# forwarded to ROIDecoder.fit (presumably a convergence tolerance and an epoch cap — confirm).
sd.fit(
    achrom_color_betas,
    targets,
    optim_threshold=1e-9,
    cutoff_epoch=10000,
)

In [9]:
# Evaluate the ROI decoder fit in the previous cell on color_betas with the same targets.
# Kept under its own name: `res` is reassigned by the searchlight fit later, and the
# filtering cells consume that searchlight result, not this one.
roi_res = sd.predict(color_betas, targets)

In [5]:
# Searchlight decoding. The first two positional arguments are None (presumably no
# atlas / label lookup, by analogy with ROIDecoder); 7 presumably is the number of
# output classes; kernel=3 and the model runs on the CPU.
searchlight = representation_geometry.SearchLightDecoder(
    None,
    None,
    7,
    kernel=3,
    dev='cpu',
)

In [8]:
# Fit the searchlight on the achrom_color betas; training logs a cross-entropy value per
# epoch, and the returned object is kept as `res` for filtering.
res = searchlight.fit(
    achrom_color_betas,
    targets,
    lr=.01,
    downscale_factor=2,
    cutoff_epoch=500,
    optim_threshold=-1,  # negative: presumably disables early stopping — TODO confirm
)

128
torch.Size([28, 1, 64, 64, 64])
1
CE on epoch 0 is 0.44117085129838884
CE on epoch 1 is 0.4411267280913219
CE on epoch 2 is 0.4411227920568625
CE on epoch 3 is 0.44111993171347713
CE on epoch 4 is 0.4411121950112044
CE on epoch 5 is 0.4411017987972731
CE on epoch 6 is 0.4410928674061992
CE on epoch 7 is 0.44109392237789624
CE on epoch 8 is 0.44109899486704
CE on epoch 9 is 0.441094060249955
CE on epoch 10 is 0.44108675041275675
CE on epoch 11 is 0.4410845314437765
CE on epoch 12 is 0.4410846969520871
CE on epoch 13 is 0.4410831248415506
CE on epoch 14 is 0.4410796774938311
CE on epoch 15 is 0.44107626944504447
CE on epoch 16 is 0.4410745699663477
CE on epoch 17 is 0.44107451018446653
CE on epoch 18 is 0.44107477508700843
CE on epoch 19 is 0.441074401219161
CE on epoch 20 is 0.4410734618814699
CE on epoch 21 is 0.4410725571611493
CE on epoch 22 is 0.4410720574418668
CE on epoch 23 is 0.4410717932178792
CE on epoch 24 is 0.4410713864063695
CE on epoch 25 is 0.44107074257065687
CE on 

In [44]:
# Keep only entries of `res` (iterated along its first dimension) whose minimum value —
# a cross-entropy, per the `good_ce` naming — falls below CE_THRESHOLD; the rest are
# treated as not decodable.
CE_THRESHOLD = 1.2
decodable = [trial for trial in res if torch.min(trial) < CE_THRESHOLD]
# torch.stack fails with an opaque error on an empty list; report the real cause instead.
if not decodable:
    raise ValueError(f"no entry of res reaches a minimum CE below {CE_THRESHOLD}")
good_ce = torch.stack(decodable)

In [45]:
# Average the retained entries along dim 0, working on a detached NumPy view of the tensor.
npres = np.mean(good_ce.detach().numpy(), axis=0)

In [46]:
# Smallest value in the averaged map (displayed as the cell output).
npres.min()

1.7030872091280125

In [47]:
# Persist the fitted searchlight decoder. The `with` block guarantees the file handle is
# flushed and closed even if pickling fails part-way.
# NOTE: only unpickle this file from trusted sources — pickle.load can execute arbitrary code.
with open('../MTurk1/misc_testing_files/searchlight_decode_uncolored_shapes_500downsampled.pkl', 'wb') as fh:
    pkl.dump(searchlight, fh)

In [48]:
# Write the mean cross-entropy map as NIfTI, reusing the atlas affine and header for
# spatial metadata. NOTE(review): this assumes npres lies on the atlas voxel grid —
# confirm, since the searchlight was fit with downscale_factor=2.
ce_nii = nib.Nifti1Image(npres, affine=atlas_nii.affine, header=atlas_nii.header)
# A supplied header keeps its own on-disk data dtype, which for a label atlas may be an
# integer type; force float so fractional CE values are not quantized on save.
ce_nii.header.set_data_dtype(np.float32)
nib.save(ce_nii, '../MTurk1/misc_testing_files/train_train_uncolored_shape_searchlight_ce.nii.gz')