

### Image segmentation model training code




## 1. Import packages

In [None]:
# Colab environment setup: install/upgrade the third-party packages used by this notebook.
# NOTE(review): grad-cam and pytorchtools are not imported in this notebook; presumably the local
# helper modules imported below use them — verify. Versions are unpinned, so re-runs may pick up
# different releases; consider pinning (e.g. `%pip install -q timm==X.Y.Z`).
!pip install torchvision --upgrade
!pip install grad-cam
!pip install timm
!pip install imagecodecs
!pip install pytorchtools
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision.datasets import VisionDataset
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset

import os 
from os import path

import numpy as np
import pandas as pd
from scipy.io import loadmat

from tqdm import tqdm
from PIL import Image

# read tiff
import zipfile
from tifffile import imread
from torchvision.transforms import ToTensor
import random
import csv
import matplotlib.pyplot as plt
import cv2 as cv

In [None]:
from unet import UNet
from keyholeDataset import Keyhole
from loss import DiceBCEWithActivationLoss 
from augmentation import get_training_augmentation, preprocess
from utils import plot_2_sidebyside, plot_3_sidebyside, save_model, save_loss_record
from iou import iou_numpy
from train import train
from validation import validation
import segmentation_models_pytorch as smp


## 2. Initialize a model

In [None]:
# Drop any model left over from a previous run so its GPU memory can be reclaimed before a new
# one is built. Guarded so the cell also works on a fresh kernel, where a bare `del model`
# raises NameError.
if "model" in globals():
    del model

In [None]:
# Alternative architectures tried earlier (kept for reference):
#model = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)
#model = UNet(n_channels=3, n_classes=1, bilinear=1)

# Segmentation network. Encoders tried: resnet50, mobilenet_v2.
model = smp.DeepLabV3(  # alternative: smp.Unet
    encoder_name="resnet50",  # encoder backbone, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=None,     # random init; pass "imagenet" to start from pre-trained encoder weights
    in_channels=3,            # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                # single output channel (binary mask)
)

# Identifiers used for output file names and for selecting the train/val/test split CSV.
# NOTE(review): "DLV3+" reads like DeepLabV3Plus, but the model built above is smp.DeepLabV3 — confirm.
model_name = "DLV3+Res50_Split5"
csv_split_name = "/image_and_split_5.csv"

# Choose the device here (the data cell later redefines the same `device`) so the notebook also
# runs on a CPU-only runtime instead of failing in `model.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
model = model.to(device)

In [None]:
# all models other than pure unet
# Drop the last layer of the segmentation head (applies to all models other than the plain UNet).
# NOTE(review): presumably this removes the output activation so the network emits raw scores for
# DiceBCEWithActivationLoss, which by its name applies the activation itself — confirm.
# The flag makes the cell idempotent: running it a second time would otherwise strip another layer.
if not getattr(model, "_head_activation_stripped", False):
    model.segmentation_head = nn.Sequential(*list(model.segmentation_head.children())[:-1])
    model._head_activation_stripped = True

In [None]:
# #load model
# path = "/content/drive/MyDrive/DL_segmentation_models/UnetRes50_Split4_epoch_107"
# checkpoint = torch.load(path)
# model.load_state_dict(checkpoint['model_state_dict'])
# for key, value in checkpoint.items():
#     print(key)

## 3. Load data and set batch size and epochs

In [None]:
# Mount Google Drive and unpack the keyhole dataset archive into /content/Keyhole.
from google.colab import drive
drive.mount('/content/drive')

data_zip = '/content/drive/MyDrive/DL_segmentation_data/keyhole_segmentation_data.zip'
extract_root = '/content/Keyhole'
# os.makedirs(exist_ok=True) replaces `!mkdir`, which errors when the folder already exists.
os.makedirs(extract_root, exist_ok=True)
# Skip the extraction when the dataset folder used by the next cell already exists, so
# re-running this cell is cheap.
if not os.path.isdir(os.path.join(extract_root, 'keyhole_segmentation_data')):
    with zipfile.ZipFile(data_zip, 'r') as archive:  # `archive`, not `zip`: don't shadow the builtin
        archive.extractall(path=extract_root)

# GPU if available; 4 DataLoader worker processes on GPU runtimes, loading in the main process otherwise.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
num_workers = 4 if cuda else 0
print(f"Cuda = {cuda} with num_workers = {num_workers}")


In [None]:
# Training hyper-parameters. TODO: move these into a config file.
batch_size = 2
epochs = 300

data_root = '/content/Keyhole/keyhole_segmentation_data'

# Only the training split is augmented; no split uses a preprocessing function.
split_transforms = {
    "train": get_training_augmentation(),
    "val": None,
    "test": None,
}
train_dataset, val_dataset, test_dataset = (
    Keyhole(data_root,
            transform=split_transforms[split],
            preprocess=None,
            mode=split,
            csv_name=csv_split_name)
    for split in ("train", "val", "test")
)

for label, split_ds in (("Train", train_dataset), ("Valid", val_dataset), ("Test", test_dataset)):
    print(f"{label} size: {len(split_ds)}")

# Only training batches are shuffled; the evaluation loaders keep dataset order and use one worker.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader, test_loader = (
    DataLoader(eval_ds, batch_size=batch_size, shuffle=False, num_workers=1)
    for eval_ds in (val_dataset, test_dataset)
)

In [None]:
# Per-image IoU of the current model on the test split.
# The model runs in eval mode (BatchNorm uses its running statistics instead of updating them from
# test batches) under torch.no_grad() (no autograd graph is kept, saving GPU memory). Its previous
# train/eval mode is restored afterwards, even if the loop fails.
show_plots = False  # set True to plot image / ground truth / prediction for every test sample
pred_masks = []  # NOTE(review): not filled or read anywhere in this notebook
iou_record = []
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="test IoU"):
            x = batch['image'].float().to(device)
            y = batch['mask'].float().to(device)
            assert len(x) == len(y), "image and mask batch sizes differ"
            yp = model(x)

            # Copy each batch to the host once rather than once per sample; channel 0 of each
            # mask and of each prediction is compared.
            y_np = y.cpu().numpy()
            yp_np = yp.cpu().numpy()
            x_np = x.cpu().numpy() if show_plots else None
            for k in range(len(x)):
                # NOTE(review): 0.5 is applied to the raw model output. If the network emits logits
                # (a cell above drops the head's last layer, and the loss name suggests it applies
                # the activation), the effective probability cut-off is sigmoid(0.5) ~ 0.62 — confirm.
                pred_mask = (yp_np[k, 0] > 0.5).astype(int)
                true_mask = y_np[k, 0].astype(int)
                iou_record.append(iou_numpy(pred_mask, true_mask))
                if show_plots:
                    plot_3_sidebyside(x_np[k, 0], true_mask, pred_mask)
finally:
    model.train(was_training)


In [None]:
# Mean per-image IoU over the test split.
np.asarray(iou_record).mean()

In [None]:
# Spread of per-image test IoU (population std, ddof=0).
np.asarray(iou_record).std()


## 4. Model training

In [None]:
# #del model
# torch.cuda.empty_cache()
# model.cuda()

In [None]:
from torchsummary import summary

# Layer-by-layer summary for a 3x576x576 input.
# torchsummary feeds random tensors through the model without calling eval(), so in train mode the
# BatchNorm running statistics would be updated with noise. Run it in eval mode, then restore the
# previous mode. `device.type` ("cuda" or "cpu") lets the summary also run on a CPU-only runtime.
was_training = model.training
model.eval()
try:
    summary(model, (3, 576, 576), device=device.type)
finally:
    model.train(was_training)

In [None]:
 # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
 # https://github.com/milesial/Pytorch-UNet/blob/master/train.py
optimizer = optim.RMSprop(
    model.parameters(),
    lr=1e-5,
    weight_decay=1e-8,
    momentum=0.99,
)
# mode="min" is ReduceLROnPlateau's default: the LR is reduced once the monitored value has not
# decreased for `patience` consecutive scheduler steps. NOTE(review): the upstream comment read
# "goal: maximize Dice score"; here the scheduler is handed to train()/validation() and presumably
# stepped with a loss, for which "min" is right — confirm.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=15)
grad_scaler = torch.cuda.amp.GradScaler(enabled=True)  # AMP loss scaling, always enabled
criterion = DiceBCEWithActivationLoss()  # alternative tried: nn.BCEWithLogitsLoss()

In [None]:
# Training-run state, filled once per epoch by the training loop. The loop uses prev_lr/lr_count
# to detect learning-rate reductions for early stopping.
epochs = 300
amp = True  # mixed-precision flag
train_loss_record = []
val_loss_record = []
lr_record = []
# Count how often the learning rate changes. prev_lr starts above any possible learning rate, so
# the first epoch's LR always registers (lr_count goes -1 -> 0 before any real reduction).
prev_lr = float("inf")
lr_count = -1


In [None]:
#save_loss_record(train_loss_record, val_loss_record, lr_record, model_name+".csv")

In [None]:
# Train for `epochs` epochs (0-based epoch index) with LR-based early stopping: once the learning
# rate has been reduced 3 times, save the model and loss record, then stop. The reductions presumably
# come from the ReduceLROnPlateau scheduler passed to train()/validation().
for epoch in range(epochs):  # exactly `epochs` iterations; range(0, epochs + 1) ran one extra
    curr_lr = optimizer.param_groups[0]['lr']
    lr_record.append(curr_lr)
    print('New epoch lr: ', curr_lr)
    if curr_lr < prev_lr:
        prev_lr = curr_lr
        lr_count += 1
    if lr_count == 3:
        print("Early Stop")
        save_model(model, epoch, model_name, optimizer, scheduler, grad_scaler, batch_size)
        save_loss_record(train_loss_record, val_loss_record, lr_record, model_name+".csv")
        break
    # train
    train_loss = train(model, device, train_loader, optimizer, criterion, scheduler, grad_scaler, epoch, epochs, amp=amp)
    train_loss_record.append(train_loss)
    # validation
    val_loss = validation(model, device, val_loader, optimizer, criterion, scheduler, epoch, epochs, amp=amp)
    val_loss_record.append(val_loss)


## 5. Save model and loss data

In [None]:
# Checkpoint model/optimizer/scheduler/grad-scaler state after training. `epoch` is the last epoch
# index reached by the training loop (presumably embedded in the checkpoint file name — verify in utils).
save_model(model, epoch, model_name, optimizer, scheduler, grad_scaler, batch_size)

In [None]:
# Training vs. validation loss per epoch.
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(train_loss_record, label="train")
ax.plot(val_loss_record, label="validation")
ax.set(title=f"{model_name} loss curves", xlabel="epoch", ylabel="loss")
ax.legend()
plt.show()

In [None]:
# save_model(model, 50, "unet_test", optimizer, scheduler, 1)

# NOTE(review): downloads the helper module loss.py from the Colab VM to the local machine. This
# looks like a leftover from editing the loss function rather than a step of the training workflow.
from google.colab import files
files.download('loss.py') 

In [None]:
# checkpoint = torch.load("Unet_MobV3_Nopretrain_epoch_56")
# model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Write the per-epoch train/val losses and learning rates to "<model_name>.csv".
save_loss_record(train_loss_record, val_loss_record, lr_record, model_name+".csv")

In [None]:
# !cp DeepLabV3_Nopretrain_epoch_190 /content/drive/MyDrive/DL_segmentation_models

In [None]:
# NOTE(review): saves an extra checkpoint named "test" with epoch 0. This appears to be a leftover
# check of save_model; remove it if not needed.
save_model(model, 0, "test", optimizer, scheduler, grad_scaler, batch_size)

## 6. Check test loss

In [None]:
# Loss on the held-out test split, computed with the same helper used for validation.
# NOTE(review): `scheduler` is passed through. If validation() steps it with the loss, this call feeds
# the test loss into the LR scheduler — confirm it is unused when only evaluating.
test_loss = validation(model, device, test_loader, optimizer, criterion, scheduler, 0, epochs, amp=True)

In [None]:
# Display the test-split loss.
test_loss

In [None]:
# Re-evaluate the validation split with the final model (overwrites the last `val_loss` from training).
# NOTE(review): as with the test evaluation, `scheduler` is passed through — confirm it is not stepped here.
val_loss = validation(model, device, val_loader, optimizer, criterion, scheduler, 0, epochs, amp=True)

In [None]:
# Display the validation-split loss.
val_loss

In [None]:
# Loss over the training split, computed with the evaluation helper, then displayed.
# NOTE(review): train_dataset was built with get_training_augmentation(), so this loss is measured on
# augmented images. It also overwrites the `train_loss` left over from the training loop.
train_loss = validation(model, device, train_loader, optimizer, criterion, scheduler, 0, epochs, amp=True)
train_loss