In [1]:
import os
import sys
import glob
import math
import time
import timeit
import random
import shutil
import importlib
import pandas as pd
import numpy as np 
from datetime import datetime
from typing import Any, Callable, Optional, Sequence, Union
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=Warning)

import torch
import torch.nn as nn 
from torch.utils.data import DataLoader
from torch.nn.utils import weight_norm
import torch.utils.data
import torchmetrics

from monai.config import DtypeLike
from monai.data.image_reader import ImageReader
from monai.data import ImageDataset
from monai.transforms import Compose, RandAffine, AddChannel, ScaleIntensity, RandFlip, Resize, NormalizeIntensity, ToTensor, SpatialPad

# Run configuration. iterations is an inclusive [first, last] range; each
# iteration number is also used as the random seed inside main().
# The trailing "# N" comments look like quick-debug values -- TODO confirm.
iterations = [1, 5] # 1
epochs = 100 # 1
batch_size = 8 # 2
num_workers = 16 # 0

# Parallel lists: module_names[k] is imported from model_dir, model_names[k]
# is the class name looked up and instantiated for that run.
module_names = ['GOTDNet']
model_names = ['GOTDNet']

model_dir = 'models'
# Star-import each model module so everything it exports becomes a global of
# this notebook; main() resolves the model class by name via str_to_class().
for  module_name in module_names:
    exec(f'from {model_dir}.{module_name} import *')

# Image files are dataset_dir/<first CSV column>; the remaining column is the label.
dataset_dir = '../Total_Datasets/Datasets-TAO/231002_Treatment_Proposed_Dataset_v5'
label_file = 'Treatment_Label_v0.5_230913.csv'

# CUDA device ids; main() wraps the model in DataParallel when more than one is given.
devices = [0,1]
# Column order of Performance_<time>.csv; main() appends one row in this order per run.
Metrics = ['Experient Time', 'Train Time','Iteration', 'Model Name', 'Loss', 'AUROC', 'AUPRC', 'Accuracy', 'F1 Score', 'Sensitivity', 'Specificity', 'Precision', 'Threshold', 'Params', 'Training Time (s)', 'Test Time (s)', 'Best_Epoch','DIR']
# Passed to the model constructor as (in_channels, num_classes).
in_channels = 1
num_classes = 2

# configure_optimizer(): 'momentum' is read only by SGD; 'weight_decay' only by
# adamw / lamb / jitlamb.
optimizer = 'adamw'
lr = 1e-3
momentum = 0.9
weight_decay = 1e-4
optim_args = {'optimizer': optimizer, 'lr': lr, 'momentum': momentum, 'weight_decay': weight_decay}

# 'CosineAnnealingLR' uses T_max/eta_min; 'CosineAnnealingWarmRestarts' uses T_0/eta_min.
lr_scheduler = 'CosineAnnealingLR'
T_max = epochs
T_0 = 50
eta_min = 1e-6
lr_scheduler_args = {'lr_scheduler': lr_scheduler, 'T_max': T_max, 'T_0': T_0, 'eta_min': eta_min}

# main() reads only 'loss_function' and (for FocalLoss) 'gamma'; 'reduction' and
# 'weight' are carried along but not consumed.
loss_function = 'FocalLoss'
reduction = 'mean'
gamma = 2.0
weight = None
loss_function_args = {'loss_function': loss_function, 'reduction': reduction, 'gamma': gamma, 'weight': weight}

# When True, stdout is tee'd into <output_dir>/Log.txt for each run via DualOutput.
save_log = False

In [2]:
def train_epoch(model, optimizer, criterion, train_loader, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is expected to be ``(images, labels, sample_index)``; the index
    is ignored here.

    Returns:
        The batch-size-weighted mean training loss, rounded to 6 decimals.
    """
    model.train()
    meter = AverageMeter()
    softmax = nn.Softmax(dim=1)
    for images, labels, _ in train_loader:
        model.zero_grad(set_to_none=True)
        images = images.to(device)
        labels = labels.to(device)
        # NOTE(review): the criterion receives softmax probabilities, while
        # CrossEntropyLoss / FocalLoss apply log_softmax internally (a double
        # softmax). Kept because infer() feeds softmaxed outputs to the same
        # criterion for the validation loss -- change both together.
        probabilities = softmax(model(images))
        batch_loss = criterion(probabilities, labels)
        batch_loss.backward()
        optimizer.step()
        meter.update(batch_loss.detach().cpu().numpy(), images.shape[0])
    return round(float(meter.avg), 6)

def infer(model, criterion, valid_loader, device):
    """Run ``model`` over every batch of ``valid_loader`` without gradients.

    Each batch is expected to be ``(inputs, targets, sample_index)``.

    Args:
        model: network producing per-class scores with classes on dim 1.
        criterion: unused; kept for interface compatibility with callers.
        valid_loader: iterable of batches.
        device: device the model lives on; inputs are moved there.

    Returns:
        ``(probabilities, targets, sample_indexes)`` concatenated over all
        batches, all on CPU. ``probabilities`` is the softmax over dim 1.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    model.eval()
    softmax = nn.Softmax(dim=1)
    batch_outputs, batch_targets, batch_indexes = [], [], []
    with torch.no_grad():
        for inputs, target, sample_index in valid_loader:
            # Move inputs explicitly: a bare (non-DataParallel) model on GPU
            # cannot consume CPU tensors.
            inputs = inputs.to(device)
            output = softmax(model(inputs))
            batch_outputs.append(output.to('cpu'))
            batch_targets.append(target.to('cpu'))
            batch_indexes.append(sample_index)
    if not batch_outputs:
        raise ValueError('infer() received a data loader with no batches')
    # Concatenate once at the end instead of growing tensors every step.
    return torch.cat(batch_outputs, dim=0), torch.cat(batch_targets), torch.cat(batch_indexes)

def count_parameters(model):
    """Return the number of trainable parameters formatted in millions, e.g. ``"1.00 M"``."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return f"{trainable / 1_000_000:4.2f} M"

def copy_sourcefile(output_dir, src_dir='src'):
    """Snapshot the working directory's source/config files into ``output_dir/src_dir``.

    Copies every ``*.py``, ``*.sh``, ``*.ipynb``, ``*.txt`` and ``*.json`` file
    found directly in the current working directory (non-recursive). The
    destination directory is created if needed; existing files are overwritten.
    """
    destination = os.path.join(output_dir, src_dir)
    os.makedirs(destination, exist_ok=True)

    matched = []
    for extension in ('py', 'sh', 'ipynb', 'txt', 'json'):
        matched.extend(glob.glob(os.path.join('./', f'*.{extension}')))

    target = os.path.join(destination, '.')
    for path in matched:
        shutil.copy(path, target)

def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise ``m.weight`` from N(mean, std) when ``m``'s class name contains "Conv".

    Only ``m`` itself is inspected. To initialise every layer, pass this to
    ``model.apply(init_weights)``. Modules whose name matches but that have no
    ``weight`` (e.g. a custom ``ConvBlock`` container) are skipped instead of
    raising AttributeError.
    """
    if "Conv" not in type(m).__name__:
        return
    weight = getattr(m, "weight", None)
    if weight is None:
        return
    weight.data.normal_(mean, std)

def save_checkpoint(args, model, optimizer, optim_args, epoch, loss, output_dir, file_path):
    """Serialise model/optimizer state plus run metadata to ``file_path`` with ``torch.save``.

    ``output_dir`` is created first (``file_path`` is expected to live inside it).
    The saved dict has keys: args, model_state, optimizer_state, optim_args,
    epoch, val_loss.
    """
    os.makedirs(output_dir, exist_ok=True)
    checkpoint = dict(
        args=args,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        optim_args=optim_args,
        epoch=epoch,
        val_loss=loss,
    )
    torch.save(checkpoint, file_path)

def load_checkpoint(model, optimizer, file_path, map_location=None):
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: module whose ``state_dict`` layout matches the saved one (a
            DataParallel-wrapped checkpoint must be loaded into a
            DataParallel-wrapped model).
        optimizer: optimizer built over ``model``'s parameters.
        file_path: checkpoint path.
        map_location: passed to ``torch.load``. Defaults to the current CUDA
            device when CUDA is available, otherwise ``'cpu'``. The previous
            unconditional ``cuda:`` mapping crashed on CPU-only machines.
    """
    if map_location is None:
        if torch.cuda.is_available():
            map_location = f'cuda:{torch.cuda.current_device()}'
        else:
            map_location = 'cpu'
    checkpoint = torch.load(file_path, map_location=map_location)

    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])

def configure_optimizer(model, optim_args):
    """Build the optimizer named by ``optim_args['optimizer']`` (case-insensitive).

    Supported names: sgd, adagrad, adam, lamb, jitlamb, adamw. ``optim_args``
    must provide 'lr'. 'momentum' is used only by sgd; 'weight_decay' only by
    lamb, jitlamb and adamw.

    Raises:
        ValueError: for an unsupported optimizer name. Previously this fell
            through to an UnboundLocalError on ``return optimizer``.
    """
    import torch.optim as optim
    name = optim_args['optimizer'].lower()
    params = model.parameters()
    lr = optim_args['lr']
    if name == 'sgd':
        return optim.SGD(params, lr=lr, momentum=optim_args['momentum'])
    if name == 'adagrad':
        return optim.Adagrad(params, lr=lr)
    if name == 'adam':
        return optim.Adam(params, lr=lr)
    # NOTE(review): `lamb` is not imported anywhere in this file; these two
    # branches raise NameError unless it is provided by another import -- confirm.
    if name == 'lamb':
        return lamb.Lamb(params, lr=lr, betas=(0.9, 0.98), eps=1e-9, weight_decay=optim_args['weight_decay'])
    if name == 'jitlamb':
        return lamb.JITLamb(params, lr=lr, betas=(0.9, 0.98), eps=1e-9, weight_decay=optim_args['weight_decay'])
    if name == 'adamw':
        return optim.AdamW(params, lr=lr, weight_decay=optim_args['weight_decay'])
    raise ValueError(f"Unsupported optimizer: {optim_args['optimizer']!r}")

class ImageDatasetWithIndex(ImageDataset):
    """MONAI ``ImageDataset`` that additionally returns a per-sample index.

    The index lets inference results be matched back to rows of an external
    table (``main`` uses it as the 'Sample Index' of the data-split CSV).
    All constructor arguments except ``sample_indexes`` are forwarded to
    ``ImageDataset``.
    """
    def __init__(
        self,
        image_files: Sequence[str],
        seg_files: Optional[Sequence[str]] = None,
        labels: Optional[Sequence[float]] = None,
        transform: Optional[Callable] = None,
        seg_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        image_only: bool = True,
        transform_with_metadata: bool = False,
        dtype: DtypeLike = np.float32,
        reader: Optional[Union[ImageReader, str]] = None,
        sample_indexes: Optional[Sequence[int]] = None,
        *args,
        **kwargs,
    ) -> None:
        # Forward positionally, mirroring ImageDataset's parameter order, so any
        # extra *args land after `reader`. Mixing keyword arguments with trailing
        # *args binds the extras to `image_files` first and raises
        # "got multiple values for argument".
        super().__init__(
            image_files,
            seg_files,
            labels,
            transform,
            seg_transform,
            label_transform,
            image_only,
            transform_with_metadata,
            dtype,
            reader,
            *args,
            **kwargs,
        )
        # sample_indexes[i] is returned alongside image_files[i].
        self.sample_indexes = sample_indexes

    def __getitem__(self, index):
        """Return ``(image, label, sample_index)``, or ``(image, label)`` without indexes.

        Takes the first two items of the parent's result as image and label.
        This assumes ``labels`` was given and ``seg_files`` was not -- TODO
        confirm before using with segmentation files.
        """
        image, label = super().__getitem__(index)[:2]
        if self.sample_indexes is None:
            return image, label
        return image, label, self.sample_indexes[index]

def create_dataset(images, targets, sample_indexes, apply_augmentation=True, spatial_size=(160, 220, 350)):
    """Build an ``ImageDatasetWithIndex`` over ``images`` with a fixed preprocessing pipeline.

    Pipeline: ToTensor -> AddChannel -> SpatialPad(spatial_size).

    Args:
        images: image file paths.
        targets: one label per image.
        sample_indexes: one sample id per image, returned with each item.
        apply_augmentation: currently ignored. Training and evaluation share
            the same deterministic pipeline. NOTE(review): add random
            transforms (e.g. RandAffine / RandFlip, already imported in this
            notebook) if training-time augmentation is intended.
        spatial_size: target size passed to ``SpatialPad``.

    Returns:
        ImageDatasetWithIndex yielding ``(image, label, sample_index)``.
    """
    transform = Compose([
        ToTensor(),
        AddChannel(),
        SpatialPad(spatial_size=spatial_size),
    ])
    return ImageDatasetWithIndex(image_files=images, labels=targets, sample_indexes=sample_indexes, transform=transform)

def calculate_performance(outputs, targets, threshold):
    """Compute binary classification metrics for positive-class scores.

    Args:
        outputs: 1-D tensor of positive-class probabilities.
        targets: tensor of 0/1 labels (cast to int).
        threshold: cutoff used by the thresholded metrics (everything except
            AUROC and AUPRC).

    Returns:
        ``(auroc, auprc, accuracy, f1, sensitivity, specificity, precision)``,
        each rounded to 3 decimals.
    """
    labels = targets.int()

    def score(metric):
        return round(float(metric(outputs, labels)), 3)

    acc = score(torchmetrics.Accuracy(task='binary', threshold=threshold))
    f1 = score(torchmetrics.F1Score(task='binary', threshold=threshold))
    sensitivity = score(torchmetrics.Recall(task='binary', threshold=threshold))
    specificity = score(torchmetrics.Specificity(task='binary', threshold=threshold))
    precision = score(torchmetrics.Precision(task='binary', threshold=threshold))
    auprc = score(torchmetrics.AveragePrecision(task='binary'))
    auroc = score(torchmetrics.AUROC(task='binary'))

    return auroc, auprc, acc, f1, sensitivity, specificity, precision

class AverageMeter(object):
    """Tracks the latest value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` items and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
class LossSaver(object):
    """Accumulates per-epoch train/validation losses and exports them as CSV."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all recorded losses."""
        self.train_losses, self.val_losses = [], []

    def update(self, train_loss, val_loss):
        """Append one epoch's train and validation loss."""
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)

    def return_list(self):
        """Return ``(train_losses, val_losses)`` (the live lists, not copies)."""
        return self.train_losses, self.val_losses

    def save_as_csv(self, csv_file):
        """Write the losses to ``csv_file`` with rows labelled "1 Epoch", "2 Epoch", ..."""
        table = pd.DataFrame({'Train Losses': self.train_losses, 'Validation Losses': self.val_losses})
        table.index = [f"{epoch} Epoch" for epoch in range(1, len(table) + 1)]
        table.to_csv(csv_file, index=True)
        

def control_random_seed(seed, pytorch=True):
    """Seed Python's ``random``, NumPy's global RNG and (optionally) PyTorch.

    Args:
        seed: integer seed.
        pytorch: when True (default), also seed torch / CUDA and switch cuDNN
            to deterministic, non-benchmark mode. This flag was previously
            ignored.

    PyTorch seeding is best-effort. A failure is reported on stderr rather than
    silently swallowed by a bare ``except``. stderr is used because this
    notebook filters out all ``warnings``.
    """
    random.seed(seed)
    np.random.seed(seed)
    if not pytorch:
        return
    try:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    except Exception as exc:
        print(f'[control_random_seed] PyTorch seeding failed: {exc!r}', file=sys.stderr)
        torch.backends.cudnn.benchmark = False
    
class DualOutput:
    """Minimal file-like tee: text goes to ``file`` first, then to ``stdout``.

    Assigned to ``sys.stdout`` so ``print`` output is both logged and shown.
    """

    def __init__(self, file, stdout):
        self.file, self.stdout = file, stdout

    def write(self, text):
        """Write ``text`` to both streams (file first)."""
        for stream in (self.file, self.stdout):
            stream.write(text)

    def flush(self):
        """Flush both streams (file first)."""
        for stream in (self.file, self.stdout):
            stream.flush()
def str_to_class(classname):
    """Return the global named ``classname`` from the module this function lives in.

    Raises AttributeError if no such global exists.
    """
    this_module = sys.modules[__name__]
    return getattr(this_module, classname)
class FocalLoss(nn.Module):
    """Multi-class focal loss: mean of ``alpha[t] * (1 - p_t) ** gamma * CE``.

    ``p_t = exp(-CE)``, where CE is ``cross_entropy(inputs, targets)`` per
    sample. ``cross_entropy`` applies log_softmax itself, so ``inputs`` are
    treated as unnormalised scores.

    Args:
        alpha: optional 1-D tensor of per-class weights indexed by target class;
            None means unweighted.
        gamma: focusing exponent.
        device: unused; kept for backward compatibility.
    """
    def __init__(self, alpha=None, gamma=2, device=False):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Use nn.functional directly: `F` is not imported in this file.
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        # Use the instance's gamma; the old code read a module-level `gamma`.
        modulating = (1 - pt) ** self.gamma
        if self.alpha is None:
            return (modulating * ce_loss).mean()
        alpha = self.alpha.to(inputs.device)
        return (alpha[targets] * modulating * ce_loss).mean()

In [None]:
def _loss_value(criterion, outputs, targets):
    """Evaluate ``criterion`` on inference results, rounded to 6 decimals for logging/CSV."""
    return round(float(criterion(outputs, targets).cpu().numpy()), 6)


def _class_ratios(split_labels):
    """Map each label to its fraction of ``split_labels`` (rounded to 3 decimals)."""
    from collections import Counter
    return {k: round(v / len(split_labels), 3) for k, v in Counter(split_labels).items()}


def _make_loader(dataset, batch_size, num_workers):
    """Sequential (unshuffled) DataLoader with the settings shared by all splits."""
    return DataLoader(dataset, num_workers=num_workers, shuffle=False,
                      sampler=None, batch_size=batch_size,
                      pin_memory=False, drop_last=False,
                      collate_fn=None)


def _build_scheduler(optimizer, lr_scheduler_args):
    """Create the LR scheduler named by ``lr_scheduler_args['lr_scheduler']``; ValueError otherwise."""
    name = lr_scheduler_args['lr_scheduler']
    if name == 'CosineAnnealingWarmRestarts':
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=lr_scheduler_args['T_0'], eta_min=lr_scheduler_args['eta_min'])
    if name == 'CosineAnnealingLR':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=lr_scheduler_args['T_max'], eta_min=lr_scheduler_args['eta_min'])
    raise ValueError(f"Unsupported lr_scheduler: {name!r}")


def _build_criterion(loss_function_args, train_labels):
    """Create the loss named by ``loss_function_args['loss_function']``; ValueError otherwise.

    FocalLoss gets inverse-frequency class weights computed from ``train_labels``.
    """
    name = loss_function_args['loss_function']
    if name == 'BCEWithLogitsLoss':
        return nn.BCEWithLogitsLoss()
    if name == 'FocalLoss':
        class_counts = np.bincount(train_labels)
        total_samples = len(train_labels)
        class_weights = torch.FloatTensor([1 / (count / total_samples) for count in class_counts])
        return FocalLoss(alpha=class_weights, gamma=loss_function_args['gamma'])
    if name == 'CrossEntropyLoss':
        return nn.CrossEntropyLoss()
    raise ValueError(f"Unsupported loss_function: {name!r}")


def _balanced_threshold(probs, targets, resolution, decimals):
    """Return the grid cutoff minimising |sensitivity - specificity|.

    The grid holds ``resolution - 1`` evenly spaced points strictly inside
    (0, 1), rounded to ``decimals``. Because the comparison is ``>=`` over an
    ascending grid, ties resolve to the largest such threshold.
    """
    step = 1 / resolution
    grid = np.round(np.linspace(step, 1 - step, resolution - 1), decimals)
    min_diff = 9999
    best = None
    for threshold in grid:
        ss = torchmetrics.Recall(task='binary', threshold=threshold)(probs, targets)
        sp = torchmetrics.Specificity(task='binary', threshold=threshold)(probs, targets)
        if min_diff >= np.abs(ss - sp):
            min_diff = np.abs(ss - sp)
            best = threshold
    return best


def _record_predictions(df, probs, sample_indexes, threshold, output_col, target_col, placeholder):
    """Write per-sample probabilities and thresholded labels into ``df`` in place.

    Rows not in ``sample_indexes`` keep ``placeholder``. NOTE(review): despite
    its name, ``target_col`` receives the *predicted* label (prob >= threshold),
    not the ground truth. The column names are kept for CSV compatibility.
    """
    df[output_col] = placeholder
    df[target_col] = placeholder
    for prob, sample_index in zip(probs, sample_indexes):
        mask = df['Sample Index'] == int(sample_index)
        df.loc[mask, output_col] = np.round(float(prob), 6)
        df.loc[mask, target_col] = 1 if float(prob) >= threshold else 0


def main(args):
    """Train, validate and test one model for one iteration.

    ``args`` is the 17-element list assembled by the driver cell (see the
    unpacking below). ``iteration`` doubles as the random seed for both RNG
    seeding and the data splits.

    Side effects:
        - best checkpoint: ``output_dir/<Train_Time>_<model_name>_iter_<iteration>.pt``
        - per-epoch losses: ``output_dir/Losses_<Experiments_Time>.csv``
        - data split + outputs: ``output_dir/Data_split_and_Outputs_Targets_<Experiments_Time>.csv``
        - one row appended to ``output/output_<Experiments_Time>/Performance_<Experiments_Time>.csv``

    Raises:
        ValueError: for an unsupported scheduler or loss name.
    """
    Experiments_Time, iteration, module_name, model_name, \
    model_dir, output_dir, batch_size, epochs, num_workers, \
    dataset_dir, label_file, devices, \
    in_channels, num_classes, optim_args, lr_scheduler_args, loss_function_args = args

    seed = iteration
    control_random_seed(seed)

    device = torch.device("cuda:" + str(devices[0]))

    # Prefer the class defined in the configured model module; fall back to a
    # notebook-global of the same name.
    module = importlib.import_module(f'{model_dir}.{module_name}')
    model_class = getattr(module, model_name) if hasattr(module, model_name) else str_to_class(model_name)
    model = model_class(in_channels, num_classes)
    try:
        # NOTE(review): this inspects only the top-level module's class name;
        # `model.apply(init_weights)` would be needed for per-layer init.
        init_weights(model)
    except Exception:
        pass
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices).to(device)
    else:
        model = model.to(device)

    optimizer = configure_optimizer(model, optim_args)
    lr_scheduler = _build_scheduler(optimizer, lr_scheduler_args)

    # Label CSV: first column is the image file name (index), second the integer label.
    df_label = pd.read_csv(label_file, header=None, index_col=0)
    images = [os.path.join(dataset_dir, f'{i}') for i in df_label.index.to_list()]
    labels = df_label.to_numpy(dtype=np.int64).flatten()

    from sklearn.model_selection import train_test_split
    train_test_split_ratio = 0.8
    train_val_split_ratio = 0.75
    # NOTE(review): the train/test split is not stratified while train/val is -- confirm intended.
    train_images, test_images, train_labels, test_labels = train_test_split(
        images, labels, train_size=train_test_split_ratio, random_state=seed
    )
    train_images, val_images, train_labels, val_labels = train_test_split(
        train_images, train_labels, train_size=train_val_split_ratio, random_state=seed,
        stratify=train_labels
    )

    print(f"Class ratios: {_class_ratios(train_labels)} / {_class_ratios(val_labels)} / {_class_ratios(test_labels)}")

    # Sample indexes are positions in the concatenated train+val+test order.
    # Computing them directly avoids the O(n^2) `.index()` search, which also
    # mapped duplicate paths to their first occurrence.
    n_train, n_val, n_test = len(train_images), len(val_images), len(test_images)
    train_indexes = list(range(n_train))
    val_indexes = list(range(n_train, n_train + n_val))
    test_indexes = list(range(n_train + n_val, n_train + n_val + n_test))

    all_images = train_images + val_images + test_images
    df_data_split = pd.DataFrame({'Images': [os.path.basename(image) for image in all_images], 'Labels': list(train_labels) + list(val_labels) + list(test_labels)})
    df_data_split = df_data_split.reset_index(drop=True).reset_index().rename(columns={'index': 'Sample Index'})
    df_data_split['Data Split'] = ['Train'] * n_train + ['Validation'] * n_val + ['Test'] * n_test

    trainset = create_dataset(train_images, train_labels, train_indexes)
    validset = create_dataset(val_images, val_labels, val_indexes, apply_augmentation=False)
    testset = create_dataset(test_images, test_labels, test_indexes, apply_augmentation=False)
    print(f'Data Split (Train/Val/Test): {len(train_images)}/{len(val_images)}/{len(test_images)}')

    # NOTE(review): the training loader is not shuffled between epochs -- confirm intended.
    train_loader = _make_loader(trainset, batch_size, num_workers)
    valid_loader = _make_loader(validset, batch_size, num_workers)
    test_loader = _make_loader(testset, batch_size, num_workers)

    criterion = _build_criterion(loss_function_args, train_labels)

    losses_csv = os.path.join(output_dir, f'Losses_{Experiments_Time}.csv')
    split_csv = os.path.join(output_dir, f'Data_split_and_Outputs_Targets_{Experiments_Time}.csv')
    performance_csv = f'output/output_{Experiments_Time}/Performance_{Experiments_Time}.csv'
    patience = 20  # epochs without val-loss improvement before early stopping

    loss_saver = LossSaver()
    Best_Loss = 999999999999999
    Best_Epoch = 1
    Train_Time = datetime.now().strftime("%y%m%d_%H%M%S")
    start = timeit.default_timer()
    print(f'Train Start ({Train_Time})')
    file_path = os.path.join(output_dir, f'{Train_Time}_{model_name}_iter_{iteration}.pt')
    save_checkpoint(args, model, optimizer, optim_args, 0, Best_Loss, output_dir, file_path)

    Early_Stop = 0
    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(model, optimizer, criterion, train_loader, device)
        lr_scheduler.step()
        outputs, targets, sample_indexes = infer(model, criterion, valid_loader, device)
        val_loss = _loss_value(criterion, outputs, targets)
        infer_date = datetime.now().strftime("%y%m%d_%H%M%S")
        print(f'{epoch} EP({infer_date}): Loss: train-{train_loss}/val-{val_loss}', end=' ')
        loss_saver.update(train_loss, val_loss)
        loss_saver.save_as_csv(losses_csv)

        if Best_Loss >= val_loss:
            save_checkpoint(args, model, optimizer, optim_args, epoch, val_loss, output_dir, file_path)
            Best_Loss = val_loss
            Best_Epoch = epoch
            Early_Stop = 0
            print(f"Best Epoch: {epoch}, Loss: {val_loss}", end=' ')
            probs = outputs[:, 1]
            THRESHOLD = _balanced_threshold(probs, targets, resolution=1000, decimals=3)
            print(f'(Cutoff: {THRESHOLD}) Loss: {val_loss}')
            _record_predictions(df_data_split, probs, sample_indexes, THRESHOLD,
                                'Model Validation Output', 'Model Validation Target', 'No Validation')
            df_data_split.to_csv(split_csv, index=False)
        else:
            print('')
            Early_Stop += 1
        if Early_Stop >= patience:
            break
    stop = timeit.default_timer()
    training_time = round((stop - start), 2)
    print(f"Train End ({datetime.now().strftime('%y%m%d_%H%M%S')})")

    # Validation: re-pick the cutoff on a finer grid using the best checkpoint.
    print(f'Validation Start ({datetime.now().strftime("%y%m%d_%H%M%S")})')
    load_checkpoint(model, optimizer, file_path)
    outputs, targets, sample_indexes = infer(model, criterion, valid_loader, device)
    val_loss = _loss_value(criterion, outputs, targets)
    probs = outputs[:, 1]
    THRESHOLD = _balanced_threshold(probs, targets, resolution=100000, decimals=5)
    print(f'(Cutoff: {THRESHOLD}) Loss: {val_loss}')
    _record_predictions(df_data_split, probs, sample_indexes, THRESHOLD,
                        'Model Validation Output', 'Model Validation Target', 'No Validation')
    df_data_split.to_csv(split_csv, index=False)
    print(f'Validation End ({datetime.now().strftime("%y%m%d_%H%M%S")})')

    # Test with the validation-derived cutoff.
    Test_Time = datetime.now().strftime("%y%m%d_%H%M%S")
    print(f"Test Start ({Test_Time})")
    load_checkpoint(model, optimizer, file_path)
    start = timeit.default_timer()

    outputs, targets, sample_indexes = infer(model, criterion, test_loader, device)
    loss = _loss_value(criterion, outputs, targets)
    probs = outputs[:, 1]
    auroc, auprc, acc, f1, ss, sp, pr = calculate_performance(probs, targets, THRESHOLD)

    print(f'Test({datetime.now().strftime("%y%m%d_%H%M%S")}): Loss: {loss}, AUROC: {auroc}, AUPRC: {auprc}, ACC: {acc}, F1: {f1}, SS: {ss}, SP: {sp}, PR:{pr}')
    params = count_parameters(model)
    stop = timeit.default_timer()
    test_time = round((stop - start), 2)

    Performances = [Experiments_Time, Train_Time, seed, model_name, loss, auroc, auprc, acc, f1, ss, sp, pr, THRESHOLD, params, training_time, test_time, Best_Epoch, os.getcwd()]
    df = pd.read_csv(performance_csv)
    df = pd.concat([df, pd.DataFrame([Performances], columns=df.columns)], ignore_index=True)
    df.to_csv(performance_csv, index=False, header=True)
    _record_predictions(df_data_split, probs, sample_indexes, THRESHOLD,
                        'Model Test Output', 'Model Test Target', 'No Test')
    df_data_split.to_csv(split_csv, index=False)

    print(f"Test End ({datetime.now().strftime('%y%m%d_%H%M%S')})")
    
Experiments_Time = datetime.now().strftime('%y%m%d_%H%M%S')
output_root = f'output/output_{Experiments_Time}'
os.makedirs(output_root, exist_ok = True)
# Header-only performance table; main() appends one row per (iteration, model) run.
df = pd.DataFrame(index=None, columns=Metrics)
df.to_csv(f'output/output_{Experiments_Time}/Performance_{Experiments_Time}.csv', index=False, header=True)
# Inclusive iteration range; each iteration number is also the run's seed.
for iteration in range(iterations[0], iterations[1]+1):
    for module_name, model_name in zip(module_names, model_names):
        print(f'{model_name} (iter: {iteration})')
        # Hard-coded skip of early iterations for one specific model.
        if (model_name=='Xuefei_Song_ResNet18_Proposed_v8_18') and iteration<4:
            continue
        output_dir = output_root + f'/{model_name}_iter_{iteration}'
        copy_sourcefile(output_dir, src_dir='src')
        if save_log:
            original_stdout = sys.stdout
            log_file = open(f'{output_dir}/Log.txt', 'w')
            sys.stdout = DualOutput(log_file, original_stdout)
        args = [
            Experiments_Time, iteration, module_name, model_name,
            model_dir, output_dir, batch_size, epochs, num_workers,
            dataset_dir, label_file, devices,
            in_channels, num_classes, optim_args, lr_scheduler_args, loss_function_args
            ]
        try:
            main(args)
        finally:
            # Restore stdout and close the log even if main() raises or is
            # interrupted; otherwise stdout stays redirected into the log file.
            if save_log:
                sys.stdout = original_stdout
                log_file.close()
        copy_sourcefile(output_dir, src_dir='src')
import os
print('End')
# os._exit terminates immediately without flushing stdio buffers, so flush
# first or the final messages can be lost.
sys.stdout.flush()
sys.stderr.flush()
os._exit(0)