# RSA Analysis

In [1]:
import numpy as np
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader

import seaborn as sns
import matplotlib.pyplot as plt

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics.pairwise import cosine_similarity, cosine_distances

from src.ml.test import test_model
from src.dataset.kay import load_dataset
from src.ml.dataset import StimulusDataset, FMRIDataset
from src.ml.model import StimulusClassifier, FMRIClassifier
from src.ml.utils import get_latent_emb_per_class
from src.ml.rsa import run_rsa
from src.ml import StimulusClassifierConfig, FMRIClassifierConfig
from src.utils.util import prepare_stimulus_data, prepare_fmri_data

In [2]:
%matplotlib inline

sns.set(style="ticks", context="talk")
plt.style.use("dark_background")

title_size = 16
params = {
    "legend.fontsize": 14,
    "axes.labelsize": title_size - 2,
    "axes.titlesize": title_size,
    "xtick.labelsize": title_size - 4,
    "ytick.labelsize": title_size - 4,
    "axes.titlepad": 1.5 * title_size,
}

plt.rcParams.update(params)

## Load Model Configuration

In [3]:
# Default configurations for the stimulus (image) and fMRI classifiers;
# the data preparation and RSA cells below read their settings from these.
stim_config, fmri_config = StimulusClassifierConfig(), FMRIClassifierConfig()

## Load Data

In [4]:
all_data = load_dataset(data_path="./../data/")

# Test-split stimuli and their labels, filtered per the stimulus config.
stim_prep_kwargs = dict(
    all_data=all_data,
    data_subset="test",
    class_ignore_list=stim_config.class_ignore_list,
    label_level=stim_config.label_level,
)
x_stim, y_stim = prepare_stimulus_data(**stim_prep_kwargs)

# Test-split fMRI responses and labels, restricted to the configured ROIs.
fmri_prep_kwargs = dict(
    all_data=all_data,
    data_subset="test",
    class_ignore_list=fmri_config.class_ignore_list,
    label_level=fmri_config.label_level,
    roi_select_list=fmri_config.roi_select_list,
)
x_fmri, y_fmri = prepare_fmri_data(**fmri_prep_kwargs)

# Display the shapes of all four arrays as the cell output.
tuple(arr.shape for arr in (x_stim, y_stim, x_fmri, y_fmri))

((114, 128, 128), (114,), (114, 8427), (114,))

In [5]:
# Both modalities share one label <-> index mapping (class2idx is passed to the
# stimulus and the fMRI dataset alike), so they must contain exactly the same
# set of labels. Comparing only the number of classes would let two different
# label sets of equal size slip through.
stim_classes = np.unique(y_stim)
fmri_classes = np.unique(y_fmri)
assert np.array_equal(stim_classes, fmri_classes), (
    f"Stimulus and fMRI label sets differ: {stim_classes} vs {fmri_classes}"
)

# np.unique returns sorted labels, so the index assignment is deterministic.
class2idx = {k: i for i, k in enumerate(stim_classes)}
idx2class = {v: k for k, v in class2idx.items()}

In [6]:
# Both loaders use the same settings: no shuffling and one sample per batch.
loader_kwargs = dict(shuffle=False, batch_size=1)

# Stimulus images with the test-time image transform from the stimulus config.
stim_dataset = StimulusDataset(
    x_data=x_stim,
    y_data=y_stim,
    img_transform=stim_config.img_transform["test"],
    class2idx=class2idx,
)

# fMRI responses, labelled through the same class2idx mapping.
fmri_dataset = FMRIDataset(x_data=x_fmri, y_data=y_fmri, class2idx=class2idx)

stim_loader = DataLoader(dataset=stim_dataset, **loader_kwargs)
fmri_loader = DataLoader(dataset=fmri_dataset, **loader_kwargs)

## Load Model

In [7]:
# Rebuild the fMRI classifier with the input width of the loaded fMRI data and
# one output per class, then restore the trained weights.
fmri_model = FMRIClassifier(num_features=x_fmri.shape[1], num_classes=len(class2idx))
fmri_model.load_state_dict(
    torch.load(
        "./../models/fmri_classifier/fmri_classifier_model.pth",
        # Load onto the CPU regardless of the device the checkpoint was saved
        # from; without this, a GPU-saved checkpoint fails on CPU-only machines.
        map_location="cpu",
    )
)
# Switch to inference mode for the analysis below.
fmri_model.eval()

print("fMRI model loaded successfully.")

fMRI model loaded successfully.


## Representational Similarity Analysis

In [None]:
# Run the RSA between the fMRI model and the stimulus models; run_rsa returns
# four values, which are unpacked into the lists used by the plots below.
rsa_results = run_rsa(
    fmri_model=fmri_model,
    fmri_loader=fmri_loader,
    fmri_config=fmri_config,
    stim_loader=stim_loader,
    stim_config=stim_config,
    class2idx=class2idx,
    idx2class=idx2class,
)

(
    stim_model_names,
    stim_vs_fmri_cosine_norm_list,
    stim_model_norm_list,
    stim_model_num_param_list,
) = rsa_results

In [None]:
# Display the raw RSA outputs as the cell result.
(
    stim_model_names,
    stim_vs_fmri_cosine_norm_list,
    stim_model_norm_list,
    stim_model_num_param_list,
)

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(30, 25))

# One panel per per-model series returned by run_rsa: (axis, values, title, y-label).
model_panels = [
    (
        axes[0, 0],
        stim_model_norm_list,
        "Frobenius Norm for Different Models",
        "Frobenius Norms",
    ),
    (
        axes[0, 1],
        stim_model_num_param_list,
        "Number of Parameters for Different Models",
        "Number of Parameters",
    ),
    (
        axes[1, 0],
        stim_vs_fmri_cosine_norm_list,
        "Cosine Similarity between fMRI and Stimulus Models",
        "Cosine Similarity",
    ),
]

for ax, values, title, ylabel in model_panels:
    sns.lineplot(x=stim_model_names, y=values, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Model Names")
    ax.set_ylabel(ylabel)

# TODO: accuracy panel is a labelled placeholder -- none of the four values
# unpacked from run_rsa is an accuracy series, so nothing is plotted here yet.
axes[1, 1].set_title("Accuracy for Different Stimulus Models")
axes[1, 1].set_xlabel("Model Names")
axes[1, 1].set_ylabel("Accuracy")

# Trailing semicolon keeps the Text repr out of the cell output.
fig.suptitle("Stimulus Model Characteristics");

In [None]:
# Only two panels are drawn, so use a single row instead of a 2x2 grid that
# would render two empty axes underneath.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(30, 12))

sns.lineplot(x=stim_model_num_param_list, y=stim_vs_fmri_cosine_norm_list, ax=axes[0])
axes[0].set_title("Effect of Number of Parameters on RSM (Stim vs fMRI)")
axes[0].set_xlabel("Number of parameters")
axes[0].set_ylabel("Cosine Similarity")

sns.lineplot(x=stim_model_norm_list, y=stim_vs_fmri_cosine_norm_list, ax=axes[1])
axes[1].set_title("Effect of Parameter Norm on RSM (Stim vs fMRI)")
axes[1].set_xlabel("Frobenius Norm of parameters")
# Same y-series as the left panel, so it carries the same label.
axes[1].set_ylabel("Cosine Similarity")

# Trailing semicolon keeps the Text repr out of the cell output.
fig.suptitle("Representational Similarity Analysis");

In [None]:
# Run the fMRI classifier over the test loader on the CPU; test_model returns
# two values, unpacked here as the true and predicted label lists.
fmri_test_output = test_model(fmri_model, fmri_loader, device="cpu")
y_true_list, y_pred_list = fmri_test_output