# Imports

In [None]:
import torch 
import numpy as np
import albumentations as A
from datasets.inria import InriaDataset
from datasets.poland import PolandDataset
from loss_functions.tversky_focal_loss import TverskyFocalLoss
from segmentation_models_pytorch import Unet
from torchmetrics import Accuracy, JaccardIndex, F1Score
import mlflow
import logging
import tqdm
import json
import os
import train_utils

# Training

## Train loop

In [None]:
def train_epoch(model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, epoch: int, epochs: int,
                optimizer: torch.optim.Optimizer,
                criterion: torch.nn.Module, scheduler: torch.optim.lr_scheduler.LRScheduler, 
                warmup_epochs: int, accumulation_steps: int,
                device: str, scaler: torch.amp.grad_scaler.GradScaler):
    """Run a single training epoch.

    Uses mixed precision (``torch.amp.autocast`` + ``scaler``), gradient
    accumulation and an iteration-wise warmup scheduler. Batch metrics are
    logged to the active MLflow run every 100 batches.

    Args:
        model (torch.nn.Module): Model. Its raw output goes through
            ``torch.sigmoid`` before the loss and the metrics, so ``criterion``
            receives probabilities, not logits.
        train_loader (torch.utils.data.DataLoader): Train dataloader yielding
            ``(inputs, targets)`` batches.
        epoch (int): Current epoch (0-based).
        epochs (int): Total number of epochs (progress-bar label only).
        optimizer (torch.optim.Optimizer): Optimizer.
        criterion (torch.nn.Module): Loss function.
        scheduler (torch.optim.lr_scheduler.LRScheduler): Warmup LR scheduler,
            stepped after every optimizer step while ``epoch < warmup_epochs``;
            may be None.
        warmup_epochs (int): Number of epochs during which ``scheduler`` is stepped.
        accumulation_steps (int): Number of batches whose gradients are
            accumulated before each optimizer step. The last batch of the epoch
            always triggers a step, even if its group is incomplete.
        device (str): 'cuda' or 'cpu'; also used as the autocast device_type.
        scaler (torch.amp.grad_scaler.GradScaler): Gradient scaler.

    Returns:
        dict: Sample-weighted epoch averages under 'train_avg_loss',
        'train_avg_accuracy', 'train_avg_iou_score' and 'train_avg_dice_score'.

    Raises:
        ZeroDivisionError: If the loader yields no batches.
    """
    processed_data = 0
    running_loss = 0.0
    running_accuracy = 0.0
    running_iou = 0.0
    running_dice = 0.0
    # target pixels labelled 255 are excluded from every metric
    accuracy = Accuracy(task='binary', threshold=0.5, average='weighted', ignore_index=255).to(device)
    iou = JaccardIndex(task='binary', threshold=0.5, average='weighted', ignore_index=255).to(device)
    dice = F1Score(task='binary', threshold=0.5, average='weighted', ignore_index=255).to(device)

    num_batches = len(train_loader)
    loop_train = tqdm.tqdm(enumerate(train_loader), total=num_batches, leave=False,
                           desc=f'Epoch[{epoch+1}/{epochs}] train')
    model.train()
    optimizer.zero_grad()
    for iteration, (inputs, targets) in loop_train:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        batch_size = inputs.size(0)

        # mixed-precision forward pass; dividing by accumulation_steps makes the
        # accumulated gradient an average over the group instead of a sum
        with torch.amp.autocast(device_type=device):
            outputs = torch.sigmoid(model(inputs))
            loss = criterion(outputs, targets) / accumulation_steps

        scaler.scale(loss).backward()

        # optimizer step every accumulation_steps batches and on the last batch
        if (iteration + 1) % accumulation_steps == 0 or iteration + 1 == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            # warmup is scheduled per optimizer step, not per epoch
            if epoch < warmup_epochs and scheduler is not None:
                scheduler.step()

        # Squeeze only dim 1 (assumed to be the single output channel, i.e.
        # outputs are (B, 1, H, W) — confirm against the model) so that a batch
        # of size 1 keeps its batch dimension and still matches targets.
        preds = outputs.detach().squeeze(1)
        propagated_loss = loss.item()  # single host sync per batch
        batch_loss = propagated_loss * accumulation_steps * batch_size
        batch_accuracy = accuracy(preds, targets).item()
        batch_iou = iou(preds, targets).item()
        batch_dice = dice(preds, targets).item()

        processed_data += batch_size
        running_loss += batch_loss
        running_accuracy += batch_accuracy * batch_size
        running_iou += batch_iou * batch_size
        running_dice += batch_dice * batch_size

        # log metrics every 100th batch
        if (iteration + 1) % 100 == 0:
            mlflow.log_metrics({'train_batch_loss': batch_loss,
                                'train_loss_propagated': propagated_loss,
                                'train_batch_accuracy': batch_accuracy,
                                'train_batch_iou': batch_iou,
                                'train_batch_dice': batch_dice},
                               step=iteration + 1)

        loop_train.set_postfix(batch=iteration + 1, loss=propagated_loss, iou=batch_iou)

    return {'train_avg_loss': running_loss / processed_data,
            'train_avg_accuracy': running_accuracy / processed_data,
            'train_avg_iou_score': running_iou / processed_data,
            'train_avg_dice_score': running_dice / processed_data}

## Val loop

In [None]:
def val_epoch(model: torch.nn.Module, val_loader: torch.utils.data.DataLoader, 
              epoch: int, epochs: int, criterion: torch.nn.Module, device: str):
    """Run a single validation epoch without gradient tracking.

    Every 100th batch, starting with the first, the batch metrics are logged to
    the active MLflow run together with an example image produced by
    ``train_utils.make_example``. A failure while building or logging the
    example is appended to ``artifacts/log.txt`` instead of aborting validation.

    Args:
        model (torch.nn.Module): Model. Its output goes through ``torch.sigmoid``
            before the loss and the metrics.
        val_loader (torch.utils.data.DataLoader): Validation dataloader.
        epoch (int): Current epoch (0-based).
        epochs (int): Total number of epochs (progress-bar label only).
        criterion (torch.nn.Module): Loss function.
        device (str): 'cuda' or 'cpu'.

    Returns:
        dict: Sample-weighted epoch averages under 'val_avg_loss',
        'val_avg_accuracy', 'val_avg_iou_score' and 'val_avg_dice_score'.

    Raises:
        ZeroDivisionError: If the loader yields no batches.
    """
    processed_data = 0
    running_loss = 0.0
    running_accuracy = 0.0
    running_iou = 0.0
    running_dice = 0.0
    # target pixels labelled 255 are excluded from every metric
    accuracy = Accuracy(task='binary', threshold=0.5, average='weighted', ignore_index=255).to(device)
    iou = JaccardIndex(task='binary', threshold=0.5, average='weighted', ignore_index=255).to(device)
    dice = F1Score(task='binary', threshold=0.5, average='weighted', ignore_index=255).to(device)

    model.eval()
    with torch.no_grad():
        loop_val = tqdm.tqdm(enumerate(val_loader), total=len(val_loader), leave=False,
                             desc=f'Epoch[{epoch+1}/{epochs}] val')

        for iteration, (inputs, targets) in loop_val:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = torch.sigmoid(model(inputs))
            loss = criterion(outputs, targets)

            batch_size = inputs.size(0)
            processed_data += batch_size

            # Squeeze only dim 1 (assumed single output channel — confirm) so a
            # batch of size 1 keeps its batch dimension and still matches targets.
            preds = outputs.squeeze(1)
            loss_value = loss.item()  # single host sync per batch
            batch_loss = loss_value * batch_size
            batch_accuracy = accuracy(preds, targets).item()
            batch_iou = iou(preds, targets).item()
            batch_dice = dice(preds, targets).item()

            running_loss += batch_loss
            running_accuracy += batch_accuracy * batch_size
            running_iou += batch_iou * batch_size
            running_dice += batch_dice * batch_size

            # log metrics and an inference example every 100th batch
            if iteration % 100 == 0:
                mlflow.log_metrics({'val_batch_loss': batch_loss,
                                    'val_loss_propagated': loss_value,
                                    'val_batch_accuracy': batch_accuracy,
                                    'val_batch_iou': batch_iou,
                                    'val_batch_dice': batch_dice},
                                   step=iteration + 1)
                try:
                    # make_example receives the un-squeezed model outputs
                    example = train_utils.make_example(inputs, targets, outputs)
                    mlflow.log_image(example, artifact_file=f'example_{iteration+1}.png')
                except Exception as e:
                    # best effort: a failed example must not abort validation
                    with open(os.path.join('artifacts', 'log.txt'), 'a', encoding='utf-8') as f:
                        f.write(f'{e}\n')

            loop_val.set_postfix(batch=iteration + 1, loss=loss_value, iou=batch_iou)

    return {'val_avg_loss': running_loss / processed_data,
            'val_avg_accuracy': running_accuracy / processed_data,
            'val_avg_iou_score': running_iou / processed_data,
            'val_avg_dice_score': running_dice / processed_data}
    
            

## Fit function

In [None]:
def fit(run_id: str, model: torch.nn.Module, criterion: torch.nn.Module,
        epochs: int, optimizer: torch.optim.Optimizer, patience: int, train_loader: torch.utils.data.DataLoader, 
        val_loader: torch.utils.data.DataLoader, scheduler: torch.optim.lr_scheduler.LRScheduler,
        warmup_scheduler: torch.optim.lr_scheduler.LRScheduler, 
        warmup_epochs: int, acc_steps: int, device: str):
    """Train a segmentation model with warmup, LR scheduling and early stopping.

    Each epoch runs inside a nested MLflow run named ``Epoch_<n>``. The nested
    run is closed on every exit path, including early stopping and exceptions.
    Per-epoch metrics and learning rates are logged to ``run_id``. The
    ``status`` tag and the final model go to the MLflow run that is active when
    ``fit`` is called (expected to be the parent run). A checkpoint is saved
    after every epoch.

    During the first ``warmup_epochs`` epochs, ``warmup_scheduler`` drives the
    LR; it is stepped per optimizer step inside ``train_epoch``. The main
    ``scheduler`` is stepped once per epoch and only after warmup, so the two
    never rescale the same LRs in the same epoch. For example, a
    ``CosineAnnealingLR`` with ``T_max=epochs - warmup_epochs`` then spans
    exactly the post-warmup epochs.

    Training stops early once more than ``patience`` consecutive epochs pass in
    which neither validation Dice nor validation IoU reaches a new best.

    Args:
        run_id (str): MLflow parent run id.
        model (torch.nn.Module): Model.
        criterion (torch.nn.Module): Loss function.
        epochs (int): Total num of epochs.
        optimizer (torch.optim.Optimizer): Optimizer.
        patience (int): Number of non-improving epochs tolerated before stopping.
        train_loader (torch.utils.data.DataLoader): Training dataloader.
        val_loader (torch.utils.data.DataLoader): Validation dataloader.
        scheduler (torch.optim.lr_scheduler.LRScheduler): Main LR scheduler
            (may be None). ``ReduceLROnPlateau`` is stepped with the train loss.
        warmup_scheduler (torch.optim.lr_scheduler.LRScheduler): Warmup LR
            scheduler (may be None).
        warmup_epochs (int): Num of epochs to use warmup scheduler.
        acc_steps (int): Number of batches accumulated per optimizer step.
        device (str): 'cuda' or 'cpu'.

    Returns:
        dict: Per-epoch histories under 'train_losses', 'train_accuracy_scores',
        'train_iou_scores', 'train_dice_scores', 'val_losses',
        'val_accuracy_scores', 'val_iou_scores' and 'val_dice_scores'.
    """
    parent_run_name = mlflow.get_run(run_id).info.run_name

    mlflow.set_tag('status', 'training')
    best_dice = float('-inf')
    best_iou = float('-inf')
    counter = 0

    # per-epoch result key -> name of the history list returned to the caller
    history_keys = {
        'train_avg_loss': 'train_losses',
        'train_avg_accuracy': 'train_accuracy_scores',
        'train_avg_iou_score': 'train_iou_scores',
        'train_avg_dice_score': 'train_dice_scores',
        'val_avg_loss': 'val_losses',
        'val_avg_accuracy': 'val_accuracy_scores',
        'val_avg_iou_score': 'val_iou_scores',
        'val_avg_dice_score': 'val_dice_scores',
    }
    history = {name: [] for name in history_keys.values()}

    scaler = torch.amp.GradScaler()
    for epoch in range(epochs):
        # Read the LRs from the optimizer: the warmup scheduler also changes
        # them, so scheduler.get_last_lr() can be stale, and scheduler may be None.
        mlflow.log_metrics({f'group_{i}_lr': group['lr'] for i, group in enumerate(optimizer.param_groups)},
                           step=epoch + 1, run_id=run_id)
        in_warmup = warmup_scheduler is not None and epoch < warmup_epochs

        # The context manager ends the nested run on every exit path, so the
        # parent run is active again for the tags and the model logged below.
        with mlflow.start_run(run_name=f'Epoch_{epoch+1}', nested=True, parent_run_id=run_id):
            train_results = train_epoch(model, train_loader, epoch, epochs, optimizer, criterion,
                                        warmup_scheduler, warmup_epochs, acc_steps, device, scaler)

            checkpoint_path = os.path.join('artifacts', parent_run_name, 'checkpoints', f'epoch{epoch+1}')
            train_utils.save_checkpoint(epoch, checkpoint_path, model, optimizer, scheduler, scaler, run_id=run_id)

            # the main scheduler takes over only once warmup is finished
            if scheduler is not None and not in_warmup:
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(train_results['train_avg_loss'])
                else:
                    scheduler.step()

            mlflow.log_metrics(train_results, step=epoch + 1, run_id=run_id)

            val_results = val_epoch(model, val_loader, epoch, epochs, criterion, device)
            mlflow.log_metrics(val_results, step=epoch + 1, run_id=run_id)

        epoch_results = {**train_results, **val_results}
        for result_key, history_name in history_keys.items():
            history[history_name].append(epoch_results[result_key])

        # early stopping: an epoch counts as an improvement if either score is a new best
        val_dice = history['val_dice_scores'][-1]
        val_iou = history['val_iou_scores'][-1]
        if val_dice > best_dice or val_iou > best_iou:
            counter = 0
            best_dice = max(best_dice, val_dice)
            best_iou = max(best_iou, val_iou)
            mlflow.log_metric('best_epoch', epoch + 1, step=epoch + 1, run_id=run_id)
            continue

        counter += 1
        if counter > patience:
            # Only the status message depends on whether the scores have
            # plateaued: mean absolute change below 0.01 over the last window
            # (at least two values so np.diff is never empty).
            window = min(max(patience, 2), epoch + 1)
            plateau = epoch > 0 and all(
                np.abs(np.diff(np.array(history[name][-window:]))).mean() < 0.01
                for name in ('val_dice_scores', 'val_iou_scores'))
            mlflow.set_tag('status', 'early stopping because of plateau' if plateau else 'early stopping')
            break

    mlflow.pytorch.log_model(model)
    return history

# Training setup

In [None]:
# Quiet MLflow before any tracking call: the env var presumably suppresses the
# run-URL printout, and the logger level hides non-error log records.
os.environ['MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT'] = '1'
mlflow_logger = logging.getLogger('mlflow')
mlflow_logger.setLevel(logging.ERROR)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
mlflow.set_tracking_uri('http://127.0.0.1:5000')
experiment = mlflow.set_experiment('test')
parent_run = mlflow.start_run(run_name='test')
parent_id = parent_run.info.run_id
parent_name = parent_run.info.run_name
# MLflow run names are not unique, so re-running under the same name must reuse
# the existing folder instead of raising FileExistsError with a run left open.
os.makedirs(os.path.join('artifacts', parent_name, 'checkpoints'), exist_ok=True)
mlflow.set_tag('status', 'setting up')

In [None]:
# Schedule: total epochs, warmup epochs, and how many non-improving epochs
# early stopping tolerates.
num_epochs = 20
warmup_epochs = 3
patience = 6

# Samples per step = batch_size * accumulation_steps (gradients are
# accumulated over accumulation_steps batches before each optimizer step).
batch_size = 4
accumulation_steps = 2

In [None]:
# Geometric branch (weight 0.8 in the OneOf below; the alternative, weight 0.2,
# is a plain 256x256 resize): resize to 384 or 512, then a random zoom
# (+-20%) and a random rotation (+-30 deg), each with p=0.6. Exposed areas are
# filled with 255 in image and mask; 255 is the ignore label of the loss/metrics.
_train_rescale = A.Compose(
    [
        A.OneOf([A.Resize(384, 384, p=0.625), A.Resize(512, 512, p=0.375)], p=1.0),
        A.ShiftScaleRotate(shift_limit=0, scale_limit=(-0.2, 0.2), rotate_limit=0,
                           fill=255, fill_mask=255, p=0.6),
        A.ShiftScaleRotate(shift_limit=0, scale_limit=0, rotate_limit=(-30, 30),
                           fill=255, fill_mask=255, p=0.6),
    ],
    p=0.8,
)

# Photometric jitter, each transform applied independently with its own p.
_train_photometric = [
    A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0, p=0.4),
    A.RandomGamma((85, 115), p=0.4),
    A.GaussianBlur((3, 4), (0.3, 0.8), p=0.4),
    A.GaussNoise((0, 0.1), (0, 0.05), p=0.3),
]

# Pad everything to a 512x512 canvas (fill 255) and normalise with INRIA
# dataset statistics rather than ImageNet ones.
train_augs = A.Compose(
    [
        A.OneOf([A.Resize(256, 256, p=0.2), _train_rescale], p=1.0),
        *_train_photometric,
        A.PadIfNeeded(min_height=512, min_width=512, fill=255, fill_mask=255),
        A.Normalize(mean=(0.4014, 0.4235, 0.3888), std=(0.1708, 0.1555, 0.1457)),
        A.ToTensorV2(),
    ],
    additional_targets={'mask': 'mask'},
)

In [None]:
# Deterministic validation pipeline: resize to 384, pad to the same 512x512
# canvas as training (fill 255 = ignore label for the mask), and normalise
# with INRIA dataset statistics.
_val_steps = [
    A.Resize(384, 384),
    A.PadIfNeeded(min_height=512, min_width=512, fill=255, fill_mask=255),
    A.Normalize(mean=(0.4014, 0.4235, 0.3888), std=(0.1708, 0.1555, 0.1457)),
    A.ToTensorV2(),
]
val_augs = A.Compose(_val_steps, additional_targets={'mask': 'mask'})

In [None]:
# Create and log datasets.
# NOTE: no RNG seed is set, so every run draws a different split; the sampled
# indices are saved and logged as artifacts so a run's split can be rebuilt.


def _sample_split(total: int, subset: int, name: str, train_fraction: float = 0.8):
    """Draw `subset` of `total` indices in random order, split them train/val,
    save both parts as `<part>_<name>_idx.npy` and log them to MLflow."""
    # permutation() already yields a random order, so no extra shuffle is needed
    idx = np.random.permutation(total)[:subset]
    split = int(len(idx) * train_fraction)
    parts = {'train': idx[:split], 'val': idx[split:]}
    for part, part_idx in parts.items():
        path = os.path.join('artifacts', parent_name, f'{part}_{name}_idx.npy')
        np.save(path, part_idx)
        mlflow.log_artifact(path)
    return parts['train'], parts['val']


train_inria_idx, val_inria_idx = _sample_split(180, 120, 'inria')
inria_train_dataset = InriaDataset(mode='train', idx=train_inria_idx, res=512,
                                   overlap=0.33, transforms=train_augs)
inria_val_dataset = InriaDataset(mode='val', idx=val_inria_idx, res=512,
                                 overlap=0.0, transforms=val_augs)

train_poland_idx, val_poland_idx = _sample_split(10674, 3250, 'poland')
poland_train_dataset = PolandDataset(idx=train_poland_idx, transforms=train_augs)
poland_val_dataset = PolandDataset(idx=val_poland_idx, transforms=val_augs)

train_dataset = torch.utils.data.ConcatDataset([inria_train_dataset, poland_train_dataset])
val_dataset = torch.utils.data.ConcatDataset([inria_val_dataset, poland_val_dataset])

In [None]:
# Train and validation loaders share identical settings.
# NOTE(review): drop_last=True on the validation loader skips the final
# incomplete batch (up to batch_size - 1 samples), and with shuffle=True the
# skipped samples change from epoch to epoch — confirm this is intended.
_loader_kwargs = dict(batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)
train_loader = torch.utils.data.DataLoader(train_dataset, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **_loader_kwargs)

In [None]:
# U-Net with an ImageNet-pretrained ResNet-50 encoder, on GPU when available.
# DEVICE stays a plain string because it is also used as autocast device_type.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
model = Unet(encoder_name='resnet50', encoder_weights='imagenet')
model = model.to(DEVICE)

In [None]:
# Combined Tversky + focal loss in binary mode; ignore_index=255 matches the
# ignore label used by the metrics and by the augmentation fill value.
# NOTE(review): weight_tversky / weight_focal presumably weight the two terms —
# confirm in TverskyFocalLoss.
criterion = TverskyFocalLoss(
    'binary',
    ignore_index=255,
    smoothing=0.8,
    alpha_tversky=0.7,
    beta_tversky=0.3,
    weight_tversky=0.6,
    alpha_focal=0.7,
    gamma_focal=1.5,
    weight_focal=0.4,
)
criterion = criterion.to(DEVICE)

In [None]:
# Layer-wise learning rates: smallest for the early encoder stages, 3e-4 for
# layer4, decoder and head. train_utils.separate_norms splits each module into
# two parameter lists: index 0 gets weight decay 1e-4, index 1 gets none
# (presumably normalisation/bias parameters — confirm in train_utils).
# Group order (per module: decayed group, then non-decayed group) is kept
# stable because optimizer checkpoints depend on it.
# NOTE(review): only the modules listed here are optimised; any parameters
# outside them (e.g. an encoder stem, if present) belong to no param group.
_group_names = ('layer1', 'layer2', 'layer3', 'layer4', 'decoder', 'head')
_group_lrs = (1e-5, 3e-5, 1e-4, 3e-4, 3e-4, 3e-4)

modules = [model.encoder.layer1, model.encoder.layer2, model.encoder.layer3, model.encoder.layer4,
           model.decoder, model.segmentation_head]
separated_params = dict(zip(_group_names, map(train_utils.separate_norms, modules)))

opt_params = [
    {"params": separated_params[name][part], "weight_decay": decay, 'lr': lr}
    for name, lr in zip(_group_names, _group_lrs)
    for part, decay in ((0, 1e-4), (1, 0.0))
]

optimizer = torch.optim.AdamW(params=opt_params, lr=3e-4, foreach=True)

In [None]:
# Warmup ramps every param group linearly from 10% to 100% of its LR; it is
# stepped per optimizer step inside train_epoch.
# NOTE(review): count_warmup_iters does not receive the loader length —
# confirm it matches the number of optimizer steps in the warmup epochs.
_warmup_iters = train_utils.count_warmup_iters(batch_size, accumulation_steps, warmup_epochs, num_epochs)
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, end_factor=1.0, total_iters=_warmup_iters)

# Cosine decay to 1e-6, with T_max sized for the post-warmup epochs.
main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_epochs - warmup_epochs, eta_min=1e-6)

In [None]:
# Record the run configuration. These values mirror hard-coded settings in the
# cells above and must be kept in sync with them.
mlflow.log_params({
    "batch_size": batch_size,
    # key name kept for comparability with earlier runs; the value is the
    # number of gradient-accumulation steps
    "accumulation_epochs": accumulation_steps,
    "warmup_epochs": warmup_epochs,
    "num_epochs": num_epochs,
    "patience": patience,
    "encoder": "resnet50",
    "encoder_weights": "imagenet",
    "optimizer": "AdamW",
    "warmup_scheduler": "LinearLR",
    "main_scheduler": "CosineAnnealingLR",
    "base_lr": 3e-4,
    "criterion": "TverskyFocalLoss",
    "criterion_params": {
        "alpha_tversky": 0.7,
        "beta_tversky": 0.3,
        "alpha_focal": 0.7,
        "gamma_focal": 1.5,
        "ignore_index": 255,
        "smoothing": 0.8,
        "weight_focal": 0.4,
        "weight_tversky": 0.6,
    },
    "inria_params": {
        'res': 512,
        # must match the overlap passed to the train / val InriaDataset
        'overlap': 0.33,
        'val_overlap': 0.0,
    },
})

mlflow.log_text(json.dumps(train_augs.to_dict(), indent=2), 'train_augmentations.json')
mlflow.log_text(json.dumps(val_augs.to_dict(), indent=2), 'val_augmentations.json')

In [None]:
# Run the whole training loop; results holds the per-epoch metric histories.
results = fit(
    run_id=parent_id,
    model=model,
    criterion=criterion,
    epochs=num_epochs,
    optimizer=optimizer,
    patience=patience,
    train_loader=train_loader,
    val_loader=val_loader,
    scheduler=main_scheduler,
    warmup_scheduler=warmup_scheduler,
    warmup_epochs=warmup_epochs,
    acc_steps=accumulation_steps,
    device=DEVICE,
)

In [None]:
# Tag the currently active MLflow run as successful, then close it with the
# default FINISHED status.
mlflow.set_tag('status', 'success')
mlflow.end_run(status='FINISHED')


In [None]:
# Safety net: close every MLflow run that is still open (for example nested
# epoch runs left behind by an interrupted training cell), not just the innermost one.
while mlflow.active_run() is not None:
    mlflow.end_run()

# ONNX Conversion

In [None]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# encoder_weights=None: the strict load_state_dict below overwrites every weight,
# so downloading ImageNet encoder weights first would be wasted work.
model = Unet(encoder_name='resnet50', encoder_weights=None).to(DEVICE)
checkpoint_file = os.path.join('artifacts', 'Unet_ResNet50_try_5', 'checkpoints', 'best_epoch_21.pth')
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [None]:
# Export the loaded model to ONNX with a dynamic batch dimension; height and
# width are fixed by the 384x384 dummy input.
# NOTE(review): nothing here converts the model to half precision, so the graph
# is exported in the model's current dtype (float32 unless changed elsewhere)
# despite the "fp16" file name — confirm whether a separate conversion step is intended.
# NOTE(review): the graph outputs raw model outputs; during training
# torch.sigmoid was applied outside the model, so consumers must apply it themselves.
dummy_input = torch.randn(1, 3, 384, 384, device=DEVICE)
onnx_file_path = os.path.join('..', 'models', 'test', 'Unet_1.0_fp16.onnx')
# make sure the target directory exists before exporting
os.makedirs(os.path.dirname(onnx_file_path), exist_ok=True)

with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input,
        onnx_file_path,
        export_params=True,
        opset_version=20,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )