**Semi-supervised training: Phase 1**

In [None]:
# Mount Google Drive (Colab only) so that the project directory under
# /content/drive/MyDrive used throughout this notebook is accessible.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import sys

# Make the project root and its helper folders importable
# (appended in this order: root, utils, wound_lib).
_project_root = '/content/drive/MyDrive/Wound_tissue_segmentation'
sys.path.extend([
    _project_root,
    _project_root + '/utils',
    _project_root + '/wound_lib',
])
del _project_root

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import albumentations as albu
import cv2
import numpy as np
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils import metrics, losses, base
import random
import matplotlib.pyplot as plt
import os
from copy import deepcopy
from datetime import datetime
import torch.nn.functional as F
import time

%matplotlib inline

## Dataloader

In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset: reads image/mask pairs and applies transformations.

    Images are read with OpenCV and converted from BGR to RGB. Masks are read
    as single-channel grayscale label maps; their pixel values are used
    directly as class indices (no /255 scaling).

    Args:
        list_IDs (list[str]): file names (with extension); each name is looked
            up in both ``images_dir`` and ``masks_dir``.
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        to_categorical (bool): if True, one-hot encode the mask into a
            float32 array with the class axis first, ``(n_classes, H, W)``.
        resize (tuple): ``(enabled, (width, height))``. Image and mask are
            resized only when the first value is True.
        n_classes (int): number of classes used for the one-hot encoding.

    """

    def __init__(
            self,
            list_IDs,
            images_dir,
            masks_dir,
            augmentation=None,
            preprocessing=None,
            to_categorical:bool=False,
            resize=(False, (256, 256)), # To resize, the first value has to be True
            n_classes:int=6,
    ):
        self.ids = list_IDs
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.to_categorical = to_categorical
        self.resize = resize
        self.n_classes = n_classes

    @staticmethod
    def _one_hot_channels_first(mask, n_classes):
        """Return a float32 one-hot encoding of ``mask`` with the class axis first.

        ``mask`` may be ``(H, W)``, ``(1, H, W)`` (after ``to_tensor``
        preprocessing) or ``(H, W, 1)`` (without preprocessing). Only that
        single channel axis is removed, so an image with H == 1 or W == 1
        keeps its spatial dimensions. The result has shape ``(n_classes, H, W)``.
        """
        mask = np.asarray(mask)
        if mask.ndim == 3 and mask.shape[0] == 1:
            mask = mask[0]
        elif mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[..., 0]
        # F.one_hot needs int64 class indices; values must lie in [0, n_classes).
        onehot = F.one_hot(torch.from_numpy(mask.astype(np.int64)), num_classes=n_classes)
        onehot = onehot.to(torch.float32).numpy()
        return np.moveaxis(onehot, -1, 0)  # (H, W, C) -> (C, H, W), as smp expects

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i`` after all transforms.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        image_path = self.images_fps[i]
        mask_path = self.masks_fps[i]

        # cv2.imread returns None (it does not raise) for missing/unreadable
        # files; fail here with a clear message instead of inside cvtColor.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        if self.resize[0]:
            # cv2.resize takes dsize as (width, height). Nearest-neighbour keeps
            # mask labels intact; the image is resized the same way.
            image = cv2.resize(image, self.resize[1], interpolation=cv2.INTER_NEAREST)
            mask = cv2.resize(mask, self.resize[1], interpolation=cv2.INTER_NEAREST)

        # Add a channel axis: (H, W) -> (H, W, 1).
        mask = np.expand_dims(mask, axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.to_categorical:
            mask = self._one_hot_channels_first(mask, self.n_classes)

        return image, mask

    def __len__(self):
        return len(self.ids)

## Augmentation

In [None]:
def get_training_augmentation():
    """Build the random augmentation pipeline used for training images/masks.

    The whole pipeline fires with probability 0.9 and consists of four
    ``albu.OneOf`` groups. Each group picks at most one of its transforms; per
    albumentations' OneOf semantics, the inner ``p`` values act as relative
    selection weights.

    1. flip (horizontal or vertical), group p=0.8;
    2. ``ShiftScaleRotate`` geometry (scale only / rotate only / shift only /
       combined affine) with ``border_mode=0`` (constant fill), group p=0.9;
    3. distortion or degradation (perspective, Gaussian noise, sharpen, blur,
       motion blur), group p=0.5;
    4. photometric change (CLAHE, brightness/contrast, gamma,
       hue/saturation/value), group p=0.3.
    """
    flip_group = albu.OneOf(
        [albu.HorizontalFlip(p=0.5), albu.VerticalFlip(p=0.5)],
        p=0.8,
    )

    geometry_group = albu.OneOf(
        [
            # scale only
            albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0, p=0.1, border_mode=0),
            # rotate only
            albu.ShiftScaleRotate(scale_limit=0, rotate_limit=30, shift_limit=0, p=0.1, border_mode=0),
            # shift only
            albu.ShiftScaleRotate(scale_limit=0, rotate_limit=0, shift_limit=0.1, p=0.6, border_mode=0),
            # scale + rotate + shift
            albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=30, shift_limit=0.1, p=0.2, border_mode=0),
        ],
        p=0.9,
    )

    distortion_group = albu.OneOf(
        [
            albu.Perspective(p=0.2),
            albu.GaussNoise(p=0.2),
            albu.Sharpen(p=0.2),
            albu.Blur(blur_limit=3, p=0.2),
            albu.MotionBlur(blur_limit=3, p=0.2),
        ],
        p=0.5,
    )

    photometric_group = albu.OneOf(
        [
            albu.CLAHE(p=0.25),
            albu.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.25),
            albu.RandomGamma(p=0.25),
            albu.HueSaturationValue(p=0.25),
        ],
        p=0.3,
    )

    pipeline = [flip_group, geometry_group, distortion_group, photometric_group]
    return albu.Compose(pipeline, p=0.9)  # augment 90% of the samples


def get_validation_augmentation():
    """Return the transform pipeline applied to validation samples.

    The pipeline is empty, i.e. an identity transform: no padding or resizing
    is done here. This assumes the validation images already have a size the
    model accepts (height and width divisible by 32 for the U-Net). If they do
    not, add e.g. ``albu.PadIfNeeded(512, 512)`` to the list.
    """
    return albu.Compose([])


def to_tensor(x, **kwargs):
    """Reorder an ``(H, W, C)`` array to ``(C, H, W)`` and cast it to float32.

    Extra keyword arguments are accepted, and ignored, so the function can be
    used as an ``albu.Lambda`` target.
    """
    channels_first = np.transpose(x, (2, 0, 1))
    return channels_first.astype(np.float32)


def get_preprocessing(preprocessing_fn):
    """Construct the preprocessing transform applied after augmentation.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image only (can be specific for each pretrained neural network).

    Return:
        albumentations.Compose: first normalizes the image with
        ``preprocessing_fn``, then converts both image and mask from HWC to
        CHW float32 with ``to_tensor``.
    """
    normalize_image = albu.Lambda(image=preprocessing_fn)
    hwc_to_chw = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_image, hwc_to_chw])

## Parameters

In [None]:
# Parameters
BASE_MODEL = 'MiT+pscse' # used as the prefix of each run's model name
ENCODER = 'mit_b3' # smp encoder name
ENCODER_WEIGHTS = 'imagenet' # pretrained encoder weights; also selects the input normalization function
BATCH_SIZE = 16
n_classes = 4 # number of segmentation classes (model output channels and one-hot mask channels)
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LR = 0.0001 # learning rate
EPOCHS = 500 # maximum number of epochs per run
WEIGHT_DECAY = 1e-5 # Adam weight decay
SAVE_WEIGHTS_ONLY = True # NOTE(review): not referenced in the cells shown here
RESIZE = (False, (256,256)) # (enabled, (width, height)); resizing happens only when the first value is True
TO_CATEGORICAL = True # one-hot encode masks into n_classes channels
SAVE_BEST_MODEL = True # best_model.pth is saved either way; when False, periodic checkpoints are saved as well
SAVE_LAST_MODEL = False # only honoured when EARLY_STOP is False

PERIOD = 10 # periodically save checkpoints (every PERIOD epochs, only when SAVE_BEST_MODEL is False)
RAW_PREDICTION = False # if true, then stores raw predictions (i.e. before applying threshold); not referenced in the cells shown here
RETRAIN = False # NOTE(review): not referenced in the cells shown here

# For early stopping
EARLY_STOP = True # True to activate early stopping
PATIENCE = 50 # for early stopping: epochs without improvement in validation loss or IoU

# NOTE(review): the next line disables TLS certificate verification for every
# stdlib HTTPS connection (urllib / http.client) in this process. This is
# presumably a workaround for certificate errors when downloading the
# pretrained encoder weights. It is insecure: prefer fixing the CA bundle, and
# drop this once the weights are cached.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

## Helper function: save a model

In [None]:
def save(model_path, epoch, model_state_dict, optimizer_state_dict):
    """Write a training checkpoint to ``model_path`` with ``torch.save``.

    Args:
        model_path (str): destination file.
        epoch (int): 0-based index of the epoch that just finished. The
            checkpoint stores ``epoch + 1`` under ``'epoch'``.
        model_state_dict (dict): ``model.state_dict()``, stored under ``'state_dict'``.
        optimizer_state_dict (dict): ``optimizer.state_dict()``, stored under ``'optimizer'``.
    """
    # torch.save serializes the tensors immediately, so no defensive deepcopy is
    # needed. Deep-copying a CUDA state_dict would temporarily double GPU memory.
    state = {
        'epoch': epoch + 1,
        'state_dict': model_state_dict,
        'optimizer': optimizer_state_dict,
    }

    torch.save(state, model_path)

## Loss functions and metrics

In [None]:
# Loss components. Each training run combines them via base.HybridLoss.
dice_loss = losses.DiceLoss()
focal_loss = losses.FocalLoss()
dce_loss = losses.DynamicCEAndSCELoss() # dynamic CE

# Metrics, thresholded at 0.5.
# The list is bound to the name ``metrics``, which shadows the ``metrics``
# module imported at the top of the notebook. The metric classes are therefore
# referenced through the fully-qualified path, so re-running this cell does
# not fail with "'list' object has no attribute 'IoU'".
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
    smp.utils.metrics.Fscore(threshold=0.5),
]


## Model Run

In [None]:
# Create a function to read names from a text file, and add extensions
def read_names(txt_file, ext=".png"):
    """Read one name per line from ``txt_file`` and append ``ext`` to each.

    Surrounding whitespace is stripped and blank lines (e.g. a trailing empty
    line) are skipped, so they do not become bogus entries such as ``".png"``.

    Args:
        txt_file (str): path to a text file with one name (without extension) per line.
        ext (str): extension appended to every name.

    Returns:
        list[str]: the names with ``ext`` appended, in file order.
    """
    with open(txt_file, "r", encoding="utf-8") as f:
        stems = [line.strip() for line in f]

    return [stem + ext for stem in stems if stem]

In [None]:
# Output folder for per-run text files (e.g. the sampled unsupervised names).
dir_txt_save = '/content/drive/MyDrive/Wound_tissue_segmentation/texts/'
os.makedirs(dir_txt_save, exist_ok=True)

# Names of the unsupervised images.
dir_txt_load = '/content/drive/MyDrive/Wound_tissue_segmentation/Dataset/FUWound_mmseg_all_in_one_cropped_padded_selfSupervised'
unsup_names = read_names(os.path.join(dir_txt_load, 'Unsupervised_name.txt'), ext='.png')

# Supervised split lists, read in the order train -> val -> test.
dir_txt = '/content/drive/MyDrive/Wound_tissue_segmentation/Dataset/dataset_MiT_v3+aug-added'
sup_IDs_train, list_IDs_val, list_IDs_test = (
    read_names(os.path.join(dir_txt, split_file), ext='.png')
    for split_file in ('labeled_train_names.txt', 'labeled_val_names.txt', 'test_names.txt')
)

In [None]:
n_runs = 5 # No. of runs

# One seed per run. The seeds are drawn from the unseeded global RNG, so they
# differ between notebook executions; each one is printed below for reproducibility.
seeds = [random.randint(0, 5000) for _ in range(n_runs)]

save_dir_pred_root = '/content/drive/MyDrive/Wound_tissue_segmentation/predictions'
os.makedirs(save_dir_pred_root, exist_ok = True)

# Passed to base.HybridLoss together with (dice_loss, focal_loss, dce_loss);
# presumably one weight per loss term, in that order.
weight_factor = [1.0, 1.0, 1.0]

for run, seed in enumerate(seeds):

    print('===================================================================')
    print('===================================================================')
    print(f'=========================== run {run} ============================')
    print('===================================================================')
    print('===================================================================')

    # Seed every RNG *before* the model is built, so the randomly initialised
    # (non-pretrained) weights are reproducible as well as the data sampling.
    print(f'seed: {seed}')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    total_loss = base.HybridLoss(dice_loss, focal_loss, dce_loss, weight_factor)

    start = time.time() # start of training

    # Create a unique, timestamped model name. It is also used for the checkpoint
    # folder, the sampled-names text file and the plot file.
    model_name = BASE_MODEL + '_padded_' + ENCODER + '_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_selfSupervised'
    print(model_name)

    # Auxiliary-head configuration. It is not passed to the model below, so the aux head is disabled.
    aux_params=dict(
        classes=n_classes,
        activation=ACTIVATION,
        dropout=0.1, # dropout ratio, default is None
    )

    # create segmentation model with pretrained encoder
    model = smp.Unet(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=n_classes,
        activation=ACTIVATION,
        decoder_attention_type='pscse',
    )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    model.to(DEVICE)

    # Optimizer
    optimizer = torch.optim.Adam([
        dict(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY),
    ])

    # Learning rate scheduler: divide LR by 10 after 10 epochs without a lower validation loss
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                  factor=0.1,
                                  mode='min',
                                  patience=10,
                                  min_lr=0.00001,
                                  verbose=True,
                                  )

    x_train_dir = x_valid_dir = x_test_dir = '/content/drive/MyDrive/Wound_tissue_segmentation/Dataset/FUWound_mmseg_all_in_one_cropped_padded_selfSupervised/images'
    y_train_dir = y_valid_dir = y_test_dir = '/content/drive/MyDrive/Wound_tissue_segmentation/Dataset/FUWound_mmseg_all_in_one_cropped_padded_selfSupervised/annotations'

    # Training set = all supervised training names + 50 unsupervised names
    # drawn at random (the draw is driven by the seed set above).
    unsup_names_IDs = unsup_names.copy() # make a copy of unsupervised names
    random.shuffle(unsup_names_IDs) # shuffle unsupervised names
    unsup_names_IDs = unsup_names_IDs[:50] # take 50 unsupervised images

    list_IDs_train = sup_IDs_train + unsup_names_IDs # supervised + unsupervised

    print('No. of training images: ', len(list_IDs_train))
    print('No. of validation images: ', len(list_IDs_val))
    print('No. of test images: ', len(list_IDs_test))

    # Save the randomly picked 50 unsupervised names in a text file
    with open(os.path.join(dir_txt_save, model_name + '_unsup_train.txt'), "w") as f:
        for name in unsup_names_IDs:
            print(name, file=f)

    # Checkpoint directory (created if it does not exist)
    checkpoint_loc = '/content/drive/MyDrive/Wound_tissue_segmentation/checkpoints/' + model_name
    os.makedirs(checkpoint_loc, exist_ok=True)

    # Dataloader ===================================================================
    train_dataset = Dataset(
        list_IDs_train,
        x_train_dir,
        y_train_dir,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        to_categorical=TO_CATEGORICAL,
        resize=(RESIZE),
        n_classes=n_classes,
    )

    valid_dataset = Dataset(
        list_IDs_val,
        x_valid_dir,
        y_valid_dir,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        resize=(RESIZE),
        to_categorical=TO_CATEGORICAL,
        n_classes=n_classes,
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=6)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=6)

    # create epoch runners =========================================================
    # it is a simple loop of iterating over dataloader`s samples
    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=total_loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )

    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=total_loss,
        metrics=metrics,
        device=DEVICE,
        verbose=True,
    )

    # Train ========================================================================
    # train model for up to EPOCHS epochs
    best_viou = 0.0
    best_vloss = 1_000_000.
    best_model_epoch = None # stays None if no epoch ever improves (e.g. NaN losses)
    save_model = False # Initially start with False
    cnt_patience = 0

    store_train_loss, store_val_loss = [], []
    store_train_iou, store_val_iou = [], []
    store_train_dice, store_val_dice = [], []

    for epoch in range(EPOCHS):

        print('\nEpoch: {}'.format(epoch))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # Store losses and metrics
        train_loss_key = list(train_logs.keys())[0] # first key is for loss
        val_loss_key = list(valid_logs.keys())[0] # first key is for loss
        val_loss = valid_logs[val_loss_key]
        val_iou = valid_logs['iou_score']

        store_train_loss.append(train_logs[train_loss_key])
        store_val_loss.append(val_loss)
        store_train_iou.append(train_logs["iou_score"])
        store_val_iou.append(val_iou)
        store_train_dice.append(train_logs["fscore"])
        store_val_dice.append(valid_logs["fscore"])

        # Track best performance: save when either the validation loss
        # decreases or the validation IoU increases.
        if best_vloss > val_loss:
            best_vloss = val_loss
            # Keep the IoU record current too. Otherwise a later epoch could
            # "beat" a stale best IoU and overwrite a better checkpoint.
            best_viou = max(best_viou, val_iou)
            print(f'Validation loss reduced. Saving the model at epoch: {epoch:04d}')
            cnt_patience = 0 # reset patience
            best_model_epoch = epoch
            save_model = True

        # Compare iou score
        elif best_viou < val_iou:
            best_viou = val_iou
            print(f'Validation IoU increased. Saving the model at epoch: {epoch:04d}.')
            cnt_patience = 0 # reset patience
            best_model_epoch = epoch
            save_model = True

        else:
            cnt_patience += 1

        # Learning rate scheduler: monitor the validation loss explicitly.
        # sorted(valid_logs)[0] would be whichever key sorts first
        # alphabetically (e.g. 'fscore'), not necessarily the loss.
        scheduler.step(val_loss)

        # Save the model. save() stores epoch + 1 itself, so pass the 0-based epoch.
        if save_model:
            save(os.path.join(checkpoint_loc, 'best_model' + '.pth'),
                epoch, model.state_dict(), optimizer.state_dict())
            save_model = False

        # Early stopping
        if EARLY_STOP and cnt_patience >= PATIENCE:
            print(f"Early stopping at epoch: {epoch:04d}")
            break

        # Periodic checkpoint save
        if not SAVE_BEST_MODEL:
            if (epoch+1) % PERIOD == 0:
                save(os.path.join(checkpoint_loc, f"cp-{epoch+1:04d}.pth"),
                    epoch, model.state_dict(), optimizer.state_dict())
                print(f'Checkpoint saved for epoch {epoch:04d}')

    # The last model is saved only when early stopping is disabled.
    if not EARLY_STOP and SAVE_LAST_MODEL:
        print('Saving last model')
        save(os.path.join(checkpoint_loc, 'last_model' + '.pth'),
            epoch, model.state_dict(), optimizer.state_dict())

    print(best_model_epoch)
    print('Min validation loss:', np.min(store_val_loss))

    end = time.time() # End of training

    print(f'Training time: {end - start:.2f} seconds')

    # Plot loss and metric curves ==================================================
    fig, ax = plt.subplots(1,3, figsize=(12, 3))

    ax[0].plot(store_train_loss, 'r')
    ax[0].plot(store_val_loss, 'b')
    ax[0].set_title('Loss curve')
    ax[0].legend(['training', 'validation'])

    ax[1].plot(store_train_iou, 'r')
    ax[1].plot(store_val_iou, 'b')
    ax[1].set_title('IoU curve')
    ax[1].legend(['training', 'validation'])

    # F-score (Dice) curves
    ax[2].plot(store_train_dice, 'r')
    ax[2].plot(store_val_dice, 'b')
    ax[2].set_title('Dice curve')
    ax[2].legend(['training', 'validation'])

    fig.tight_layout()

    save_fig_dir = "/content/drive/MyDrive/Wound_tissue_segmentation/plots/"
    os.makedirs(save_fig_dir, exist_ok=True)

    fig.savefig(os.path.join(save_fig_dir, model_name + '.png'))
