# RSA Analysis

In [1]:
import numpy as np

import torch
from torch.utils.data import DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

from src.ml import StimulusClassifierConfig, FMRIClassifierConfig

from src.dataset.kay import load_dataset
from src.utils.util import prepare_stimulus_data, prepare_fmri_data
from src.ml.model import StimulusClassifier, FMRIClassifier
from src.ml.dataset import StimulusDataset, FMRIDataset

In [2]:
config_stim = StimulusClassifierConfig()
config_fmri = FMRIClassifierConfig()

In [3]:
all_data = load_dataset(data_path="./../data/")

x_stim, y_stim = prepare_stimulus_data(
    all_data=all_data,
    data_subset="test",
    class_ignore_list=config_stim.class_ignore_list,
    label_level=config_stim.label_level,
)

x_fmri, y_fmri = prepare_fmri_data(
    all_data=all_data,
    data_subset="test",
    class_ignore_list=config_fmri.class_ignore_list,
    label_level=config_fmri.label_level,
    roi_select_list=config_fmri.roi_select_list,
)

x_stim.shape, y_stim.shape, x_fmri.shape, y_fmri.shape

((119, 128, 128), (119,), (119, 5166), (119,))

In [4]:
assert len(np.unique(y_stim)) == len(np.unique(y_fmri))
class2idx = {k: i for i, k in enumerate(np.unique(y_stim))}
idx2class = {v: k for k, v in class2idx.items()}

In [5]:
stim_dataset = StimulusDataset(
    x_data=x_stim,
    y_data=y_stim,
    img_transform=config_stim.img_transform["test"],
    class2idx=class2idx,
)

stim_loader = DataLoader(dataset=stim_dataset, shuffle=False, batch_size=1)


fmri_dataset = FMRIDataset(x_data=x_fmri, y_data=y_fmri, class2idx=class2idx)

fmri_loader = DataLoader(dataset=fmri_dataset, shuffle=False, batch_size=1)

In [6]:
stim_model = StimulusClassifier(num_channel=3, num_classes=5)
stim_model.load_state_dict(torch.load('./../models/stimulus_classifier/stim_classifier_model.pth'))
stim_model.eval()

fmri_model = FMRIClassifier(num_features=5166, num_classes=5)
fmri_model.load_state_dict(torch.load('./../models/fmri_classifier/fmri_classifier_model.pth'))
fmri_model.eval()

print("Models loaded.")

Models loaded.


In [8]:
with torch.no_grad():
    y_pred_stim = stim_model.predict(next(iter(stim_loader))[0])
    print(idx2class[y_pred_stim])
    
with torch.no_grad():
    y_pred_fmri = fmri_model.predict(next(iter(fmri_loader))[0])
    print(idx2class[y_pred_fmri])

artifact
entity


In [19]:
stim_model.fc1.weight.shape, fmri_model.block_1[0].weight.shape

(torch.Size([64, 8192]), torch.Size([64, 5166]))

In [20]:
with torch.no_grad():
    hidden_stim = stim_model.get_latent_rep(next(iter(stim_loader))[0])
    
with torch.no_grad():
    hidden_fmri = fmri_model.get_latent_rep(next(iter(fmri_loader))[0])

In [22]:
hidden_stim.shape, hidden_fmri.shape

(torch.Size([1, 64]), torch.Size([1, 64]))

In [24]:
hidden_fmri @ hidden_stim.T

tensor([[-0.4580]])