In [1]:
# Install required packages

!pip install kornia
!pip install torchmetrics

# Set some parameters

class Argument:
    """Hyper-parameters and switches for one training run of this notebook.

    Every attribute has a default below; any of them can be overridden by
    keyword, e.g. ``Argument(epochs=5, isAttention=True)``, so a quick
    experiment does not require editing this cell. Unknown keywords raise
    ``TypeError`` so typos such as ``epoch=5`` are caught instead of ignored.
    """

    def __init__(self, **overrides):
        # Paths: the notebook chdirs into base_dir before importing the helper
        # modules and loading the saved tensors; the trained model goes to save_dir.
        self.base_dir = '/kaggle/input/segmentation-dataset/'
        self.save_dir = '/kaggle/working/'

        # Image geometry (not read directly in this notebook -- presumably used
        # when the tensors / helper modules were prepared; TODO confirm).
        self.IMG_WIDTH = 128
        self.IMG_HEIGHT = 128
        self.IMG_CHANNELS = 3
        self.width_out = 128
        self.height_out = 128

        # Optimisation: batch_size feeds the data helper and the training loop,
        # learning_rate feeds SGD, epochs/epoch_lapse are forwarded to train_valid().
        self.batch_size = 32
        self.learning_rate = 0.01
        self.epochs = 100
        self.epoch_lapse = 10  # presumably the logging/validation interval; TODO confirm
        self.threshold = 0.33  # not read in this notebook; TODO confirm consumer
        self.sample_size = None  # not read in this notebook; TODO confirm consumer

        # Architecture selection: ishybrid takes precedence over isAttention.
        self.ishybrid = False
        self.isAttention = False
        # One of: 'cosine', 'regular_pointwise', 'regular_full',
        # 'regular_full_dim_add', 'channel_attention'.
        self.attn_type = 'cosine'

        # Data: True loads the pre-split *_split.pt tensors; False lets the
        # Dataset_train_val helper split x_train.pt / y_train.pt.
        self.use_split_data = True

        # Options forwarded to Training_Validation; power and is_swap_coeffs
        # also appear in the saved model name when ishybrid is True.
        self.power = 2
        self.is_swap_coeffs = False
        self.is_train_data_aug = False
        self.is_apply_color_jitter = False

        self.save_model = True  # save the trained model into save_dir

        # Apply keyword overrides last, rejecting names that are not settings.
        known = vars(self)
        for name, value in overrides.items():
            if name not in known:
                raise TypeError(f'Argument() got an unexpected setting {name!r}')
            setattr(self, name, value)


args = Argument()

In [4]:
import os
os.chdir(args.base_dir)
import sys
import random
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from itertools import chain
from sklearn.model_selection import train_test_split
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
import torch

from loss_utils import *
from model_utils import *
from dataset_train_val import Dataset_train_val
from training_validation_utils import *
from plotutils import *


warnings.filterwarnings('ignore', category=UserWarning, module='skimage')

# Seed every RNG this notebook relies on (Python, NumPy, PyTorch) so runs are
# reproducible. These must be *calls*: assigning to `random.seed` or
# `np.random.seed` would replace the library functions with an int, leaving the
# generators unseeded and breaking any later seed() call in any module.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU and all CUDA devices

# Device flag: True when PyTorch can see a usable CUDA device.
use_gpu = torch.cuda.is_available()
# Data helper used by the data-loading and training code below, built once
# with the configured batch size and the device flag.
dtv = Dataset_train_val(args.batch_size, use_gpu)

def create_model_name(args):
    """Return the checkpoint filename ('*.pt') that encodes the model variant.

    The stem reflects the architecture: the hybrid net (tagged with its power
    and an optional '_swap_coeffs'), an attention U-Net tagged with its
    attention type, or the plain U-Net. '_augmentation' is appended when
    training-time data augmentation is enabled.
    """
    if args.ishybrid == True:
        stem = 'unet_hybrid_power' + str(args.power)
        if args.is_swap_coeffs == True:
            stem += '_swap_coeffs'
    elif args.isAttention == True:
        stem = 'unet_' + args.attn_type + '_attn'
    else:
        stem = 'unet_no_attn'

    suffix = '_augmentation' if args.is_train_data_aug == True else ''
    return stem + suffix + '.pt'

if args.use_split_data == True:
    # Pre-split tensors; the relative paths resolve against args.base_dir
    # because the notebook chdirs there before importing the helpers.
    x_train = torch.load('x_train_split.pt')
    x_val = torch.load('x_val_split.pt')
    y_train = torch.load('y_train_split.pt')
    y_val = torch.load('y_val_split.pt')
    x_test = torch.load('x_test.pt')
else:
    # Let the data helper split x_train.pt / y_train.pt into train and val
    # (split_ratio=0.2 -- presumably the validation fraction; TODO confirm).
    x_train, x_val, y_train, y_val, x_test = dtv.load_train_val_test('x_train.pt', 'y_train.pt', 'x_test.pt', split_ratio=0.2)

# Fail fast if images and masks are misaligned, rather than after a long training run.
assert len(x_train) == len(y_train), f'train images/masks differ: {len(x_train)} vs {len(y_train)}'
assert len(x_val) == len(y_val), f'val images/masks differ: {len(x_val)} vs {len(y_val)}'
assert len(x_test) > 0, 'test set is empty'

# Train Model

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from tqdm import trange
from time import sleep

# Pick the architecture and its loss from the config flags; `ishybrid` takes
# precedence over `isAttention`. Alternative losses tried previously:
# nn.CrossEntropyLoss(), TverskyBCELoss().
if args.ishybrid == True:
    print('Training Hybrid Net')
    flatten_arg = False  # forwarded as full_flatten to the hybrid-net loss
    unet = Hybrid_Net(img_ch=3, output_ch=1)
    criterion = DiceBCELossModified(full_flatten=flatten_arg)
elif args.isAttention == True:
    print('Training Attention U-Net')
    unet = AttU_Net(img_ch=3, output_ch=1, attn_type=args.attn_type)
    criterion = DiceBCELoss()
else:
    print('Training Vanilla U-Net')
    unet = UNet(in_channel=3, out_channel=1)
    criterion = DiceBCELoss()

if use_gpu:
    unet = unet.cuda()

optimizer = optim.SGD(unet.parameters(), lr=args.learning_rate, momentum=0.99)

# Build the trainer from the run options and train; it returns the
# per-epoch training and validation loss histories.
train_losses, val_losses = Training_Validation(
    ishybrid=args.ishybrid,
    power=args.power,
    swap_coeffs=args.is_swap_coeffs,
    is_train_data_aug=args.is_train_data_aug,
    is_apply_color_jitter=args.is_apply_color_jitter,
).train_valid(
    unet, x_train, y_train, x_val, y_val,
    optimizer, criterion, args.batch_size, dtv, use_gpu,
    epochs=args.epochs, epoch_lapse=args.epoch_lapse,
)

if args.save_model == True:
    # Saves the whole module object (not only its state_dict); loading it
    # back requires the model classes to be importable.
    torch.save(unet, args.save_dir + create_model_name(args))

**Train Model**

**Plot Results**

In [5]:
allp = Allplots()

# Loss curves; the hybrid run is plotted with an empty validation series.
allp.plot_losses(train_losses, [] if args.ishybrid == True else val_losses)

# Example predictions (12 per call) on the training split, then the validation split.
for split_x, split_y in ((x_train, y_train), (x_val, y_val)):
    allp.plot_examples(unet, split_x, split_y, 12, ishybrid=args.ishybrid)

**Display Metrics**

In [6]:
from segment_metrics import IOU_eval

iou_ev = IOU_eval(ishybrid=args.ishybrid)
N_SHOW = 5  # number of examples plotted per split

print('Training')
iou_t, iou_t_indices = iou_ev.iou_evaluate(unet, x_train, y_train)
# Highest-IoU training examples (argsort is ascending, so take the tail).
allp.plot_best(unet, x_train, datay=y_train, indx=np.argsort(iou_t)[-N_SHOW:], index_ranks=iou_t_indices, ishybrid=args.ishybrid)

print('Validation')
iou_v, iou_v_indices = iou_ev.iou_evaluate(unet, x_val, y_val)
allp.plot_best(unet, x_val, datay=y_val, indx=np.argsort(iou_v)[-N_SHOW:], index_ranks=iou_v_indices, ishybrid=args.ishybrid)

print('Testing')
# No test masks are loaded (datay=None), so show a random sample instead of the
# best-scoring images. Sample from the actual test-set size rather than a
# hard-coded 65, which could yield out-of-range indices or skip later images.
n_test_show = min(N_SHOW, len(x_test))
test_indx = np.random.permutation(len(x_test))[:n_test_show]
allp.plot_best(unet, x_test, datay=None, indx=test_indx, index_ranks=np.zeros(n_test_show), ishybrid=args.ishybrid)

print('Train Mean IOU: '+str(np.mean(iou_t))+', Valid Mean IOU: '+str(np.mean(iou_v)))