In [None]:
import os
import torch
import json
import scipy
import numpy as np
import pandas as pd
import sklearn
import matplotlib.pyplot as plt

from helper import read_data, frame_based_corrected_trials, RSA_template_matrix_constructor, centering, linear_HSIC, linear_CKA

from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.linear_model import Ridge
from sklearn.linear_model import LogisticRegression

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC


import matplotlib as mpl 
import seaborn as sns
# Global figure styling: seaborn "ticks" look at poster scale.
sns.set_style("white")
sns.set_style("ticks", {"xtick.major.size": 14, "ytick.major.size": 14})
sns.set_context("poster")
# Thicker axes spines and major tick marks to match the poster context.
mpl.rcParams.update({
    'axes.linewidth': 2.5,
    'ytick.major.width': 2.5,
    'xtick.major.width': 2.5,
})

In [None]:
# sample dataset to keep trial_wise accurate trials only
def trial_based_corrected_trials(df, is_balanced = True, random_state=42):
    """Keep only trials whose predicted action matches the correct action on every frame.

    Args:
    - df: DataFrame with columns ``plain_task_index``, ``predicted_action`` and
      ``corrected_action``. Each action cell holds one trial's actions (any
      array-like, presumably one entry per frame); a trial is kept only when the
      two cells are element-wise equal (``np.array_equal``).
    - is_balanced: if True, subsample every task down to the number of kept trials
      of the task with the fewest kept trials.
    - random_state: seed for the balancing subsample.

    Returns:
    - DataFrame of the kept trials, grouped by task in order of first appearance,
      with a fresh RangeIndex. An empty ``df`` yields an empty frame with the same
      columns.
    """
    per_task = []
    for task_index in df.plain_task_index.unique():
        task_df = df[df["plain_task_index"] == task_index]
        # np.array_equal works for lists, arrays and scalars alike (a plain
        # `list == list` returns a bool, which has no `.all()`).
        keep = [
            bool(np.array_equal(predicted, corrected))
            for predicted, corrected in zip(task_df.predicted_action, task_df.corrected_action)
        ]
        per_task.append(task_df.iloc[np.flatnonzero(keep)])

    if not per_task:
        # pd.concat refuses an empty list; return an empty frame with the same columns.
        return df.iloc[0:0].reset_index(drop=True)

    if is_balanced:  # subsample each task to the smallest task's trial count
        min_len_df = min(len(task_df) for task_df in per_task)
        per_task = [task_df.sample(n=min_len_df, random_state=random_state) for task_df in per_task]

    return pd.concat(per_task, ignore_index=True)


In [None]:
# Optional: project CNN activations down to as many dimensions as the RNN activation has.
target_dimension = 256

from sklearn.decomposition import PCA

def reduce_dimensionality(data, target_dimension):
    """
    Project ``data`` onto its leading principal components.

    A fresh PCA is fitted on ``data`` itself on every call.

    Args:
    - data: array of shape (n_samples, n_features).
    - target_dimension: number of principal components to keep.

    Returns:
    - Array of shape (n_samples, target_dimension) holding the PCA scores.
    """
    projector = PCA(n_components=target_dimension)
    return projector.fit_transform(data)


In [None]:
def decoder(X_train, X_test, y_train, y_test, return_weight_vector=True, param_grid=None, cv=2):
    """Fit a standardized linear SVC on the training split and score it on the test split.

    The regularization strength C is chosen on the training split only, so the
    held-out split never influences model selection.

    Args:
    - X_train, X_test: feature arrays of shape (n_samples, n_features).
    - y_train, y_test: label arrays matching X_train / X_test.
    - return_weight_vector: if True, also return the fitted SVC's ``coef_``.
    - param_grid: dict whose ``'C'`` entry lists candidate C values (other keys are
      ignored); defaults to ``{'C': [0.0001]}``.
    - cv: number of cross-validation folds used when there is more than one candidate C.

    Returns:
    - ``(svc, accuracy)`` if ``return_weight_vector`` is False, otherwise
      ``(svc, weight_vector, accuracy)``. ``svc`` is the fitted SVC step of a
      StandardScaler -> SVC pipeline, so ``weight_vector`` lives in standardized
      feature space.
    """
    if param_grid is None:
        param_grid = {'C': [0.0001]}
    candidate_Cs = list(param_grid['C'])

    # Select C using the training split only, with the same scaling as the final model.
    if len(candidate_Cs) == 1:
        best_C = candidate_Cs[0]  # nothing to search over
    else:
        grid_search = GridSearchCV(
            make_pipeline(StandardScaler(), SVC(kernel='linear')),
            {'svc__C': candidate_Cs},
            cv=cv,
        )
        grid_search.fit(X_train, y_train)
        best_C = grid_search.best_params_['svc__C']
    print("Best C:", best_C)

    # Train the final classifier with the selected C.
    clf = make_pipeline(StandardScaler(), SVC(kernel="linear", C=best_C))
    clf.fit(X_train, y_train)

    accuracy = accuracy_score(y_test, clf.predict(X_test))
    print(f'Accuracy: {accuracy:.2f}')

    svc = clf.named_steps["svc"]
    if not return_weight_vector:
        return svc, accuracy
    return svc, svc.coef_, accuracy



In [None]:
import copy
### binarize data
def binarize_labels(data, label, selected_index, rng=None, other_label=1000):
    """Build a balanced one-vs-rest dataset for the class ``selected_index``.

    All samples labelled ``selected_index`` are kept. The function randomly draws an
    equal number of samples from the remaining classes, or all of them if fewer
    remain, and relabels them ``other_label``.

    Args:
    - data: numpy array indexed by sample along axis 0.
    - label: 1-D numpy array of labels, same length as ``data``; not modified.
    - selected_index: label value treated as the positive class.
    - rng: optional ``np.random.Generator`` for the random draw. When None, the
      legacy global ``np.random`` state is used.
    - other_label: label given to the drawn non-matching samples.

    Returns:
    - ``(curr_data, curr_label)``, with the matching samples first and then the drawn
      non-matching samples.

    Raises:
    - ValueError: if ``other_label`` equals ``selected_index``, because the result
      would contain a single class.
    """
    if selected_index == other_label:
        raise ValueError(f"other_label ({other_label}) must differ from selected_index")

    match_index = np.flatnonzero(label == selected_index)
    non_match_index = np.flatnonzero(label != selected_index)
    if rng is None:
        np.random.shuffle(non_match_index)
    else:
        rng.shuffle(non_match_index)
    non_match_index = non_match_index[:len(match_index)]

    curr_data = np.concatenate([data[match_index], data[non_match_index]], axis = 0)
    curr_label = np.concatenate([label[match_index], label[non_match_index]], axis = 0)
    # concatenate returned a new array, so relabelling the negatives leaves `label` intact.
    curr_label[len(match_index):] = other_label
    return curr_data, curr_label


In [None]:
from scipy.spatial.distance import cosine

def obtain_hyperplanes(data, label, random_state=None):
    """Estimate every class's one-vs-rest decoding hyperplane twice, on two data halves.

    For each unique class in ``label``, the data are binarized with
    ``binarize_labels`` and split 50/50. ``decoder`` is then trained on each half,
    which gives two independent weight-vector estimates per class.

    Args:
    - data: array of shape (n_samples, n_features).
    - label: array of per-sample class labels.
    - random_state: seed forwarded to ``train_test_split`` (None = unseeded).

    Returns:
    - hyperplanes: array of shape (2, n_classes, n_features). ``hyperplanes[k, i]``
      is the weight vector for class ``np.unique(label)[i]`` trained on half ``k``.
    """
    classes = np.unique(label)
    hyperplanes = np.zeros((2, len(classes), data.shape[-1]))
    for i, selected_index in enumerate(classes):
        # binarize label
        curr_data, curr_label = binarize_labels(data, label, selected_index)

        # 50/50 split so the two halves give independent estimates (on-diagonal values).
        # Stratify so both halves contain both classes. An unstratified split of a
        # small class can leave one half single-class, which makes SVC fitting fail.
        X_first, X_second, y_first, y_second = train_test_split(
            curr_data, curr_label, test_size=0.5, stratify=curr_label, random_state=random_state,
        )

        _, weight_vector, _ = decoder(X_first, X_second, y_first, y_second)
        hyperplanes[0, i, :] = np.squeeze(weight_vector)

        _, weight_vector, _ = decoder(X_second, X_first, y_second, y_first)
        hyperplanes[1, i, :] = np.squeeze(weight_vector)
    return hyperplanes



In [None]:
def cossim_cal(hyperplanes):
    """Pairwise orthogonality, 1 - |cosine similarity|, between two sets of hyperplanes.

    Args:
    - hyperplanes: array of shape (2, n_classes, ...). ``hyperplanes[0, i]`` and
      ``hyperplanes[1, j]`` are flattened before being compared.

    Returns:
    - (n_classes, n_classes) array whose entry [i, j] is
      ``1 - |cos(hyperplanes[0, i], hyperplanes[1, j])|``, in [0, 1]. The value is
      0 for parallel or anti-parallel vectors and 1 for orthogonal ones, so it is
      an orthogonality score rather than a similarity.

    Raises:
    - ValueError: if any hyperplane has zero norm, since its cosine is undefined.
    """
    n_classes = hyperplanes.shape[1]
    first = np.asarray(hyperplanes[0], dtype=float).reshape(n_classes, -1)
    second = np.asarray(hyperplanes[1], dtype=float).reshape(n_classes, -1)

    first_norms = np.linalg.norm(first, axis=1)
    second_norms = np.linalg.norm(second, axis=1)
    if np.any(first_norms == 0) or np.any(second_norms == 0):
        raise ValueError("cossim_cal: zero-norm hyperplane, cosine similarity is undefined")

    # Row i of the product holds the cosines between first[i] and every second[j].
    cos = (first / first_norms[:, None]) @ (second / second_norms[:, None]).T
    # Parallel hyperplanes legitimately give |cos| == 1. Clip rounding overshoot so
    # the result stays in [0, 1] instead of aborting the analysis.
    abs_cos = np.clip(np.abs(cos), 0.0, 1.0)
    return 1.0 - abs_cos

### Obtain the data

In [None]:
# Load the evaluation logs and keep only the arrays the analysis needs.
# Every entry of `basepath_list` is processed the same way inside the loop, and
# the per-run arrays are concatenated across runs afterwards.

model_type = "GRU"
which_decoder = "SVC"

basepath_list = [
    # update the path to the correct one
    "/mnt/store1/xiaoxuan/multfs_triple/eval/experiment_logs/2024318/nback_mulfeat_gru_rep1"
]

RNN_activations = []
CNN_activations = []
ntask_indices = []
feature_indices = []
loc_labels = []
obj_labels = []
ctg_labels = []
for basepath in basepath_list:
    task_name, df = read_data(basepath)

    # Keep trials answered correctly on every frame, balanced across tasks, then
    # shuffle the rows (seeded so the trial order is reproducible).
    df = trial_based_corrected_trials(df, is_balanced=True)
    df = df.sample(frac=1, random_state=42)

    if model_type == "LSTM":
        RNN_activation = np.stack(df.hidden_state.to_numpy())[:, 0, :]
    else:
        # for GRU and RNN
        RNN_activation = np.stack(df.activation.to_numpy())[:, 0, :]

    # CNN activation flattened per trial. Trailing dims are inferred, so this works
    # before projection (2048x7x7, `CNN_activation`) or after it (256x7x7,
    # `CNN_activation_2`).
    CNN_activation = np.stack(df.CNN_activation.to_numpy())[:, 0, :, :, :].reshape(len(df), -1)

    RNN_activations.append(RNN_activation)
    CNN_activations.append(CNN_activation)
    ntask_indices.append(df.ntask_index.to_numpy())
    feature_indices.append(df.feature_index.to_numpy())
    loc_labels.append(np.squeeze(np.stack(df.input_loc.to_numpy())[:, 0]))
    obj_labels.append(np.squeeze(np.stack(df.input_obj.to_numpy())[:, 0]))
    ctg_labels.append(np.squeeze(np.stack(df.input_ctg.to_numpy())[:, 0]))

    # np.stack copied everything needed above; release the frame before the next run.
    del df

RNN_activation = np.concatenate(RNN_activations, axis = 0)
CNN_activation = np.concatenate(CNN_activations, axis = 0)
ntask_index = np.concatenate(ntask_indices, axis = 0)
feature_index = np.concatenate(feature_indices, axis = 0)
loc_label = np.concatenate(loc_labels, axis = 0)
obj_label = np.concatenate(obj_labels, axis = 0)
ctg_label = np.concatenate(ctg_labels, axis = 0)


In [None]:
# Unique n-back task indices and task-feature indices present in the data. They are
# sorted so that their positions (used as result-array axes) are reproducible.
ntask_indices = sorted(set(ntask_index))
task_feature_indices = sorted(set(feature_index))

### Decoding analysis
Adjust the settings at the top of the next cell (`n_bootstraps`, `ntask_indices`, `task_feature_indices`) as needed.

In [None]:
# Split-half hyperplane orthogonality per (task feature, n-back task, decoded
# feature, repeat). For each condition, `obtain_hyperplanes` estimates every class's
# hyperplane on two random halves. `cossim_cal` then gives a class-by-class
# 1 - |cos| matrix: its diagonal compares the two estimates of the same class, and
# its upper triangle compares different classes. The means of both are stored.
n_bootstraps = 2
ntask_indices = [1]
task_feature_indices = [2]
n_label_types = 3  # label index 0: location, 1: object, 2: category
result_shape = (len(task_feature_indices), len(ntask_indices), n_label_types, n_bootstraps)
RNN_diag_means = np.zeros(result_shape)
CNN_diag_means = np.zeros(result_shape)
RNN_off_diag_means = np.zeros(result_shape)
CNN_off_diag_means = np.zeros(result_shape)

# label index -> per-trial labels (same trial order as the activations)
labels_by_index = {0: loc_label, 1: obj_label, 2: ctg_label}
# (name, activations, on-diagonal result array, off-diagonal result array)
activation_sources = [
    ("RNN", RNN_activation, RNN_diag_means, RNN_off_diag_means),
    ("CNN", CNN_activation, CNN_diag_means, CNN_off_diag_means),
]

for sni, selected_ntask_index in enumerate(ntask_indices):
    for stf, selected_task_feature_index in enumerate(task_feature_indices):
        # Decode the same feature the task is defined on.
        selected_label_index = selected_task_feature_index
        if selected_label_index not in labels_by_index:
            raise ValueError(f"no label array for feature index {selected_label_index}")

        # Trials of this (n-back task, task feature) condition, shared by both activation types.
        task_subsample_indices = np.where((ntask_index == selected_ntask_index) & (feature_index == selected_task_feature_index))[0]
        label = labels_by_index[selected_label_index][task_subsample_indices]

        for source_name, curr_activation, diag_means, off_diag_means in activation_sources:
            data = curr_activation[task_subsample_indices]
            for nb in range(n_bootstraps):
                hyperplanes = obtain_hyperplanes(data, label)
                cosine_sim = cossim_cal(hyperplanes)
                off_diags = list(cosine_sim[np.triu_indices(cosine_sim.shape[0], k=1)])
                on_diags = list(np.diag(cosine_sim))
                diag_means[stf, sni, selected_label_index, nb] = np.mean(on_diags)
                off_diag_means[stf, sni, selected_label_index, nb] = np.mean(off_diags)
                print(f"{source_name} off diags:", off_diags)


### Visualization
Adjust the plotting settings in the next cell as needed.

In [None]:
# Scatter of RNN (encoding) vs. CNN (perceptual) orthogonalization scores for one
# (task feature, decoded feature) pair, with one series per n-back task.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

# NOTE(review): `s_RNNs` and `s_CNNs` are not defined anywhere in this notebook.
# They are indexed like the *_diag_means / *_off_diag_means arrays of the decoding
# cell, i.e. (n_task_features, n_ntasks, 3, n_bootstraps). Define them before
# running this cell (presumably from RNN_/CNN_off_diag_means; TODO confirm).

# Define colors
colors = [(105/256, 103/256, 149/256), (114/256, 188/256, 213/256), (255/256, 208/256, 111/256), (231/256, 98/256, 84/256)]
bg_colors = [(55/256, 103/256, 149/256), (170/256, 220/256, 224/256), (255/256, 230/256, 183/256), (239/256, 138/256, 71/256)]

# Set plot aesthetics
sns.set_style("ticks")
mpl.rcParams.update({
    'axes.linewidth': 2.0,       # Thicker border of the plots
    'ytick.major.width': 2.0,    # Thicker y-tick marks
    'xtick.major.width': 2.0,    # Thicker x-tick marks
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'axes.labelsize': 18,
    'axes.titlesize': 20
})

# Create subplots grid (only one plot needed)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

feature_full_names = ['Location', "Identity", 'Category']
nback_names = ["1-back", "2-back", "3-back"]

# Positions along the task-feature and decoded-feature axes of s_RNNs / s_CNNs to plot
task_feature_index = 0
decoding_feature_index = 0

common_range = (np.min(s_CNNs)-0.1, np.max(s_CNNs)+0.2)  # Adjusted common range calculation

# Plot the diagonal axis
ax.plot([common_range[0], common_range[1]], [common_range[0], common_range[1]], linestyle='--', color=bg_colors[0], linewidth=4)  # Thicker diagonal line

# One scatter series per n-back position actually present in the results. The loop
# variable is deliberately not `ntask_index`, which would overwrite the per-trial
# n-back array the decoding cell relies on.
n_series = min(len(nback_names), s_CNNs.shape[1])
for nback_pos in range(n_series):
    s_perceptuals = s_CNNs[task_feature_index, nback_pos, decoding_feature_index, :]
    s_encodings = s_RNNs[task_feature_index, nback_pos, decoding_feature_index, :]
    sns.scatterplot(x=s_encodings, y=s_perceptuals, ax=ax, label=nback_names[nback_pos], s=50, color=colors[nback_pos])

ax.set_ylabel('O(perceptual)', labelpad=10, fontsize=24)
ax.set_xlabel('O(encoding)', labelpad=10, fontsize=24)
ax.set_xlim([0,1])
ax.set_ylim([0,1])

ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.1f'))
ax.yaxis.set_major_formatter(plt.FormatStrFormatter('%.1f'))
ax.tick_params(axis='x', labelsize=34)
ax.tick_params(axis='y', labelsize=34)

sns.despine()

# Add legend with larger font size
# handles, labels = ax.get_legend_handles_labels()
# fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.0), ncol=3, fontsize=12, markerscale=1)

plt.tight_layout(rect=[0, 0, 1, 0.95])  # Leave room at the top for an optional figure legend
# plt.savefig("/home/xiaoxuan/projects/WM_geometry_figures/Figures/neurips_2024/orthogonalization/9task_gru_orthogalization_%dpc.pdf"%target_dimension)

# Show the plot
plt.show()
