In [1]:
import Transform.Schedule

from Configuration import Editor

with Editor('Config') as Config:
    # Experiment configuration. Every value assigned below is read back by the
    # training cell through `Configuration.Config`.

    # ---- Dataset -----------------------------------------------------------
    # NOTE(review): absolute, machine-specific paths; they must be adjusted
    # before this notebook can run anywhere else.
    Config.Dataset.ImagesRootPath = r'D:\Dataset_Collection\Cardiac_Catheterization\train\images'
    Config.Dataset.MasksRootPath = r'D:\Dataset_Collection\Cardiac_Catheterization\train\masks'
    # Input / output format and loader settings (forwarded to TrainDataLoader).
    Config.Dataset.IO.InputRGBImage = False
    Config.Dataset.IO.NumWorkers = 0
    Config.Dataset.IO.PinMemory = False
    Config.Dataset.IO.PrefetchFactor = 2
    Config.Dataset.IO.OutputDtype = 'float'
    # Name of the attribute looked up on Transform.Preprocess.
    Config.Dataset.Preprocess.Version = 'v1'
    # Train dataset. Combination.Version names an attribute of
    # Transform.Combinations; the remaining values are passed to it.
    Config.Dataset.Train.BatchSize = 1
    Config.Dataset.Train.Transform.Combination.Version = 'v1'
    Config.Dataset.Train.Transform.Schedule = 0.8
    Config.Dataset.Train.Transform.Combination.Components = 'default'
    Config.Dataset.Train.Transform.Combination.Params = 'default'
    Config.Dataset.Train.Transform.Combination.Schedules = 'default'
    # Validation dataset.
    Config.Dataset.Validation.BatchSize = 2
    Config.Dataset.Validation.Ratio = 0.05
    Config.Dataset.Validation.Transform.Combination.Version = 'v1'
    Config.Dataset.Validation.Transform.Schedule = 0.8
    Config.Dataset.Validation.Transform.Combination.Components = 'default'
    Config.Dataset.Validation.Transform.Combination.Params = 'default'
    Config.Dataset.Validation.Transform.Combination.Schedules = 'default'

    # ---- Training structure --------------------------------------------------
    # Selects the model, loss, metrics, optimizer and scheduler by name.
    Config.Training.Structure.Type = 'SimpleSeg'
    # Model
    Config.Training.Structure.Model.Backbone.Name = 'Backbone1'
    # NOTE(review): 'ues_instance_norm' looks like a typo of 'use_instance_norm'.
    # It is passed verbatim as a keyword argument to the backbone class, so it
    # must match that signature -- confirm before renaming.
    Config.Training.Structure.Model.Backbone.Param = dict(ues_instance_norm=True)
    Config.Training.Structure.Model.Head.Name = 'V1'
    Config.Training.Structure.Model.Head.Param = dict(logit_output=True,in_channels=4)
    # Loss
    Config.Training.Structure.Loss.Name = 'DiceBCELoss'
    Config.Training.Structure.Loss.Param = dict(use_logit=True,w_bce=0.2)
    # Optimizer (name is looked up on torch.optim)
    Config.Training.Structure.Optimizer.Name = 'Adam'
    Config.Training.Structure.Optimizer.Param = dict(lr=0.001)
    # Scheduler (name is looked up on utils.schedulers)
    Config.Training.Structure.Scheduler.Name = 'CustomSchedule1'
    Config.Training.Structure.Scheduler.Param = dict(warmup_epochs=1,reduce_gamma=-2)
    # Metrics: the two lists are parallel (one parameter dict per metric name).
    Config.Training.Structure.Metrics.Name = ['DiceBCELoss']
    Config.Training.Structure.Metrics.Param = [dict(use_logit=True,w_bce=0.2)]

    # ---- Checkpoint ----------------------------------------------------------
    # Path = None means "start from scratch".
    Config.Training.Checkpoint.Path = None
    Config.Training.Checkpoint.FileName = ''
    Config.Training.Checkpoint.Resume.Process = False
    Config.Training.Checkpoint.Resume.Optimizer = False
    Config.Training.Checkpoint.Resume.Scheduler = False

    # Whether to freeze the backbone.
    Config.Training.Settings.Model.FreezeBackbone = True

    # ---- Training process ----------------------------------------------------
    Config.Training.Settings.Epochs = 20
    Config.Training.Settings.GradientAccumulation = 4
    Config.Training.Settings.AmpScaleTrain = True
    # Randomness / reproducibility.
    Config.Training.Settings.Random.cuDNN.Deterministic = True
    # Benchmark mode lets cuDNN auto-tune and possibly pick a different
    # algorithm on each run, which defeats the seeded, deterministic setup
    # above; it is therefore disabled (was True).
    Config.Training.Settings.Random.cuDNN.Benchmark = False
    Config.Training.Settings.Random.Seed.Dataset.Split = 4
    Config.Training.Settings.Random.Seed.Dataset.Transform = 10
    Config.Training.Settings.Random.Seed.Dataset.Shuffle = 6
    Config.Training.Settings.Random.Seed.Model = 99

    # ---- Logging -------------------------------------------------------------
    Config.Logging.StepsPerLog = 8
    # Layout of the validation example figure.
    Config.Logging.Image.Columns = 3
    Config.Logging.Image.Rows = 10
    Config.Logging.Image.Figsize = (300,300)
    Config.Logging.Image.Fontsize = 200
    Config.Logging.Image.DPI = 10
    Config.Logging.Image.MaskAlpha = 0.6

    # Run identification passed to the recorder.
    Config.Logging.RootPath = 'logging'
    Config.Logging.Model.Reference = 'Resnet'
    Config.Logging.Model.Derivative = 'WithRegionalAttention'
    Config.Logging.Model.Branch = 'instance_norm02'
    Config.Logging.Comment = 'DiceBCELoss'
    Config.Logging.Purpose = 'None'
    Config.Logging.Note = 'None'

In [2]:
import io
import pathlib
import random
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.cuda.amp import autocast, GradScaler

import Configuration
from Dataset import TrainDataLoader
import Transform.Preprocess
import Transform.Combinations
import Transform.Schedule
import Structure
import utils.train
import utils.schedulers


Config = Configuration.Config

# Pick the compute device: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Apply the configured cuDNN reproducibility flags.
_cudnn_cfg = Config.Training.Settings.Random.cuDNN
torch.backends.cudnn.deterministic = _cudnn_cfg.Deterministic
torch.backends.cudnn.benchmark = _cudnn_cfg.Benchmark

# Build the train-time transform: look the factory up by version name on
# Transform.Combinations, then call it with the configured arguments.
_train_tf_cfg = Config.Dataset.Train.Transform
train_transform_creater = getattr(Transform.Combinations, _train_tf_cfg.Combination.Version)(
    _train_tf_cfg.Schedule,
    _train_tf_cfg.Combination.Components,
    _train_tf_cfg.Combination.Params,
    _train_tf_cfg.Combination.Schedules,
)

# Same construction for the validation-time transform.
_val_tf_cfg = Config.Dataset.Validation.Transform
validation_transform_creater = getattr(Transform.Combinations, _val_tf_cfg.Combination.Version)(
    _val_tf_cfg.Schedule,
    _val_tf_cfg.Combination.Components,
    _val_tf_cfg.Combination.Params,
    _val_tf_cfg.Combination.Schedules,
)

# Shared preprocessing step, selected by version name.
preprocess = getattr(Transform.Preprocess, Config.Dataset.Preprocess.Version)

# Assemble the dataloader from the dataset section of the configuration.
_dataset_cfg = Config.Dataset
_io_cfg = _dataset_cfg.IO
dataloader = TrainDataLoader(
    images_root=_dataset_cfg.ImagesRootPath,
    masks_root=_dataset_cfg.MasksRootPath,
    train_transform=train_transform_creater,
    train_batch_size=_dataset_cfg.Train.BatchSize,
    validation_ratio=_dataset_cfg.Validation.Ratio,
    validation_transform=validation_transform_creater,
    validation_batch_size=_dataset_cfg.Validation.BatchSize,
    image_rgb=_io_cfg.InputRGBImage,
    preprocess=preprocess,
    num_workers=_io_cfg.NumWorkers,
    pin_memory=_io_cfg.PinMemory,
    prefetch_factor=_io_cfg.PrefetchFactor,
    dtype=_io_cfg.OutputDtype,
)

# Seed Python's `random` with the transform seed and NumPy with the split seed
# before the datasets are materialised.
_dataset_seed_cfg = Config.Training.Settings.Random.Seed.Dataset
dataset_transform_seed = _dataset_seed_cfg.Transform
dataset_split_seed = _dataset_seed_cfg.Split
random.seed(dataset_transform_seed)
np.random.seed(dataset_split_seed)

# Three datasets: train, validation, and validation without augmentation.
train_dataset, validation_dataset, validation_dataset_wo_arg = dataloader.get_dataset()

# Number of optimizer steps in one epoch (needed by the scheduler): one step
# consumes `gradient_accumulation` batches of `train_batch_size` each.
gradient_accumulation = Config.Training.Settings.GradientAccumulation
train_batch_size = Config.Dataset.Train.BatchSize
train_data_count = len(train_dataset)
step_size = gradient_accumulation * train_batch_size
steps_per_epoch = int(math.ceil(train_data_count / step_size))

# The training structure bundles the available backbones, heads, losses and
# metrics; it is selected by name.
_structure_cfg = Config.Training.Structure
structure = getattr(Structure, _structure_cfg.Type)

# Seed torch right before any model weights are created.
model_seed = Config.Training.Settings.Random.Seed.Model
torch.manual_seed(model_seed)

# Backbone
model_backbone_param = _structure_cfg.Model.Backbone.Param
model_backbone_class = getattr(structure.Model.Backbones, _structure_cfg.Model.Backbone.Name)
model_backbone = model_backbone_class(**model_backbone_param)

# Head
model_head_param = _structure_cfg.Model.Head.Param
model_head_class = getattr(structure.Model.Heads, _structure_cfg.Model.Head.Name)
model_head = model_head_class(**model_head_param)

# Full model, moved onto the selected device.
model = utils.train.ModelBuilder(model_backbone, model_head).to(device)

# Loss function
loss_fn_class = getattr(structure.Losses, _structure_cfg.Loss.Name)
loss_fn = loss_fn_class(**_structure_cfg.Loss.Param)

# Optimizer, looked up by name on torch.optim
optimizer_class = getattr(torch.optim, _structure_cfg.Optimizer.Name)
optimizer = optimizer_class(model.parameters(), **_structure_cfg.Optimizer.Param)

# Scheduler: the configured class is instantiated first, and the resulting
# object is then called with the optimizer to obtain the actual scheduler.
scheduler_class = getattr(utils.schedulers, _structure_cfg.Scheduler.Name)
scheduler = scheduler_class(steps_per_epoch=steps_per_epoch, **_structure_cfg.Scheduler.Param)(optimizer)

# Metric classes and their parameters (consumed later by the recorder).
metrics_class = [
    getattr(structure.Metrics, metric_name)
    for metric_name in _structure_cfg.Metrics.Name
]
metrics_params = _structure_cfg.Metrics.Param

# Initialize the process state, then optionally overwrite it from a checkpoint.
start_epoch = 0
global_step = 0
images_count = 0
checkpoint_path = Config.Training.Checkpoint.Path
checkpoint_file = Config.Training.Checkpoint.FileName
steps_per_log = Config.Logging.StepsPerLog
# Resume only when a checkpoint directory is actually configured. The previous
# condition `(path != None) or (path == '')` was also true for an empty path
# and then tried to load a non-existent file.
if checkpoint_path not in (None, ''):
    checkpoint_file_path = pathlib.Path(checkpoint_path, checkpoint_file).as_posix()
    # map_location lets a checkpoint written on a GPU be restored on whatever
    # device this run uses (e.g. a CPU-only machine).
    checkpoint = torch.load(checkpoint_file_path, map_location=device)

    # Model weights are always restored.
    model.load_state_dict(checkpoint['model'])

    # Progress counters are always restored as well.
    images_count = checkpoint['images_count']
    global_step = checkpoint['global_step']

    if Config.Training.Checkpoint.Resume.Process:
        # The checkpoint stores the index of the epoch that had just finished
        # training and validation, so training resumes with the next one.
        start_epoch = checkpoint['epoch'] + 1

    if Config.Training.Checkpoint.Resume.Optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if Config.Training.Checkpoint.Resume.Scheduler:
        scheduler.load_state_dict(checkpoint['scheduler'])
    
# Optionally freeze the backbone. Per the original author's note the input and
# output layers are meant to stay trainable; that depends on freeze_backbone(),
# which is defined elsewhere.
if Config.Training.Settings.Model.FreezeBackbone:
    model.freeze_backbone()

# Layout of the figure that shows validation segmentation examples. Each
# example occupies four panels, hence the factor of 4 on the column count.
_image_log_cfg = Config.Logging.Image
ncols = _image_log_cfg.Columns * 4
nrows = _image_log_cfg.Rows
figsize, fontsize = _image_log_cfg.Figsize, _image_log_cfg.Fontsize
dpi, mask_alpha = _image_log_cfg.DPI, _image_log_cfg.MaskAlpha

# Gradient scaler; it is a pass-through when AMP training is switched off.
amp_scale_train = Config.Training.Settings.AmpScaleTrain
scaler = GradScaler(enabled=amp_scale_train)

# Re-seed torch with the shuffle seed.
dataset_shuffle_seed = Config.Training.Settings.Random.Seed.Dataset.Shuffle
torch.manual_seed(dataset_shuffle_seed)

# Create the recorder, register the metrics and store the configuration used
# for this run.
_recorder_kwargs = dict(
    root=Config.Logging.RootPath,
    reference=Config.Logging.Model.Reference,
    derivative=Config.Logging.Model.Derivative,
    branch=Config.Logging.Model.Branch,
    comment=Config.Logging.Comment,
    purpose=Config.Logging.Purpose,
    dataset_split_seed=dataset_split_seed,
    dataset_transform_seed=dataset_transform_seed,
    dataset_shuffle_seed=dataset_shuffle_seed,
    model_seed=model_seed,
)
recorder = utils.train.Recorder(**_recorder_kwargs)
recorder.create_metrics_and_writers(metrics_class, metrics_params)
recorder.log_config(Config)

# Training process: for every epoch run the training pass with gradient
# accumulation, then evaluate on both validation datasets, log metrics, save a
# checkpoint and log a figure with example segmentations.
end_epoch = start_epoch + Config.Training.Settings.Epochs
validation_purposes = ('validation', 'validation_wo_arg')
for epoch in range(start_epoch, end_epoch):
    acc_count = 0
    acc_loss = 0
    # set model to train mode
    model.train()
    for batch_train_data in train_dataset(epoch):
        images, masks = batch_train_data
        images = images.to(device)
        masks = masks.to(device)
        # NOTE(review): float32 is not a reduced-precision autocast dtype;
        # confirm whether mixed precision is really intended here.
        with autocast(enabled=amp_scale_train, dtype=torch.float32):
            predicts = model(images)
            loss = loss_fn(predicts, masks)['loss']
            recorder.update_metrics_state(purpose='train', predict=predicts, label=masks)

        # update process state; the losses of the accumulated batches are
        # summed and back-propagated together once the window is full
        images_count += images.shape[0]
        acc_loss += loss
        acc_count += 1

        # if it's time to do back propagation
        if acc_count == gradient_accumulation:
            acc_loss = acc_loss / gradient_accumulation
            scaler.scale(acc_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad()
            scheduler.step()
            acc_loss = 0
            acc_count = 0
            global_step += 1

            if (global_step % steps_per_log == 0) and global_step != 0:
                recorder.log_metrics_state(purpose='train', images_count=images_count)
                recorder.reset_metrics_state(purpose='train')
                recorder.log_lr(lr=optimizer.param_groups[0]['lr'], images_count=images_count)

    # flush the train metrics gathered since the last periodic log
    if global_step % steps_per_log != 0:
        recorder.log_metrics_state(purpose='train', images_count=images_count)
        recorder.reset_metrics_state(purpose='train')
        recorder.log_lr(lr=optimizer.param_groups[0]['lr'], images_count=images_count)

    # trailing, incomplete accumulation window: average over the batches that
    # were actually accumulated and take one more optimizer step
    if acc_count != 0:
        acc_loss = acc_loss / acc_count
        scaler.scale(acc_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()
        acc_loss = 0
        acc_count = 0
        global_step += 1

    model.eval()
    for purpose, val_dataset in zip(validation_purposes, (validation_dataset, validation_dataset_wo_arg)):

        # 1-based index of the next free subplot slot in the example figure
        acc_output_images = 1
        max_panels = nrows * ncols
        fig = plt.figure(figsize=figsize, dpi=dpi)
        plt.axis(False)

        for batch_val_data in val_dataset(epoch):
            images, masks = batch_val_data
            images = images.to(device)
            masks = masks.to(device)
            with autocast(enabled=amp_scale_train, dtype=torch.float32):
                with torch.no_grad():
                    intermediate_predicts = model(images)
                recorder.update_metrics_state(purpose=purpose, predict=intermediate_predicts, label=masks)
            del intermediate_predicts

            # inference only: no autograd graph is needed for the figure
            with autocast(enabled=amp_scale_train, dtype=torch.float32), torch.no_grad():
                final_predicts = model.predict(images)
            if acc_output_images <= max_panels:
                # squeeze(axis=1) drops the size-1 second axis of the image
                # batch so each image can be drawn with imshow
                samples = zip(
                    images.cpu().numpy().squeeze(axis=1),
                    masks.cpu().numpy(),
                    final_predicts.cpu().detach().numpy())
                for image, mask, predict in samples:
                    if acc_output_images < max_panels:
                        # four panels per example; each panel is a title plus
                        # the (array, alpha) layers drawn on top of each other
                        panels = (
                            ('Input Image', ((image, None),)),
                            ('Mask', ((image, None), (mask, mask_alpha))),
                            ('Predict', ((image, None), (predict, mask_alpha))),
                            ('Predict', ((predict, mask_alpha),)),
                        )
                        for title, layers in panels:
                            ax = fig.add_subplot(nrows, ncols, acc_output_images)
                            for layer, alpha in layers:
                                ax.imshow(layer, cmap='gray', alpha=alpha)
                            ax.set_title(title, fontsize=fontsize)
                            ax.axis(False)
                            acc_output_images += 1
            del final_predicts
            del images
            del masks

        # NOTE(review): this runs once per validation purpose, i.e. twice per
        # epoch -- confirm that is what the scheduler expects.
        scheduler.epoch(recorder)

        recorder.log_metrics_state(purpose=purpose, images_count=images_count)
        checkpoint = dict(
            model=model.state_dict(),
            optimizer=optimizer.state_dict(),
            scheduler=scheduler.state_dict(),
            epoch=epoch,
            global_step=global_step,
            images_count=images_count)
        recorder.save_checkpoint(purpose=purpose, checkpoint=checkpoint)
        # clears train metrics left over from a trailing partial accumulation
        # step, so they do not leak into the next epoch
        recorder.reset_metrics_state(purpose='train')

        # record image: render the figure to raw RGBA bytes and reshape them
        # into a (height, width, channels) array
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='raw', dpi=dpi)
        plt.close(fig)
        buf.seek(0)
        height = int(fig.bbox.bounds[3])
        width = int(fig.bbox.bounds[2])
        # shape is passed positionally: the `newshape` keyword is deprecated
        img_arr = np.reshape(np.frombuffer(buf.getvalue(), dtype=np.uint8), (height, width, -1))
        buf.close()
        recorder.log_image(purpose=purpose, image=img_arr, epoch=epoch)

    # Start every epoch with empty validation metrics. Previously only the
    # 'train' state was ever reset, so validation metrics would keep
    # accumulating over all epochs (presumably log_metrics_state does not
    # reset by itself, since the train path resets explicitly after logging).
    # Done after both purposes so scheduler.epoch() still sees populated state.
    for finished_purpose in validation_purposes:
        recorder.reset_metrics_state(purpose=finished_purpose)

EPOCH: 0 Training: 100%|██████████| 1631/1631 [05:53<00:00,  4.61it/s]
EPOCH: 0 Validation: 100%|██████████| 43/43 [00:14<00:00,  3.03it/s]
EPOCH: 0 Validation Without Transform: 100%|██████████| 43/43 [00:06<00:00,  6.42it/s]
EPOCH: 1 Training: 100%|██████████| 1631/1631 [05:40<00:00,  4.80it/s]
EPOCH: 1 Validation: 100%|██████████| 43/43 [00:13<00:00,  3.09it/s]
EPOCH: 1 Validation Without Transform: 100%|██████████| 43/43 [00:06<00:00,  6.96it/s]
EPOCH: 2 Training: 100%|██████████| 1631/1631 [05:15<00:00,  5.17it/s]
EPOCH: 2 Validation: 100%|██████████| 43/43 [00:14<00:00,  3.06it/s]
EPOCH: 2 Validation Without Transform: 100%|██████████| 43/43 [00:06<00:00,  6.59it/s]
EPOCH: 3 Training: 100%|██████████| 1631/1631 [05:14<00:00,  5.19it/s]
EPOCH: 3 Validation: 100%|██████████| 43/43 [00:12<00:00,  3.41it/s]
EPOCH: 3 Validation Without Transform: 100%|██████████| 43/43 [00:06<00:00,  6.32it/s]
EPOCH: 4 Training: 100%|██████████| 1631/1631 [05:19<00:00,  5.11it/s]
EPOCH: 4 Validation: 