In [None]:
# Mount Google Drive; every dataset, CSV and checkpoint path in this notebook lives under it.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
#Data Loader

import os
import numpy as np
import torch
import glob
from scipy.io import loadmat
from torch.utils.data import Dataset
from skimage.exposure import rescale_intensity


class OrganSegmentationDataset(Dataset):
  """Slice-level segmentation dataset backed by one ``.mat`` file per slice.

  File names are parsed by splitting on ``_``: the 2nd field is the patient id
  and the 4th field (extension stripped) is the slice id, e.g.
  ``patient_7_slice_42.mat``. Each file must hold a 2-D image under
  ``'main_img'`` and its label mask under ``'seg_img'``.

  Args:
    images_dir: directory searched (non-recursively) for ``*.mat`` files.
    subset: ``"train"`` keeps patient ids ``<= traning_patient``;
      ``"validation"`` keeps ids in
      ``(traning_patient, traning_patient + test_length]``; ``"all"`` keeps
      every file.
    traning_patient: highest patient id that belongs to the training split.
    test_length: how many patient ids after ``traning_patient`` form the
      validation split.

  For every subset, ``data_paths`` is ordered numerically by
  (patient id, slice id) and ``patient_ids`` lists the selected patients in
  ascending order.
  """

  def __init__(
        self,
        images_dir = '/content/drive/MyDrive/Dataset/segmentation/Covert_mat_into_numpy/processed_data',
        subset = "train",
        traning_patient = 39,
        test_length = 1,
    ):
    assert subset in ["all", "train", "validation"]
    self.images_dir = images_dir
    self.subset = subset
    self.traning_patient = traning_patient
    self.test_length = test_length
    self.required_test = False  # not read by this class; kept for backward compatibility

    print("reading {} images...".format(subset))
    # Filter first (only the patient id is needed), then sort once by a numeric
    # (patient, slice) tuple. This avoids lexicographic "..._10" < "..._2"
    # ordering and the cost of re-sorting after every append.
    selected = [
      path for path in glob.glob(os.path.join(images_dir, "*.mat"))
      if self._in_subset(self._patient_id(path))
    ]
    self.data_paths = sorted(
      selected, key=lambda path: (self._patient_id(path), self._slice_id(path)))
    # dict.fromkeys de-duplicates while keeping the (sorted) first-seen order.
    self.patient_ids = list(dict.fromkeys(self._patient_id(path) for path in self.data_paths))
    print(self.patient_ids)

  @staticmethod
  def _patient_id(path):
    """Patient id: the 2nd ``_``-separated field of the file name."""
    return int(os.path.basename(path).split("_")[1])

  @staticmethod
  def _slice_id(path):
    """Slice id: the 4th ``_``-separated field of the file name, extension removed."""
    return int(os.path.basename(path).split("_")[3].split(".")[0])

  def _in_subset(self, patient_id):
    """True when ``patient_id`` belongs to the configured ``subset``."""
    if self.subset == "train":
      return patient_id <= self.traning_patient
    if self.subset == "validation":
      return self.traning_patient < patient_id <= self.traning_patient + self.test_length
    return True  # "all"

  def __len__(self):
    """Number of slices in the selected subset."""
    return len(self.data_paths)

  def __getitem__(self, id):
    """Return ``(image, mask, patient_id, slice_id)`` for the ``id``-th slice.

    ``image`` is a float32 tensor with a leading channel axis, ``(1, H, W)``
    for a 2-D ``main_img``. ``mask`` is ``seg_img`` converted to an integer
    tensor with its shape unchanged.
    """
    file_path = self.data_paths[id]
    patient_id = self._patient_id(file_path)
    slice_id = self._slice_id(file_path)
    mat = loadmat(file_path)
    mask = mat['seg_img']
    image = mat['main_img']

    # Add a channel axis and move it to the front: (H, W) -> (H, W, 1) -> (1, H, W).
    image = np.expand_dims(image, axis=2).transpose(2, 0, 1)

    image_tensor = torch.from_numpy(image.astype(np.float32))
    mask_tensor = torch.from_numpy(mask.astype(int))
    return image_tensor, mask_tensor, patient_id, slice_id

In [None]:
#U-Net Model

from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):
    """2-D U-Net: four down-sampling stages, a bottleneck and four up-sampling stages.

    Args:
        in_channels: channels of the input image.
        out_channels: number of output score maps (one per class).
        init_features: channels after the first encoder block; doubled at each
            down-sampling step, so the bottleneck has ``16 * init_features``.

    ``forward`` returns ``(output, dec1)``: the ``out_channels`` score maps and
    the ``init_features``-channel features of the last decoder block. Input
    height and width must be divisible by 16 (four 2x poolings) so the skip
    connections concatenate cleanly.

    Compatibility note: each conv block has distinct ``norm2``/``relu2``
    layers. State dicts saved while the block reused the ``norm1``/``relu1``
    names (one BatchNorm+ReLU per block) lack the ``*norm2`` parameters and
    will not load with ``strict=True``.
    """

    def __init__(self, in_channels=3, out_channels=5, init_features=64):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        # Each decoder stage up-samples 2x, then concatenates the matching
        # encoder output, hence the doubled input channels of decoderN.
        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        # 1x1 convolution mapping decoder features to per-class scores.
        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """Return ``(class_scores, last_decoder_features)`` for input ``x``."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        output = self.conv(dec1)
        return output, dec1

    @staticmethod
    def _block(in_channels, features, name):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; spatial size is preserved.

        Layer names must be unique. ``OrderedDict`` keeps one entry per key, so
        reusing ``norm1``/``relu1`` for the second stage would silently drop a
        BatchNorm/ReLU pair and leave the block ending in a bare convolution.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,  # BatchNorm supplies the shift
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

In [None]:
# Training and validation data loaders

def data_loaders(batch_size=1, num_workers=2):
    """Build the training and validation ``DataLoader``s.

    Args:
        batch_size: samples per training batch (default 1).
        num_workers: worker processes per loader (default 2).

    Returns:
        ``(loader_train, loader_valid)``. The training loader shuffles and drops
        the last incomplete batch. The validation loader keeps dataset order and
        every sample, with batch size fixed at 1 because downstream code converts
        the per-batch patient/slice ids to scalars.
    """
    # Imported here so this cell does not rely on a later cell having imported it.
    from torch.utils.data import DataLoader

    dataset_train, dataset_valid = datasets()

    def worker_init(worker_id):
        # Distinct but reproducible NumPy seed per worker process.
        np.random.seed(42 + worker_id)

    loader_train = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        worker_init_fn=worker_init,
    )
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
        worker_init_fn=worker_init,
    )

    return loader_train, loader_valid

def datasets():
    """Return ``(train, validation)`` OrganSegmentationDataset instances with default settings."""
    # Built in order: the training split first, then the validation split.
    split_names = ("train", "validation")
    return tuple(OrganSegmentationDataset(subset=name) for name in split_names)

In [None]:
#Stored model load

import glob
import torch

def get_prev_traning_data(
    checkpoint_dir="/content/drive/MyDrive/Dataset/segmentation/Output/checkpoint_ioU",
    map_location=None,
):
  """Load the checkpoint with the highest epoch number from ``checkpoint_dir``.

  Checkpoint files are expected to end in ``_<epoch>_<score>.pt`` (e.g.
  ``seg_12_0.83.pt`` or ``IoU_seg_12_0.83.pt``). The epoch is read from the
  second-to-last ``_``-separated field, so both naming schemes parse.

  Args:
    checkpoint_dir: directory searched (non-recursively) for ``*.pt`` files.
    map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` to open a GPU
      checkpoint on a CPU-only runtime.

  Returns:
    Whatever ``torch.load`` returns for the newest checkpoint, or ``0`` when
    the directory holds no ``.pt`` files (callers test truthiness).
  """
  import os

  check_points = sorted(glob.glob(os.path.join(checkpoint_dir, "*.pt")))
  if not check_points:
    return 0

  def epoch_of(path):
    return int(os.path.basename(path).split('_')[-2])

  # max() keeps the first of equal epochs in sorted order, and unlike a manual
  # "> last_epoch" scan it still picks a file when every epoch is 0.
  last_checkpoint = max(check_points, key=epoch_of)
  return torch.load(last_checkpoint, map_location=map_location)

In [None]:
!pip install pytorch_toolbelt
!pip install pytorch-ignite

In [None]:
#Pixel deviation based loss

import torch
from torch.nn.functional import one_hot

def pixDeviationLoss(img, pred, true):
  """Sum of |image intensity| over pixels whose predicted class is wrong.

  Args:
    img: single-channel input batch, assumed ``(B, 1, H, W)``; it must have
      ``B*H*W`` elements.
    pred: class scores ``(B, C, H, W)``; the predicted label is the argmax over
      the class axis (dim 1).
    true: integer label map with ``B*H*W`` elements, e.g. ``(B, H, W)``.

  Returns:
    A 0-d tensor: ``sum(|img|)`` over all misclassified pixels in the batch.

  Note: ``argmax`` is not differentiable, so this term passes no gradient to
  ``pred``. Added to another loss, it only shifts the reported value.
  """
  # dim=1 is the class axis of (B, C, H, W); dim 0 would be the batch axis.
  pred_labels = torch.argmax(pred, dim=1)
  true_labels = true.reshape(pred_labels.shape)
  # Boolean 0/1 mismatch mask. Writing ``~t.int()`` would apply bitwise NOT
  # to an integer tensor (giving -1/-2), not logical negation.
  miss_match = torch.ne(true_labels, pred_labels)
  pixel_img = img.reshape(pred_labels.shape)
  return torch.sum(torch.abs(pixel_img * miss_match))

In [None]:
# Train the model, log metrics, and save a checkpoint every epoch


%load_ext tensorboard
import torch.utils.tensorboard as tb
import tempfile
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.nn.functional import one_hot
from pytorch_toolbelt import losses
from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
import csv


def eval_step(engine, batch):
    """Identity process function for the ignite ``Engine``.

    Batches handed to ``run()`` are already ``[y_pred, y]`` pairs, so they are
    passed through untouched for the attached metrics to consume.
    """
    del engine  # unused, but required by ignite's process-function signature
    return batch

default_evaluator = Engine(eval_step)

cm = ConfusionMatrix(num_classes=5)
metric = DiceCoefficient(cm)
metric.attach(default_evaluator, 'dice')

# max_slice_count = get_max_slice_count()

log_dir = tempfile.mkdtemp()
%tensorboard --logdir {log_dir} --reload_interval 1
writer = tb.SummaryWriter(log_dir, flush_secs=1)

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

loader_train, loader_valid = data_loaders()
loaders = {"train": loader_train, "valid": loader_valid}

unet = UNet(in_channels=1, out_channels=5)
unet.to(device)
dice_val = DiceVal()
loss_function = losses.JaccardLoss(mode = "multiclass")
best_validation_dsc = 0.0
optimizer = optim.Adam(unet.parameters(), lr=0.0001)
# loss_function = nn.CrossEntropyLoss()
loss_function_validatiion = losses.DiceLoss(mode = "multiclass")

loss_train_final = []
dice_valid_final = []
Valid_loss_final = []


epoch_num = 50

done_epoch = 0
prev_traning_data = get_prev_traning_data()

if prev_traning_data:
  print("hello")
  done_epoch = prev_traning_data['epoch']
  loss = prev_traning_data['loss']
  unet.load_state_dict(prev_traning_data['model_state_dict'])
  optimizer.load_state_dict(prev_traning_data['optimizer_state_dict'])
                       

for epoch in tqdm(range(done_epoch + 1, epoch_num + 1)):
  print("{epc} is running".format(epc = epoch))
  loss_train = []
  loss_valid = []
  Dice_valid = []
  img_print = 0
  # Per-patient lists of validation slices, concatenated once after the pass
  # (calling torch.cat per slice would re-copy the whole volume every time).
  validation_predict = {}
  validation_true = {}

  for phase in ["train", "valid"]:
    if phase == "train":
      unet.train()
    else:
      unet.eval()

    for x, y_true, patient_id, slice_id in loaders[phase]:
      x, y_true = x.to(device), y_true.to(device)

      if phase == "train":
        optimizer.zero_grad()
        # UNet.forward returns (class scores, decoder features); only the scores are used.
        y_pred, _ = unet(x)
        # pixDeviationLoss goes through argmax, so it adds no gradient and only
        # shifts the logged loss value.
        loss = loss_function(y_pred, y_true) + pixDeviationLoss(x, y_pred, y_true)
        loss_train.append(float(loss.item()))
        loss.backward()
        optimizer.step()
        continue

      with torch.no_grad():
        y_pred, _ = unet(x)
        loss = loss_function_validatiion(y_pred, y_true)
      loss_valid.append(float(loss.item()))

      # The validation loader yields one slice per batch, so the ids are scalars.
      patient_id = int(patient_id)
      slice_id = int(slice_id)
      # Keep outputs on the CPU so whole-patient volumes don't pile up in GPU memory.
      validation_predict.setdefault(patient_id, []).append(y_pred.cpu())
      validation_true.setdefault(patient_id, []).append(y_true.cpu())

      if img_print % 70 == 0:
        panels = (
          x.cpu().numpy().squeeze(),
          y_true.cpu().numpy().squeeze(),
          np.argmax(y_pred.cpu().numpy().squeeze(), 0),
        )
        fig, axes = plt.subplots(1, 3)
        for ax, panel, title in zip(axes, panels, ("image", "ground truth", "prediction")):
          ax.imshow(panel)
          ax.set_title(title)
        plt.show()
        plt.close(fig)
      img_print += 1

  # Per-patient Dice over the whole validation volume.
  organ_dice_rows = []
  for patient in validation_predict:
    val_patient_pred = torch.cat(validation_predict[patient])
    val_patient_true = torch.cat(validation_true[patient])
    dice = 1 - loss_function_validatiion(val_patient_pred, val_patient_true)
    Dice_valid.append(float(dice.item()))
    evaluator_state = default_evaluator.run([[val_patient_pred, val_patient_true]])
    organ_dice_rows.append(evaluator_state.metrics['dice'])

  # NaN per organ when no validation patient was seen (mean of nothing).
  if organ_dice_rows:
    avg_all_dice_per_organ = torch.stack(organ_dice_rows).mean(0)
  else:
    avg_all_dice_per_organ = torch.full((5,), float("nan"))
  organ_dice = [float(value) for value in avg_all_dice_per_organ]

  # Plain Python floats (not numpy scalars) keep the checkpoint loadable with
  # torch.load(weights_only=True).
  t_loss = float(np.mean(loss_train))
  v_loss = float(np.mean(loss_valid))
  V_dice = float(np.mean(Dice_valid))

  with open("/content/drive/MyDrive/Dataset/segmentation/Output/D_loss_data_IoU.csv", 'a', newline='') as f:
    csv.writer(f, delimiter=',').writerow([epoch, t_loss, v_loss, V_dice, *organ_dice])

  # NOTE: the tags say "IoU" but V_dice / organ_dice are Dice scores
  # (1 - DiceLoss, ignite DiceCoefficient); tags kept unchanged.
  loss_train_final.append(t_loss)
  writer.add_scalar("train Loss", t_loss, epoch)
  dice_valid_final.append(V_dice)
  writer.add_scalar("validation IoU", V_dice, epoch)
  Valid_loss_final.append(v_loss)
  writer.add_scalar("validation loss", v_loss, epoch)
  for organ, organ_value in enumerate(organ_dice):
    writer.add_scalar("Organ IoU/{organ}".format(organ=organ), organ_value, epoch)

  check_file = "/content/drive/MyDrive/Dataset/segmentation/Output/D_loss_checkpoint_ioU/IoU_seg_{ep}_{dic}.pt".format(ep=epoch, dic = V_dice)
  os.makedirs(os.path.dirname(check_file), exist_ok=True)
  torch.save({
            'epoch': epoch,
            'model_state_dict': unet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': float(loss.item()),
            't_loss': t_loss,
            'v_loss': v_loss,
            'V_dice': V_dice,
            'dice_01': avg_all_dice_per_organ
            }, check_file)

writer.flush()
writer.close()

# x axis counts the epochs run in this session (1..N), not absolute epoch numbers
# after a resume.
fig, ax = plt.subplots(num="train loss")
ax.plot(range(1, len(loss_train_final) + 1), loss_train_final)
ax.set(title="Training loss", xlabel="epoch (this run)", ylabel="loss")
plt.show()

fig, ax = plt.subplots(num="Valid Dice")
ax.plot(range(1, len(dice_valid_final) + 1), dice_valid_final)
ax.set(title="Validation Dice (per-patient mean)", xlabel="epoch (this run)", ylabel="Dice")
plt.show()