In [None]:
import pandas as pd
import nilearn as nl
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr
from nilearn.connectome import vec_to_sym_matrix, sym_matrix_to_vec
from nilearn import plotting
import seaborn as sns
import numpy as np
import torch
import xarray as xr
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from PIL import Image
from scipy.linalg import issymmetric
import os
import re
from tqdm import tqdm

In [None]:
# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Human-readable label -> raw variable name for each target.
# The "Interview Age" -> "interview_age" entry is currently disabled.
var_dict = dict([
    ("CBCL Internalizing", "cbcl_scr_syn_internal_r"),
    ("CBCL Externalizing", "cbcl_scr_syn_external_r"),
    ("CBCL Thought Problems", "cbcl_scr_syn_thought_r"),
])

In [None]:
def replace_with_network(label, network_labels):
    """Map a parcel label to the first network name it contains.

    Parameters
    ----------
    label : str
        Parcel label to classify.
    network_labels : iterable of str
        Candidate network names, checked in order as substrings of ``label``.

    Returns
    -------
    str
        The first candidate found inside ``label``; ``label`` itself when none match.
    """
    hits = (name for name in network_labels if name in label)
    return next(hits, label)

In [None]:
def mape_cog(csv, cog_score, scale=1):
    """Load the per-experiment scores of one target from a results CSV.

    Parameters
    ----------
    csv : str, path-like or file-like
        Anything ``pd.read_csv`` accepts. Must contain ``train_ratio``,
        ``experiment``, ``dataset`` and ``cog_score`` columns.
    cog_score : str
        Name of the score column to keep.
    scale : float, default 1
        Multiplier applied to the score column (e.g. 100 to express a fraction
        as a percentage). When 1, the column is returned untouched (dtype kept).

    Returns
    -------
    pd.DataFrame
        An independent copy holding only the four selected columns.
    """
    results = pd.read_csv(csv)
    # .copy() makes the subset its own frame, so writing a column below cannot
    # trigger pandas' SettingWithCopyWarning.
    scores = results[["train_ratio", "experiment", "dataset", cog_score]].copy()
    if scale != 1:
        scores[cog_score] = scores[cog_score] * scale
    return scores

In [None]:
def plot_cog(data, cog_score, title):
    """Split violin plot of ``cog_score`` per training ratio, train vs. test.

    The median of each (train_ratio, dataset) group is overlaid as a point marker.

    Parameters
    ----------
    data : pd.DataFrame
        Long-form table with ``train_ratio``, ``dataset`` (levels 'train' / 'test')
        and ``cog_score`` columns.
    cog_score : str
        Column to plot on the y axis (the axis is labelled "MAPE").
    title : str
        Figure super-title.
    """
    hue_levels = ['train', 'test']
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.violinplot(data=data, x="train_ratio", y=cog_score, hue="dataset",
                   hue_order=hue_levels, width=0.8, scale='count', split=True, ax=ax)

    # Fade everything violinplot drew so the median markers stay visible on top.
    for collection in ax.collections:
        collection.set_alpha(0.4)

    group_medians = data.groupby(['train_ratio', 'dataset'], as_index=False)[cog_score].median()
    sns.pointplot(x='train_ratio', y=cog_score, hue='dataset', data=group_medians,
                  ax=ax, hue_order=hue_levels, markers="_")

    ax.set_ylabel("MAPE")
    fig.suptitle(title)
    ax.grid()

In [None]:
def plot_loss(csv, title):
    """Plot the median of each loss component against the training-set ratio.

    Parameters
    ----------
    csv : str, path-like or file-like
        Loss log readable by ``pd.read_csv``. Must contain ``train_ratio`` and one
        column per plotted component.
    title : str
        Used as the legend title.
    """
    loss_j = pd.read_csv(csv)

    # (column, legend label) for every component drawn.
    components = [
        ('loss', 'loss'),
        ('target_decoding', 'target decoding'),
        ('kernel_feature', 'kernel_feature'),
        ('kernel_target', 'kernel_target'),
        ('joint_embedding', 'joint_embedding'),
        # ('feature_decoding', 'feature_decoding'),
    ]
    columns = [column for column, _ in components]
    # One groupby for all components instead of one per pointplot call.
    medians = loss_j.groupby('train_ratio', as_index=False)[columns].median()

    plt.figure(figsize=(10, 6))
    for k, (column, label) in enumerate(components):
        # Explicit colours: without them, some seaborn versions draw every un-hued
        # pointplot in the same first palette colour, making the legend useless.
        sns.pointplot(x='train_ratio', y=column, data=medians, markers="_",
                      label=label, color=f"C{k}")

    # Without this, the y label would be whichever column was plotted last.
    plt.ylabel("median loss")
    plt.grid()
    plt.legend(title=title)
    plt.show()

In [None]:
def combine_images(image_paths, save_to):
    """Place images side by side (left to right) and save the composite.

    Parameters
    ----------
    image_paths : iterable of str or path-like
        Images to concatenate horizontally, in order.
    save_to : str or path-like
        Output file. PIL infers the format from the extension.

    Raises
    ------
    ValueError
        If ``image_paths`` is empty.
    """
    import contextlib

    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("combine_images() needs at least one image path")

    # ExitStack closes every file opened by Image.open once the composite has
    # been written, instead of leaving the handles to the garbage collector.
    with contextlib.ExitStack() as stack:
        images = [stack.enter_context(Image.open(path)) for path in image_paths]

        total_width = sum(image.width for image in images)
        max_height = max(image.height for image in images)
        combined_image = Image.new("RGB", (total_width, max_height))

        x_offset = 0
        for image in images:
            combined_image.paste(image, (x_offset, 0))
            x_offset += image.width

        combined_image.save(save_to)


In [None]:
def mat_correlations(true, recon):
    """Pearson correlation of every matrix element across the batch axis.

    For each element (r, c), ``true[:, r, c]`` is correlated with
    ``recon[:, r, c]`` over the batch. That value does not depend on the batch
    slot, so the (rows, cols) map is computed once and repeated along axis 0 to
    keep the (batch_size, rows, cols) output shape.

    Parameters
    ----------
    true, recon : np.ndarray
        Arrays shaped (batch_size, rows, cols).

    Returns
    -------
    np.ndarray
        Shape (batch_size, rows, cols). Every slice along axis 0 is the same map.
    """
    batch_size, rows, cols = true.shape
    if batch_size == 0:
        # Nothing to correlate: return an empty result of the expected shape.
        return np.zeros((0, rows, cols))

    flat_true = true.reshape(batch_size, rows * cols)
    flat_recon = recon.reshape(batch_size, rows * cols)

    # Compute once per element rather than once per batch slot. Iterating
    # through tqdm makes the progress bar actually advance.
    element_corr = np.empty(rows * cols)
    for i in tqdm(range(rows * cols), desc='Computing correlations'):
        element_corr[i] = pearsonr(flat_true[:, i], flat_recon[:, i])[0]

    return np.repeat(element_corr.reshape(1, rows, cols), batch_size, axis=0)

In [None]:
def compute_batch_elementwise_correlation(true, recon):
    """Spearman correlation of each matrix element across the batch axis.

    Parameters
    ----------
    true, recon : np.ndarray
        Stacks of matrices shaped (batch, rows, cols).

    Returns
    -------
    np.ndarray
        (rows, cols) array whose entry (r, c) is Spearman's rho between
        ``true[:, r, c]`` and ``recon[:, r, c]``.
    """
    n_batch, n_rows, n_cols = true.shape
    true_columns = true.reshape(n_batch, -1)
    recon_columns = recon.reshape(n_batch, -1)

    result = np.zeros((n_rows, n_cols))
    for row, col in np.ndindex(n_rows, n_cols):
        column = row * n_cols + col
        rho, _ = spearmanr(true_columns[:, column], recon_columns[:, column])
        result[row, col] = rho
    return result

In [None]:
# Schaefer-2018 parcel names (nilearn default: 400 parcels), each collapsed to its network.
schaefer_atlas = nl.datasets.fetch_atlas_schaefer_2018()
# Older nilearn releases return labels as bytes, newer ones as str. Newer releases
# may also prepend a "Background" entry, which has no row in the 400x400 matrices.
atlas_labels = [
    label.decode('utf-8') if isinstance(label, bytes) else str(label)
    for label in schaefer_atlas['labels']
]
atlas_labels = [label for label in atlas_labels if label != 'Background']
network_names = ['Vis', 'SomMot', 'DorsAttn', 'SalVentAttn', 'Limbic', 'Cont', 'Default']
# One network name per parcel, in atlas order (rows/cols of the connectivity matrices).
network_labels = [replace_with_network(label, network_names) for label in atlas_labels]
assert len(network_labels) == 400, len(network_labels)

## Autoencoder

In [None]:
root = "/gpfs3/well/margulies/users/cpy397/contrastive-learning"
exp = "ae_loss_norm"
exp_dir = f"{root}/results/{exp}"
recon_mat_dir = f"{exp_dir}/recon_mat"

# Batch files of reconstructions and MAPE maps, ordered lexicographically by name.
# NOTE(review): if the numeric parts of the file names are not zero-padded
# ("...10" sorts before "...2"), rows will not line up with test_idx — confirm the naming.
recon_mat_files = sorted(filter(lambda name: "recon_mat" in name, os.listdir(recon_mat_dir)))
mape_mat_files = sorted(filter(lambda name: "mape_mat" in name, os.listdir(recon_mat_dir)))
recon_paths = list(map(lambda name: os.path.join(recon_mat_dir, name), recon_mat_files))
mape_paths = list(map(lambda name: os.path.join(recon_mat_dir, name), mape_mat_files))

In [None]:
# Positions along the dataset's `subject` dimension of the held-out test subjects
# (consumed by `isel(subject=...)` below).
test_idx = np.load(
    f"{exp_dir}/test_idx.npy"
)

In [None]:
dataset_path = f"{root}/ABCD/abcd_dataset_400parcels.nc"
dataset = xr.open_dataset(dataset_path)
# Stack the test subjects' matrices into one numpy array: all data variables via
# to_array(), size-1 dimensions squeezed out.
true_mat = (
    dataset
    .isel(subject=test_idx)
    .to_array()
    .squeeze()
    .values
)

In [None]:
# Stitch the per-batch files back into one array per quantity.
recon_mat = np.concatenate([np.load(path) for path in recon_paths])
mape_mat = np.concatenate([np.load(path) for path in mape_paths])

# Later cells pair recon_mat[i] / mape_mat[i] with test_idx[i] (and true_mat[i]).
# Fail loudly here if the batch files do not add up to the test set, rather than
# silently correlating mismatched subjects.
assert recon_mat.shape[0] == len(test_idx), (recon_mat.shape, len(test_idx))
assert mape_mat.shape[0] == len(test_idx), (mape_mat.shape, len(test_idx))

In [None]:
# Pin every reconstructed matrix's diagonal to 1.0, in place. Iterating over the
# first axis yields writable views, so recon_mat itself is modified.
for subject_mat in recon_mat:
    np.fill_diagonal(subject_mat, 1.0)

In [None]:
# Lower-triangle vectors (diagonal dropped) of each subject's matrices.
recon_mat_flat, mape_mat_flat, true_mat_flat = (
    sym_matrix_to_vec(stack, discard_diagonal=True)
    for stack in (recon_mat, mape_mat, true_mat)
)

In [None]:
# NOTE(review): this re-opens the same file already opened a few cells above;
# it looks redundant and can probably be dropped.
dataset_path = f"{root}/ABCD/abcd_dataset_400parcels.nc"
dataset = xr.open_dataset(
    dataset_path,
)

In [None]:
# Write the flattened true / reconstructed matrices to the current working directory.
for out_file, flat_array in (("true_mat_flat.npy", true_mat_flat),
                             ("recon_mat_flat.npy", recon_mat_flat)):
    np.save(out_file, flat_array)

In [None]:
# Single Spearman correlation pooled over every flattened entry of every test subject.
corr_sub, p_value = spearmanr(np.ravel(true_mat_flat), np.ravel(recon_mat_flat))
(corr_sub, p_value)

In [None]:
corr_mat_pred = compute_batch_elementwise_correlation(true_mat, recon_mat)
# Overwrite the diagonal with 1.0. The true diagonal is presumably constant across
# subjects, which would leave Spearman's rho undefined there — confirm.
corr_mat_pred[np.diag_indices_from(corr_mat_pred)] = 1.0

In [None]:
# Long-form table: one row per matrix element, tagged with the network of its row parcel.
corr_data_ae = pd.DataFrame({
    'correlation': [value for i in range(len(network_labels)) for value in corr_mat_pred[i]],
    'network': [network for network in network_labels for _ in range(corr_mat_pred.shape[1])],
    'model': 'AE Only',
})

In [None]:
plotting.plot_matrix(
    corr_mat_pred,
    title=f"Corr(True, Recon) | Exp {exp} | AE Only",
    grid=False,
    vmax=1.,
    vmin=-1.,
)
# Summarise over off-diagonal entries only. The diagonal of corr_mat_pred was pinned
# to 1.0 when it was computed and would otherwise bias the mean upwards. The same mask
# is applied to every subject's MAPE matrix, so both numbers cover the same edges.
off_diag = ~np.eye(corr_mat_pred.shape[0], dtype=bool)
mean_corr = corr_mat_pred[off_diag].mean()
mean_mape = mape_mat[:, off_diag].mean()
# Annotation positions are in axes-fraction coordinates of the current axes (hand-tuned).
plt.text(-12, 0.02, f'mean_corr = {mean_corr:.2f}', color='black', ha='right', va='bottom', fontsize=12, transform=plt.gca().transAxes,
         bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
plt.text(-10.5, 0.09, f'mean_mape = {mean_mape:.2f}', color='black', ha='right', va='bottom', fontsize=12, transform=plt.gca().transAxes,
         bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))


In [None]:
# Side-by-side view for the first five test subjects: true matrix, reconstruction,
# residual (true - recon) and absolute MAPE (colour range fixed to 0-100).
for i, mat_idx in enumerate(test_idx[0:5]):
    recon = recon_mat[i]
    mape = np.abs(mape_mat[i])
    # Convert to a plain numpy array, so the residual is numpy rather than an xarray DataArray.
    true = dataset.isel(subject=mat_idx).to_array().squeeze().values
    residual = true - recon

    fig, axes = plt.subplots(1, 4, figsize=(36, 7))
    plotting.plot_matrix(true, axes=axes[0], title=f"True Mat | Exp {exp} idx{mat_idx}")
    plotting.plot_matrix(recon, axes=axes[1], title=f"Recon Mat | Exp {exp} idx{mat_idx}")
    plotting.plot_matrix(residual, axes=axes[2], title=f"Residuals | Exp {exp} idx{mat_idx}")
    plotting.plot_matrix(mape, axes=axes[3], title=f"MAPE | Exp {exp} idx{mat_idx}",
                         vmax=100, vmin=0)
    # Render and release each large figure before building the next one.
    plt.show()


## Full Model

In [None]:
root = "/gpfs3/well/margulies/users/cpy397/contrastive-learning"
exp = "main_model_loss_norm"
exp_dir = f"{root}/results/{exp}"
recon_mat_dir = f"{exp_dir}/recon_mat"
# Prediction table. Later cells filter it by dataset / train_ratio / experiment
# to recover the test-subject indices.
predictions = pd.read_csv(
    f"{exp_dir}/pred_results.csv"
)

In [None]:
# Same file the autoencoder section loads. Anchor it to `root` so this cell does not
# depend on the kernel's current working directory.
dataset_path = f"{root}/ABCD/abcd_dataset_400parcels.nc"
dataset = xr.open_dataset(dataset_path)


In [None]:
exp = 15
# Test-set subject indices for this experiment at the full training ratio.
true_mat_idx = predictions.loc[
    (predictions["dataset"] == "test")
    & (predictions["train_ratio"] == 1)
    & (predictions["experiment"] == exp),
    "indices",
].values
true_mat = dataset.isel(subject=true_mat_idx).to_array().squeeze().values

# Match "exp<N>" only when it is not followed by another digit. A plain substring test
# would make exp=1 also pick up the exp10..exp19 files.
exp_pattern = re.compile(rf"exp{exp}(?!\d)")
# NOTE(review): files are ordered lexicographically; if batch numbers are not
# zero-padded, rows will not line up with true_mat_idx — confirm the naming.
recon_paths = sorted(i for i in os.listdir(recon_mat_dir) if "recon_mat" in i and exp_pattern.search(i))
mape_paths = sorted(i for i in os.listdir(recon_mat_dir) if "mape_mat" in i and exp_pattern.search(i))

recon_mat = np.concatenate([np.load(f"{recon_mat_dir}/{i}") for i in recon_paths])
mape_mat = np.concatenate([np.load(f"{recon_mat_dir}/{i}") for i in mape_paths])

# Downstream cells pair recon_mat[i] / mape_mat[i] with true_mat_idx[i]. Stop here if
# the counts disagree rather than silently correlating mismatched subjects.
assert recon_mat.shape[0] == len(true_mat_idx), (recon_mat.shape, len(true_mat_idx))
assert mape_mat.shape[0] == len(true_mat_idx), (mape_mat.shape, len(true_mat_idx))

In [None]:
# Lower-triangle vectors (diagonal dropped) of each subject's matrices.
recon_mat_flat, mape_mat_flat, true_mat_flat = (
    sym_matrix_to_vec(stack, discard_diagonal=True)
    for stack in (recon_mat, mape_mat, true_mat)
)

In [None]:
# Single Spearman correlation pooled over every flattened entry of every test subject.
corr_sub, p_value = spearmanr(np.ravel(true_mat_flat), np.ravel(recon_mat_flat))
(corr_sub, p_value)

## Correlation: True Mat vs. Recon Across Subjects

In [None]:
corr_mat_pred = compute_batch_elementwise_correlation(true_mat, recon_mat)
# Overwrite the diagonal with 1.0. The true diagonal is presumably constant across
# subjects, which would leave Spearman's rho undefined there — confirm.
corr_mat_pred[np.diag_indices_from(corr_mat_pred)] = 1.0

In [None]:
plotting.plot_matrix(
    corr_mat_pred,
    title=f"Corr(True, Recon) | Exp {exp}",
    grid=False,
    vmax=1.,
    vmin=-1.,
)

# Summarise over off-diagonal entries only. The diagonal of corr_mat_pred was pinned
# to 1.0 when it was computed and would otherwise bias the mean upwards. The same mask
# is applied to every subject's MAPE matrix, so both numbers cover the same edges.
off_diag = ~np.eye(corr_mat_pred.shape[0], dtype=bool)
mean_corr = corr_mat_pred[off_diag].mean()
mean_mape = mape_mat[:, off_diag].mean()
# Annotation positions are in axes-fraction coordinates of the current axes (hand-tuned).
plt.text(-12, 0.02, f'mean_corr = {mean_corr:.2f}', color='black', ha='right', va='bottom', fontsize=12, transform=plt.gca().transAxes,
         bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
plt.text(-10.5, 0.09, f'mean_mape = {mean_mape:.2f}', color='black', ha='right', va='bottom', fontsize=12, transform=plt.gca().transAxes,
         bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

## Correlation: True Mat vs. Recon Across Subjects Per Network

In [None]:
# Long-form table: one row per matrix element, tagged with the network of its row parcel.
corr_data_main = pd.DataFrame({
    'correlation': [value for i in range(len(network_labels)) for value in corr_mat_pred[i]],
    'network': [network for network in network_labels for _ in range(corr_mat_pred.shape[1])],
    'model': 'Main',
})

In [None]:
# Sanity check: rows of the AE table whose network label is the empty string
# (presumably expected to be none).
corr_data_ae[corr_data_ae["network"] == '']

In [None]:
# Both inputs carry their own 0..N-1 index. ignore_index gives the stacked table a
# unique index, avoiding duplicate-label reindexing errors in pandas/seaborn.
corr_data = pd.concat([corr_data_ae, corr_data_main], ignore_index=True)

In [None]:
sns.set(style="whitegrid")
# Per-network distribution of element-wise correlations, AE-only vs. main model.
fig, ax = plt.subplots(figsize=(15, 8))
sns.violinplot(data=corr_data, x="network", y="correlation", hue="model",
               split=True, inner="quart", width=1., dodge=True, palette='hls', ax=ax)
plt.ylim(0, 1)


In [None]:
sns.set(style="whitegrid")


def add_mean_line(x, **kwargs):
    """Draw a dashed vertical line at the mean of ``x`` on the current facet."""
    plt.axvline(x.mean(), color='#7b2cbf', linestyle='--', lw=2)


# Ridge plot: one row per network, one column per model.
g = sns.FacetGrid(corr_data, row="network", col='model', hue="network",
                  aspect=10, height=1, palette="Spectral_r", xlim=(0, 1.))

# Filled density, then a black outline on top. `fill` / `bw_method` replace the
# deprecated `shade` / `bw` arguments (seaborn >= 0.11 maps bw=.5 to bw_method=.5).
g.map(sns.kdeplot, "correlation", clip_on=False, fill=True, alpha=1, lw=1.5, bw_method=.5)
g.map(sns.kdeplot, "correlation", clip_on=False, color="black", lw=1.5, bw_method=.5)

g.map(plt.axhline, y=0, lw=2, clip_on=False)
g.map(add_mean_line, "correlation")

g.fig.subplots_adjust(hspace=0.25)

# Label each row once, on its first-column facet. g.axes is (n_rows, n_cols).
# Zipping g.axes.flat with row_names would walk across the columns and mislabel
# every row after the first.
for ax, label in zip(g.axes[:, 0], g.row_names):
    ax.text(0, 0.2, label, fontsize=20, ha='left', va='center', transform=ax.transAxes)

g.set_titles("")
g.despine(bottom=True, left=True)
g.set(yticks=[], xlim=(0.1, 1.0))

plt.show()

## Looking at Individual Reconstructions

In [None]:
# Side-by-side view for the first five test subjects: true matrix, reconstruction,
# residual (true - recon) and absolute MAPE (colour range fixed to 0-100).
for i, mat_idx in enumerate(true_mat_idx[0:5]):
    recon = recon_mat[i]
    mape = np.abs(mape_mat[i])
    # Convert to a plain numpy array, so the residual is numpy rather than an xarray DataArray.
    true = dataset.isel(subject=mat_idx).to_array().squeeze().values
    residual = true - recon

    fig, axes = plt.subplots(1, 4, figsize=(36, 7))
    plotting.plot_matrix(true, axes=axes[0], title=f"True Mat | Exp {exp} idx{mat_idx}")
    plotting.plot_matrix(recon, axes=axes[1], title=f"Recon Mat | Exp {exp} idx{mat_idx}")
    plotting.plot_matrix(residual, axes=axes[2], title=f"Residuals | Exp {exp} idx{mat_idx}")
    plotting.plot_matrix(mape, axes=axes[3], title=f"MAPE | Exp {exp} idx{mat_idx}",
                         vmax=100, vmin=0)

    # save_plot_path = f"results/multivariate/abcd/recon_mat/plots/mat_{exp}_idx{mat_idx}.png"
    # plt.savefig(save_plot_path)

    # Render and release each large figure before building the next one.
    plt.show()

## Targets: Learning Curve