In [2]:
import glob
import os
import sys
import numpy as np
from scipy import stats
from scipy import io as sio
import scipy.spatial.distance as sp_distance
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.patches as patches
from scipy.stats import norm, zscore, pearsonr
from scipy.signal import gaussian, convolve
from sklearn import decomposition
import pandas as pd
import pickle

# Import BrainIAK modules
from brainiak.isc import isc
from brainiak.fcma.util import compute_correlation

In [None]:
# Root of the thesis pipeline. Set the THESIS_PIPELINE_DIR environment variable
# to point elsewhere instead of editing this hard-coded absolute path.
base_path = os.environ.get("THESIS_PIPELINE_DIR", "/Volumes/ARCHIVES/thesis_pipeline")
# Directory holding the per-ROI "<roi>_shareddata.pkl" SRM outputs.
srm_data_path = os.path.join(base_path, 'SRM_data')

# Function to load shared response data for all subjects and stack into a 3D array (subjects x TRs x features)
def load_feature_time_series(roi_name, n_subjects, data_dir=None):
    """Load an ROI's SRM shared responses and stack them as (subjects, TRs, features).

    Parameters
    ----------
    roi_name : str
        ROI label; the file read is ``<data_dir>/<roi_name>_shareddata.pkl``.
    n_subjects : int
        Number of subjects to take from the start of the pickled sequence.
    data_dir : str, optional
        Directory containing the pickle. Defaults to the notebook-level
        ``srm_data_path``.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_subjects, n_TRs, n_features). Each subject's stored
        array is transposed before stacking.

    Raises
    ------
    FileNotFoundError
        If the pickle does not exist.
    ValueError
        If ``n_subjects < 1``, if the pickle holds fewer than ``n_subjects``
        entries, or if the per-subject arrays do not all share one shape.
    """
    if data_dir is None:
        data_dir = srm_data_path
    shared_data_file = os.path.join(data_dir, f"{roi_name}_shareddata.pkl")

    if not os.path.exists(shared_data_file):
        raise FileNotFoundError(f"Shared data for {roi_name} not found.")
    if n_subjects < 1:
        raise ValueError(f"n_subjects must be >= 1, got {n_subjects}")

    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(shared_data_file, 'rb') as file:
        shared_data = pickle.load(file)

    if len(shared_data) < n_subjects:
        raise ValueError(
            f"{shared_data_file} holds {len(shared_data)} subjects, "
            f"but n_subjects={n_subjects} was requested."
        )

    # Presumably each entry is (features x TRs) as produced by SRM -- confirm;
    # the transpose makes it (TRs x features). np.stack always rejects ragged
    # inputs (older np.array versions would build an object array instead).
    return np.stack([np.asarray(shared_data[subj_idx]).T for subj_idx in range(n_subjects)])

# Function to compute RSA matrix using BrainIAK's compute_correlation
def compute_rsa_using_brainiak(feature_time_series):
    """Average inter-subject Pearson correlation of SRM feature time courses.

    For each shared feature, BrainIAK's ``compute_correlation`` correlates every
    subject's time course (across TRs) with every other subject's, giving a
    subjects x subjects matrix. These matrices are averaged (plain arithmetic
    mean of r, not Fisher-z) over features.

    Parameters
    ----------
    feature_time_series : array_like
        3D array shaped (n_subjects, n_TRs, n_features), e.g. the output of
        ``load_feature_time_series``.

    Returns
    -------
    numpy.ndarray
        (n_subjects, n_subjects) similarity matrix of mean correlations.

    Raises
    ------
    ValueError
        If the input is not 3D, has fewer than 2 TRs, or has no features.
    """
    data = np.asarray(feature_time_series)
    if data.ndim != 3:
        raise ValueError(
            f"Expected a 3D (subjects, TRs, features) array, got shape {data.shape}"
        )
    n_subjects, n_trs, n_features = data.shape
    if n_trs < 2 or n_features < 1:
        raise ValueError("Need at least 2 TRs and 1 feature to correlate time courses.")

    rsa_matrix = np.zeros((n_subjects, n_subjects))

    for feature in range(n_features):
        # Rows = subjects, columns = TRs. compute_correlation correlates the rows
        # of its two arguments, so this yields the subjects x subjects matrix.
        # Passing a single (1, n_subjects) row would only correlate that row with
        # itself and return a 1x1 matrix.
        # ascontiguousarray: BrainIAK expects C-contiguous input, and a slice
        # along the last axis of a 3D array is not contiguous.
        subject_time_courses = np.ascontiguousarray(data[:, :, feature], dtype=np.float32)
        # NOTE(review): constant time courses presumably give r = 0 here
        # (BrainIAK's default return_nans=False) -- confirm.
        rsa_matrix += compute_correlation(subject_time_courses, subject_time_courses)

    # Mean correlation across features.
    rsa_matrix /= n_features
    return rsa_matrix

# Example run: one ROI and a fixed cohort size.
roi_name, n_subjects = 'IFG', 8  # set n_subjects to the real number of subjects

# Stack the SRM shared responses, then build the subject-by-subject RSA matrix.
feature_time_series = load_feature_time_series(roi_name, n_subjects)
rsa_matrix = compute_rsa_using_brainiak(feature_time_series)

# Quick look at the result (shape plus the top-left corner).
preview = rsa_matrix[:5, :5]
print("RSA matrix shape:", rsa_matrix.shape)
print("RSA matrix values (first 5x5):")
print(preview)

In [None]:
# Root of the thesis pipeline. Set the THESIS_PIPELINE_DIR environment variable
# to point elsewhere instead of editing this hard-coded absolute path.
base_path = os.environ.get('THESIS_PIPELINE_DIR', '/Volumes/ARCHIVES/thesis_pipeline')
# Directory holding the per-ROI "<roi>_shareddata.pkl" SRM outputs.
srm_data_path = os.path.join(base_path, 'SRM_data')

# Function to load the relevant SRM file and stack into a 3D array (subjects x TRs x features)
def load_srm_data(roi_name, n_subjects, data_dir=None):
    """Load an ROI's SRM shared responses and stack them as (subjects, TRs, features).

    Parameters
    ----------
    roi_name : str
        ROI label; the file read is ``<data_dir>/<roi_name>_shareddata.pkl``.
    n_subjects : int
        Number of subjects to take from the start of the pickled sequence.
    data_dir : str, optional
        Directory containing the pickle. Defaults to the notebook-level
        ``srm_data_path``.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_subjects, n_TRs, n_features). Each subject's stored
        array is transposed before stacking.

    Raises
    ------
    FileNotFoundError
        If the pickle does not exist (the message includes the full path).
    ValueError
        If ``n_subjects < 1``, if the pickle holds fewer than ``n_subjects``
        entries, or if the per-subject arrays do not all share one shape.
    """
    if data_dir is None:
        data_dir = srm_data_path
    shared_data_file = os.path.join(data_dir, f"{roi_name}_shareddata.pkl")

    if not os.path.exists(shared_data_file):
        raise FileNotFoundError(f"Shared data for {roi_name} not found at {shared_data_file}")
    if n_subjects < 1:
        raise ValueError(f"n_subjects must be >= 1, got {n_subjects}")

    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(shared_data_file, 'rb') as file:
        shared_data = pickle.load(file)

    if len(shared_data) < n_subjects:
        raise ValueError(
            f"{shared_data_file} holds {len(shared_data)} subjects, "
            f"but n_subjects={n_subjects} was requested."
        )

    # Presumably each entry is (features x TRs) as produced by SRM -- confirm;
    # the transpose makes it (TRs x features). np.stack always rejects ragged
    # inputs (older np.array versions would build an object array instead).
    return np.stack([np.asarray(shared_data[subj_idx]).T for subj_idx in range(n_subjects)])

# Example: stack the SRM shared responses for a single ROI.
roi_name, n_subjects = 'IFG', 8  # change the ROI / subject count to match your data
feature_time_series = load_srm_data(roi_name, n_subjects)

# Sanity check: the stacked array should be (subjects, TRs, features).
loaded_shape = feature_time_series.shape
print(f"Loaded SRM data shape for {roi_name}: {loaded_shape}")


In [None]:
# Function to compute RSA matrix by running pairwise correlations on the 3D array
def compute_rsa_for_feature_time_series(feature_time_series):
    """Average inter-subject correlation distance across SRM features.

    For every feature, each subject's time course (across TRs) is compared with
    every other subject's using correlation distance (1 - Pearson r). The
    resulting subjects x subjects matrices are averaged over features.

    Parameters
    ----------
    feature_time_series : array_like
        3D array shaped (n_subjects, n_TRs, n_features).

    Returns
    -------
    numpy.ndarray
        Symmetric (n_subjects, n_subjects) distance matrix with zeros on the
        diagonal: 0 = perfectly correlated, 2 = perfectly anti-correlated.
        Off-diagonal entries become NaN if any subject's time course is
        constant for some feature (its correlation is undefined).

    Raises
    ------
    ValueError
        If the input is not 3D, has fewer than 2 TRs, or has no features.
    """
    data = np.asarray(feature_time_series, dtype=float)
    if data.ndim != 3:
        raise ValueError(
            f"Expected a 3D (subjects, TRs, features) array, got shape {data.shape}"
        )
    n_subjects, n_trs, n_features = data.shape
    if n_trs < 2 or n_features < 1:
        raise ValueError("Need at least 2 TRs and 1 feature to correlate time courses.")

    rsa_matrix = np.zeros((n_subjects, n_subjects))

    for feature in range(n_features):
        # Rows = subjects, columns = TRs. Correlating whole time courses is
        # required: one scalar per subject has no variance, so its
        # correlation distance is 0/0 (NaN).
        subject_time_courses = data[:, :, feature]
        condensed = sp_distance.pdist(subject_time_courses, metric='correlation')
        rsa_matrix += sp_distance.squareform(condensed)

    # Mean distance across features.
    rsa_matrix /= n_features
    return rsa_matrix

# Build the subject-by-subject RSA matrix for the currently loaded ROI.
rsa_matrix = compute_rsa_for_feature_time_series(feature_time_series)

# Persist it under <base_path>/rsa_matrices/ so it can be inspected later.
rsa_save_path = os.path.join(
    base_path,
    'rsa_matrices',
    f'{roi_name}_rsa_matrix.npy',
)
os.makedirs(os.path.dirname(rsa_save_path), exist_ok=True)
np.save(rsa_save_path, rsa_matrix)

print(f"Computed RSA matrix for {roi_name} and saved to {rsa_save_path}")


In [None]:
# Function to inspect the RSA matrix before plotting
def inspect_rsa_matrix(rsa_matrix, roi_name, n_preview=5):
    """Print a quick diagnostic summary of an RSA matrix before plotting.

    Min/max/mean are computed over the non-NaN entries only. A single NaN
    would otherwise make all three print ``nan`` and hide the value range. The
    NaN count is reported separately. If every entry is NaN (or the matrix is
    empty), the statistics print as ``nan``.

    Parameters
    ----------
    rsa_matrix : array_like
        The RSA matrix to summarize.
    roi_name : str
        ROI label used in the printed headings.
    n_preview : int, optional
        Size of the top-left corner slice printed at the end (default 5).
    """
    rsa_matrix = np.asarray(rsa_matrix)
    nan_mask = np.isnan(rsa_matrix)
    valid_values = rsa_matrix[~nan_mask]

    if valid_values.size:
        min_val, max_val, mean_val = valid_values.min(), valid_values.max(), valid_values.mean()
    else:
        min_val = max_val = mean_val = np.nan

    print(f"RSA matrix shape for {roi_name}: {rsa_matrix.shape}")

    print(f"RSA matrix summary for {roi_name}:")
    print(f"  Min value: {min_val}")
    print(f"  Max value: {max_val}")
    print(f"  Mean value: {mean_val}")
    print(f"  Number of NaN values: {nan_mask.sum()}")

    print(f"First few rows of the RSA matrix for {roi_name}:")
    print(rsa_matrix[:n_preview, :n_preview])

# Sanity-check the matrix (shape, value range, NaN count) before any plotting.
inspect_rsa_matrix(rsa_matrix=rsa_matrix, roi_name=roi_name)