In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from nilearn.connectome import sym_matrix_to_vec
from scipy.stats import pearsonr
from cmath import isinf
import torch.nn.functional as F
#import seaborn as sns
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import pandas as pd
import math
from cmath import isinf
from utils_v import compute_target_score
import torch.nn.functional as F
#from sklearn.model_selection import train_test_split, KFold, LearningCurveDisplay, learning_curve
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_absolute_percentage_error, r2_score
from helper_classes import MatData, MLP
#from dev_losses import cauchy, rbf, gaussian_kernel, CustomSupCon, CustomContrastiveLoss
#from losses import KernelizedSupCon

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def multivariate_kernel(d, sigma : float,):
    
    exponent = -0.5 *(d/sigma)  
    exp = torch.exp(exponent)
    
    return exp

In [3]:
def mat_threshold(matrices, threshold): # as in Margulies et al. (2016)
    perc = np.percentile(np.abs(matrices), threshold, axis=2, keepdims=True)
    mask = np.abs(matrices) >= perc
    thresh_mat = matrices * mask
    return thresh_mat

In [5]:
path_matrix = "/data/parietal/store/work/dwassermann/data/victoria_mat_age/matrices.npy"
matrix = np.load(path_matrix)[:10]
matrix_trs = mat_threshold(matrix, 95)
vec_m = sym_matrix_to_vec(matrix, discard_diagonal=True)
vec_m_trs = sym_matrix_to_vec(matrix_trs, discard_diagonal=True)
original = torch.tensor(vec_m)
augmented = torch.tensor(vec_m_trs)

In [16]:
list_threshold = {10, 20, 30, 40,50,60,70,80,90,95,96,97,98,99}


In [17]:
def compute_kernelized_distance(matrix1, matrix2, sigma = 100_000, save_path=None):
    dist = torch.cdist(matrix1,matrix2, p=2)**2
    kernel = multivariate_kernel(dist, sigma)
    if save_path:
        torch.save(kernel, save_path)
    return kernel

In [18]:
def row_by_row_analysis(anchor_vs_augmented, other_matrix):
    diag = anchor_vs_augmented.diag()
    mask = 1-torch.eye(len(other_matrix))
    other_matrix_masked = other_matrix*mask
    comparison = (diag.unsqueeze(1) > other_matrix_masked).all(dim=1)
    percentage_superior = comparison.float().mean() * 100
    return percentage_superior

In [19]:
def col_by_col_analysis(anchor_vs_augmented, other_matrix):
    diag = anchor_vs_augmented.diag()
    mask = 1-torch.eye(len(other_matrix))
    other_matrix_masked = other_matrix*mask
    comparison = (diag.unsqueeze(1) > other_matrix_masked).all(dim=0)
    percentage_superior = comparison.float().mean() * 100
    return percentage_superior

In [42]:
def comparison(anchor_vs_augmented, other_matrix, dim):
    diag = anchor_vs_augmented.diag()
    mask = 1-torch.eye(len(other_matrix))
    other_matrix_masked = other_matrix*mask
    comparison = (diag.unsqueeze(1) > other_matrix_masked).all(dim=dim)
    percentage= comparison.float().mean() * 100
    return percentage.item()
    

In [43]:
def hyperparameter_search_augmentation (matrix, list_threshold, sigma = 100000):
    
    results = []

    for thr in list_threshold:
        matrix_thr = mat_threshold(matrix, thr)
        vec_m = sym_matrix_to_vec(matrix, discard_diagonal=True)
        vec_m_thr = sym_matrix_to_vec(matrix_thr, discard_diagonal=True)
        original = torch.tensor(vec_m)
        augmented = torch.tensor(vec_m_thr)
        
        #Compute kernelized distance matrix
        original_vs_original = compute_kernelized_distance(original, original, sigma=sigma)
        original_vs_augmented = compute_kernelized_distance(original, augmented, sigma=sigma)#originals are the rows, and augmented are the columns
        augmented_vs_augmented = compute_kernelized_distance(augmented, augmented, sigma=sigma)
        
        #Compare
        anchor_closer_to_aug_than_other_augs = comparison(original_vs_augmented, original_vs_augmented, 1)#compared to non diag values of original_vs_augmented, row by row
        anchor_closer_to_aug_than_other_originals = comparison(original_vs_augmented, original_vs_original, 1)
        aug_closer_to_anchor_than_other_augs = comparison(original_vs_augmented, augmented_vs_augmented, 0)#compared to non diag values of original_vs_augmented, col by col
        aug_closer_to_anchor_than_other_originals = comparison(original_vs_augmented, original_vs_augmented, 0)
        
        results.append({
            'Threshold': thr,
            'Anchor Closer to Aug Than Other Augs': anchor_closer_to_aug_than_other_augs,
            'Anchor Closer to Aug Than Other Originals': anchor_closer_to_aug_than_other_originals,
            'Aug Closer to Anchor Than Other Augs': aug_closer_to_anchor_than_other_augs,
            'Aug Closer to Anchor Than Other Originals': aug_closer_to_anchor_than_other_originals
        })
        
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('Threshold', ascending=True)

    file_path = '/storage/store2/work/mrenaudi/contrastive-reg-2/tests_aug/threshold.csv'  # Change this to your desired path

    results_df.to_csv(file_path, index=False)


    return results_df 

In [44]:
hyperparameter_search_augmentation (matrix, list_threshold, sigma = 100000)

Unnamed: 0,Threshold,Anchor Closer to Aug Than Other Augs,Anchor Closer to Aug Than Other Originals,Aug Closer to Anchor Than Other Augs,Aug Closer to Anchor Than Other Originals
6,10,100.0,100.0,100.0,100.0
9,20,100.0,100.0,100.0,100.0
12,30,100.0,100.0,100.0,100.0
5,40,100.0,100.0,100.0,100.0
8,50,100.0,100.0,100.0,100.0
11,60,100.0,100.0,100.0,100.0
4,70,100.0,100.0,100.0,100.0
7,80,100.0,100.0,100.0,100.0
10,90,100.0,100.0,30.000002,100.0
13,95,100.0,100.0,0.0,100.0


In [30]:
df

Unnamed: 0,Threshold,Anchor Closer to Aug Than Other Augs,Anchor Closer to Aug Than Other Originals,Aug Closer to Anchor Than Other Augs,Aug Closer to Anchor Than Other Originals
6,10,tensor(100.),tensor(100.),tensor(100.),tensor(100.)
9,20,tensor(100.),tensor(100.),tensor(100.),tensor(100.)
12,30,tensor(100.),tensor(100.),tensor(100.),tensor(100.)
5,40,tensor(100.),tensor(100.),tensor(100.),tensor(100.)
8,50,tensor(100.),tensor(100.),tensor(100.),tensor(100.)
11,60,tensor(100.),tensor(100.),tensor(100.),tensor(100.)
4,70,tensor(100.),tensor(100.),tensor(100.),tensor(100.)
7,80,tensor(100.),tensor(100.),tensor(100.),tensor(100.)
10,90,tensor(100.),tensor(100.),tensor(90.),tensor(100.)
13,95,tensor(100.),tensor(100.),tensor(0.),tensor(100.)
