In [None]:
import os
import numpy as np
from sklearn.metrics import pairwise_distances
from nltools.data import Adjacency
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import pandas as pd






In [None]:

# Absolute input/output locations for the IS-TR RSA pipeline; adjust these for
# the machine the notebook runs on.
srm_data_path, save_data_dir, save_fig_dir = (
    '/Volumes/ARCHIVES/thesis_pipeline/SRM_data',  # <roi>_shareddata.pkl inputs
    '/Volumes/ARCHIVES/thesis_pipeline/data',  # saved 3-D time-series .npy files
    '/Volumes/ARCHIVES/thesis_pipeline/figures/IS-TR_RSA',  # RSA heatmap PNGs
)

# Function to load shared response data for all subjects and stack into a 3D array (subjects x TRs x features)
def load_feature_time_series(roi_name, n_subjects):
    shared_data_file = os.path.join(srm_data_path, f"{roi_name}_shareddata.pkl")
    
    if os.path.exists(shared_data_file):
        with open(shared_data_file, 'rb') as file:
            shared_data = pickle.load(file)

        # Stack data into 3D array: subjects x TRs x features
        feature_time_series = np.array([shared_data[subj_idx].T for subj_idx in range(n_subjects)])  # TRs x features for each subject
        
        return feature_time_series
    else:
        raise FileNotFoundError(f"Shared data for {roi_name} not found.")

# Function to compute RSA between subjects for a given region
def compute_rsa_for_feature_time_series(feature_time_series):
    n_subjects, n_trs, n_features = feature_time_series.shape
    
    # Initialize an empty matrix to store RSA results: subjects x subjects
    rsa_matrix = np.zeros((n_subjects, n_subjects))
    
    # Compute pairwise distances for each feature at each TR
    for tr in range(n_trs):
        for feature in range(n_features):
            # Extract the feature values across all subjects at the current TR
            tr_feature_data = feature_time_series[:, tr, feature]
            
            # Compute pairwise correlation distances between all subjects
            rsa_tr_feature = pairwise_distances(tr_feature_data[:, np.newaxis], metric='correlation')
            
            # Add the distances to the RSA matrix (average across features and TRs)
            rsa_matrix += rsa_tr_feature

    # Normalize by the number of TRs and features to get the average distance
    rsa_matrix /= (n_trs * n_features)
    
    return rsa_matrix

# Main function to load feature time series, compute RSA, and save results for each region
def run_rsa_analysis(roi_names, n_subjects):
    # Directory to save the time series data
    os.makedirs(save_data_dir, exist_ok=True)

    for roi_name in roi_names:
        print(f"Processing ROI: {roi_name}")
        
        # Load feature time series for the current region
        feature_time_series = load_feature_time_series(roi_name, n_subjects)
        
        # Save the 3D feature time series array before computing RSA
        time_series_save_path = os.path.join(save_data_dir, f'SRM_3D_Time_series_{roi_name}.npy')
        np.save(time_series_save_path, feature_time_series)
        print(f"Saved feature time series for {roi_name} at {time_series_save_path}")
        
        # Compute the RSA matrix for the current region
        rsa_matrix = compute_rsa_for_feature_time_series(feature_time_series)
        
        # Plot and save the RSA matrix after computation
        plot_and_save_rsa_matrix(rsa_matrix, roi_name)

# Function to plot and save the RSA matrix
def plot_and_save_rsa_matrix(rsa_matrix, roi_name):
    plt.figure(figsize=(8, 6))
    sns.heatmap(rsa_matrix, annot=False, cmap='coolwarm', vmin=0, vmax=2)
    plt.title(f'Inter-Subject RSA for {roi_name}')
    plt.xlabel('Subjects')
    plt.ylabel('Subjects')
    
    # Save the figure
    save_path = os.path.join(save_fig_dir, f'{roi_name}_IS-TR_RSA.png')
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"Saved RSA matrix for {roi_name} at {save_path}")

# Example usage: Define the regions and number of subjects
roi_names = ['PTL', 'ATL', 'AG', 'IFG', 'MFG', 'IFGorb']
n_subjects = 8  # Adjust according to the actual number of subjects

# Run the RSA analysis
run_rsa_analysis(roi_names, n_subjects)
