In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from nilearn.connectome import sym_matrix_to_vec
from scipy.stats import pearsonr
from cmath import isinf
import torch.nn.functional as F
#import seaborn as sns
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import pandas as pd
import math
from cmath import isinf
from utils_v import compute_target_score
import torch.nn.functional as F
#from sklearn.model_selection import train_test_split, KFold, LearningCurveDisplay, learning_curve
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_absolute_percentage_error, r2_score
from helper_classes import MatData, MLP
#from dev_losses import cauchy, rbf, gaussian_kernel, CustomSupCon, CustomContrastiveLoss
#from losses import KernelizedSupCon
import itertools

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
print(device)

cuda


In [6]:
def multivariate_kernel(d, sigma : float,):
    
    exponent = -0.5 *(d/sigma)  
    exp = torch.exp(exponent)
    
    return exp

In [7]:
def mat_threshold(matrices, threshold): # as in Margulies et al. (2016)
    perc = np.percentile(np.abs(matrices), threshold, axis=2, keepdims=True)
    mask = np.abs(matrices) >= perc
    thresh_mat = matrices * mask
    return thresh_mat

In [8]:
def random_threshold(matrices, threshold, bound = 1): # as in Margulies et al. (2016)
    
    perc = np.percentile(np.abs(matrices), threshold, axis=2, keepdims=True)
    mask = np.abs(matrices) >= perc
    thresh_mat = matrices * mask
    
    random_values = np.random.uniform(-1/bound,1/bound, matrices.shape) * perc
    random_values_masked = random_values * (1-mask)
    
    mat = thresh_mat + random_values_masked
    
    return mat

In [9]:
path_matrix = "/data/parietal/store/work/dwassermann/data/victoria_mat_age/matrices.npy"
matrix = np.load(path_matrix)[:10]

In [15]:
def compute_kernelized_distance(matrix1, matrix2, sigma = 100_000, save_path=None):
    dist = torch.cdist(matrix1,matrix2, p=2)**2
    kernel = multivariate_kernel(dist, sigma)
    if save_path:
        torch.save(kernel, save_path)
    return kernel

In [16]:
def comparison(anchor_vs_augmented, other_matrix, dim):
    diag = anchor_vs_augmented.diag()
    mask = 1-torch.eye(len(other_matrix))
    other_matrix_masked = other_matrix*mask
    comparison = (diag.unsqueeze(1) > other_matrix_masked).all(dim=dim)
    percentage= comparison.float().mean() * 100
    return percentage.item()
    

In [33]:
import itertools

threshold = [10,20,30,40,50,60,70,80,90,95,96,97,98,99]  
bound = list(range(1,11)) 

param_combinations = list(itertools.product(threshold, bound))
threshold_params = [{'threshold' : comb[0]} for comb in param_combinations]
random_threshold_params = [{'threshold': comb[0], 'bound': comb[1]} for comb in param_combinations]

augmentations_with_params = [
    (mat_threshold, threshold_params), (random_threshold, random_threshold_params)
]

In [34]:
def hyperparameter_search_augmentation(matrix, augmentations_with_params, sigma=100000, file_path=None):
    results = []

    # Iterate over each augmentation and its corresponding list of parameters
    for augmentation, list_params in augmentations_with_params:
        for params in list_params:
            matrix_aug = augmentation(matrix, **params)  
            vec_m = sym_matrix_to_vec(matrix, discard_diagonal=True)
            vec_m_aug = sym_matrix_to_vec(matrix_aug, discard_diagonal=True)
            original = torch.tensor(vec_m)
            augmented = torch.tensor(vec_m_aug)

            # Compute kernelized distance matrix
            original_vs_original = compute_kernelized_distance(original, original, sigma=sigma)
            original_vs_augmented = compute_kernelized_distance(original, augmented, sigma=sigma)
            augmented_vs_augmented = compute_kernelized_distance(augmented, augmented, sigma=sigma)

            # Compare
            anchor_closer_to_aug_than_other_augs = comparison(original_vs_augmented, original_vs_augmented, 1)
            anchor_closer_to_aug_than_other_originals = comparison(original_vs_augmented, original_vs_original, 1)
            aug_closer_to_anchor_than_other_augs = comparison(original_vs_augmented, augmented_vs_augmented, 0)
            aug_closer_to_anchor_than_other_originals = comparison(original_vs_augmented, original_vs_augmented, 0)

            
            results.append({
                'Augmentation': augmentation.__name__,
                'Parameters': params,
                'Anchor Closer to Aug Than Other Augs': anchor_closer_to_aug_than_other_augs,
                'Anchor Closer to Aug Than Other Originals': anchor_closer_to_aug_than_other_originals,
                'Aug Closer to Anchor Than Other Augs': aug_closer_to_anchor_than_other_augs,
                'Aug Closer to Anchor Than Other Originals': aug_closer_to_anchor_than_other_originals
            })

    results_df = pd.DataFrame(results)
    results_df.sort_values('Parameters', key=lambda x: x.apply(lambda d: tuple(sorted(d.items()))), ascending=True, inplace=True)

    if file_path:
        results_df.to_csv(file_path, index=False)

    return results_df


In [35]:
df = hyperparameter_search_augmentation(matrix, augmentations_with_params, sigma=100000, file_path='/storage/store2/work/mrenaudi/contrastive-reg-2/tests_aug/threshold_random_threshold_bound.csv')

In [41]:
df = pd.read_csv("/storage/store2/work/mrenaudi/contrastive-reg-2/tests_aug/threshold_random_threshold_bound.csv")
pd.set_option('display.max_rows', None)
df

Unnamed: 0,Augmentation,Parameters,Anchor Closer to Aug Than Other Augs,Anchor Closer to Aug Than Other Originals,Aug Closer to Anchor Than Other Augs,Aug Closer to Anchor Than Other Originals
0,random_threshold,"{'threshold': 10, 'bound': 1}",100.0,100.0,100.0,100.0
1,random_threshold,"{'threshold': 20, 'bound': 1}",100.0,100.0,100.0,100.0
2,random_threshold,"{'threshold': 30, 'bound': 1}",100.0,100.0,100.0,100.0
3,random_threshold,"{'threshold': 40, 'bound': 1}",100.0,100.0,100.0,100.0
4,random_threshold,"{'threshold': 50, 'bound': 1}",100.0,100.0,100.0,100.0
5,random_threshold,"{'threshold': 60, 'bound': 1}",100.0,100.0,100.0,100.0
6,random_threshold,"{'threshold': 70, 'bound': 1}",100.0,100.0,100.0,100.0
7,random_threshold,"{'threshold': 80, 'bound': 1}",100.0,100.0,100.0,100.0
8,random_threshold,"{'threshold': 90, 'bound': 1}",100.0,30.000002,100.0,100.0
9,random_threshold,"{'threshold': 95, 'bound': 1}",70.0,0.0,100.0,70.0
