In [1]:
import math
import yaml
import wandb
import xarray as xr
import asyncio
import submitit
import pickle
import sys
from pathlib import Path
import gc
from collections import defaultdict
from nilearn.connectome import sym_matrix_to_vec, vec_to_sym_matrix
import numpy as np
import pandas as pd
import hydra
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.stats import spearmanr
from sklearn.model_selection import (
    train_test_split,
)
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from tqdm.auto import tqdm
# from augmentations import augs, aug_args
import glob, os, shutil
from nilearn.datasets import fetch_atlas_schaefer_2018
import random
from sklearn.preprocessing import MinMaxScaler

from ContModeling.utils import gaussian_kernel, cauchy, standardize, save_embeddings
from ContModeling.losses import LogEuclideanLoss, NormLoss, KernelizedSupCon, OutlierRobustMSE
from ContModeling.models import PhenoProj
from ContModeling.helper_classes import MatData
from ContModeling.viz_func import wandb_plot_acc_vs_baseline, wandb_plot_test_recon_corr, wandb_plot_individual_recon

torch.cuda.empty_cache()
multi_gpu = False

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def train(run, train_ratio, train_dataset, test_dataset, mean, std, B_init_fMRI, cfg, model=None, device=device):
    print("Start training...")

    # MODEL DIMS
    input_dim_feat = cfg.input_dim_feat
    input_dim_target = cfg.input_dim_target
    hidden_dim = cfg.hidden_dim
    output_dim_target = cfg.output_dim_target
    output_dim_feat = cfg.output_dim_feat
    kernel = SUPCON_KERNELS[cfg.SupCon_kernel]
    
    # TRAINING PARAMS
    lr = cfg.lr
    batch_size = cfg.batch_size
    dropout_rate = cfg.dropout_rate
    weight_decay = cfg.weight_decay
    num_epochs = cfg.num_epochs

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    mean= torch.tensor(mean).to(device)
    std = torch.tensor(std).to(device)
    if model is None:
        model = PhenoProj(
            input_dim_feat,
            input_dim_target,
            hidden_dim,
            output_dim_target,
            output_dim_feat,
            dropout_rate,
            cfg
        ).to(device)

    if cfg.mat_ae_pretrained:
        print("Loading pretrained MatrixAutoencoder...")
        state_dict = torch.load(f"{cfg.output_dir}/{cfg.pretrained_mat_ae_exp}/saved_models/autoencoder_weights_fold{cfg.best_mat_ae_fold}.pth")
        model.matrix_ae.load_state_dict(state_dict)
    else:
        model.matrix_ae.enc_mat1.weight = torch.nn.Parameter(B_init_fMRI.transpose(0,1))
        model.matrix_ae.enc_mat2.weight = torch.nn.Parameter(B_init_fMRI.transpose(0,1))
    
    if cfg.target_ae_pretrained:
        print("Loading pretrained TargetAutoencoder...")
        state_dict = torch.load(f"{cfg.output_dir}/{cfg.pretrained_target_ae_exp}/saved_models/autoencoder_weights_fold{cfg.best_target_ae_fold}.pth")
        model.target_ae.load_state_dict(state_dict)

    criterion_pft = KernelizedSupCon(
        method="expw",
        temperature=cfg.pft_temperature,
        base_temperature= cfg.pft_base_temperature,
        reg_term = cfg.pft_reg_term,
        kernel=kernel,
        krnl_sigma=cfg.pft_sigma,
    )

    criterion_ptt = KernelizedSupCon(
        method="expw",
        temperature=cfg.ptt_temperature,
        base_temperature= cfg.ptt_base_temperature,
        reg_term = cfg.ptt_reg_term,
        kernel=kernel,
        krnl_sigma=cfg.ptt_sigma,
    )
    
    feature_autoencoder_crit = EMB_LOSSES[cfg.feature_autoencoder_crit]
    target_decoding_crit = EMB_LOSSES[cfg.target_decoding_crit]

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.1, patience = cfg.scheduler_patience)

    loss_terms = []
    validation = []
    autoencoder_features = []

    gc.collect()
    
    wandb.init(project=cfg.project,
        mode = "offline",
        name=f"{cfg.experiment_name}_run{run}_train_ratio_{train_ratio}",
        dir = cfg.output_dir,
        config = OmegaConf.to_container(cfg, resolve=True))

    with tqdm(range(num_epochs), desc="Epochs", leave=False) as pbar:
        for epoch in pbar:
            model.train()

            loss_terms_batch = defaultdict(lambda:0)
            for features, targets in train_loader:
                
                optimizer.zero_grad()
                features = features.to(device)
                targets = targets.to(device)

                ## FEATURE ENCODING
                embedded_feat = model.encode_features(features)
                ## FEATURE DECODING
                if not cfg.mat_ae_pretrained:
                    reconstructed_feat = model.decode_features(embedded_feat)
                    ## FEATURE DECODING LOSS
                    feature_autoencoder_loss = feature_autoencoder_crit(features, reconstructed_feat) / 10_000
                
                ## REDUCED FEAT TO TARGET EMBEDDING
                embedded_feat_vectorized = sym_matrix_to_vec(embedded_feat.detach().cpu().numpy(), discard_diagonal = True)
                embedded_feat_vectorized = torch.tensor(embedded_feat_vectorized).to(device)

                features_vectorized = sym_matrix_to_vec(features.detach().cpu().numpy(), discard_diagonal = True)
                features_vectorized = torch.tensor(features_vectorized).to(device)
                features_vectorized = nn.functional.normalize(features_vectorized, p=2, dim=1)

                ## TARGET DECODING FROM MAT EMBEDDINGs
                reduced_feat_embedding = model.transfer_embedding(embedded_feat_vectorized)
                out_target_decoded = model.decode_targets(reduced_feat_embedding)

                ## KERNLIZED LOSS: MAT embedding vs MAT

                if cfg.SupConLoss_on_mat:
                    kernel_embedded_feature_loss, direction_reg_features = criterion_ptt(reduced_feat_embedding.unsqueeze(1), features_vectorized)
                    kernel_embedded_feature_loss = 100 * kernel_embedded_feature_loss
                    direction_reg_features = 100 * direction_reg_features


                ## KERNLIZED LOSS: MAT embedding vs targets
                kernel_embedded_target_loss, direction_reg_target = criterion_ptt(reduced_feat_embedding.unsqueeze(1), targets)
                kernel_embedded_target_loss = 100 * kernel_embedded_target_loss
                direction_reg_target = 100 * direction_reg_target

                ## LOSS: TARGET DECODING FROM TARGET EMBEDDING
                if cfg.target_decoding_crit == 'Huber' and cfg.huber_delta != 'None':
                    target_decoding_crit = nn.HuberLoss(delta = cfg.huber_delta)
                
                target_decoding_from_reduced_emb_loss = target_decoding_crit(targets, out_target_decoded) / 10


                ## SUM ALL LOSSES
                loss = kernel_embedded_target_loss + target_decoding_from_reduced_emb_loss
                
                if cfg.SupConLoss_on_mat:
                    loss += kernel_embedded_feature_loss

                if not cfg.mat_ae_pretrained:
                    loss += feature_autoencoder_loss

                loss.backward()

                if cfg.clip_grad:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    
                if cfg.log_gradients:
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            wandb.log({
                                "Epoch": epoch,
                                f"Gradient Norm/{name}": param.grad.norm().item()
                                })  

                optimizer.step()

                loss_terms_batch['loss'] = loss.item() / len(features)
                loss_terms_batch['kernel_embedded_target_loss'] = kernel_embedded_target_loss.item() / len(features)
                loss_terms_batch['target_decoding_from_reduced_emb_loss'] = target_decoding_from_reduced_emb_loss.item() / len(features)
                loss_terms_batch['direction_reg_target_loss'] = direction_reg_target.item() / len(features)
                
                if cfg.SupConLoss_on_mat:
                    loss_terms_batch['direction_reg_features_loss'] = direction_reg_features.item() / len(features)
                    loss_terms_batch['kernel_embedded_feature_loss'] = kernel_embedded_feature_loss.item() / len(features)
                
                if not cfg.mat_ae_pretrained:
                    loss_terms_batch['feature_autoencoder_loss'] = feature_autoencoder_loss.item() / len(features)
                    wandb.log({
                        'Epoch': epoch,
                        'feature_autoencoder_loss': loss_terms_batch['feature_autoencoder_loss']
                    })
                
                wandb.log({
                    'Epoch': epoch,
                    'Run': run,
                    'total_loss': loss_terms_batch['loss'],
                    'kernel_embedded_target_loss': loss_terms_batch['kernel_embedded_target_loss'],
                    'kernel_embedded_feature_loss': loss_terms_batch['kernel_embedded_feature_loss'],
                    'direction_reg_target_loss': loss_terms_batch['direction_reg_target_loss'],
                    'target_decoding_from_reduced_emb_loss': loss_terms_batch['target_decoding_from_reduced_emb_loss']
                })

            if cfg.SupConLoss_on_mat:
                wandb.log({
                    'Epoch': epoch,
                    'Run': run,
                    'direction_reg_features_loss': loss_terms_batch['direction_reg_features_loss'],
                    'kernel_embedded_feature_loss': loss_terms_batch['kernel_embedded_feature_loss'],
                })

            loss_terms_batch['epoch'] = epoch
            loss_terms.append(loss_terms_batch)

            model.eval()
            mape_batch = 0
            corr_batch = 0
            with torch.no_grad():
                for (features, targets) in test_loader:
                    
                    features, targets = features.to(device), targets.to(device)                    
                    out_feat = model.encode_features(features)
                    out_feat = torch.tensor(sym_matrix_to_vec(out_feat.detach().cpu().numpy(), discard_diagonal = True)).float().to(device)
                    transfer_out_feat = model.transfer_embedding(out_feat)
                    out_target_decoded = model.decode_targets(transfer_out_feat)
                    
                    epsilon = 1e-8
                    mape =  torch.mean(torch.abs((targets - out_target_decoded)) / torch.abs((targets + epsilon))) * 100
                    corr =  spearmanr(targets.cpu().numpy().flatten(), out_target_decoded.cpu().numpy().flatten())[0]
                    mape_batch+=mape.item()
                    corr_batch += corr

                mape_batch = mape_batch/len(test_loader)
                corr_batch = corr_batch/len(test_loader)
                validation.append(mape_batch)

            wandb.log({
                'Target MAPE/val' : mape_batch,
                'Target Corr/val': corr_batch,
                })
            
            scheduler.step(mape_batch)
            if np.log10(scheduler._last_lr[0]) < -4:
                break

            pbar.set_postfix_str(
                f"Epoch {epoch} "
                f"| Loss {loss_terms[-1]['loss']:.02f} "
                f"| Corr {corr_batch:.02f} "
            )
    wandb.finish()
    loss_terms = pd.DataFrame(loss_terms)
    return loss_terms, model

In [4]:
EMB_LOSSES ={
    'Norm': NormLoss(),
    'LogEuclidean': LogEuclideanLoss(),
    'MSE': nn.functional.mse_loss,
    'MSERobust': OutlierRobustMSE(),
    'Huber': nn.HuberLoss(),
    'cosine': nn.functional.cosine_embedding_loss,
}

SUPCON_KERNELS = {
    'cauchy': cauchy,
    'gaussian_kernel': gaussian_kernel,
    'None': None
    }

In [11]:
class ModelRun(submitit.helpers.Checkpointable):
    def __init__(self):
        self.results = None
        self.embeddings = None

    def __call__(self, train, test_size, indices, train_ratio, run_size, run, dataset, cfg, random_state=None, device=None, save_model = True, path: Path = None):
        if self.results is None:
            if device is None:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                print(f"Device {device}, ratio {train_ratio}", flush=True)
            if not isinstance(random_state, np.random.RandomState):
                random_state = np.random.RandomState(random_state)

            augmentations = cfg.augmentation

            recon_mat_dir = os.path.join(cfg.output_dir, cfg.experiment_name, cfg.reconstructed_dir)
            os.makedirs(recon_mat_dir, exist_ok=True)
    
            predictions = {}
            autoencoder_features = {}
            losses = []
            self.embeddings = {'train': [], 'test': []}
            self.run = run

            if cfg.mat_ae_pretrained:
                print("Loading test indices from the pretraining experiment...")
                test_indices = np.load(f"{cfg.output_dir}/{cfg.pretrained_mat_ae_exp}/test_idx.npy")
                train_indices = np.setdiff1d(indices, test_indices)
            elif cfg.external_test_mode:
                test_scanners = list(cfg.test_scanners)
                xr_dataset = xr.open_dataset(cfg.dataset_path)
                scanner_mask = np.sum([xr_dataset.isin(scanner).scanner.values for scanner in test_scanners],
                                    axis = 0).astype(bool)
                test_indices = indices[scanner_mask]
                train_indices = indices[~scanner_mask]
                del xr_dataset
            else:
                run_indices = random_state.choice(indices, run_size, replace=False)
                train_indices, test_indices = train_test_split(run_indices, test_size=test_size, random_state=random_state)
                
            train_dataset = Subset(dataset, train_indices)
            test_dataset = Subset(dataset, test_indices)

            train_features = train_dataset.dataset.matrices[train_dataset.indices]
            train_targets = train_dataset.dataset.targets[train_dataset.indices].numpy()
            std_train_targets, mean, std= standardize(train_targets)
            train_targets = np.log1p(train_targets+1)
            # scaler = MinMaxScaler().fit(train_targets)
            # train_targets = scaler.transform(train_targets)

            input_dim_feat =cfg.input_dim_feat
            output_dim_feat = cfg.output_dim_feat

            ## Weight initialization for bilinear layer
            mean_f = torch.mean(train_features, dim=0).to(device)
            [D,V] = torch.linalg.eigh(mean_f,UPLO = "U")
            B_init_fMRI = V[:,input_dim_feat-output_dim_feat:] 
            test_features= test_dataset.dataset.matrices[test_dataset.indices].numpy()
            test_targets = test_dataset.dataset.targets[test_dataset.indices].numpy()
            test_targets = np.log1p(test_targets+1)
            # test_targets = scaler.transform(test_targets)

            ### Augmentation
            if augmentations != 'None':
#                 aug_params = {}
                if not isinstance(augmentations, list):
                    augmentations = [augmentations]
                n_augs = len(augmentations)
                vect_train_features = sym_matrix_to_vec(train_features, discard_diagonal=True)
                n_samples = len(train_dataset)
                n_features = vect_train_features.shape[-1]
                new_train_features = np.zeros((n_samples + n_samples * n_augs, 1, n_features))
                new_train_features[:n_samples, 0, :] = vect_train_features

                for i, aug in enumerate(augmentations):
                    transform = augs[aug]
                    transform_args = aug_args[aug]
#                     aug_params[aug] = transform_args # to save later in the metrics df

                    num_aug = i + 1
                    aug_features = np.array([transform(sample, **transform_args) for sample in train_features])
                    aug_features = sym_matrix_to_vec(aug_features, discard_diagonal=True)

                    new_train_features[n_samples * num_aug: n_samples * (num_aug + 1), 0, :] = aug_features

                train_features = new_train_features
                train_targets = np.concatenate([train_targets]*(n_augs + 1), axis=0)
            
            train_dataset = TensorDataset(train_features, torch.from_numpy(train_targets).to(torch.float32))
            test_dataset = TensorDataset(torch.from_numpy(test_features).to(torch.float32), torch.from_numpy(test_targets).to(torch.float32))

            loss_terms, model = train(run, train_ratio, train_dataset, test_dataset,mean, std, B_init_fMRI, cfg, device=device)
            losses.append(loss_terms.eval("train_ratio = @train_ratio").eval("run = @run"))

            wandb.init(project=cfg.project,
                mode = "offline",
                name=f"TEST_{cfg.experiment_name}_run{run}_train_ratio_{train_ratio}",
                dir = cfg.output_dir,
                config = OmegaConf.to_container(cfg, resolve=True))
            
            embedding_dir = os.path.join(cfg.output_dir, cfg.experiment_name, cfg.embedding_dir)
            os.makedirs(embedding_dir, exist_ok=True)

            model.eval()
            with torch.no_grad():
                train_dataset = Subset(dataset, train_indices)
                train_features = train_dataset.dataset.matrices[train_dataset.indices].numpy()
                train_targets = train_dataset.dataset.targets[train_dataset.indices].numpy()
                train_targets = np.log1p(train_targets+1)
                train_dataset = TensorDataset(torch.from_numpy(train_features).to(torch.float32), torch.from_numpy(train_targets).to(torch.float32))

                for label, d, d_indices in (('train', train_dataset, train_indices), ('test', test_dataset, test_indices)):
                    is_test = True
                    if label == 'train':
                        is_test = False
                    
                    X, y = zip(*d)
                    X = torch.stack(X).to(device)
                    y = torch.stack(y).to(device)
                    X_embedded, y_embedded = model.forward(X, y)
                                        
                    if label == 'test' and train_ratio == 1.0:
                        np.save(f'{recon_mat_dir}/test_idx_run{run}',d_indices)
                        recon_mat = model.decode_features(X_embedded)
                        mape_mat = torch.abs((X - recon_mat) / (X + 1e-10)) * 100
                        
                        wandb_plot_test_recon_corr(wandb, cfg.experiment_name, cfg.work_dir, recon_mat.cpu().numpy(), X.cpu().numpy(), mape_mat.cpu().numpy(), True, run)
                        wandb_plot_individual_recon(wandb, cfg.experiment_name, cfg.work_dir, d_indices, recon_mat.cpu().numpy(), X.cpu().numpy(), mape_mat.cpu().numpy(), 0, True, run)

                        np.save(f'{recon_mat_dir}/recon_mat_run{run}', recon_mat.cpu().numpy())
                        np.save(f'{recon_mat_dir}/mape_mat_run{run}', mape_mat.cpu().numpy())

                    X_embedded = X_embedded.cpu().numpy()
                    X_embedded = torch.tensor(sym_matrix_to_vec(X_embedded, discard_diagonal=True)).to(torch.float32).to(device)
                    X_emb_reduced = model.transfer_embedding(X_embedded).to(device)

                    y_pred = model.decode_targets(X_emb_reduced)
                    y_pred = np.exp(y_pred.cpu().numpy())-1
                    y_pred = torch.tensor(y_pred).to(device)

                    y = np.exp(y.cpu().numpy())-1
                    y = torch.tensor(y).to(device)

                    save_embeddings(X_embedded, "mat", cfg, is_test, run)
                    save_embeddings(X_emb_reduced, "joint", cfg, is_test, run)

                    if label == 'test':
                        epsilon = 1e-8
                        mape =  100 * torch.mean(torch.abs((y - y_pred)) / torch.abs((y + epsilon))).item()
                        corr =  spearmanr(y.cpu().numpy().flatten(), y_pred.cpu().numpy().flatten())[0]

                        wandb.log({
                            'Run': run,
                            'Test | Target MAPE/val' : mape,
                            'Test | Target Corr/val': corr,
                            'Test | Train ratio' : train_ratio
                            })
            
                    predictions[(train_ratio, run, label)] = (y.cpu().numpy(), y_pred.cpu().numpy(), d_indices)
                    for i, idx in enumerate(d_indices):
                        self.embeddings[label].append({
                            'index': idx,
                            'target_embedded': y_embedded[i].cpu().numpy(),
                            'feature_embedded': X_emb_reduced[i].cpu().numpy()
                        })
            wandb.finish()
            
            self.results = (losses, predictions, self.embeddings)

        if save_model:
            saved_models_dir = os.path.join(cfg.output_dir, cfg.experiment_name, cfg.model_weight_dir)
            os.makedirs(saved_models_dir, exist_ok=True)
            torch.save(model.state_dict(), f"{saved_models_dir}/model_weights_run{run}.pth")

        return self.results

    def checkpoint(self, *args, **kwargs):
        print("Checkpointing", flush=True)
        return super().checkpoint(*args, **kwargs)

In [7]:
with initialize(version_base=None, config_path="."):
    cfg = compose(config_name='main_model_config.yaml')
    print(OmegaConf.to_yaml(cfg))

project: PhenProj
experiment_name: multivar_camcan_nonlin_pred_head
hypothesis: '-'
input_dim_feat: 400
output_dim_feat: 200
hidden_dim: 100
input_dim_target: 19
output_dim_target: 50
skip_enc1: false
ReEig: false
mat_ae_pretrained: false
target_ae_pretrained: false
pretrained_mat_ae_exp: internal_mat_ae_abcd
pretrained_target_ae_exp: target_ae
best_mat_ae_fold: 4
best_target_ae_fold: 1
synth_exp: false
multi_gpu: true
num_epochs: 100
batch_size: 28
n_runs: 1
lr: 0.001
weight_decay: 0
dropout_rate: 0
scheduler_patience: 10
test_ratio: 0.3
train_ratio: 1.0
log_gradients: true
clip_grad: true
external_test_mode: false
test_scanners:
- GE MEDICAL SYSTEMS_DISCOVERY MR750
- Philips Medical Systems_Achieva dStream
- Philips Medical Systems_Ingenia
SupCon_kernel: cauchy
SupConLoss_on_mat: false
pft_base_temperature: 0.07
pft_temperature: 0.07
pft_sigma: 1
pft_reg_term: 0.01
ptt_base_temperature: 0.07
ptt_temperature: 0.07
ptt_sigma: 1
ptt_reg_term: 0.01
feature_autoencoder_crit: Norm
joint_em

In [None]:
# test_ratio = 0.3
# train_ratio = 1.0
# dataset_path = "/gpfs3/well/margulies/users/cpy397/contrastive-learning/data/abcd_dataset_400parcels_1.nc"
# dataset = MatData(dataset_path, ["nihtbx_totalcomp_agecorrected"], synth_exp = False, threshold=0)
# n_sub = len(dataset)
# indices = np.arange(n_sub)
# train_size = int(n_sub * (1 - test_ratio) * train_ratio)
# test_size = int(test_ratio * n_sub)
# run_size = test_size + train_size
# random_state = np.random.RandomState(seed=42)
# run_indices = random_state.choice(indices, run_size, replace=False)
# train_indices, test_indices = train_test_split(run_indices, test_size=0.3, random_state=random_state)
# train_dataset = Subset(dataset, train_indices)
# test_dataset = Subset(dataset, test_indices)


In [None]:
# train_features = train_dataset.dataset.matrices[train_dataset.indices]
# train_targets = train_dataset.dataset.targets[train_dataset.indices].numpy()
# std_train_targets, mean, std= standardize(train_targets)
# # scaler = MinMaxScaler().fit(train_targets)
# # train_targets = scaler.transform(train_targets)

# input_dim_feat =cfg.input_dim_feat
# output_dim_feat = cfg.output_dim_feat

# ## Weight initialization for bilinear layer
# mean_f = torch.mean(train_features, dim=0).to(device)
# [D,V] = torch.linalg.eigh(mean_f,UPLO = "U")
# B_init_fMRI = V[:,input_dim_feat-output_dim_feat:] 
# test_features= test_dataset.dataset.matrices[test_dataset.indices].numpy()
# test_targets = test_dataset.dataset.targets[test_dataset.indices].numpy()

In [None]:
# train_loader = DataLoader(train_dataset, batch_size=28, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=228, shuffle=True)

In [None]:
# model = PhenoProj(
#     cfg.input_dim_feat,
#     cfg.input_dim_target,
#     cfg.hidden_dim,
#     cfg.output_dim_target,
#     cfg.output_dim_feat,
#     cfg.dropout_rate,
#     cfg
# ).to(device)

In [None]:
# matrices = train_loader.dataset.dataset.matrices[:3]
# matrices = matrices.to(device)
# targets = train_loader.dataset.dataset.targets[:3]
# targets = targets.to(device)

In [None]:
# input_dim_feat=cfg.input_dim_feat,
# input_dim_target=cfg.input_dim_target,
# hidden_dim=cfg.hidden_dim,
# output_dim_target=cfg.output_dim_target,
# output_dim_feat=cfg.output_dim_feat,
# dropout_rate=cfg.dropout_rate,

# enc_mat1 = nn.Linear(in_features=400, out_features=70 ,bias=False)
# enc_mat2 = nn.Linear(in_features=400, out_features=70, bias=False)

# enc_mat1.weight = torch.nn.Parameter(B_init_fMRI.transpose(0,1))
# enc_mat2.weight = torch.nn.Parameter(B_init_fMRI.transpose(0,1))

# enc_mat1.weight = torch.nn.Parameter(B_init_fMRI.transpose(0,1))
# enc_mat2.weight = torch.nn.Parameter(enc_mat1.weight)


In [None]:
with initialize(version_base=None, config_path="."):
    cfg = compose(config_name='main_model_config.yaml')
    print(OmegaConf.to_yaml(cfg))

In [12]:
def main(cfg=cfg):

    results_dir = os.path.join(cfg.output_dir, cfg.experiment_name)
    os.makedirs(results_dir, exist_ok=True)

    random_state = np.random.RandomState(seed=42)

    dataset_path = cfg.dataset_path

    if isinstance(cfg.targets, str):
        
        targets =[cfg.targets]
    else:
        targets = list(cfg.targets)
        
    test_ratio = cfg.test_ratio

    dataset = MatData(dataset_path, targets, synth_exp = cfg.synth_exp, threshold=cfg.mat_threshold)
    n_sub = len(dataset)
    test_size = int(test_ratio * n_sub)
    indices = np.arange(n_sub)
    n_runs = cfg.n_runs
    multi_gpu = cfg.multi_gpu
    train_ratio = cfg.train_ratio
    
    multi_gpu = False
    if multi_gpu:
        print("Using multi-gpu")
        log_folder = Path("logs")
        executor = submitit.AutoExecutor(folder=str(log_folder / "%j"))
        executor.update_parameters(
            timeout_min=120,
            slurm_partition="gpu_short",
            gpus_per_node=1,
            tasks_per_node=1,
            nodes=1
            #slurm_constraint="v100-32g",
        )
        run_jobs = []

        with executor.batch():
            train_size = int(n_sub * (1 - test_ratio) * train_ratio)
            run_size = test_size + train_size
            for run in tqdm(range(n_runs)):
                run_model = ModelRun()
                job = executor.submit(run_model, train, test_size, indices, train_ratio, run_size, run, dataset, cfg, random_state=random_state, device=None)
                run_jobs.append(job)

        async def get_result(run_jobs):
            run_results = []
            for aws in tqdm(asyncio.as_completed([j.awaitable().result() for j in run_jobs]), total=len(run_jobs)):
                res = await aws
                run_results.append(res)
            return run_results
        run_results = asyncio.run(get_result(run_jobs))

    else:
        run_results = []
        train_size = int(n_sub * (1 - test_ratio) * train_ratio)
        run_size = test_size + train_size
        for run in tqdm(range(n_runs), desc="Model Run"):
            run_model = ModelRun()
            job = run_model(train, test_size, indices, train_ratio, run_size, run, dataset, cfg, random_state=random_state, device=None)
            run_results.append(job)

    losses, predictions, embeddings = zip(*run_results)

    prediction_metrics = predictions[0]
    for prediction in predictions[1:]:
        prediction_metrics.update(prediction)

    pred_results = []
    for k, v in prediction_metrics.items():
        true_targets, predicted_targets, indices = v
        
        true_targets_dict = {"train_ratio": [k[0]] * len(true_targets),
                             "model_run":[k[1]] * len(true_targets),
                             "dataset":[k[2]] * len(true_targets)
                            }
        predicted_targets_dict = {"indices": indices}
        
        for i, target in enumerate(targets):
            true_targets_dict[target] = true_targets[:, i]
            predicted_targets_dict[f"{target}_pred"] = predicted_targets[:, i]
            
            
        true_targets = pd.DataFrame(true_targets_dict)
        predicted_targets = pd.DataFrame(predicted_targets_dict)
        
        pred_results.append(pd.concat([true_targets, predicted_targets], axis = 1))
    pred_results = pd.concat(pred_results)
    pred_results.to_csv(f"{results_dir}/pred_results.csv", index=False)

    prediction_mape_by_element = []
    for k, v in prediction_metrics.items():
        true_targets, predicted_targets, indices = v
        
        mape_by_element = np.abs(true_targets - predicted_targets) / (np.abs(true_targets)+1e-10)
        
        for i, mape in enumerate(mape_by_element):
            prediction_mape_by_element.append(
                {
                    'train_ratio': k[0],
                    'model_run': k[1],
                    'dataset': k[2],
                    'mape': mape
                }
            )

    df = pd.DataFrame(prediction_mape_by_element)
    df = pd.concat([df.drop('mape', axis=1), df['mape'].apply(pd.Series)], axis=1)
    df.columns = ['train_ratio', 'model_run', 'dataset'] + targets
    df= df.groupby(['train_ratio', 'model_run', 'dataset']).agg('mean').reset_index()
    df.to_csv(f"{results_dir}/mape.csv", index = False)

if __name__ == "__main__":
    main()

Model Run:   0%|                                          | 0/1 [00:00<?, ?it/s]

Device cuda, ratio 1.0
Start training...



Epochs:   0%|                                           | 0/100 [00:00<?, ?it/s][A
Epochs:   0%|       | 0/100 [00:00<?, ?it/s, Epoch 0 | Loss 70.73 | Corr -0.13 ][A
Epochs:   1%| | 1/100 [00:00<00:41,  2.36it/s, Epoch 0 | Loss 70.73 | Corr -0.13[A
Epochs:   1%| | 1/100 [00:00<00:41,  2.36it/s, Epoch 1 | Loss 55.77 | Corr 0.01 [A
Epochs:   2%| | 2/100 [00:00<00:39,  2.46it/s, Epoch 1 | Loss 55.77 | Corr 0.01 [A
Epochs:   2%| | 2/100 [00:01<00:39,  2.46it/s, Epoch 2 | Loss 47.03 | Corr 0.22 [A
Epochs:   3%| | 3/100 [00:01<00:38,  2.54it/s, Epoch 2 | Loss 47.03 | Corr 0.22 [A
Epochs:   3%| | 3/100 [00:01<00:38,  2.54it/s, Epoch 3 | Loss 47.92 | Corr 0.52 [A
Epochs:   4%| | 4/100 [00:01<00:38,  2.50it/s, Epoch 3 | Loss 47.92 | Corr 0.52 [A
Epochs:   4%| | 4/100 [00:01<00:38,  2.50it/s, Epoch 4 | Loss 42.71 | Corr 0.77 [A
Epochs:   5%| | 5/100 [00:01<00:37,  2.55it/s, Epoch 4 | Loss 42.71 | Corr 0.77 [A
Epochs:   5%| | 5/100 [00:02<00:37,  2.55it/s, Epoch 5 | Loss 38.47 | Corr 

0,1
Epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▇▇▇▇█
Gradient Norm/feat_to_target_embedding.0.bias,█▅▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁
Gradient Norm/feat_to_target_embedding.0.weight,██▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▁▂▁▁▁▂▁▂▁▂▁▂▁▂▂▂
Gradient Norm/feat_to_target_embedding.1.bias,▃██▃▃▂▂▂▂▂▂▂▃▂▂▂▁▁▁▂▁▁▂▂▁▁▁▃▄▂▁▃▂▂▂▁▁▂▃▁
Gradient Norm/feat_to_target_embedding.1.weight,▇▅▃▃▄▅▄▄▄▂▂▂▂▃▂▂▁▂▁▂▂▂▂▂▁▂▂▂▃▄▅▆▇▄▁▂▄▃█▁
Gradient Norm/feat_to_target_embedding.4.bias,▅▆▅▃▄▃▄▃▂▁▃▃▂▁▂▁▄▃▂▁▃▃▁▂▂▃▃█▆▃▄▂▁▃▂▆▂▁▅▃
Gradient Norm/feat_to_target_embedding.4.weight,▇█▆▆▃▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂▄▂▁▂▂▁▁▃▃▂▁▂▂▃▂▂▁▄▂▁
Gradient Norm/feat_to_target_embedding.5.bias,▁▃▃▅▃▃▃▃▃▄▃▃█▃▂▃▄▂▃▃▃▃▃▄▃▃▃▃▃▆▅▃▃▄▆▅▃▆▃▃
Gradient Norm/feat_to_target_embedding.5.weight,▁█▃▃▂▃▁▂▃▁▂▂▂▂▁▁▁▃▂▃▃▇▄▄▃▅▅▄▅▄▅▅▆▄▄▅▅▆▅▅
Gradient Norm/feat_to_target_embedding.8.bias,█▁▂▁▃▂▁▂▁▂▂▁▂▃▁▁▁▂▁▁▂▂▂▂▁▁▁▂▁▁▁▁▂▂▁▁▂▂▂▂

0,1
Epoch,75.0
Gradient Norm/feat_to_target_embedding.0.bias,0.0
Gradient Norm/feat_to_target_embedding.0.weight,0.11572
Gradient Norm/feat_to_target_embedding.1.bias,0.00284
Gradient Norm/feat_to_target_embedding.1.weight,0.00663
Gradient Norm/feat_to_target_embedding.4.bias,0.0
Gradient Norm/feat_to_target_embedding.4.weight,0.10062
Gradient Norm/feat_to_target_embedding.5.bias,0.02388
Gradient Norm/feat_to_target_embedding.5.weight,0.05096
Gradient Norm/feat_to_target_embedding.8.bias,0.01541


  correlations[i // cols, i % cols] = spearmanr(flat_true[:, i], flat_recon[:, i])[0]


0,1
Run,▁
Test | Target Corr/val,▁
Test | Target MAPE/val,▁
Test | Train ratio,▁

0,1
Run,0.0
Test | Target Corr/val,0.98309
Test | Target MAPE/val,6.76515
Test | Train ratio,1.0


Model Run: 100%|█████████████████████████████████| 1/1 [03:40<00:00, 220.13s/it]
