In [None]:
import os
import torch
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from helper import read_data, frame_based_corrected_trials, concat_nback_decoder_data_gen, nback_decoder_cnn_data_gen, decoder, nback_decoder_data_gen
save_path = "/home/xiaoxuan/projects/multfs_triple/evaluation/Organized_analysis/results/"

import matplotlib as mpl 
import seaborn as sns
# Figure styling: seaborn "ticks" style (applied on top of "white") with 14-pt major tick
# lengths, poster-sized context, and 2.5-pt axis spines / major tick widths.
sns.set_style("white")
sns.set_style("ticks", {"xtick.major.size": 14, "ytick.major.size": 14})
sns.set_context("poster")
mpl.rcParams.update({
    'axes.linewidth': 2.5,
    'ytick.major.width': 2.5,
    'xtick.major.width': 2.5,
})

In [None]:
def trial_based_corrected_trials(df, is_balanced = True, random_state = 42):
    """Keep only trials whose predicted action is correct on every frame.

    Trials are grouped by ``plain_task_index``. Within each group, a trial (row) is kept
    only if ``predicted_action == corrected_action`` holds element-wise on every frame.
    If ``is_balanced`` is True, every group is then subsampled without replacement to the
    number of kept trials of the smallest group, so all tasks are equally represented.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``plain_task_index``, ``predicted_action`` and ``corrected_action``.
        The two action columns hold one sequence (array or list) per trial.
    is_balanced : bool, default True
        Subsample every task to the smallest per-task count of kept trials.
    random_state : int, default 42
        Seed passed to ``DataFrame.sample`` for the balancing subsample.

    Returns
    -------
    pandas.DataFrame
        Kept trials of all tasks concatenated, with a fresh 0..n-1 index. If ``df`` has
        no rows, an empty frame with the same columns is returned.
    """
    update_dfs = []
    for task_index in df.plain_task_index.unique():
        curr_df = df[df["plain_task_index"] == task_index]
        # np.asarray makes the comparison element-wise for list-valued cells too
        # (``list == list`` is a plain bool, which has no ``.all()``).
        keep = np.array(
            [np.all(np.asarray(pred) == np.asarray(corr))
             for pred, corr in zip(curr_df.predicted_action, curr_df.corrected_action)],
            dtype=bool,
        )
        print("length of selected_trials:", int(keep.sum()))
        update_dfs.append(curr_df[keep])

    if not update_dfs:
        # empty input: pd.concat([]) would raise "No objects to concatenate"
        return df.iloc[0:0].reset_index(drop=True)

    if is_balanced:
        # subsample each task to the true minimum size (no arbitrary sentinel cap)
        min_len_df = min(len(task_df) for task_df in update_dfs)
        update_dfs = [task_df.sample(n=min_len_df, random_state=random_state)
                      for task_df in update_dfs]

    return pd.concat(update_dfs, ignore_index=True)

In [None]:
# Configuration: which experiment-log directory to analyse and how to load it.

basepath_list = [
        "/mnt/store1/xiaoxuan/multfs_triple/eval/experiment_logs/2024611/RNN_rank256_init" # change it to the path you need
                ]


model_type = "RNN"
# True: load one combined dataframe here; False: one pickle per (nback, feature) task is loaded later
small_dataset = False
task_name = "nback"

mode = "val_angle"
basepath = basepath_list[0]  # only the first directory in basepath_list is analysed
if not small_dataset:
    # rebinds read_data to decoding_helper's version; the per-task loading later calls it
    # with a path2file= argument
    from decoding_helper import read_data
else:
    task_name, df = read_data(basepath, mode = mode)
    


if small_dataset:
    # Keep trials answered correctly on every frame (balanced across tasks), then shuffle the
    # trial order with a fixed seed so reruns are reproducible.
    df = trial_based_corrected_trials(df,is_balanced = True)
    df = df.sample(frac=1, random_state=42)

    # Per-trial arrays below are aligned row-for-row with the shuffled df.
    # The [:, 0, ...] indexing keeps index 0 of axis 1 (presumably the first frame — confirm).
    # if with RNN and GRU
    RNN_activation = np.stack(df.activation.to_numpy())[:,0,:]
    # if with LSTM cell_state
    # RNN_activation = np.stack(df.cell_state.to_numpy())[:,0,:]
    # if with LSTM hidden_state
    # RNN_activation = np.stack(df.hidden_state.to_numpy())[:,0,:]

    # flatten each trial's 3-D CNN feature map into 256*7*7 values
    CNN_activation = np.stack(df.CNN_activation_2.to_numpy())[:,0,:,:,:].reshape(-1, 256*7*7)
    ntask_index = df.ntask_index.to_numpy()
    feature_index = df.feature_index.to_numpy()
    loc_label = np.squeeze(np.stack(df.input_loc.to_numpy())[:,0])
    obj_label = np.squeeze(np.stack(df.input_obj.to_numpy())[:,0])
    ctg_label = np.squeeze(np.stack(df.input_ctg.to_numpy())[:,0])

    # free the dataframe; only the arrays above are used downstream
    del df

In [None]:
# make sure there are 9 tasks in total: ntask index 1..3 x feature index 0..2
if small_dataset:
    assert (np.min(ntask_index), np.max(ntask_index)) == (1, 3)
    assert (np.min(feature_index), np.max(feature_index)) == (0, 2)

In [None]:
if not small_dataset:
    # Discover which per-task pickles exist. Files are named "<n>back_feature_<f>_activations.pkl",
    # the same pattern the retention-analysis cell loads with read_data(path2file=...).
    # (Previously this filtered "_activations.csv", which never matches those pickles.)
    activation_suffix = "_activations.pkl"
    feature_nback_mapping = {}  # task feature index -> set of n-back values found on disk

    for filename in sorted(os.listdir(basepath)):
        if not filename.endswith(activation_suffix):
            continue
        # "1back_feature_0_activations.pkl" -> ("1", "back_feature_", "0")
        nback_str, sep, feature_str = filename[:-len(activation_suffix)].partition("back_feature_")
        if not (sep and nback_str.isdigit() and feature_str.isdigit()):
            continue  # some other *_activations.pkl file, not a per-task dump
        feature_nback_mapping.setdefault(int(feature_str), set()).add(int(nback_str))

    if not feature_nback_mapping:
        # an empty mapping would otherwise leave the accuracy table silently unfilled
        raise FileNotFoundError("no '<n>back_feature_<f>%s' files found in %s" % (activation_suffix, basepath))

    print("Feature-Nback Mapping:")
    for task_feature, nback_indices in feature_nback_mapping.items():
        print("Feature", task_feature, ": Possible nback indices:", nback_indices)


# Single task decoding analysis
Single-task (N) single-feature (L, I, C) models:
		For each single-task model (3 in total), each row of the matrix shows the decoder's performance on the validation dataset.
	The point to be made is:
		whether single-task networks retain or forget the task-irrelevant features.


In [None]:
def concat_nback_decoder_data_gen(RNN_activation,ntask_index, selected_ntask_index, feature_index, selected_feature_index, decoding_feature_labels, split_ratio = 0.8, rng = None):
    """Build a shuffled, mean-centred train/validation split for one (n-back, feature) task.

    Note: this definition shadows the ``concat_nback_decoder_data_gen`` imported from ``helper``.

    Parameters
    ----------
    RNN_activation : np.ndarray
        Activations with one trial per row (indexed along axis 0).
    ntask_index, feature_index : np.ndarray
        1-D per-trial task labels aligned with the rows of ``RNN_activation``.
    selected_ntask_index, selected_feature_index
        A trial is selected only if it matches both values.
    decoding_feature_labels : np.ndarray
        Per-trial decoding targets aligned with ``RNN_activation``.
    split_ratio : float, default 0.8
        Fraction of selected trials (rounded down) used for training; the rest is validation.
    rng : numpy Generator or RandomState, optional
        Randomness source for the shuffle; the global ``np.random`` state is used when None.

    Returns
    -------
    train_data, train_label, val_data, val_label
        Both data splits are centred with the *training* mean, so no validation statistics
        leak into preprocessing.

    Raises
    ------
    ValueError
        If no trial matches both selected indices.
    """
    # indices (ascending) of trials belonging to the selected task
    indice = np.flatnonzero((np.asarray(ntask_index) == selected_ntask_index)
                            & (np.asarray(feature_index) == selected_feature_index))
    if indice.size == 0:
        raise ValueError("no trials with ntask_index == %r and feature_index == %r"
                         % (selected_ntask_index, selected_feature_index))
    # randomize trial order before splitting
    (np.random if rng is None else rng).shuffle(indice)

    curr_data = RNN_activation[indice]
    print("curr data shape:", curr_data.shape)
    curr_label = decoding_feature_labels[indice]

    l_train = int(np.floor(curr_data.shape[0] * split_ratio))
    train_data, val_data = curr_data[:l_train], curr_data[l_train:]
    train_label, val_label = curr_label[:l_train], curr_label[l_train:]

    # subtract the training mean from both splits (skip if the training split is empty)
    if l_train > 0:
        train_mean = np.mean(train_data, axis = 0)
        train_data = train_data - train_mean
        val_data = val_data - train_mean

    return train_data, train_label, val_data, val_label

In [None]:
# Decode each input feature (loc, id, ctg) from the RNN activations of one selected task.
if small_dataset:
    task_feature_index = 0
    # a single row (the selected task); columns: decoding accuracy for loc, id, ctg feature
    accs = np.zeros((1,3))

    classifier_type = "svm_linear"
    labels = [loc_label, obj_label, ctg_label]
    n_bootstraps = 1

    # NOTE(review): earlier comments called this the 1back task, but ntask index 3 is
    # selected here (presumably the 3-back task) — confirm which is intended.
    selected_train_ntask_index = 3
    selected_train_feature_index = task_feature_index
    # validation is drawn from the same task as training
    selected_val_ntask_index = selected_train_ntask_index
    selected_val_feature_index = selected_train_feature_index

    for i in range(n_bootstraps):
        for decoding_feature, decoding_feature_labels in enumerate(labels):
            train_data, train_label, val_data, val_label = concat_nback_decoder_data_gen(
                RNN_activation,
                ntask_index,
                selected_train_ntask_index,
                feature_index,
                selected_train_feature_index,
                decoding_feature_labels,
                split_ratio = 0.8,
            )
            acc = decoder(train_data, val_data, train_label, val_label, type = classifier_type)
            print("decoding accuracy:", acc)
            accs[0, decoding_feature] = acc


In [None]:
# information retention analysis (small dataset): for the 1back task trained on each feature
# (L, I, C), decode every input feature -> 3 x 3 accuracies averaged over bootstraps.
if small_dataset:
    n_bootstraps = 10
    # (task feature, decoding feature, bootstrap); every bootstrap is kept and averaged below
    # (previously each bootstrap overwrote the last, so only one of the 10 counted)
    accs = np.zeros((3,3, n_bootstraps))

    classifier_type = "svm_linear"
    labels = [loc_label, obj_label, ctg_label]

    selected_train_ntask_index = 1  # only consider 1back task here
    for selected_train_feature_index in range(3):
        for i in range(n_bootstraps):
            for decoding_feature in range(3):
                decoding_feature_labels = labels[decoding_feature]

                train_data, train_label, val_data, val_label= concat_nback_decoder_data_gen(RNN_activation,
                                                                             ntask_index,
                                                                             selected_train_ntask_index,
                                                                             feature_index,
                                                                             selected_train_feature_index,
                                                                             decoding_feature_labels,
                                                                             split_ratio = 0.8)

                acc= decoder(train_data, val_data, train_label, val_label, type = classifier_type)
                print("decoding accuracy:", acc)
                accs[selected_train_feature_index, decoding_feature, i] = acc
    # rows: L, I, C task feature; columns: decoding of loc, id, ctg feature
    accs = np.mean(accs, axis = -1)
    print(accs)


In [None]:
# information retention analysis: for every (n-back, task feature) model found on disk, decode
# each input feature (loc, obj, ctg) from its RNN activations -> up to 3 x 3 x 3 accuracies.
if not small_dataset:
    n_bootstraps = 2
    # (nback n - 1, task feature, decoding feature, bootstrap). Initialised to NaN so that
    # (nback, feature) combinations without a pickle on disk are not mistaken for 0% accuracy.
    accs = np.full((3,3,3, n_bootstraps), np.nan)

    classifier_type = "svm_linear"

    for task_feature in sorted(feature_nback_mapping):
        for nback_n in sorted(feature_nback_mapping[task_feature]):
            # guard the accs indexing: nback_n - 1 < 0 would silently write into the last row
            assert 1 <= nback_n <= accs.shape[0] and 0 <= task_feature < accs.shape[1], \
                "unexpected task (nback=%d, feature=%d) for accs of shape %s" % (nback_n, task_feature, accs.shape)

            # obtain the task relevant data
            task_name, df = read_data(basepath, path2file = basepath + "/%dback_feature_%d_activations.pkl" % (nback_n, task_feature), mode = mode)
            # distinct values of each input feature (loc, obj, ctg) present in this dataset
            n_feature_values = [
                list(set(np.stack(df.input_loc.to_numpy()).reshape(-1))),
                list(set(np.stack(df.input_obj.to_numpy()).reshape(-1))),
                list(set(np.stack(df.input_ctg.to_numpy()).reshape(-1))),
            ]

            # keep trials correct on every frame (balanced per plain_task_index), seeded shuffle
            df = trial_based_corrected_trials(df,is_balanced = True)
            df = df.sample(frac=1, random_state=42)
            # if with RNN and GRU
            RNN_activation = np.stack(df.activation.to_numpy())[:,0,:]
            # if with LSTM cell_state
            # RNN_activation = np.stack(df.cell_state.to_numpy())[:,0,:]
            # if with LSTM hidden_state
            # RNN_activation = np.stack(df.hidden_state.to_numpy())[:,0,:]

            CNN_activation = np.stack(df.CNN_activation_2.to_numpy())[:,0,:,:,:].reshape(-1, 256*7*7)
            ntask_index = df.ntask_index.to_numpy()
            feature_index = df.feature_index.to_numpy()
            loc_label = np.squeeze(np.stack(df.input_loc.to_numpy())[:,0])
            obj_label = np.squeeze(np.stack(df.input_obj.to_numpy())[:,0])
            ctg_label = np.squeeze(np.stack(df.input_ctg.to_numpy())[:,0])

            del df

            labels = [loc_label, obj_label, ctg_label]

            selected_train_ntask_index = nback_n
            selected_train_feature_index = task_feature

            for i in range(n_bootstraps):
                for decoding_feature in range(3):
                    decoding_feature_labels = labels[decoding_feature]

                    train_data, train_label, val_data, val_label= concat_nback_decoder_data_gen(RNN_activation,
                                                                                 ntask_index,
                                                                                 selected_train_ntask_index,
                                                                                 feature_index,
                                                                                 selected_train_feature_index,
                                                                                 decoding_feature_labels,
                                                                                 split_ratio = 0.8)

                    acc= decoder(train_data, val_data, train_label, val_label, type = classifier_type)
                    print("decoding accuracy:", acc)
                    accs[nback_n-1, selected_train_feature_index, decoding_feature,i] = acc
    # average over bootstraps -> (nback n - 1, task feature, decoding feature); NaN = task not on disk
    accs = np.mean(accs, axis = -1)