In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

### 참고로 쌤이 run 할떄 쓰던 script
```
#!/bin/bash

#SBATCH --job-name ADNI_eval  #job name을 다르게 하기 위해서
#SBATCH --nodes=1
#SBATCH --nodelist=node3 #used node4
#SBATCH -t 10:00:00 # Time for running job #길게 10일넘게 잡음
#SBATCH -o ./output/ADCN_yaware_epoch99_noCV.txt
#SBATCH -e ./error/error_%J.error
#SBATCH --ntasks=1
#SBATCH --mem-per-cpu=20000MB
#SBATCH --gpus=2
#SBATCH --cpus-per-task=8

######################################################################################
source /home/connectome/mieuxmin/.bashrc

echo "Used (y-Aware_Contrastive_MRI_epoch_99) for pretrained weights."
#echo "Used self.model='DenseNet', self.nb_epochs=100, self.tf='cutout', self.batch_size = 8"
echo "self.input_size=(1,80,80,80)<- 얘는 그렇고, 바꾸기 전 버전임., self.lr=1e-4, self.weight_decay=5e-5, self.patience=20"
echo ""
echo "--train_num 100 & --layer_control tune_all"
python3 main.py --pretrained_path ./weights/y-Aware_Contrastive_MRI_epoch_99.pth --mode finetuning --train_num 100 --task_name AD/MCI --layer_control tune_all --stratify balan --random_seed 0
echo "--train_num 100 & --layer_control freeze"
python3 main.py --pretrained_path ./weights/y-Aware_Contrastive_MRI_epoch_99.pth --mode finetuning --train_num 100 --task_name AD/MCI --layer_control freeze --stratify balan --random_seed 0
echo "--train_num 300 & --layer_control tune_all"
python3 main.py --pretrained_path ./weights/y-Aware_Contrastive_MRI_epoch_99.pth --mode finetuning --train_num 300 --task_name AD/MCI --layer_control tune_all --stratify balan --random_seed 0
echo "--train_num 300 & --layer_control freeze"
python3 main.py --pretrained_path ./weights/y-Aware_Contrastive_MRI_epoch_99.pth --mode finetuning --train_num 300 --task_name AD/MCI --layer_control freeze --stratify balan --random_seed 0
```

> our version : 


```bash

weight_pth=xXXX
layer_control=tune_all #freeze
task=XXX

python3 finetune.py --pretrained_path $weight_pth --mode finetuning --train_num,  ###적기 

```


### 밑 : the dataset.py that junbeom's code uses... let's try to modify ourse to fit this!

In [2]:
from cmath import nan

PRETRAINING = 0
FINE_TUNING = 1

class Config:

    def __init__(self, mode):
        assert mode in {PRETRAINING, FINE_TUNING}, "Unknown mode: %i"%mode

        self.mode = mode

        if self.mode == PRETRAINING:
            self.batch_size = 8 # ADNI
            self.nb_epochs_per_saving = 1
            self.pin_mem = True
            self.num_cpu_workers = 8
            self.nb_epochs = 100 # ADNI #####
            self.cuda = True
            # Optimizer
            self.lr = 1e-4
            self.weight_decay = 5e-5
            # Hyperparameters for our y-Aware InfoNCE Loss
            self.temperature = 0.1
            self.tf = 'cutout' # ADNI
            self.model = 'DenseNet' # 'UNet'
            ### ADNI
            self.data = './adni_t1s_baseline' # ADNI
            self.label = './csv/fsdat_baseline_CN.csv' # ADNI
            self.valid_ratio = 0.25 # ADNI (valid set ratio compared to training set)
            self.input_size = (1, 80, 80, 80) # ADNI #####
            
            self.label_name = ['PTAGE', 'PTGENDER'] # ADNI
            self.label_type = ['cont', 'cat'] # ADNI
            self.cat_similarity = [nan, 0] # similarity for mismatched categorical meta-data. set nan for continuous meta-data
            self.alpha_list = [0.5, 0.5] # ADNI # sum = 1
            self.sigma = [5, 5] # ADNI # depends on the meta-data at hand
            
            self.checkpoint_dir = './ckpts' # ADNI
            self.patience = 20 # ADNI

        elif self.mode == FINE_TUNING:
            ## We assume a classification task here
            self.batch_size = 8
            self.nb_epochs_per_saving = 10
            self.pin_mem = True
            self.num_cpu_workers = 1
            self.nb_epochs = 100 # ADNI #####
            self.cuda = True
            # Optimizer
            self.lr = 1e-4
            self.weight_decay = 5e-5
            self.tf = 'cutout' # ADNI
            self.model = 'DenseNet' # 'UNet'
            ### ADNI
            self.data = '/scratch/connectome/study_group/VAE_ADHD/data' #'./adni_t1s_baseline' # ADNI
            self.label = '/scratch/connectome/dyhan316/VAE_ADHD/junbeom_finetuning/csv/fsdat_baseline.csv' # ADNI
            self.valid_ratio = 0.25 # ADNI (valid set ratio compared to training set)
            self.input_size = (1, 80, 80, 80) # ADNI

            self.task_type = 'cls' # ADNI # 'cls' or 'reg' #####
            self.label_name = 'Dx.new' # ADNI # `Dx.new` #####
            self.num_classes = 2 # ADNI - AD vs CN or MCI vs CN or AD vs MCI or reg #####

            #self.pretrained_path = './weights/BHByAa64c.pth' # ADNI #####
            #self.layer_control = 'tune_all' # ADNI # 'freeze' or 'tune_diff' (whether to freeze pretrained layers or not) #####
            self.patience = 20 # ADNI



In [3]:
#some things to use 
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import torch
from augmentations import Transformer, Crop, Cutout, Noise, Normalize, Blur, Flip
### ADNI
import os
import nibabel as nib
from skimage.transform import resize

class ADNI_Dataset(Dataset):

    def __init__(self, config, data_csv, data_type, *args, **kwargs): # ADNI
        super().__init__(*args, **kwargs)
        
        ### ADNI
        self.config = config
        self.data_type = data_type
        self.transforms = Transformer()
        self.transforms.register(Normalize(), probability=1.0)

        if self.data_type == 'train' or self.config.mode == 0:
            if self.config.tf == "all_tf":
                self.transforms.register(Flip(), probability=0.5)
                self.transforms.register(Blur(sigma=(0.1, 1)), probability=0.5)
                self.transforms.register(Noise(sigma=(0.1, 1)), probability=0.5)
                self.transforms.register(Cutout(patch_size=np.ceil(np.array(self.config.input_size)/4)), probability=0.5)
                self.transforms.register(Crop(np.ceil(0.75*np.array(self.config.input_size)), "random", resize=True),
                                        probability=0.5)

            elif self.config.tf == "cutout":
                self.transforms.register(Cutout(patch_size=np.ceil(np.array(self.config.input_size)/4)), probability=1)

            elif self.config.tf == "crop":
                self.transforms.register(Crop(np.ceil(0.75*np.array(self.config.input_size)), "random", resize=True),
                                        probability=1)
        
        self.data_dir = self.config.data ###CHANGED### #'./adni_t1s_baseline'
        self.data_csv = data_csv
        self.files = [x for x in os.listdir(self.data_dir) if x[4:12] in list(self.data_csv['SubjectID'])]
        ###
        
    def collate_fn(self, list_samples):
        list_x = torch.stack([torch.as_tensor(x, dtype=torch.float) for (x, y) in list_samples], dim=0)
        list_y = torch.stack([torch.as_tensor(y, dtype=torch.float) for (x, y) in list_samples], dim=0)

        return (list_x, list_y)

    def __getitem__(self, idx):
        # For a single input x, samples (t, t') ~ T to generate (t(x), t'(x))
        ### ADNI
        if self.config.mode == 0: # Pre-training # consider multiple labels (list)
            labels = []
            for label_nm in self.config.label_name: # ["PTAGE", "PTGENDER"]
                labels.append(float(self.data_csv[label_nm].values[idx]))
            labels = tuple(labels)
        else: # Fine-tuning
            labels = self.data_csv[self.config.label_name].values[idx]
        SubjectID = self.data_csv['SubjectID'].values[idx]
        file_match = [file for file in self.files if SubjectID in file]
        path = os.path.join(self.data_dir, file_match[0])
        img = nib.load(os.path.join(path, 'brain_to_MNI_nonlin.nii.gz'))
        img = np.swapaxes(img.get_data(),1,2)
        img = np.flip(img,1)
        img = np.flip(img,2)
        img = resize(img, (self.config.input_size[1], self.config.input_size[2], self.config.input_size[3]), mode='constant')
        img = torch.from_numpy(img).float().view(self.config.input_size[0], self.config.input_size[1], self.config.input_size[2], self.config.input_size[3])
        img = img.numpy()
        
        np.random.seed()
        if self.config.mode == 0: # Pre-training
            x1 = self.transforms(img)
            x2 = self.transforms(img)
            x = np.stack((x1, x2), axis=0)
        else: # Fine-tuning
            x = self.transforms(img)
        ###

        return (x, labels)

    def __len__(self):
        return len(self.data_csv)

In [4]:
config = Config(mode = FINE_TUNING)
config.data

'/scratch/connectome/study_group/VAE_ADHD/data'

In [5]:
class Args(): 
    def __init__(self):
        self.mode = "finetuning"
        self.train_num = 100
        self.task_name = "MCI/CN"
        self.layer_control = "tune_all" #freeze
        self.stratify = "balan"
        self.rancom_seed = 0

In [6]:
##THE THREE KEY THINGS THAT ARE USED!!! (the pandas directory, the name of the task, and the task itself)
#done so that even if there's like three task possibles for a given task (ex : Dx.new), we can still choose two out of those three to do classification
config.label, config.label_name, args.task_name

NameError: name 'args' is not defined

In [None]:
print(label_name)
labels.keys()

In [None]:
labels['Dx.new'] #config.label_name

In [None]:
labels['Dx.new']== "CN"

In [10]:
label = "./csv/BT_ABCD_ADHD_sex_edited.csv"
label_name = "sex"
task_name = "M/W"

In [11]:
labels = pd.read_csv(label)
task_include = task_name.split('/')

labels[labels[label_name] == task_include[0]]

Unnamed: 0.1,Unnamed: 0,subjectkey,sex,HCvsADHD
0,0,NDARINV0A4P0LWM,M,SecHC
5,5,NDARINV0CBFTKR7,M,SecHC
7,7,NDARINV0E4CT74P,M,HC
10,10,NDARINV0H2AWWPU,M,SecHC
12,12,NDARINV0J1M2ETU,M,SecHC
...,...,...,...,...
11281,11610,NDARINVZR16R6Y3,M,HC
11282,11611,NDARINVZRR4D9LW,M,SecHC
11285,11614,NDARINVZW8G4W5A,M,HC
11287,11617,NDARINVZWP0XZ9A,M,SecHC


In [22]:
set(labels["HCvsADHD"])

{'ADHD', 'HC', 'SecHC'}

In [12]:
labels

Unnamed: 0.1,Unnamed: 0,subjectkey,sex,HCvsADHD
0,0,NDARINV0A4P0LWM,M,SecHC
1,1,NDARINV0A6WVRZY,W,SecHC
2,2,NDARINV0A86UD86,W,SecHC
3,3,NDARINV0AU5R8NA,W,SecHC
4,4,NDARINV0BEPJHU1,W,SecHC
...,...,...,...,...
11287,11617,NDARINVZWP0XZ9A,M,SecHC
11288,11618,NDARINVZWWDT1TG,W,SecHC
11289,11619,NDARINVZY3TE53A,W,SecHC
11290,11620,NDARINVZYRTFYRP,M,ADHD


In [54]:
print(labels[label_name] == 1.0)

0         True
1        False
2        False
3        False
4        False
         ...  
11287     True
11288    False
11289    False
11290     True
11291    False
Name: sex, Length: 11292, dtype: bool


In [59]:
#FROM THE  main.py! 
args = Args() #brought in parser to here
random_seed = 0 #set for things
label_name = config.label_name # 'Dx.new'

labels = pd.read_csv(config.label)

print('Task: Fine-tuning for {0}'.format(args.task_name))
print('Task type: {0}'.format(config.task_type))
print('N = {0}'.format(args.train_num))

if config.task_type == 'cls':
    print('Policy: {0}'.format(args.stratify))
    task_include = args.task_name.split('/')
    #print(task_include)
    assert len(task_include) == 2, 'Set two labels.'
    assert config.num_classes == 2, 'Set config.num_classes == 2'
    data_1 = labels[labels[label_name] == task_include[0]]
    data_2 = labels[labels[label_name] == task_include[1]]
    if args.stratify == 'strat':
        ratio = len(data_1) / (len(data_1) + len(data_2))
        len_1_train = round(args.train_num*ratio)
        len_2_train = args.train_num - len_1_train
        len_1_valid = round(int(args.train_num*config.valid_ratio)*ratio)
        len_2_valid = int(args.train_num*config.valid_ratio) - len_1_valid
        assert args.train_num*(1+config.valid_ratio) < (len(data_1) + len(data_2)), 'Not enough valid data. Set smaller --train_num or smaller config.valid_ratio in config.py.'
        train1, valid1, test1 = np.split(data_1.sample(frac=1, random_state=random_seed), 
                                         [len_1_train, len_1_train + len_1_valid])
        train2, valid2, test2 = np.split(data_2.sample(frac=1, random_state=random_seed), 
                                         [len_2_train, len_2_train + len_2_valid])
        label_train = pd.concat([train1, train2]).sample(frac=1, random_state=random_seed)
        label_valid = pd.concat([valid1, valid2]).sample(frac=1, random_state=random_seed)
        label_test = pd.concat([test1, test2]).sample(frac=1, random_state=random_seed)
        assert len(label_test) >= 100, 'Not enough test data. (Total: {0})'.format(len(label_test))
    else: # args.stratify == 'balan'
        if len(data_1) <= len(data_2):
            limit = len(data_1)
        else:
            limit = len(data_2)
        data_1 = data_1.sample(frac=1, random_state=random_seed)[:limit]
        data_2 = data_2.sample(frac=1, random_state=random_seed)[:limit]
        len_1_train = round(args.train_num*0.5)
        len_2_train = args.train_num - len_1_train
        len_1_valid = round(int(args.train_num*config.valid_ratio)*0.5)
        len_2_valid = int(args.train_num*config.valid_ratio) - len_1_valid
        assert args.train_num*(1+config.valid_ratio) < limit*2, 'Not enough data to make balanced set.'
        train1, valid1, test1 = np.split(data_1.sample(frac=1, random_state=random_seed), 
                                         [len_1_train, len_1_train + len_1_valid])
        train2, valid2, test2 = np.split(data_2.sample(frac=1, random_state=random_seed), 
                                         [len_2_train, len_2_train + len_2_valid])
        label_train = pd.concat([train1, train2]).sample(frac=1, random_state=random_seed)
        label_valid = pd.concat([valid1, valid2]).sample(frac=1, random_state=random_seed)
        label_test = pd.concat([test1, test2]).sample(frac=1, random_state=random_seed)
        assert len(label_test) >= 100, 'Not enough test data. (Total: {0})'.format(len(label_test))
    print('\nTrain data info:\n{0}\nTotal: {1}\n'.format(label_train[label_name].value_counts().sort_index(), len(label_train)))
    print('Valid data info:\n{0}\nTotal: {1}\n'.format(label_valid[label_name].value_counts().sort_index(), len(label_valid)))
    print('Test data info:\n{0}\nTotal: {1}\n'.format(label_test[label_name].value_counts().sort_index(), len(label_test)))
    label_train[label_name].replace({task_include[0]: 0, task_include[1]: 1}, inplace=True)
    label_valid[label_name].replace({task_include[0]: 0, task_include[1]: 1}, inplace=True)
    label_test[label_name].replace({task_include[0]: 0, task_include[1]: 1}, inplace=True)
else: # config.task_type = 'reg'
    task_include = args.task_name.split('/')
    assert len(task_include) == 1, 'Set only one label.'
    assert config.num_classes == 1, 'Set config.num_classes == 1'
    labels = pd.read_csv(config.label)
    labels = labels[(np.abs(stats.zscore(labels[label_name])) < 3)] # remove outliers w.r.t. z-score > 3
    assert args.train_num*(1+config.valid_ratio) <= len(labels), 'Not enough valid data. Set smaller --train_num or smaller config.valid_ratio in config.py.'
    label_train, label_valid, label_test = np.split(labels.sample(frac=1, random_state=random_seed), 
                                                    [args.train_num, int(args.train_num*(1+config.valid_ratio))])
    
    print('\nTrain data info:\nTotal: {0}\n'.format(len(label_train)))
    print('Valid data info:\nTotal: {0}\n'.format(len(label_valid)))
    print('Test data info:\nTotal: {0}\n'.format(len(label_test)))

Task: Fine-tuning for MCI/CN
Task type: cls
N = 100
Policy: balan

Train data info:
CN     50
MCI    50
Name: Dx.new, dtype: int64
Total: 100

Valid data info:
CN     13
MCI    12
Name: Dx.new, dtype: int64
Total: 25

Test data info:
CN     675
MCI    676
Name: Dx.new, dtype: int64
Total: 1351



In [77]:
labels[['SubjectID', 'Dx.new']]

Unnamed: 0,SubjectID,Dx.new
0,002S0295,CN
1,002S0413,CN
2,002S0559,CN
3,002S0685,CN
4,002S0729,MCI
...,...,...
1802,941S6570,CN
1803,941S6574,CN
1804,941S6575,CN
1805,941S6580,CN


RangeIndex(start=0, stop=1807, step=1)

In [3]:
### ADNI
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = '1, 2, 3'
import time
import datetime
###
import numpy as np
from dataset import ADNI_Dataset
from torch.utils.data import DataLoader, Dataset, RandomSampler
from yAwareContrastiveLearning import yAwareCLModel
from losses import GeneralizedSupervisedNTXenLoss
from torch.nn import CrossEntropyLoss, MSELoss # ADNI
from models.densenet import densenet121

import argparse
from config import Config, PRETRAINING, FINE_TUNING
### ADNI
import torch
import random
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from sklearn.metrics import roc_auc_score, mean_absolute_error, mean_squared_error, r2_score
import pandas as pd
from scipy import stats
###

now = datetime.datetime.now()
nowDatetime = now.strftime('%Y-%m-%d %H:%M:%S') # ADNI
print("[main.py started at {0}]".format(nowDatetime))
start_time = time.time() # ADNI
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_path", type=str, required=True, # ADNI
                    help="Set the pretrained model path.")   
parser.add_argument("--mode", type=str, choices=["pretraining", "finetuning"], required=True,
                    help="Set the training mode. Do not forget to configure config.py accordingly !")
parser.add_argument("--train_num", type=int, required=True, # ADNI
                    help="Set the number of training samples.")                        
parser.add_argument("--task_name", type=str, required=False, # ADNI
                    help="Set the name of the fine-tuning task. (e.g. AD/MCI)")
parser.add_argument("--layer_control", type=str, choices=['tune_all', 'freeze', 'tune_diff'], required=False, # ADNI
                    help="Set pretrained weight layer control option.")
parser.add_argument("--stratify", type=str, choices=["strat", "balan"], required=False, # ADNI
                    help="Set training samples are stratified or balanced for fine-tuning task.")
parser.add_argument("--random_seed", type=int, required=False, default=0, # ADNI
                    help="Random seed for reproduction.")
args = parser.parse_args()
mode = PRETRAINING if args.mode == "pretraining" else FINE_TUNING
config = Config(mode)
pretrained_path = args.pretrained_path
print('Pretrained path:', pretrained_path)
### ADNI
label_name = config.label_name # 'Dx.new'
# Control randomness for reproduction
if args.random_seed != None:
    random_seed = args.random_seed
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
#if config.mode == PRETRAINING:
#    data = pd.read_csv(config.label)
#    for i in range(len(label_name)): # ["PTAGE", "PTGENDER"]
#        if config.label_type[i] != 'cont': # convert str object to numbers
#            data[label_name[i]] = pd.Categorical(data[label_name[i]])
#            data[label_name[i]] = data[label_name[i]].cat.codes

#    assert args.train_num*(1+config.valid_ratio) <= len(data), 'Not enough valid data. Set smaller --train_num or smaller config.valid_ratio in config.py.'
#    label_train, label_valid, label_test = np.split(data.sample(frac=1, random_state=random_seed), 
#                                                    [args.train_num, int(args.train_num*(1+config.valid_ratio))])
#    
#    print('Task: Pretraining')
#    print('N = {0}'.format(args.train_num))
#    print('meta-data: {0}\n'.format(label_name))
#    assert len(label_name) == len(config.alpha_list), 'len(label_name) and len(alpha_list) should match.'
#    assert len(label_name) == len(config.label_type), 'len(label_name) and len(label_type) should match.'
#    assert len(label_name) == len(config.sigma), 'len(alpha_list) and len(sigma) should match.'
#    assert sum(config.alpha_list) == 1.0, 'Sum of alpha list should be 1.'

else: # config.mode == FINE_TUNING:
    labels = pd.read_csv(config.label)
    print('Task: Fine-tuning for {0}'.format(args.task_name))
    print('Task type: {0}'.format(config.task_type))
    print('N = {0}'.format(args.train_num))
    
    if config.task_type == 'cls':
        print('Policy: {0}'.format(args.stratify))
        task_include = args.task_name.split('/')
        assert len(task_include) == 2, 'Set two labels.'
        assert config.num_classes == 2, 'Set config.num_classes == 2'
        data_1 = labels[labels[label_name] == task_include[0]]
        data_2 = labels[labels[label_name] == task_include[1]]
        if args.stratify == 'strat':
            ratio = len(data_1) / (len(data_1) + len(data_2))
            len_1_train = round(args.train_num*ratio)
            len_2_train = args.train_num - len_1_train
            len_1_valid = round(int(args.train_num*config.valid_ratio)*ratio)
            len_2_valid = int(args.train_num*config.valid_ratio) - len_1_valid
            assert args.train_num*(1+config.valid_ratio) < (len(data_1) + len(data_2)), 'Not enough valid data. Set smaller --train_num or smaller config.valid_ratio in config.py.'
            train1, valid1, test1 = np.split(data_1.sample(frac=1, random_state=random_seed), 
                                             [len_1_train, len_1_train + len_1_valid])
            train2, valid2, test2 = np.split(data_2.sample(frac=1, random_state=random_seed), 
                                             [len_2_train, len_2_train + len_2_valid])
            label_train = pd.concat([train1, train2]).sample(frac=1, random_state=random_seed)
            label_valid = pd.concat([valid1, valid2]).sample(frac=1, random_state=random_seed)
            label_test = pd.concat([test1, test2]).sample(frac=1, random_state=random_seed)
            assert len(label_test) >= 100, 'Not enough test data. (Total: {0})'.format(len(label_test))
        else: # args.stratify == 'balan'
            if len(data_1) <= len(data_2):
                limit = len(data_1)
            else:
                limit = len(data_2)
            data_1 = data_1.sample(frac=1, random_state=random_seed)[:limit]
            data_2 = data_2.sample(frac=1, random_state=random_seed)[:limit]
            len_1_train = round(args.train_num*0.5)
            len_2_train = args.train_num - len_1_train
            len_1_valid = round(int(args.train_num*config.valid_ratio)*0.5)
            len_2_valid = int(args.train_num*config.valid_ratio) - len_1_valid
            assert args.train_num*(1+config.valid_ratio) < limit*2, 'Not enough data to make balanced set.'
            train1, valid1, test1 = np.split(data_1.sample(frac=1, random_state=random_seed), 
                                             [len_1_train, len_1_train + len_1_valid])
            train2, valid2, test2 = np.split(data_2.sample(frac=1, random_state=random_seed), 
                                             [len_2_train, len_2_train + len_2_valid])
            label_train = pd.concat([train1, train2]).sample(frac=1, random_state=random_seed)
            label_valid = pd.concat([valid1, valid2]).sample(frac=1, random_state=random_seed)
            label_test = pd.concat([test1, test2]).sample(frac=1, random_state=random_seed)
            assert len(label_test) >= 100, 'Not enough test data. (Total: {0})'.format(len(label_test))
        print('\nTrain data info:\n{0}\nTotal: {1}\n'.format(label_train[label_name].value_counts().sort_index(), len(label_train)))
        print('Valid data info:\n{0}\nTotal: {1}\n'.format(label_valid[label_name].value_counts().sort_index(), len(label_valid)))
        print('Test data info:\n{0}\nTotal: {1}\n'.format(label_test[label_name].value_counts().sort_index(), len(label_test)))
        label_train[label_name].replace({task_include[0]: 0, task_include[1]: 1}, inplace=True)
        label_valid[label_name].replace({task_include[0]: 0, task_include[1]: 1}, inplace=True)
        label_test[label_name].replace({task_include[0]: 0, task_include[1]: 1}, inplace=True)
    else: # config.task_type = 'reg'
        task_include = args.task_name.split('/')
        assert len(task_include) == 1, 'Set only one label.'
        assert config.num_classes == 1, 'Set config.num_classes == 1'
        labels = pd.read_csv(config.label)
        labels = labels[(np.abs(stats.zscore(labels[label_name])) < 3)] # remove outliers w.r.t. z-score > 3
        assert args.train_num*(1+config.valid_ratio) <= len(labels), 'Not enough valid data. Set smaller --train_num or smaller config.valid_ratio in config.py.'
        label_train, label_valid, label_test = np.split(labels.sample(frac=1, random_state=random_seed), 
                                                        [args.train_num, int(args.train_num*(1+config.valid_ratio))])
        
        print('\nTrain data info:\nTotal: {0}\n'.format(len(label_train)))
        print('Valid data info:\nTotal: {0}\n'.format(len(label_valid)))
        print('Test data info:\nTotal: {0}\n'.format(len(label_test)))
###

### ADNI
#if config.mode == PRETRAINING:
#    dataset_train = ADNI_Dataset(config, label_train, data_type='train')
#    dataset_val = ADNI_Dataset(config, label_valid, data_type='valid')
#    dataset_test = ADNI_Dataset(config, label_test, data_type='test')
elif config.mode == FINE_TUNING:
    dataset_train = ADNI_Dataset(config, label_train, data_type='train')
    dataset_val = ADNI_Dataset(config, label_valid, data_type='valid')
    dataset_test = ADNI_Dataset(config, label_test, data_type='test')
###
loader_train = DataLoader(dataset_train,
                          batch_size=config.batch_size,
                          sampler=RandomSampler(dataset_train),
                          collate_fn=dataset_train.collate_fn,
                          pin_memory=config.pin_mem,
                          num_workers=config.num_cpu_workers
                          )
loader_val = DataLoader(dataset_val,
                        batch_size=config.batch_size,
                        sampler=RandomSampler(dataset_val),
                        collate_fn=dataset_val.collate_fn,
                        pin_memory=config.pin_mem,
                        num_workers=config.num_cpu_workers
                        )
### ADNI
loader_test = DataLoader(dataset_test,
                         batch_size=1,
                         sampler=RandomSampler(dataset_test),
                         collate_fn=dataset_test.collate_fn,
                         pin_memory=config.pin_mem,
                         num_workers=config.num_cpu_workers
                         )
###
#if config.mode == PRETRAINING:
#    if config.model == "DenseNet":
#        net = densenet121(mode="encoder", drop_rate=0.0)
#    elif config.model == "UNet":
#        net = UNet(config.num_classes, mode="simCLR") # ?? why num_classes?
#    else:
#        raise ValueError("Unkown model: %s"%config.model)
else: # config.mode == FINETUNING:
    if config.model == "DenseNet":
        net = densenet121(mode="classifier", drop_rate=0.0, num_classes=config.num_classes)
    elif config.model == "UNet":
        net = UNet(config.num_classes, mode="classif")
    else:
        raise ValueError("Unkown model: %s"%config.model)
#if config.mode == PRETRAINING:
#    loss = GeneralizedSupervisedNTXenLoss(config=config, # ADNI
#                                          temperature=config.temperature,
#                                          sigma=config.sigma,
#                                          return_logits=True)
elif config.mode == FINE_TUNING:
    if config.task_type == 'cls':
        loss = CrossEntropyLoss()
    else: # config.task_type == 'reg': # ADNI
        loss = MSELoss()
#if config.mode == PRETRAINING:
#    model = yAwareCLModel(net, loss, loader_train, loader_val, loader_test, config, "no", 0, "no", None, pretrained_path) # ADNI
else:
    model = yAwareCLModel(net, loss, loader_train, loader_val, loader_test, config, args.task_name, args.train_num, args.layer_control, None, pretrained_path) # ADNI
#if config.mode == PRETRAINING:
#    model.pretraining()
else:
    outGT, outPRED = model.fine_tuning() # ADNI
    #print('outGT:', outGT)
    #print('outPRED:', outPRED)

### ADNI
if config.mode == FINE_TUNING:
    if config.task_type == 'cls':
        outGTnp = outGT.cpu().numpy()
        outPREDnp = outPRED.cpu().numpy()
        print('\n<<< Test Results: AUROC >>>')
        outAUROC = []
        for i in range(config.num_classes):
            outAUROC.append(roc_auc_score(outGTnp[:, i], outPREDnp[:, i]))
        aurocMean = np.array(outAUROC).mean()
        print('MEAN', ': {:.4f}'.format(aurocMean))
        fig, ax = plt.subplots(nrows = 1, ncols = config.num_classes)
        ax = ax.flatten()
        fig.set_size_inches((config.num_classes * 10, 10))
        for i in range(config.num_classes):
            print(task_include[i], ': {:.4f}'.format(outAUROC[i]))
            fpr, tpr, threshold = metrics.roc_curve(outGT.cpu()[:, i], outPRED.cpu()[:, i])
            roc_auc = metrics.auc(fpr, tpr)
            ax[i].plot(fpr, tpr, label = 'AUC = %0.2f' % (roc_auc))
            ax[i].set_title('ROC for {0}'.format(task_include[i]))
            ax[i].legend(loc = 'lower right')
            ax[i].plot([0, 1], [0, 1], 'r--')
            ax[i].set_xlim([0, 1])
            ax[i].set_ylim([0, 1])
            ax[i].set_ylabel('True Positive Rate')
            ax[i].set_xlabel('False Positive Rate')
        
        if args.layer_control == 'tune_all':
            control = 'a'
        elif args.layer_control == 'freeze':
            control = 'f'
        else:
            control = 'd'
        plt.savefig('./figs/{0}_ADNI_{1}{2}{3}_{4}_ROC.png'.format(str(datetime.datetime.now())[2:10].replace('-', ''),
                                                                   args.task_name.replace('/', ''), 
                                                                   str(args.train_num)[0], 
                                                                   args.stratify[0], 
                                                                   control), dpi = 100)
        plt.close()
        
        ############################# rename stats.txt ###########################################
        os.rename('./stats.txt', "./"+args.task_name.replace('/', '')+"_stats.txt")
        #stats_file = open("./"+args.task_name.replace('/', '')+"_stats.txt", 'a', buffering=1)
        #stats_file.write("until here_"+args.task_name.replace('/', '')+"_"+str(args.train_num)[0]+"_"+args.stratify[0]+"_"+control)
        #stats_file.write("#########################################################################")
        #stats_file.close()
        
        
    else: # config.task_type == 'reg':
        outGTnp = outGT.cpu().numpy()
        outPREDnp = outPRED.cpu().numpy()
        print('\n<<< Test Results >>>')
        print('MSE: {:.2f}'.format(mean_squared_error(outGTnp, outPREDnp)))
        print('MAE: {:.2f}'.format(mean_absolute_error(outGTnp, outPREDnp)))
        print('RMSE: {:.2f}'.format(np.sqrt(mean_squared_error(outGTnp, outPREDnp))))
        print('R2-score: {:.4f}'.format(r2_score(outGTnp, outPREDnp)))
end_time = time.time()
print('\nTotal', round((end_time - start_time) / 60), 'minutes elapsed.')
now = datetime.datetime.now()
nowDatetime = now.strftime('%Y-%m-%d %H:%M:%S') # ADNI
print("[main.py finished at {0}]".format(nowDatetime))
###

SyntaxError: invalid syntax (3075718483.py, line 157)