<a href="https://colab.research.google.com/github/young-hyun-park/HW/blob/main/2d_liver_segmentation_using_smp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import

In [1]:
!pip install segmentation-models-pytorch

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.2.1-py3-none-any.whl (88 kB)
[?25l[K     |███▊                            | 10 kB 30.4 MB/s eta 0:00:01[K     |███████▍                        | 20 kB 24.3 MB/s eta 0:00:01[K     |███████████                     | 30 kB 18.5 MB/s eta 0:00:01[K     |██████████████▉                 | 40 kB 16.4 MB/s eta 0:00:01[K     |██████████████████▌             | 51 kB 9.0 MB/s eta 0:00:01[K     |██████████████████████▏         | 61 kB 10.5 MB/s eta 0:00:01[K     |██████████████████████████      | 71 kB 9.5 MB/s eta 0:00:01[K     |█████████████████████████████▋  | 81 kB 10.7 MB/s eta 0:00:01[K     |████████████████████████████████| 88 kB 5.2 MB/s 
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[K     |████████████████████████████████| 58 kB 7.7 MB/s 
[?25hCollecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[K     |███████████████████

In [2]:
import torch
import numpy as np
from glob import glob
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import albumentations as A
import torch.nn.functional as F
import albumentations.pytorch
from torchvision import datasets, transforms
from torchsummary import summary
from PIL import Image
import os
import re
import cv2 
import nibabel as nib
import segmentation_models_pytorch as smp
import zipfile

In [3]:
seed = 1

lr = 0.001
momentum = 0.99

batch_size = 32
test_batch_size = 32

epochs = 10
log_interval = 10

# Load Data path

In [4]:
PATH = '/content/drive/MyDrive/Task03_Liver/Task03_Liver'

In [5]:
image_path =os.path.join(PATH+'/image.zip')
mask_path = os.path.join(PATH+'/mask.zip')

In [6]:
with zipfile.ZipFile(image_path, 'r') as zip_ref:
    zip_ref.extractall('/content/image')
with zipfile.ZipFile(mask_path, 'r') as zip_ref:
    zip_ref.extractall('/content/mask')

In [7]:
image_path_list = glob('/content/image'+'/*')
mask_path_list = glob('/content/mask'+'/*')

In [8]:
from sklearn.model_selection import train_test_split
import pandas as pd

train_path,test_path = train_test_split(image_path_list, test_size = 0.1, random_state = 42)

In [9]:
train_image_path = list()
test_image_path = list()
for i in train_path:
  train_image_path_ = glob(i+'/*.png')
  for path in train_image_path_:
    train_image_path.append(path)
for i in test_path:
  test_image_path_ = glob(i +'/*.png')
  for path in train_image_path_:
    test_image_path.append(path)

In [10]:
image = Image.open(train_image_path[0])
image = np.array(image)
image.shape

(512, 512)

# Dataset

In [11]:
class Train_Dataset(Dataset):
    def __init__(self, data_path,transform = None):
        self.data_path = data_path
        self.transform = transform
    def __len__(self):
        return len(self.data_path)
    def __getitem__(self,idx):
        path = self.data_path[idx]
        img_name = path.split('/')[-1]
        file_name = path.split('/')[-2]
        one_mask_path =  '/content/mask/mask'+ '/' + file_name + '/'+ img_name
        image = np.array(Image.open(path).resize((256,256)))
        image = image[:,:,np.newaxis]/255
        mask =  np.array(Image.open(one_mask_path).resize((256,256)))
        mask[mask==2] = 1
        mask = mask[...,np.newaxis]
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        return image, mask

In [12]:
class Test_Dataset(Dataset):
    def __init__(self, data_path,transform = None):
        self.data_path = data_path
        self.transform = transform
    def __len__(self):
        return len(self.data_path)
    def __getitem__(self,idx):
        path = self.data_path[idx]
        img_name = path.split('/')[-1]
        file_name = path.split('/')[-2]
        one_mask_path = '/content/mask/mask' + '/' + file_name + '/'+img_name
        image = np.array(Image.open(path).resize((256,256)))
        image = image[:,:,np.newaxis]/255
        mask =  np.array(Image.open(one_mask_path).resize((256,256)))
        mask[mask==2] =1
        mask = mask[...,np.newaxis]
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        return image, mask

In [13]:
train_transform = A.Compose(
    [
     A.pytorch.transforms.ToTensor()
     ]
    )

test_transform = A.Compose(
    [
     A.pytorch.transforms.ToTensor()
     ]
    )

In [14]:
train_data = Train_Dataset(train_image_path,transform = train_transform)
test_data = Test_Dataset(test_image_path,transform = test_transform)

In [15]:
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size = batch_size,
    shuffle = True,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
        batch_size = batch_size,
    shuffle = False,
)

# Model

In [16]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            #self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            #self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [17]:
save_path = "/content/drive/MyDrive/Task03_Liver/Task03_Liver/model"
###############################
trial = 1
n_epoches = 100
LR = 0.001
LR_DECREASE = 1e-5
lr_decrease_epoch = 10
BATCH_SIZE = 64
patience= 15
###############################
ENCODER = 'resnet34'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multicalss segmentation
DEVICE = 'cuda'

# create segmentation model with pretrained encoder
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=None,
    in_channels=1,
    classes=1,
    activation=ACTIVATION,
)

loss = smp.utils.losses.DiceLoss()

metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=LR),
])

In [18]:
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

val_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [19]:
max_score = 0
early_stopping = EarlyStopping(patience=patience, verbose=True)

for epoch in range(0, n_epoches):
    
    print(f'\nEpoch: {epoch}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = val_epoch.run(test_loader)
    
    with open(os.path.join(save_path, f'results{str(trial).zfill(2)}.csv'), 'a') as f:
            f.write('%03d,%0.6f,%0.6f,%0.6f,%0.6f\n' % (
                (epoch + 1),
                train_logs['dice_loss'],
                train_logs['iou_score'],
                valid_logs['dice_loss'],
                valid_logs['iou_score'],
            ))
    
    # do something (save model, change lr, etc.)
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, os.path.join(save_path, f'best_model{str(trial).zfill(2)}.pth'))
        print('New Record!')
        
    torch.save(model, os.path.join(save_path, f'final_model{str(trial).zfill(2)}.pth'))
    
    early_stopping(valid_logs['dice_loss'], model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
    if epoch == lr_decrease_epoch:
        optimizer.param_groups[0]['lr'] = LR_DECREASE
        print(f'Decrease decoder learning rate to {LR_DECREASE}!')


Epoch: 0
train:   0%|          | 0/1591 [00:00<?, ?it/s]


RuntimeError: ignored

In [None]:
plt.plot(train_acc)
plt.plot(val_acc) 
plt.title('U-Net Model Accuracy', fontsize = 15)
plt.xlabel('Epoch', fontsize = 15)
plt.ylabel('Accuaracy', fontsize = 15)
plt.ylim(0.95,1)
plt.legend(['train','val'], loc='upper left')
plt.show()

In [None]:
plt.plot(train_loss)
plt.plot(loss_val)
plt.title('U-Net Model Loss', fontsize = 15)
plt.xlabel('Epoch', fontsize = 15)
plt.ylabel('Loss', fontsize = 15)
plt.ylim(0,0.2)
plt.legend(['train','val'], loc='upper left')
plt.show()

In [None]:
plt.plot(train_dice)
plt.plot(val_dice)
plt.title('U-Net Model Dice Coef', fontsize = 15)
plt.xlabel('Epoch', fontsize = 15)
plt.ylabel('Dice', fontsize = 15)
plt.ylim(0.6,1)
plt.legend(['train','val'], loc='upper left')
plt.show()

In [None]:
pred = F.softmax(test_masks_pred,dim = 1)
prediction = torch.argmax(pred,dim = 1)
prediction.shape

In [None]:
plt.figure(figsize = (10,10))
for i in range(3):  
  plt.subplot(3,1,i+1)
  plt.imshow(prediction[i].cpu().data.numpy(),'gray')

In [None]:
plt.figure(figsize = (10,10))
for i in range(3):  
  plt.subplot(3,1,i+1)
  plt.imshow(torch.squeeze(inputs_test[i]).cpu().data.numpy(),'gray')

In [None]:
inputs_test.shape

In [None]:
plt.figure(figsize = (10,10))
for i in range(3):  
  plt.subplot(3,1,i+1)
  plt.imshow(torch.squeeze(labels_test[i]).cpu().data.numpy(),'gray')