In [2]:
import torch
num_gpus = torch.cuda.device_count()
for i in range(num_gpus):
    device = torch.device(f'cuda:{i}')
print(torch.cuda.current_device())


0


In [3]:
num_gpus

1

In [1]:
import sys
sys.path.append('/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main')

### UQ_Writer

In [2]:
import os
import dill
from pathlib import Path
import argparse
import sys

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.utils.data import Subset, DataLoader
from torch import nn

from utils import *
from trainer import Trainer
from loss_writer import Writer
from metrics import Metrics

class UQWriter(Writer):
    def __init__(self, sets, val_threshold, **kwargs):
        super().__init__(sets, val_threshold, **kwargs)
        self.confidence_list = []
        self.is_correct_list = []
        self.uncertainty_statistics_dict = {}
        self.uncertainty_quantification_stat = {}

    def compute_confidence(self, confidence_list: list, is_correct_list: list):
        num_bins = 10
        bin_edges = np.linspace(0.0, 1.0, num_bins + 1)  # Bin edges from 0 to 1
        bin_indices = np.digitize(confidence_list, bin_edges, right=True)
        bin_middlepoint = (bin_edges[1:] + bin_edges[:-1])/2

        bin_confidences = []
        bin_accuracies = []
        bin_gaps = []
        confidence_nparray = np.array(confidence_list)
        is_correct_nparray = np.array(is_correct_list)
        # ECE is weighted average of calibration error in each bin
        # MCE is maximum calibration error in each bin
        cum_ce = 0
        mce = 0

        # organizing bin elements
        for i in range(1, num_bins + 1):
            indices = np.where(bin_indices == i)[0]  # Get indices of elements in the bin
            if len(indices) > 0:
                avg_confidence = np.mean(confidence_nparray[indices])  # Average confidence
                avg_accuracy = np.mean(is_correct_nparray[indices])  # Accuracy as mean of correct labels
                gap = avg_confidence - avg_accuracy  # Gap between confidence and accuracy

                bin_confidences.append(avg_confidence)
                bin_accuracies.append(avg_accuracy)
                bin_gaps.append(gap)
                cum_ce += np.abs(gap) * len(indices)
                mce = max(mce, np.abs(gap))
            else:
                bin_confidences.append(0)
                bin_accuracies.append(0)
                bin_gaps.append(0)
        
        ece = cum_ce / len(confidence_list)


        # FAR95 statistics
        far95, threshold, fpr, tpr = self.metrics.compute_far95(confidence_list, is_correct_list) # Returns None if no threshold found
        if far95 is None:
            print("\nFalse Acceptance Rate at 95% Recall: No threshold found")
        
        # Drawing ROC curve
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, marker='o', linestyle='-', label='ROC curve')
        plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Chance')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc='lower right')
        plt.grid(True)
        
        # Save the ROC plot
        roc_save_path = os.path.join(self.kwargs.get("experiment_folder"), 'roc_curve.png')
        plt.savefig(roc_save_path, dpi=300, bbox_inches='tight')
        plt.close()

        # Save the ECE/MCE/FAR95 statistics
        self.uncertainty_quantification_stat['ece'] = ece
        self.uncertainty_quantification_stat['mce'] = mce
        self.uncertainty_quantification_stat['far95'] = far95
        self.uncertainty_quantification_stat['threshold'] = threshold

        stat_save_path = os.path.join(self.kwargs.get("experiment_folder"), 'statistics.txt')
        with open(stat_save_path, 'w') as f:
            f.write("==========All samples evaluated==========\n")
            f.write(f"Expected Calibration Error: {ece}\n")
            f.write(f"Maximum Calibration Error: {mce}\n")
            if far95 is not None:
                f.write(f"False Acceptance Rate at 95% Recall: {far95} (threshold: {threshold})\n")

        # drawing plot
        bar_width = 0.08  # Width of the bars
        plt.figure(figsize=(8, 6))

        plt.bar(bin_edges[:-1], bin_accuracies, width=bar_width, align='edge', color='blue', edgecolor='black', label="Outputs")
        plt.bar(bin_edges[:-1], bin_gaps, width=bar_width, align='edge', color='pink', alpha=0.7, label="Gap", bottom=bin_accuracies)
        plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label="Perfect Calibration")

        plt.text(0.7, 0.1, f'ECE={ece:.4f}', fontsize=14, bbox=dict(facecolor='lightgray', alpha=0.5))

        plt.xlabel('Confidence')
        plt.ylabel('Accuracy')
        plt.title('Reliability Diagram')
        plt.legend()
        plt.grid(True)
        plt.xlim([0, 1])
        plt.ylim([0, 1])

        diagram_save_path = os.path.join(self.kwargs.get("experiment_folder"), 'reliability_diagram.png')
        plt.savefig(diagram_save_path, dpi=300, bbox_inches='tight')

    def sample_uncertainty_statistic(self, subj_name: int, subj_dict: dict, subj_truth: int):

        sample_prediction = 1 if subj_dict['score'].mean().item() > 0.5 else 0
        is_correct = 1 if subj_truth == sample_prediction else 0

        # probabilities list for predicted sample 
        # (which 0 or 1, prob of 1 when sample_prediction == 1 and prob of 0 when sample_prediction == 0)
        if sample_prediction == 1:
            sample_pred_probabilities_list = subj_dict['score'].tolist()
        else:
            sample_pred_probabilities_list = (1 - subj_dict['score']).tolist()
        
        # Calculate uncertainty
        mean = torch.mean(torch.tensor(sample_pred_probabilities_list), axis=0)
        variance = torch.var(torch.tensor(sample_pred_probabilities_list), axis=0)

        confidence_intervals_dict = {}
        confidence_levels = [0.9, 0.95]
        for confidence_level in confidence_levels:
            lower_percentile = (1 - confidence_level) / 2 * 100  # 2.5% for 95% CI
            upper_percentile = (1 + confidence_level) / 2 * 100  # 97.5% for 95% CI

            # Compute confidence intervals for each class
            probabilities_list = torch.stack([(1 - subj_dict['score']), subj_dict['score']], dim=1)
            confidence_intervals = np.percentile(probabilities_list, [lower_percentile, upper_percentile], axis=0)
            confidence_intervals_dict[confidence_level] = confidence_intervals

        stat_save_path = os.path.join(self.kwargs.get("experiment_folder"), 'sample_statistics.txt')
        with open(stat_save_path, 'a') as f:
            f.write(f"\nStatistics for sample {subj_name}\n")
            f.write(f"No. of forward passes: {len(sample_pred_probabilities_list)}\n")
            f.write(f"Final prediction: {sample_prediction}\n")
            f.write(f"True label: {subj_truth}\n")
            f.write(f"Correct: {True if is_correct else False}\n")
            f.write(f"Prediction probability: {mean}\n")
            f.write(f"Variance: {variance}\n")
            for confidence_level in confidence_levels:
                for i in range(2):
                    f.write(f"Class {i}: {confidence_level * 100}% CI = [{confidence_intervals_dict[confidence_level][0, i]}, {confidence_intervals_dict[confidence_level][1, i]}]\n")
            f.write("\n")

        self.uncertainty_statistics_dict[subj_name] = {
            'mean': mean,
            'variance': variance,
            'confidence_intervals': confidence_intervals_dict,
            'is_correct': is_correct,
            'sample_prediction': sample_prediction,
            'sample_pred_probabilities_list': sample_pred_probabilities_list,
            'truth': subj_truth
        }

        return mean, is_correct

    
    def accuracy_summary(self, mid_epoch, mean, std):
        pred_all_sets = {x:[] for x in self.sets}   # dictionary to store predictions
        truth_all_sets = {x:[] for x in self.sets}  # dictionary to store ground truth values
        std_all_sets = {x:[] for x in self.sets}  # dictionary to store prediction errors
        metrics = {}
        confidence_list = []
        is_correct_list = []
        
        for subj_name,subj_dict in self.subject_accuracy.items():  # per-subject prediction scores (score), ground truth labels (truth), and the set (mode) they belong to
            
            if self.fine_tune_task == 'binary_classification':
                subj_dict['score'] = torch.sigmoid(subj_dict['score'].float())

            # subj_dict['score'] denotes the logits for sequences for a subject
            subj_pred = subj_dict['score'].mean().item() 
            subj_error = subj_dict['score'].std().item()

            subj_truth = subj_dict['truth'].item()
            subj_mode = subj_dict['mode'] # train, val, test

            conf, is_corr = self.sample_uncertainty_statistic(subj_name, subj_dict, subj_truth)
            confidence_list.append(conf)
            is_correct_list.append(is_corr)

            # with open(os.path.join(self.per_subject_predictions,'iter_{}.txt'.format(self.eval_iter)),'a+') as f:
            #     f.write('subject:{} ({})\noutputs: {:.4f}\u00B1{:.4f}  -  truth: {}\n'.format(subj_name,subj_mode,subj_pred,subj_error,subj_truth))
            
            pred_all_sets[subj_mode].append(subj_pred) # don't use std in computing AUROC, ACC
            std_all_sets[subj_mode].append(subj_error)
            truth_all_sets[subj_mode].append(subj_truth)

        for (name,pred),(_, std),(_,truth) in zip(pred_all_sets.items(), std_all_sets.items(), truth_all_sets.items()):
            if len(pred) == 0:
                continue

            if self.fine_tune_task == 'regression':
                ## return to original scale ##
                unnormalized_pred = [i * std + mean for i in pred]
                unnormalized_truth = [i * std + mean for i in truth]

                metrics[name + '_MAE'] = self.metrics.MAE(unnormalized_truth,unnormalized_pred)
                metrics[name + '_MSE'] = self.metrics.MSE(unnormalized_truth,unnormalized_pred)
                metrics[name +'_NMSE'] = self.metrics.NMSE(unnormalized_truth,unnormalized_pred)
                metrics[name + '_R2_score'] = self.metrics.R2_score(unnormalized_truth,unnormalized_pred)
                
            else:
                metrics[name + '_Balanced_Accuracy'] = self.metrics.BAC(truth,[x>0.5 for x in torch.Tensor(pred)])
                metrics[name + '_Regular_Accuracy'] = self.metrics.RAC(truth,[x>0.5 for x in torch.Tensor(pred)]) # Stella modified it
                metrics[name + '_AUROC'] = self.metrics.AUROC(truth,pred)             
                metrics[name +'_best_bal_acc'], metrics[name + '_best_threshold'],metrics[name + '_gmean'],metrics[name + '_specificity'],metrics[name + '_sensitivity'],metrics[name + '_f1_score'] = self.metrics.ROC_CURVE(truth,pred,name,self.val_threshold)

            self.current_metrics = metrics
            
            
        for name,value in metrics.items():
            self.scalar_to_tensorboard(name,value)
            if hasattr(self,name):
                l = getattr(self,name)
                l.append(value)
                setattr(self,name,l)
            else:
                setattr(self, name, [value])
                
        self.eval_iter += 1
        if mid_epoch and len(self.subject_accuracy) > 0:
            self.subject_accuracy = {k: v for k, v in self.subject_accuracy.items() if v['mode'] == 'train'}
        else:
            self.subject_accuracy = {}

        self.confidence_list = confidence_list
        self.is_correct_list = is_correct_list


### UQ_Trainer

In [1]:
class UQTrainer(Trainer):

    def __init__(self, sets, model_idx = None, **kwargs):
        super().__init__(sets, **kwargs)
        self.model_idx = model_idx
        self.writer = UQWriter(sets, self.val_threshold, **kwargs)
        print(f"model_idx: {model_idx}")
    
    ## Should be changed to save at asdf_epoch{1}/checkpoint_model{0}.pth
    ## YC : CHANGED
    def save_checkpoint_(self, epoch, batch_idx, scaler):
        model_idx = self.model_idx

        loss = self.get_last_loss()
        #accuracy = self.get_last_AUROC()
        val_ACC = self.get_last_ACC()
        val_best_ACC = self.get_last_best_ACC()
        val_AUROC = self.get_last_AUROC()
        val_MAE = self.get_last_MAE()
        val_threshold = self.get_last_val_threshold()

        if self.method == 'ensemble':
            if model_idx is None:
                raise ValueError("model_idx must be provided for ensemble method.")
                
            title = str(self.writer.experiment_title) + '_epoch_' + str(epoch)
            directory = os.path.join(self.writer.experiment_folder, '_model_{}'.format(model_idx))
            if not os.path.exists(directory):
                os.makedirs(directory)
            if self.amp:
                amp_state = scaler.state_dict()
        
        else:
            title = str(self.writer.experiment_title) + '_epoch_' + str(int(epoch))
            directory = self.writer.experiment_folder

            # Create directory to save to
            if not os.path.exists(directory):
                os.makedirs(directory)
            if self.amp:
                amp_state = scaler.state_dict()

        # Build checkpoint dict to save.
        ckpt_dict = {
            # 'model_state_dict':self.model.module.state_dict(),  # Distributed case
            'model_state_dict':self.model.module.state_dict() if hasattr(self.model, "module") else self.model.state_dict(),
            'optimizer_state_dict':self.optimizer.state_dict() if self.optimizer is not None else None,
            'epoch':epoch,
            'loss_value':loss,
            'amp_state': amp_state}

        # if val_ACC is not None:
        #     ckpt_dict['val_ACC'] = val_ACC
        if val_AUROC is not None:
            ckpt_dict['val_AUROC'] = val_AUROC
        if val_threshold is not None:
            ckpt_dict['val_threshold'] = val_threshold
        if val_MAE is not None:
            ckpt_dict['val_MAE'] = val_MAE
        if self.lr_handler.schedule is not None:
            ckpt_dict['schedule_state_dict'] = self.lr_handler.schedule.state_dict()
            ckpt_dict['lr'] = self.optimizer.param_groups[0]['lr']
            print(f"current_lr:{self.optimizer.param_groups[0]['lr']}")
        if hasattr(self,'loaded_model_weights_path'):
            ckpt_dict['loaded_model_weights_path'] = self.loaded_model_weights_path
        
        # classification
        if val_AUROC is not None:
            if self.best_AUROC < val_AUROC:
                self.best_AUROC = val_AUROC
                name = "{}_BEST_val_AUROC.pth".format(title)
                torch.save(ckpt_dict, os.path.join(directory, name))
                print(f'updating best saved model with AUROC:{val_AUROC}')

                if self.best_ACC < val_ACC:
                    self.best_ACC = val_ACC
            elif self.best_AUROC >= val_AUROC:
                # If model is not improved in val AUROC, but improved in val ACC.
                if self.best_ACC < val_ACC:
                    self.best_ACC = val_ACC
                    name = "{}_BEST_val_ACC.pth".format(title)
                    torch.save(ckpt_dict, os.path.join(directory, name))
                    print(f'updating best saved model with ACC:{val_ACC}')

        # regression
        elif val_AUROC is None and val_MAE is not None:
            if self.best_MAE > val_MAE:
                self.best_MAE = val_MAE
                name = "{}_BEST_val_MAE.pth".format(title)
                torch.save(ckpt_dict, os.path.join(directory, name))
                print(f'updating best saved model with MAE: {val_MAE}')
            else:
                pass
                
        else:
            if self.best_loss > loss:
                self.best_loss = loss
                name = "{}_BEST_val_loss.pth".format(title)
                torch.save(ckpt_dict, os.path.join(directory, name))
                print(f'updating best saved model with loss: {loss}')
            else:
                pass

    def set_model_device(self):  # assigns the model to appropriate devices (e.g., GPU or CPU)
        if self.distributed:
            # For multiprocessing distributed, DistributedDataParallel constructor
            # should always set the single device scope, otherwise,
            # DistributedDataParallel will use all available devices.
            
            ### DEBUG STATEMENT ###
            print(f"self.gpu: {self.gpu}")
            if self.gpu is None:
                print("self.gpu is None")
            #######################
            
            if self.gpu is not None:
                print('id of gpu is:', self.gpu)
                self.device = torch.device('cuda:{}'.format(self.gpu))
                torch.cuda.set_device(self.gpu)
                self.model.cuda(self.gpu)
                self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.gpu], broadcast_buffers=False, find_unused_parameters=True) 
                net_without_ddp = self.model.module
            else:
                
                ### DEBUG STATEMENT ###
                print("Distributed training without specific GPU assignment")
                #######################
                
                self.device = torch.device("cuda" if self.cuda else "cpu")
                self.model.cuda()
                if 'reconstruction' in self.task.lower():
                    self.model = torch.nn.parallel.DistributedDataParallel(self.model) 
                else: # having unused parameter (classifier token)
                    self.model = torch.nn.parallel.DistributedDataParallel(self.model,find_unused_parameters=True) 
                model_without_ddp = self.model.module
        else:
            
            ### DEBUG STATEMENT ###
            print("Single GPU or CPU training")
            #######################
            
            self.device = torch.device("cuda" if self.cuda else "cpu")
            
            ### DEBUG STATEMENT ###
            print(f"self.gpu: {self.gpu}")
            print(f"self.device: {self.device}")
            #######################
            
            #self.model = DataParallel(self.model).to(self.device)
            
            ### DEBUG STATEMENT ###
            self.device = torch.device("cuda:0")   # added for debugging
            self.model = self.model.to(self.device)  
            #######################
            
            ### DEBUG STATEMENT ###
            print(f"moved model to: {self.device}")
            #######################


    def eval(self,set):
        ## If set == 'MC_dropout', then set dropout to True
        if set not in ['MC_dropout', 'train', 'val', 'test']:
            raise ValueError(f"Invalid set: {set}")
        self.mode = set
        if set == 'MC_dropout':
            for layer in self.model.modules():
                if isinstance(layer, nn.Dropout):
                    print(f"Enabling MC Dropout for layer {layer} - p={layer.p}")
                    layer.train()
        else:
            self.model = self.model.eval()

    def finish_eval(self, set):
        if set not in ['MC_dropout', 'train', 'val', 'test']:
            raise ValueError(f"Invalid set: {set}")
        if set == 'MC_dropout':
            self.model = self.model.eval()

    def concat_batch_results(self, inout_batches: list):
        inout_keys = inout_batches[0].keys()
        concat_inout = dict()
        for inout in inout_batches:
            for key in inout_keys:
                if key not in concat_inout:
                    concat_inout[key] = inout[key]
                else:
                    if isinstance(inout[key], list):
                        concat_inout[key] += inout[key]
                    elif isinstance(inout[key], torch.Tensor):
                        concat_inout[key] = torch.cat((concat_inout[key], inout[key]), dim=0)
                    else:
                        raise ValueError(f"Invalid inout type: {type(inout[key])}")
        
        return concat_inout
    
    def forward_pass(self,input_dict): 
        input_dict = {
            k: (
                v.to(self.device) if (self.cuda and torch.is_tensor(v)) else v
            ) for k, v in input_dict.items()
        }
        for k, v in input_dict.items():
            if torch.is_tensor(v):
                if not v.is_contiguous():
                    v = v.contiguous()
        
        if self.task.lower() == 'test':
            if self.fmri_type in ['timeseries', 'frequency', 'time_domain_high', 'time_domain_low', 'time_domain_ultralow', 'frequency_domain_low', 'frequency_domain_ultralow', 'frequency_domain_high']:
                output_dict = self.model(input_dict['fmri_sequence'])
            elif self.fmri_type == 'divided_timeseries':
                if self.fmri_dividing_type == 'two_channels':
                    output_dict = self.model(input_dict['fmri_lowfreq_sequence'], input_dict['fmri_ultralowfreq_sequence'])
                elif self.fmri_dividing_type == 'three_channels':
                    output_dict = self.model(input_dict['fmri_highfreq_sequence'], input_dict['fmri_lowfreq_sequence'], input_dict['fmri_ultralowfreq_sequence'])
                elif self.fmri_dividing_type == 'four_channels':
                    output_dict = self.model(input_dict['fmri_imf1_sequence'], input_dict['fmri_imf2_sequence'], input_dict['fmri_imf3_sequence'], input_dict['fmri_imf4_sequence'])

        
        #### train & valid ####
        else:
            if self.fmri_type in ['timeseries', 'frequency', 'time_domain_high', 'time_domain_low', 'time_domain_ultralow', 'frequency_domain_low', 'frequency_domain_ultralow', 'frequency_domain_high']:
                output_dict = self.model(input_dict['fmri_sequence'])
            elif self.fmri_type == 'divided_timeseries':
                if self.fmri_dividing_type == 'two_channels':
                    output_dict = self.model(input_dict['fmri_lowfreq_sequence'], input_dict['fmri_ultralowfreq_sequence'])
                elif self.fmri_dividing_type == 'three_channels':
                    output_dict = self.model(input_dict['fmri_highfreq_sequence'], input_dict['fmri_lowfreq_sequence'], input_dict['fmri_ultralowfreq_sequence'])
                elif self.fmri_dividing_type == 'four_channels':
                    output_dict = self.model(input_dict['fmri_imf1_sequence'], input_dict['fmri_imf2_sequence'], input_dict['fmri_imf3_sequence'], input_dict['fmri_imf4_sequence'])
                    
                    torch.cuda.synchronize()
                                
        return input_dict, output_dict

    def eval_epoch(self,set):  # evaluates the model for a single epoch
        loader = self.test_loader
        subset_indices = list(range(len(self.test_loader.dataset))) * self.num_forward_passes
        subset = Subset(self.test_loader.dataset, subset_indices)
        loader = DataLoader(subset, batch_size=8, shuffle=False, num_workers=0)
        subject_names = [data['subject_name'] for data in loader.dataset]

        self.eval(set)
        input_batches = []
        output_batches = []
        with torch.no_grad():
            for batch_idx, input_dict in enumerate(tqdm(loader, position=0, leave=True)):
                with autocast():
                    input_dict, output_dict = self.forward_pass(input_dict)
                    input_batches.append(input_dict)
                    output_batches.append(output_dict)

        self.finish_eval(set)
        return input_batches, output_batches

    def testing(self):  # manages the testing phase of the model
        # options = ['MC_dropout']
        roc_save_path = os.path.join(self.kwargs.get("experiment_folder"), 'roc_curve.png')
        stat_save_path = os.path.join(self.kwargs.get("experiment_folder"), 'statistics.txt')
        samp_stat_save_path = os.path.join(self.kwargs.get("experiment_folder"), 'sample_statistics.txt')
        if os.path.exists(roc_save_path):
            os.remove(roc_save_path)
        if os.path.exists(stat_save_path):
            os.remove(stat_save_path)
        if os.path.exists(samp_stat_save_path):
            os.remove(samp_stat_save_path)

        input_batches, output_batches = self.eval_epoch('MC_dropout')
        inputs = self.concat_batch_results(input_batches)
        outputs = self.concat_batch_results(output_batches)

        self.compute_accuracy(inputs, outputs)
        self.writer.accuracy_summary(mid_epoch=False, mean=None, std=None)
        self.writer.compute_confidence(self.writer.confidence_list, self.writer.is_correct_list)

NameError: name 'Trainer' is not defined

### main.py

In [4]:
from utils import *  #including 'init_distributed', 'weight_loader'
from trainer import Trainer
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import sys
from pathlib import Path

## YC : CHANGED
import torch
import torch.multiprocessing as mp


def get_arguments(base_path):
    """
    handle arguments from commandline.
    some other hyper parameters can only be changed manually (such as model architecture,dropout,etc)
    notice some arguments are global and take effect for the entire three phase training process, while others are determined per phase
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str,default="baseline") 
    parser.add_argument('--dataset_name', type=str, choices=['HCP1200', 'ABCD', 'ABIDE', 'UKB', 'ENIGMA_OCD'], default="ENIGMA_OCD")
    parser.add_argument('--fmri_type', type=str, choices=['timeseries', 'frequency', 'divided_timeseries', 'time_domain_low', 'time_domain_ultralow', 'time_domain_high', 'frequency_domain_low', 'frequency_domain_ultralow', 'frequency_domain_high'], default="divided_timeseries")
    parser.add_argument('--intermediate_vec', type=int, default=400)
    parser.add_argument('--abcd_path', default='/scratch/connectome/stellasybae/ABCD_ROI/7.ROI') ## labserver
    parser.add_argument('--ukb_path', default='/scratch/connectome/stellasybae/UKB_ROI') ## labserver
    parser.add_argument('--abide_path', default='/scratch/connectome/stellasybae/ABIDE_ROI') ## labserver
    parser.add_argument('--enigma_path', default='/pscratch/sd/p/pakmasha/MBBN_data') ## Perlmutter 
    parser.add_argument('--base_path', default=base_path) # where your main.py, train.py, model.py are in.
    parser.add_argument('--step', default='1', choices=['1','2','3','4'], help='which step you want to run') # YC : Step 1 : vanilla_BERT / Step 2 : MBBN / Step 3 : divfreqBERT_reconstruction / Step 4 : test
    
    
    parser.add_argument('--target', type=str, default='OCD')
    parser.add_argument('--fine_tune_task',
                        choices=['regression','binary_classification'],
                        help='fine tune model objective. choose binary_classification in case of a binary classification task')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--visualization', action='store_true')
    parser.add_argument('--prepare_visualization', action='store_true')
    parser.add_argument('--weightwatcher', action='store_true')
    parser.add_argument('--weightwatcher_save_dir', default=None)

    
    
    parser.add_argument('--norm_axis', default=1, type=int, choices=[0,1,None])
    
    parser.add_argument('--cuda', default=True)
    parser.add_argument('--log_dir', type=str, default=os.path.join(base_path, 'runs'))

    parser.add_argument('--transformer_hidden_layers', type=int,default=8)
    
    # DDP configs:
    parser.add_argument('--world_size', default=-1, type=int, 
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int, 
                        help='node rank for distributed training')
    parser.add_argument('--local_rank', default=-1, type=int, 
                        help='local rank for distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str, 
                        help='distributed backend')
    parser.add_argument('--init_method', default='file', type=str, choices=['file','env'], help='DDP init method')
    parser.add_argument('--distributed', default=True)

    # AMP configs:
    parser.add_argument('--amp', action='store_false')
    parser.add_argument('--gradient_clipping', action='store_true')
    parser.add_argument('--clip_max_norm', type=float, default=1.0)
    
    # Gradient accumulation
    parser.add_argument("--accumulation_steps", default=1, type=int,required=False,help='mini batch size == accumulation_steps * args.train_batch_size')
    
    # Nsight profiling
    parser.add_argument("--profiling", action='store_true')
    
    #wandb related
    parser.add_argument('--wandb_key', default='d0330ca06936eecd637c3470c47af6d33e1cb277', type=str,  help='default: key for ycryu')
    parser.add_argument('--wandb_mode', default='online', type=str,  help='online|offline')
    parser.add_argument('--wandb_entity', default='youngchanryu-seoul-national-university', type=str)
    parser.add_argument('--wandb_project', default='enigma-ocd_mbbn', type=str)

    
    # dividing
    parser.add_argument('--filtering_type', default='Boxcar', choices=['FIR', 'Boxcar'])
    parser.add_argument('--use_high_freq', action='store_true')
    parser.add_argument('--divide_by_lorentzian', action='store_true')
    parser.add_argument('--use_raw_knee', action='store_true')
    parser.add_argument('--seq_part', type=str, default='head')
    parser.add_argument('--fmri_dividing_type', default='three_channels', choices=['two_channels', 'three_channels', 'four_channels'])
    
    # Dropouts
    parser.add_argument('--transformer_dropout_rate', type=float, default=0.3) 

    # Architecture
    parser.add_argument('--num_heads', type=int, default=12,
                        help='number of heads for BERT network (default: 12)')
    parser.add_argument('--attn_mask', action='store_false',
                        help='use attention mask for Transformer (default: true)')
                        
    
    ## for finetune
    parser.add_argument('--pretrained_model_weights_path', default=None)
    parser.add_argument('--finetune', action='store_true')
    parser.add_argument('--finetune_test', action='store_true', help='test phase of finetuning task')
    
    
    ## spatiotemporal
    parser.add_argument('--spatiotemporal', action = 'store_true')
    parser.add_argument('--spat_diff_loss_type', type=str, default='minus_log', choices=['minus_log', 'reciprocal_log', 'exp_minus', 'log_loss', 'exp_whole'])
    parser.add_argument('--spatial_loss_factor', type=float, default=0.1)
    
    ## ablation
    parser.add_argument('--ablation', type=str, choices=['convolution', 'no_high_freq'])
    
    ## YC : Phase means step
    ## phase 1 vanilla BERT
    parser.add_argument('--task_phase1', type=str, default='vanilla_BERT')
    parser.add_argument('--batch_size_phase1', type=int, default=8, help='for DDP, each GPU processes batch_size_pahse1 samples')
    parser.add_argument('--validation_frequency_phase1', type=int, default=10000000)
    parser.add_argument('--nEpochs_phase1', type=int, default=2)  # initially, default=100
    parser.add_argument('--optim_phase1', default='AdamW')
    parser.add_argument('--weight_decay_phase1', type=float, default=1e-2)
    parser.add_argument('--lr_policy_phase1', default='SGDR', help='learning rate policy: step|SGDR')
    parser.add_argument('--lr_init_phase1', type=float, default=1e-3)
    parser.add_argument('--lr_gamma_phase1', type=float, default=0.97)
    parser.add_argument('--lr_step_phase1', type=int, default=3000)
    parser.add_argument('--lr_warmup_phase1', type=int, default=500)
    parser.add_argument('--sequence_length_phase1', type=int ,default=300) # ABCD 348 ABIDE 280 UKB 464
    parser.add_argument('--workers_phase1', type=int,default=4)
    parser.add_argument('--num_heads_2DBert', type=int, default=12)
    
    ## phase 2 MBBN
    parser.add_argument('--task_phase2', type=str, default='MBBN')
    parser.add_argument('--batch_size_phase2', type=int, default=8, help='for DDP, each GPU processes batch_size_pahse1 samples')
    parser.add_argument('--nEpochs_phase2', type=int, default=100)  # initially, default=100
    parser.add_argument('--optim_phase2', default='AdamW')
    parser.add_argument('--weight_decay_phase2', type=float, default=1e-2)
    parser.add_argument('--lr_policy_phase2', default='SGDR', help='learning rate policy: step|SGDR')
    parser.add_argument('--lr_init_phase2', type=float, default=1e-3)
    parser.add_argument('--lr_gamma_phase2', type=float, default=0.97)
    parser.add_argument('--lr_step_phase2', type=int, default=3000)
    parser.add_argument('--lr_warmup_phase2', type=int, default=500)
    parser.add_argument('--sequence_length_phase2', type=int ,default=300) # ABCD 348 ABIDE 280 UKB 464
    parser.add_argument('--workers_phase2', type=int, default=4)   # default=4
    
    ##phase 3 pretraining
    parser.add_argument('--task_phase3', type=str, default='MBBN_pretraining')
    parser.add_argument('--batch_size_phase3', type=int, default=8, help='for DDP, each GPU processes batch_size_pahse1 samples')
    parser.add_argument('--validation_frequency_phase3', type=int, default=10000000)
    parser.add_argument('--nEpochs_phase3', type=int, default=1000)
    parser.add_argument('--optim_phase3', default='AdamW')
    parser.add_argument('--weight_decay_phase3', type=float, default=1e-2)
    parser.add_argument('--lr_policy_phase3', default='SGDR', help='learning rate policy: step|SGDR')
    parser.add_argument('--lr_init_phase3', type=float, default=1e-3)
    parser.add_argument('--lr_gamma_phase3', type=float, default=0.97)
    parser.add_argument('--lr_step_phase3', type=int, default=3000)
    parser.add_argument('--lr_warmup_phase3', type=int, default=500)
    parser.add_argument('--sequence_length_phase3', type=int ,default=300)
    parser.add_argument('--workers_phase3', type=int,default=4)
    parser.add_argument('--use_recon_loss', action='store_true')
    parser.add_argument('--use_mask_loss', action='store_true') 
    parser.add_argument('--use_cont_loss', action='store_true')
    parser.add_argument('--masking_rate', type=float, default=0.1)
    parser.add_argument('--masking_method', type=str, default='spatiotemporal', choices=['temporal', 'spatial', 'spatiotemporal'])
    parser.add_argument('--temporal_masking_type', type=str, default='time_window', choices=['single_point','time_window'])
    parser.add_argument('--temporal_masking_window_size', type=int, default=20)
    parser.add_argument('--window_interval_rate', type=int, default=2)
    parser.add_argument('--spatial_masking_type', type=str, default='random_ROIs', choices=['hub_ROIs', 'random_ROIs'])
    parser.add_argument('--communicability_option', type=str, default='remove_high_comm_node', choices=['remove_high_comm_node', 'remove_low_comm_node'])
    parser.add_argument('--num_hub_ROIs', type=int, default=5)
    parser.add_argument('--num_random_ROIs', type=int, default=5)
    parser.add_argument('--spatiotemporal_masking_type', type=str, default='whole', choices=['whole', 'separate'])
    
    
    ## phase 4 (test)
    parser.add_argument('--task_phase4', type=str, default='test')
    parser.add_argument('--model_weights_path_phase4', default=None)
    parser.add_argument('--batch_size_phase4', type=int, default=4)
    parser.add_argument('--nEpochs_phase4', type=int, default=1)
    parser.add_argument('--optim_phase4', default='AdamW')
    parser.add_argument('--weight_decay_phase4', type=float, default=1e-2)
    parser.add_argument('--lr_policy_phase4', default='SGDR', help='learning rate policy: step|SGDR')
    parser.add_argument('--lr_init_phase4', type=float, default=1e-4)
    parser.add_argument('--lr_gamma_phase4', type=float, default=0.9)
    parser.add_argument('--lr_step_phase4', type=int, default=3000)
    parser.add_argument('--lr_warmup_phase4', type=int, default=100)
    parser.add_argument('--sequence_length_phase4', type=int,default=300) # ABCD 348 ABIDE 280 UKB 464
    parser.add_argument('--workers_phase4', type=int, default=4)
                        
    ## Uncertainty Quantification
    ## YC : CHANGED
    parser.add_argument('--UQ', action='store_true')
    parser.add_argument('--UQ_method', type=str, default='none', choices=['MC_dropout', 'ensemble'])
    parser.add_argument('--num_forward_passes', type=int, default=0) # for MC_dropout
    parser.add_argument('--num_ensemble_models', type=int, default=0) # for ensemble, should use same number when training and testing
    parser.add_argument('--ensemble_models_per_gpu', type=int, default=1)
    parser.add_argument('--UQ_model_weights_path', default=None)

    args = parser.parse_args()
        
    return args

def setup_folders(base_path): 
    os.makedirs(os.path.join(base_path,'experiments'),exist_ok=True) 
    os.makedirs(os.path.join(base_path,'runs'),exist_ok=True)
    os.makedirs(os.path.join(base_path, 'splits'), exist_ok=True)
    return None


def run_phase(args,loaded_model_weights_path,phase_num,phase_name, model_idx = None):
    experiment_folder = '{}_{}_{}_{}'.format(args.dataset_name,phase_name,args.target,args.exp_name)
    experiment_folder = Path(os.path.join(args.base_path,'experiments',experiment_folder))
    os.makedirs(experiment_folder, exist_ok=True)
    setattr(args,'loaded_model_weights_path_phase' + phase_num,loaded_model_weights_path)
    args.experiment_folder = experiment_folder
    args.experiment_title = experiment_folder.name
    
    print(f'saving the results at {args.experiment_folder}')
    
    # save hyperparameters
    args_logger(args)
    
    # make args to dict. + detach phase numbers from args
    kwargs = sort_args(phase_num, vars(args))
    if args.prepare_visualization:
        S = ['train','val']
    else:
        S = ['train','val','test']

    trainer = Trainer(sets=S,model_idx=model_idx,**kwargs)
    trainer.training()

    #S = ['train','val']

    if phase_num == '3' and not fine_tune_task == 'regression':
        critical_metric = 'accuracy'
    else:
        critical_metric = 'loss'
    model_weights_path = os.path.join(trainer.writer.experiment_folder,trainer.writer.experiment_title + '_BEST_val_{}.pth'.format(critical_metric)) 

    return model_weights_path


## YC : CHANGED
def test(args,phase_num,model_weights_path):
    UQ = args.UQ
    UQ_method = args.UQ_method
    print(f"UQ : {UQ} / UQ_method : {UQ_method}")
    
    experiment_folder = '{}_{}_{}'.format(args.dataset_name, 'test_{}'.format(args.fine_tune_task), args.exp_name) #, datestamp())
    experiment_folder = Path(os.path.join(args.base_path,'tests', experiment_folder))
    os.makedirs(experiment_folder,exist_ok=True)
    
    args.experiment_folder = experiment_folder
    args.experiment_title = experiment_folder.name

    if UQ:
        S = [UQ_method]
        if UQ_method == 'MC_dropout':
            # YC : Retrieve the last checkpoint from directory
            file_name_and_time_lst = []
            for f_name in os.listdir(model_weights_path):
                if f_name.endswith('.pth'):
                    written_time = os.path.getctime(os.path.join(model_weights_path,f_name))
                    file_name_and_time_lst.append((f_name, written_time))
            # Backward order of file creation time
            sorted_file_lst = sorted(file_name_and_time_lst, key=lambda x: x[1], reverse=True)

            if len(sorted_file_lst) == 0:
                raise Exception('No model weights found')
            loaded_model_weights_path = os.path.join(model_weights_path,sorted_file_lst[0][0])
            setattr(args,'loaded_model_weights_path_phase' + phase_num, loaded_model_weights_path)
            args_logger(args)
            args = sort_args(args.step, vars(args))
            trainer = UQTrainer(sets=S,**args)

    else:
        # YC : Retrieve the most recent checkpoint from directory
        file_name_and_time_lst = []
        for f_name in os.listdir(model_weights_path):
            if f_name.endswith('.pth'):
                written_time = os.path.getctime(os.path.join(model_weights_path,f_name))
                file_name_and_time_lst.append((f_name, written_time))
        # Backward order of file creation time
        sorted_file_lst = sorted(file_name_and_time_lst, key=lambda x: x[1], reverse=True)

        if len(sorted_file_lst) == 0:
            raise Exception('No model weights found')
        loaded_model_weights_path = os.path.join(model_weights_path,sorted_file_lst[0][0])
        setattr(args,'loaded_model_weights_path_phase' + phase_num, loaded_model_weights_path)
        S = ['test']
        args_logger(args)
        args = sort_args(args.step, vars(args))
        trainer = Trainer(sets=S,**args)
    
    trainer.testing()
    

## YC : CHANGED
# if __name__ == '__main__':
def main():
    base_path = os.getcwd() 
    setup_folders(base_path) 
    args = get_arguments(base_path)

    # UQ condition check
    if args.UQ:
        if args.UQ_method == 'none':
            raise Exception('UQ method is not specified')
        elif args.UQ_method == 'MC_dropout':
            if args.num_forward_passes == 0:
                raise Exception('num_forward_passes is not specified')
            elif args.num_ensemble_models != 0:
                raise Exception('num_ensemble_models should not be set for MC_dropout')
            if args.step != '4':
                raise Exception('MC_dropout is only available for testing')
        elif args.UQ_method == 'ensemble':
            if args.num_ensemble_models == 0:
                raise Exception('num_ensemble_models is not specified')
            elif args.num_forward_passes != 0:
                raise Exception('num_forward_passes should not be set for ensemble')
        
        print(f'UQ enabled - method : {args.UQ_method} | step : {args.step}')
        if args.UQ_method == 'ensemble':
            print(f'num_ensemble_models : {args.num_ensemble_models}')
            if args.step == '2':
                args.distributed = False
                print('distributed set False due to manual distributed setting in ensemble method')
        elif args.UQ_method == 'MC_dropout':
            print(f'num_forward_passes : {args.num_forward_passes}')

    # DDP initialization
    if not (args.step == '2' and args.UQ):
        init_distributed(args)

    # load weights that you specified at the Argument
    model_weights_path, step, task = weight_loader(args)

    if step == '4' :
        print(f'starting testing')
        phase_num = '4'
        if args.UQ:
            model_weights_path = args.UQ_model_weights_path
        test(args, phase_num, model_weights_path)
    else:
        print(f'starting phase{step}: {task}')
        if args.UQ and args.UQ_method == 'ensemble':
            if args.UQ_model_weights_path is not None:
                model_weights_path = args.UQ_model_weights_path
                print(f'UQ ensemble model weights loaded from {model_weights_path}')    
            mp.set_start_method("spawn", force=True)
            run_disributed_phase(args,model_weights_path,step,task)
        else:
            run_phase(args,model_weights_path,step,task)
        print(f'finishing phase{step}: {task}')


random seed for torch and np at utils.reproducibility()

for dataset, use         self.seed = kwargs.get('seed')  # random seed for reproducibility at dataloaders.py 
at train_test_split from sklearn. 

So I can use same seed for dataset split, and use different seed for model training!

In [5]:
import os
from pathlib import Path
import torch
import torch.multiprocessing as mp
"""
def run_phase(args,loaded_model_weights_path,phase_num,phase_name):
    experiment_folder = '{}_{}_{}_{}'.format(args.dataset_name,phase_name,args.target,args.exp_name)
    experiment_folder = Path(os.path.join(args.base_path,'experiments',experiment_folder))
    os.makedirs(experiment_folder, exist_ok=True)
    setattr(args,'loaded_model_weights_path_phase' + phase_num,loaded_model_weights_path)
    args.experiment_folder = experiment_folder
    args.experiment_title = experiment_folder.name
    
    print(f'saving the results at {args.experiment_folder}')
    
    # save hyperparameters
    args_logger(args)
    
    # make args to dict. + detach phase numbers from args
    kwargs = sort_args(phase_num, vars(args))
    if args.prepare_visualization:
        S = ['train','val']
    else:
        S = ['train','val','test']

    trainer = Trainer(sets=S,**kwargs)
    trainer.training()

    #S = ['train','val']

    if phase_num == '3' and not fine_tune_task == 'regression':
        critical_metric = 'accuracy'
    else:
        critical_metric = 'loss'
    model_weights_path = os.path.join(trainer.writer.experiment_folder,trainer.writer.experiment_title + '_BEST_val_{}.pth'.format(critical_metric)) 

    return model_weights_path
"""
def train_single_model(args, loaded_model_weights_path, phase_num, phase_name, model_idx, device_id):
    # Set the current GPU for this process
    torch.cuda.set_device(device_id)
    # Optionally update args with the device info so Trainer uses the correct device.
    args.device = f"cuda:{device_id}"
    print(f"Starting training for model {model_idx} on GPU {device_id}")
    # Call the original run_phase which trains the model.
    model_path = run_phase(args, loaded_model_weights_path, phase_num, phase_name, model_idx)
    print(f"Completed training for model {model_idx}, saved to {model_path}")

def run_disributed_phase(args,loaded_model_weights_path,phase_num,phase_name):
        # torchrun: sbatch script에서 WORLD_SIZE를 지정해준 경우 (노드 당 gpu * 노드의 수)
    if "WORLD_SIZE" in os.environ: # for torchrun
        args.world_size = int(os.environ["WORLD_SIZE"])
        #print('args.world_size:',args.world_size)
    elif 'SLURM_NTASKS' in os.environ: # for slurm scheduler
        args.world_size = int(os.environ['SLURM_NTASKS'])
    else:
        pass # torch.distributed.launch
        
    args.distributed = args.world_size > 1 # default: world_size = -1 
    
    num_gpus = torch.cuda.device_count()

    ### DEBUG STATEMENT ###
    print(f'world_size: {args.world_size}')
    print(f'distributed: {args.distributed}')
    print(f'num_gpus: {num_gpus}')
    #######################
    
    # Determine how many ensemble models to train concurrently per GPU.
    # For instance, if you want two models per GPU at a time, then:
    models_per_gpu = args.ensemble_models_per_gpu
    concurrent_models = num_gpus * models_per_gpu  # e.g. 4 GPUs * 2 = 8 models concurrently.

    # List all ensemble model indices (for example: [0, 1, 2, ..., args.num_ensemble_models-1])
    ensemble_indices = list(range(args.num_ensemble_models))

    ### DEBUG STATEMENT ###
    print(f'models_per_gpu: {models_per_gpu}')
    print(f'concurrent_models: {concurrent_models}')
    print(f'ensemble_indices: {ensemble_indices}')
    #######################

    # Iterate over ensemble indices in batches of 'concurrent_models'
    for batch_start in range(0, args.num_ensemble_models, concurrent_models):
        processes = []
        batch_indices = ensemble_indices[batch_start: batch_start + concurrent_models]
        print(f"#Training batch models: {batch_indices}")
        for slot, model_idx in enumerate(batch_indices):
            # For assignment, rotate across GPUs. Adjust if you want a different scheduling.
            device_id = slot % num_gpus
            print(f"##slot: {slot} / device_id: {device_id} / model_idx: {model_idx}")
            p = mp.Process(
                target=train_single_model,
                args=(args, loaded_model_weights_path, phase_num, phase_name, model_idx, device_id)
            )
            p.start()
            processes.append(p)
        # Wait for this batch of models to finish training.
        for p in processes:
            p.join()
        print(f"Finished training batch models: {batch_indices}")

    # Optionally, you can collect or post-process the model weight paths here.
    print("All ensemble models trained.")


sys.argv = ['main.py', '--dataset_name', 'ENIGMA_OCD', '--base_path', '/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main', '--enigma_path', '/pscratch/sd/y/ycryu/MBBN_data_mini', '--step', '2', '--batch_size_phase2', '8', '--lr_init_phase2', '3e-5', '--lr_policy_phase2', 'step', '--workers_phase2', '8', '--fine_tune_task', 'binary_classification', '--target', 'OCD', '--fmri_type', 'divided_timeseries', '--transformer_hidden_layers', '8', '--divide_by_lorentzian', '--seq_part', 'head', '--use_raw_knee', '--fmri_dividing_type', 'three_channels', '--use_high_freq', '--spatiotemporal', '--spat_diff_loss_type', 'minus_log', '--spatial_loss_factor', '4.0', '--exp_name', 'from_scratch_seed101', '--seed', '101', '--sequence_length_phase2', '100', '--intermediate_vec', '316', '--nEpochs_phase2', '100', '--num_heads', '4', '--UQ', '--UQ_method', 'ensemble', '--num_ensemble_models', '16', '--UQ_model_weights_path', '/scratch/connectome/ycryu/ENIGMA_OCD_MBBN/MBBN-main/experiments/ENIGMA_OCD_mbbn_OCD_from_scratch_seed101']

filtered_args = [arg for arg in base_args if arg not in ['--UQ']]
['main.py', '--dataset_name', 'ENIGMA_OCD', '--base_path', '/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main', '--enigma_path', '/pscratch/sd/y/ycryu/MBBN_data_mini', '--step', '2', '--batch_size_phase2', '8', '--lr_init_phase2', '3e-5', '--lr_policy_phase2', 'step', '--workers_phase2', '8', '--fine_tune_task', 'binary_classification', '--target', 'OCD', '--fmri_type', 'divided_timeseries', '--transformer_hidden_layers', '8', '--divide_by_lorentzian', '--seq_part', 'head', '--use_raw_knee', '--fmri_dividing_type', 'three_channels', '--use_high_freq', '--spatiotemporal', '--spat_diff_loss_type', 'minus_log', '--spatial_loss_factor', '4.0', '--exp_name', 'from_scratch_seed101', '--seed', '101', '--sequence_length_phase2', '100', '--intermediate_vec', '316', '--nEpochs_phase2', '100', '--num_heads', '4', '--UQ_method', 'ensemble', '--num_ensemble_models', '16', '--UQ_model_weights_path', '/scratch/connectome/ycryu/ENIGMA_OCD_MBBN/MBBN-main/experiments/ENIGMA_OCD_mbbn_OCD_from_scratch_seed101']

##### dealing with Args

    print(args)
    args.new_seed = 1010
    print(args)
    delattr(args,'new_seed')
    print(args)

In [6]:
sys.argv = [
    'main.py', '--dataset_name', 'ENIGMA_OCD', '--batch_size_phase2', '8', '--lr_init_phase2', '3e-5', 
    '--lr_policy_phase2', 'step', '--workers_phase2', '8', '--fine_tune_task', 'binary_classification', '--target', 'OCD', 
    '--fmri_type', 'divided_timeseries', '--transformer_hidden_layers', '8', '--divide_by_lorentzian', '--seq_part', 'head', 
    '--use_raw_knee', '--fmri_dividing_type', 'three_channels', '--use_high_freq', '--spatiotemporal', '--spat_diff_loss_type', 'minus_log', 
    '--spatial_loss_factor', '4.0', '--sequence_length_phase2', '100', 
    '--intermediate_vec', '316', '--nEpochs_phase2', '100', '--num_heads', '4', 
    '--base_path', '/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main', 
    '--enigma_path', '/pscratch/sd/y/ycryu/MBBN_data_mini', 
    '--exp_name', 'ensemble_training_from_scratch_seed101',
    '--seed', '101', 
    '--step', '2', 
    '--UQ', 
    '--UQ_method', 'ensemble', 
    '--num_ensemble_models', '4', 
    '--ensemble_models_per_gpu', '2',
]



# sys.argv = [
#     'main.py', '--dataset_name', 'ENIGMA_OCD', '--fine_tune_task', 'binary_classification', '--target', 'OCD',
#     '--fmri_type', 'divided_timeseries', '--transformer_hidden_layers', '8', '--divide_by_lorentzian', '--seq_part', 'head',
#     '--use_raw_knee', '--fmri_dividing_type', 'three_channels', '--use_high_freq', '--spatiotemporal', '--spat_diff_loss_type', 'minus_log',
#     '--spatial_loss_factor', '4.0', '--intermediate_vec', '316', '--num_heads', '4', '--sequence_length_phase4', '100', '--lr_warmup_phase4', '1', '--workers_phase4', '1',
#     '--wandb_mode', 'disabled',
#     '--base_path', '/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main', 
#     '--enigma_path', '/pscratch/sd/y/ycryu/MBBN_data_mini',
#     '--exp_name', 'test_evaluation_seed101', 
#     '--seed', '101', 
#     '--step', '4',
#     '--UQ', 
#     '--UQ_method', 'ensemble', 
#     '--num_ensemble_models', '16', 
#     '--UQ_model_weights_path', '/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main/experiments/ENIGMA_OCD_mbbn_OCD_from_scratch_seed101_1gpu_perlmutter',
# ]



# sys.argv = [
#     'main.py', '--dataset_name', 'ENIGMA_OCD', '--fine_tune_task', 'binary_classification', '--target', 'OCD',
#     '--fmri_type', 'divided_timeseries', '--transformer_hidden_layers', '8', '--divide_by_lorentzian', '--seq_part', 'head',
#     '--use_raw_knee', '--fmri_dividing_type', 'three_channels', '--use_high_freq', '--spatiotemporal', '--spat_diff_loss_type', 'minus_log',
#     '--spatial_loss_factor', '4.0', '--intermediate_vec', '316', '--num_heads', '4', '--sequence_length_phase4', '100', '--lr_warmup_phase4', '1', '--workers_phase4', '1',
#     '--wandb_mode', 'disabled',
#     '--base_path', '/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main', 
#     '--enigma_path', '/pscratch/sd/y/ycryu/MBBN_data_mini',
#     '--exp_name', 'test_evaluation_seed101', 
#     '--seed', '101', 
#     '--step', '4',
#     '--UQ', 
#     '--UQ_method', 'MC_dropout', 
#     '--num_forward_pass', '16', 
#     '--UQ_model_weights_path', '/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main/experiments/ENIGMA_OCD_mbbn_OCD_from_scratch_seed101',
# ]


In [7]:
main()
## Should set distributed = False when training ensemble

UQ enabled - method : ensemble | step : 2
num_ensemble_models : 4
distributed set False due to manual distributed setting in ensemble method
starting phase2: mbbn
world_size: -1
distributed: False
num_gpus: 1
models_per_gpu: 2
concurrent_models: 2
ensemble_indices: [0, 1, 2, 3]
#Training batch models: [0, 1]
##slot: 0 / device_id: 0 / model_idx: 0
##slot: 1 / device_id: 0 / model_idx: 1


Finished training batch models: [0, 1]
#Training batch models: [2, 3]
##slot: 0 / device_id: 0 / model_idx: 2


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/global/homes/y/ycryu/.conda/envs/mbbn-env/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/global/homes/y/ycryu/.conda/envs/mbbn-env/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/global/homes/y/ycryu/.conda/envs/mbbn-env/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    exitcode = _main(fd, parent_sentinel)
  File "/global/homes/y/ycryu/.conda/envs/mbbn-env/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train_single_model' on <module '__main__' (built-in)>
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train_single_model' on <module '__main__' (built-in)>


##slot: 1 / device_id: 0 / model_idx: 3
Finished training batch models: [2, 3]
All ensemble models trained.
finishing phase2: mbbn


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/global/homes/y/ycryu/.conda/envs/mbbn-env/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/global/homes/y/ycryu/.conda/envs/mbbn-env/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train_single_model' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/global/homes/y/ycryu/.conda/envs/mbbn-env/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/global/homes/y/ycryu/.conda/envs/mbbn-env/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train_single_model' on <module '__main__' (built-in)>


In [14]:
# import os
# import sys
# import subprocess

# def launch_ensemble_processes(gpu_ids, base_seed=42):
#     # Remove any argument you don't want the children to see, if needed.
#     # For example, if you want the child processes to know they're not responsible for launching ensembles,
#     # you might remove the `--method ensemble` flag.
#     # One simple way is to filter sys.argv (or use parse_known_args) if needed.
#     base_args = sys.argv[1:]  # All original command-line arguments except the script name.
    
#     # Optionally, filter out ensemble-specific arguments if you don't want children to spawn further ensembles.
#     # For example:
#     filtered_args = [arg for arg in base_args if arg not in ['--UQ']]
    
#     processes = []
#     for idx, gpu in enumerate(gpu_ids):
#         seed = base_seed + idx
#         # Build the new command: start with 'python main.py' then the filtered arguments
#         # then add GPU and seed.
#         cmd = ["python", "main.py"] + filtered_args + ["--gpu", str(gpu), "--seed", str(seed)]
#         print("Launching subprocess with command:", " ".join(cmd))
#         proc = subprocess.Popen(cmd)
#         processes.append(proc)
    
#     # Optionally wait for all processes to finish:
#     for proc in processes:
#         proc.wait()

# if __name__ == '__main__':
#     # Setup folders, etc.
#     base_path = os.getcwd()
#     setup_folders(base_path)
    
#     # Parse your arguments
#     args = get_arguments(base_path)
    
#     # Check if ensemble mode is requested:
#     if args.method == 'ensemble':
#         # Specify the list of GPU ids you want to use
#         gpu_ids = [0, 1, 2, 3]  # for example, adjust based on your node
#         launch_ensemble_processes(gpu_ids)
#         # Optionally, exit the parent process if it is only responsible for spawning children.
#         exit()
    
#     # Continue with the rest of your training/inference code if not in ensemble mode
#     ...

(mbbn-env) ycryu@login33:/pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main/yc/ensemble> python main.py --dataset_name ENIGMA_OCD --base_path /pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main --enigma_path /pscratch/sd/y/ycryu/MBBN_data_mini --step 2 --batch_size_phase2 8 --lr_init_phase2 3e-5 --lr_policy_phase2 step --workers_phase2 8 --fine_tune_task binary_classification --target OCD --fmri_type divided_timeseries --transformer_hidden_layers 8 --divide_by_lorentzian --seq_part head --use_raw_knee --fmri_dividing_type three_channels --use_high_freq --spatiotemporal --spat_diff_loss_type minus_log --spatial_loss_factor 4.0 --exp_name ensemble_training_from_scratch_seed101 --seed 101 --sequence_length_phase2 100 --intermediate_vec 316 --nEpochs_phase2 10 --num_heads 4 --UQ --UQ_method ensemble --num_ensemble_models 4 --ensemble_models_per_gpu 2  2> /pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main/failed_experiments/enigma_ocd_error_from_scratch_seed101.log
UQ enabled - method : ensemble | step : 2
num_ensemble_models : 4
distributed set False due to manual distributed setting in ensemble method
DEBUG : args.distributed : False / args.rank : 0 / args.local_rank : -1 / args.world_size : -1 / args.gpu : 0
starting phase2: mbbn
world_size: -1
distributed: False
num_gpus: 1
models_per_gpu: 2
concurrent_models: 2
ensemble_indices: [0, 1, 2, 3]
#Training batch models: [0, 1]
##slot: 0 / device_id: 0 / model_idx: 0
##slot: 1 / device_id: 0 / model_idx: 1
Starting training for model 0 on GPU 0
saving the results at /pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main/experiments/ENIGMA_OCD_mbbn_OCD_ensemble_training_from_scratch_seed101/model_0
DEBUG : seed for ensemble model 0 is 101
Starting training for model 1 on GPU 0
saving the results at /pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main/experiments/ENIGMA_OCD_mbbn_OCD_ensemble_training_from_scratch_seed101/model_1
DEBUG : seed for ensemble model 1 is 102
generating splits...
generating step 1
generating splits...
Number of subjects used for training: 112
generating step 2
generating step 1
Training set class distribution: {1: 39, 2: 39}
Validation set class distribution: {1: 9, 2: 8}
Test set class distribution: {1: 9, 2: 8}
distribution seed: 101
generating step 3.. saving splits...
Number of subjects used for training: 112
generating step 2
Training set class distribution: {1: 39, 2: 39}
Validation set class distribution: {1: 9, 2: 8}
Test set class distribution: {1: 9, 2: 8}
distribution seed: 101
generating step 3.. saving splits...
Finished training batch models: [0, 1]
#Training batch models: [2, 3]
##slot: 0 / device_id: 0 / model_idx: 2
##slot: 1 / device_id: 0 / model_idx: 3
Starting training for model 2 on GPU 0
saving the results at /pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main/experiments/ENIGMA_OCD_mbbn_OCD_ensemble_training_from_scratch_seed101/model_2
DEBUG : seed for ensemble model 2 is 103
Starting training for model 3 on GPU 0
saving the results at /pscratch/sd/y/ycryu/ENIGMA_OCD_MBBN/MBBN-main/experiments/ENIGMA_OCD_mbbn_OCD_ensemble_training_from_scratch_seed101/model_3
DEBUG : seed for ensemble model 3 is 104
generating splits...
generating step 1
generating splits...
generating step 1
Number of subjects used for training: 112
generating step 2
Training set class distribution: {1: 39, 2: 39}
Validation set class distribution: {1: 9, 2: 8}
Test set class distribution: {1: 9, 2: 8}
distribution seed: 101
generating step 3.. saving splits...
Number of subjects used for training: 112
generating step 2
Training set class distribution: {1: 39, 2: 39}
Validation set class distribution: {1: 9, 2: 8}
Test set class distribution: {1: 9, 2: 8}
distribution seed: 101
generating step 3.. saving splits...
Finished training batch models: [2, 3]
All ensemble models trained.
finishing phase2: mbbn