

### Image segmentation model (UNet): fine-tuning and inference code




## 1. Import packages

In [None]:
!pip install torchvision --upgrade
!pip install grad-cam
!pip install timm
!pip install imagecodecs
!pip install pytorchtools
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision.datasets import VisionDataset
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset

import os
from os import path

import numpy as np
import pandas as pd
from scipy.io import loadmat

from tqdm import tqdm
from PIL import Image

# read tiff
import zipfile
from tifffile import imread
from torchvision.transforms import ToTensor
import random
import csv
import matplotlib.pyplot as plt
import cv2 as cv

In [None]:
# Don't forget to upload all .py files from the "utils" and "unet_model" folders before running the next cell.

In [None]:
from unet import UNet
from keyholeDataset import Keyhole
from loss import DiceBCEWithActivationLoss
from augmentation import get_training_augmentation, preprocess
from utils import plot_2_sidebyside, plot_3_sidebyside, save_model, save_loss_record
from iou import iou_numpy
from train import train
from validation import validation
import segmentation_models_pytorch as smp


## 2. Initiate a model

In [None]:
# Label used later when saving checkpoints and the loss-record csv.
model_name = "UNet"

# UNet with 3 input channels and a single output class.
# NOTE(review): bilinear=1 is a truthy flag; presumably it selects bilinear
# upsampling — confirm against unet.py.
model = UNet(
    n_channels=3,
    n_classes=1,
    bilinear=1,
)

# Release cached GPU memory, then move the model onto the GPU.
torch.cuda.empty_cache()
model.cuda()

In [None]:
# Mount Google Drive at /content/drive (Colab only); later cells read the
# pretrained checkpoint and the zipped data from paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load pretrained weights.
# Upload the "public_trained_models" folder to your own Google Drive first; it is
# available for download at
# https://drive.google.com/drive/folders/1PjvG199PSNGER255jMh35cCw4MV0Lp3G?usp=share_link
# Any one of the 5 available checkpoints can be used.
#
# The variable is called `checkpoint_path` rather than `path` so that the
# `from os import path` import at the top of the notebook is not shadowed.
checkpoint_path = "/content/drive/MyDrive/public_trained_models/Unet_segmentation/Unet_Split1_epoch_153"

# map_location="cpu" deserializes the tensors on the host regardless of the device
# they were saved from; load_state_dict then copies the values into the model's
# existing parameters on whatever device the model lives on.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])

# Show which entries the checkpoint contains besides the model weights.
for key in checkpoint:
    print(key)

## 3. Load data + specify batch_size and epochs

In [None]:
# create a folder called Keyhole, and import your annotated data to this folder
# your data will be a zip folder uploaded to google drive
# in case of fine-tuning, this will be your fine-tune data
# one folder "images", one folder "masks", and one csv file containing how you split the data

!mkdir Keyhole

with zipfile.ZipFile('/content/drive/MyDrive/keyhole_segmentation_data/keyhole_segmentation_data.zip', 'r') as zip:
  zip.extractall(path='/content/Keyhole')

# you need to create this csv file for your own fine-tune data
# left colum is image name as in images folder, right colomn is 1train(80%), 0val(105),2test(10%)
# change this to your csv file name
csv_split_name = "/image_and_split_1.csv"

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
num_workers = 2 if cuda else 0
print("Cuda = " + str(cuda)+" with num_workers = "+str(num_workers))


In [None]:
batch_size = 2

# Root of the extracted archive; all three splits read from the same folder and
# are selected through `mode` plus the split csv.
data_root = '/content/Keyhole/keyhole_segmentation_data'

# Only the training split is built with the augmentation transform.
train_dataset = Keyhole(data_root,
                        transform=get_training_augmentation(),
                        preprocess=None,
                        mode="train",
                        csv_name=csv_split_name)
val_dataset = Keyhole(data_root,
                      transform=None,
                      preprocess=None,
                      mode="val",
                      csv_name=csv_split_name)
test_dataset = Keyhole(data_root,
                       transform=None,
                       preprocess=None,
                       mode="test",
                       csv_name=csv_split_name)

print(f"Train size: {len(train_dataset)}")
print(f"Valid size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

# All loaders share the `num_workers` setting chosen in the device cell
# (previously val/test hard-coded 1 worker and ignored it). Only the training
# loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


## 4. Model training

In [None]:
# #del model
# torch.cuda.empty_cache()
# model.cuda()

In [None]:
# Print a summary of the model; (3, 576, 576) is passed as the input size
# (no batch dimension). Imported here rather than in the import cell at the top.
from torchsummary import summary
summary(model, (3, 576, 576))

In [None]:
 # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
 # https://github.com/milesial/Pytorch-UNet/blob/master/train.py

 # you can experiment with lower lr bc finetune
# RMSprop over all model parameters with a small learning rate for fine-tuning.
optimizer = optim.RMSprop(
    model.parameters(),
    lr=1e-5,
    weight_decay=1e-8,
    momentum=0.99,
)

# Reduce the learning rate when the monitored metric plateaus for 15 steps.
# NOTE(review): no `mode` is passed, so the scheduler uses its default mode; the
# earlier comment here said "goal: maximize Dice score" — confirm which quantity
# train()/validation() actually feed to scheduler.step().
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=15)

# Gradient scaler for mixed-precision training (always enabled here).
grad_scaler = torch.cuda.amp.GradScaler(enabled=True)

# Loss used for training and evaluation; alternative: nn.BCEWithLogitsLoss()
criterion = DiceBCEWithActivationLoss()

In [None]:
# Training-loop settings; `epochs` can be reduced for fine-tuning.
epochs, amp = 50, True

# Per-epoch histories filled in by the training loop.
train_loss_record, val_loss_record, lr_record = [], [], []

# Bookkeeping for counting learning-rate changes:
# prev_lr starts at 100 (a stand-in for "larger than any real lr") and
# lr_count starts at -1.
prev_lr, lr_count = 100, -1

In [None]:
# Fine-tuning loop with learning-rate-based early stopping.
# NOTE(review): range(0, epochs+1) runs epochs+1 iterations (51 for epochs=50);
# confirm whether that extra epoch is intended.
for epoch in range(0, epochs+1):
    # Learning rate of the first parameter group at the start of this epoch
    # (starts at the optimizer's configured lr; adjust it depending on results).
    curr_lr = optimizer.param_groups[0]['lr']
    lr_record.append(curr_lr)
    print('New peoch lr: ', curr_lr)

    # Count every time the learning rate drops below the last value seen.
    if curr_lr < prev_lr:
        prev_lr = curr_lr
        lr_count += 1

    # Early stop: once the counter reaches 3, save the model and the loss
    # history, then leave the loop.
    if (lr_count == 3):
        print("Early Stop")
        save_model(model, epoch, model_name, optimizer, scheduler, grad_scaler, batch_size,
                   path="/content/drive/MyDrive/") # you can change the path
        save_loss_record(train_loss_record, val_loss_record, lr_record, model_name+".csv")
        break

    # One training pass; honours the `amp` setting from the config cell instead
    # of a hard-coded True.
    train_loss = train(model, device, train_loader, optimizer, criterion, scheduler, grad_scaler, epoch, epochs, amp=amp)
    train_loss_record.append(train_loss)

    # One validation pass with the same `amp` setting.
    val_loss = validation(model, device, val_loader, optimizer, criterion, scheduler, epoch, epochs, amp=amp)
    val_loss_record.append(val_loss)


In [None]:
# Save the current model/optimizer/scheduler/scaler state to Google Drive.
# The destination path can be changed.
save_model(
    model, epoch, model_name,
    optimizer, scheduler, grad_scaler,
    batch_size,
    path="/content/drive/MyDrive/",
)

## 5. Save model and loss data

In [None]:
# Save the same state to the local Colab filesystem.
# Point `path` at a Google Drive folder instead to have it stored on your Drive.
save_model(
    model, epoch, model_name,
    optimizer, scheduler, grad_scaler,
    batch_size,
    path="/content/",
)

In [None]:
# Plot the per-epoch training and validation losses on one figure.
# Labels and a legend make the two curves distinguishable; plt.show() avoids
# echoing the Line2D repr as the cell output.
plt.figure(figsize=(10,10))
plt.plot(train_loss_record, label="train loss")
plt.plot(val_loss_record, label="validation loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()

In [None]:
# Save the train/validation loss and learning-rate histories plotted above;
# the file name passed is the model name with a ".csv" suffix.
save_loss_record(train_loss_record, val_loss_record, lr_record, model_name+".csv")

## 6. Check test loss and IoU score

#### You need to decide when to stop fine-tuning based on the train/val loss; you may also need to add more annotated data if the loss is not good enough.

In [None]:
# Run the validation helper over the test split and keep its return value.
# 0 is passed in the position where the training loop passes the epoch index.
# NOTE(review): the scheduler is handed to validation(); confirm the helper
# does not step it when evaluating on test data.
test_loss = validation(model, device, test_loader, optimizer, criterion, scheduler, 0, epochs, amp=True)

In [None]:
# Display the test loss.
test_loss

In [None]:
# Re-run the validation helper over the validation split (epoch argument 0);
# this overwrites the `val_loss` left behind by the training loop.
val_loss = validation(model, device, val_loader, optimizer, criterion, scheduler, 0, epochs, amp=True)

In [None]:
# Display the validation loss.
val_loss

In [None]:
# Evaluate the training split with the validation helper (epoch argument 0) and display it.
# NOTE(review): train_dataset was built with get_training_augmentation(), so if
# that transform is random this value may vary between runs.
train_loss = validation(model, device, train_loader, optimizer, criterion, scheduler, 0, epochs, amp=True)
train_loss

In [None]:
# Compute a per-image IoU score over the test split.
pred_masks = []  # not filled in this cell
iou_record = []

# Inference should run in eval mode and without autograd bookkeeping; the
# previous train/eval mode is restored afterwards so later cells are unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        print("i = ", i)
        x = batch['image'].float().to(device)
        y = batch['mask'].float().to(device)
        assert(len(x) == len(y))
        print(x.shape)
        yp = model(x)

        # `j` indexes samples inside the batch; a separate name keeps the batch
        # counter `i` from being overwritten.
        for j in range(len(x)):
            y_ = y[j].unsqueeze(0)
            yp_ = yp[j]
            print(y_.shape)

            # NOTE(review): thresholding at 0.5 assumes the model output is a
            # probability; if UNet returns raw logits the cut-off should be 0
            # (or a sigmoid applied first) — confirm against unet.py / loss.py.
            pred_mask = (yp_.detach().cpu().numpy()[0]>0.5).astype(int)
            true_mask = y_.detach().cpu().numpy()[0][0].astype(int)

            # To inspect a pair visually: plot_2_sidebyside(true_mask, pred_mask)
            iou_score = iou_numpy(pred_mask, true_mask)
            print("iou: ", iou_score)
            iou_record.append(iou_score)
model.train(was_training)

      # print("yp shape", yp.shape)#torch.Size([1, 1, 572, 572])
      # plot_2_sidebyside(x.detach().cpu().numpy()[0][0],
      #                 y.detach().cpu().numpy()[0][0])

      # plot_2_sidebyside(
      #                 y.detach().cpu().numpy()[0][0],
      #                 (yp.detach().cpu().numpy()[0][0]>0.5).astype(int))


      # plot_3_sidebyside(x.detach().cpu().numpy()[0][0],
      #                 y.detach().cpu().numpy()[0][0],
      #                 (yp.detach().cpu().numpy()[0][0]>0.5).astype(int))


In [None]:
# A notebook cell only displays its last expression, so the mean was silently
# dropped before; print both statistics explicitly.
print("mean IoU:", np.mean(iou_record))
print("std IoU: ", np.std(iou_record))
# 0.8 or 80% IOU is good

In [None]:
# NOTE(review): this is a bare reference — it only displays the function object
# and does not save anything; call save_model(...) with arguments to save.
save_model

## 7. Inference on data with no masks

Assuming you have loaded your model

In [None]:
# Load the unlabeled data; don't forget to change the paths below.
# Bound to `archive` rather than `zip` so the built-in zip() is not shadowed.
with zipfile.ZipFile('/content/drive/MyDrive/keyhole_segmentation_data/keyhole_segmentation_data_no_mask.zip', 'r') as archive:
    archive.extractall(path='/content/Keyhole')

# The KeyholeNoMask class loads image-only data (no masks).
from keyholeDataset import KeyholeNoMask
# Change this path to the folder extracted from your archive.
infer_dataset = KeyholeNoMask('/content/Keyhole/keyhole_segmentation_infer_data')

print(f"infer_dataset size: {len(infer_dataset)}")

# Order is preserved (no shuffling) so predictions line up with the dataset order.
batch_size=2
infer_loader = DataLoader(infer_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Predict a binary mask for every image in the inference set and collect the
# masks, in dataset order, in `pred_masks`.
pred_masks = []

# Inference should run in eval mode and without autograd bookkeeping; the
# previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    for i, batch in enumerate(infer_loader):
        print("i = ", i)
        x = batch['image'].float().to(device)
        print(x.shape)
        yp = model(x)

        # `j` indexes samples inside the batch; a separate name keeps the batch
        # counter `i` from being overwritten.
        for j in range(len(x)):
            x_ = x[j].unsqueeze(0)
            yp_ = yp[j]
            # Threshold once and reuse the mask for both storage and plotting.
            # NOTE(review): 0.5 assumes the model output is a probability, not
            # a raw logit — confirm against unet.py / loss.py.
            pred_mask = (yp_.detach().cpu().numpy()[0]>0.5).astype(int)
            pred_masks.append(pred_mask)
            print(yp_.shape)

            # Show the first channel of the input next to the predicted mask.
            # NOTE(review): astype(int) truncates; if images are scaled to
            # [0, 1] the left panel will be all zeros — confirm the value range.
            plot_2_sidebyside(
                x_.detach().cpu().numpy()[0][0].astype(int),
                pred_mask)
model.train(was_training)



# Now save your pred_masks list, save each one as "tif" images
# use the keyhole_feature_extraction code to get your features
# https://github.com/rubyjiang18/Deep-learning-approaches-for-time-resolved-laser-absorptance-prediction/tree/main/keyhole_feature_extraction