<a href="https://colab.research.google.com/github/pkliui/machine-learning/blob/master/stepik-deep-learning/16-HW-semantic-segmentation/16_hw_semantic_segmentation_unet_kaggle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Semantic segmentation of moles' images by UNet

* This is a pytorch implementation of the UNet network for semantic segmentation of moles' images
* Dataset: PH2 dermoscopic image database (ADDI project), https://www.fc.up.pt/addi/ph2%20database.html
* Homework of Unit 16, Deep learning course (part 1, spring 2021) offered by Moscow Institute of Physics and Technology https://stepik.org/course/91157/info

## Fix seed for reproducibility
* `torch.use_deterministic_algorithms(True)` raises an error on Kaggle (but not on Colab), so it is left commented out below.

In [None]:
import numpy as np
import torch, os
import random
# fix seed for reproducible results
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's global RNG and PyTorch, and puts cuDNN
    into deterministic, non-benchmarking mode. PYTHONHASHSEED is read when a
    Python interpreter starts, so setting it here only influences Python
    processes launched afterwards, not the running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    # torch.use_deterministic_algorithms(True) raises on Kaggle, so it stays disabled
    torch.manual_seed(seed)


set_seed(42)

## Declare parameters

In [None]:
# Spatial sizes passed to cv2.resize as dsize (width, height):
# the UNet maps 572x572 input images to 388x388 output masks.
SIZE_X = (572, 572)
SIZE_Y = (388, 388)

BATCH_SIZE = 8  # mini-batch size used by the DataLoaders

# Number of images per split (counts, not fractions). They must add up to the
# number of loaded images, otherwise the split-size assertion fails.
TRAIN_SHARE = 100
VAL_SHARE = 50
TEST_SHARE = 50

MAX_EPOCHS = 150
LEARNING_RATE = 1e-3

# StepLR: the learning rate is multiplied by SCHEDULER_GAMMA every SCHEDULER_STEP epochs
SCHEDULER_STEP = 50
SCHEDULER_GAMMA = 0.1


## Read and resize images
> Read images and lesions (segmented images) from the root

In [None]:
# Collect the dermoscopic images and their lesion masks from the PH2 folder tree.
# The code assumes each case folder has a `*_Dermoscopic_Image` and a `*_lesion`
# subfolder, and loads the first file (in sorted order) of each.
from skimage.io import imread
import os

images = []
lesions = []
root = '/kaggle/input/ph2databaseaddi/PH2Dataset'

# `dirpath` (not `root`) as the loop variable, so the dataset root is not overwritten.
for dirpath, dirnames, filenames in os.walk(os.path.join(root, 'PH2 Dataset images')):
    # os.walk lists entries in arbitrary, filesystem-dependent order. Sorting in place
    # (honoured by the default top-down walk) makes the sample order - and therefore
    # the seeded random train/val/test split - reproducible across machines.
    dirnames.sort()
    filenames.sort()
    if dirpath.endswith('_Dermoscopic_Image'):
        images.append(imread(os.path.join(dirpath, filenames[0])))
    if dirpath.endswith('_lesion'):
        lesions.append(imread(os.path.join(dirpath, filenames[0])))

# every image needs exactly one mask, otherwise X[i] / Y[i] pairs are misaligned
assert len(images) == len(lesions), (
    f'found {len(images)} images but {len(lesions)} lesion masks')

> Resize images to have the same size as the net expects

> Use **nearest-neighbour interpolation**. This matters most for the lesion masks: any other interpolation creates intermediate values along mask boundaries, which "may result in tampering with the ground truth labels" [ https://ai.stackexchange.com/questions/6274/how-can-i-deal-with-images-of-variable-dimensions-when-doing-image-segmentation ]

In [None]:
# Resize to the sizes the UNet expects, scale to [0, 1] and convert to float32.
# X = images to be segmented (cv2.resize does NOT normalize; that is done below)
# Y = binary lesion masks (ground-truth segmentation)

from skimage.transform import resize  # NOTE(review): unused in this cell; resizing uses cv2
size_X = SIZE_X
size_Y = SIZE_Y

import cv2
import numpy as np

# nearest-neighbour interpolation: required for the masks (no new label values),
# and also used for the images here
X = [cv2.resize(x, size_X, interpolation=cv2.INTER_NEAREST) for x in images]
# binarize the masks right after resizing: any pixel value > 0.5 is foreground
Y = [cv2.resize(y, size_Y, interpolation=cv2.INTER_NEAREST) > 0.5 for y in lesions]

# Stack straight into float32 and normalize in place. Dividing the Python list
# directly (X / np.max(X)) first builds a float64 copy of the whole image stack,
# doubling peak memory before the float32 conversion.
X = np.asarray(X, dtype=np.float32)
X /= X.max()
# masks are booleans, so the float32 cast already yields exactly 0.0 / 1.0
# (dividing by the max would only produce NaNs if every mask were empty)
Y = np.asarray(Y, dtype=np.float32)

print(f'Loaded {len(X)} images')

> Draw a few samples: input images (top row) and their lesion masks (bottom row)

In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Preview the first six samples: images in row 1 (panels 1-6), masks in row 2 (panels 7-12).
plt.figure(figsize=(18, 6))
for col in range(6):
    for row, stack in enumerate((X, Y)):
        plt.subplot(2, 6, row * 6 + col + 1)
        plt.axis("off")
        plt.imshow(stack[col])
plt.show();

## Split into train-val-test

In [None]:
# Shuffle all sample positions: a random permutation of 0..len(X)-1, which is the
# same draw as sampling len(X) indices without replacement.
ix = np.random.permutation(len(X))
# Cut the shuffled indices into consecutive chunks: the first TRAIN_SHARE for
# training, the next VAL_SHARE for validation and the remainder (expected to be
# TEST_SHARE) for testing - 100 / 50 / 50 with the default configuration.
tr, val, ts = np.split(ix, [TRAIN_SHARE, TRAIN_SHARE + VAL_SHARE])

In [None]:
# np.split silently puts any leftover indices into the test split, so a dataset
# with an unexpected number of images is caught here - with a readable message.
assert (len(tr), len(val), len(ts)) == (TRAIN_SHARE, VAL_SHARE, TEST_SHARE), (
    f'split sizes {(len(tr), len(val), len(ts))} do not match the configured '
    f'{(TRAIN_SHARE, VAL_SHARE, TEST_SHARE)} ({len(X)} images loaded)')

In [None]:
# load data using dataloader
from torch.utils.data import DataLoader

# set the batch size
batch_size = BATCH_SIZE


def _make_loader(indices, shuffle):
    """Build a DataLoader over (image, mask) pairs selected by `indices`.

    Images are moved from channels-last (N, H, W, C) to channels-first
    (N, C, H, W), and masks get a singleton channel axis (N, 1, H, W).
    drop_last=True skips a final batch smaller than batch_size, so every batch
    is full; NOTE: this means len(indices) % batch_size samples are never used.
    """
    inputs = np.moveaxis(X[indices], 3, 1)
    targets = Y[indices, np.newaxis]
    return DataLoader(list(zip(inputs, targets)), batch_size=batch_size,
                      shuffle=shuffle, drop_last=True)


# Only the training data is shuffled. Validation/test are read in a fixed order so
# that every epoch's validation loss and score are computed on the same samples
# (with shuffling + drop_last, a different random subset was evaluated each epoch).
data_tr = _make_loader(tr, shuffle=True)
data_val = _make_loader(val, shuffle=False)
data_ts = _make_loader(ts, shuffle=False)

In [None]:
# Sanity check of the validation loader: for every batch print its 1-based number,
# the shape of its first image and the shape of the whole batch.
ii = 0  # stays 0 if the loader yields nothing
for ii, (X_val, Y_val) in enumerate(data_val, start=1):
    print(ii, X_val[0].shape, X_val.shape, sep='\n')

In [None]:
import torch

# run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

## Implement UNet model

In [None]:
# install torchvision
!pip install torchvision

In [None]:
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from time import time

from matplotlib import rcParams

In [None]:
def _double_conv(in_channels, out_channels):
    """Return Conv3x3 -> Conv3x3 -> BatchNorm -> ReLU with unpadded convolutions.

    Each unpadded 3x3 convolution trims one pixel per border, so the block
    shrinks H and W by 4 in total. The child indices (0: conv, 1: conv,
    2: batch norm, 3: ReLU) form the state_dict keys (e.g. ``e0_conv.0.weight``);
    keep them stable so saved checkpoints still load.
    NOTE(review): there is no activation between the two convolutions, so they
    compose to one linear map before BatchNorm; adding a ReLU would change the
    architecture and invalidate existing checkpoints.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.Conv2d(out_channels, out_channels, kernel_size=3),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )


def _up_conv(in_channels, out_channels):
    """Return ConvTranspose3x3 (stride 1, padding 1) -> BatchNorm -> ReLU.

    With stride 1 and padding 1, H_out = (H - 1) - 2 + 3 = H, so only the
    channel count changes; the spatial upsampling is done by MaxUnpool2d.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )


def _check_shape(tensor, expected, name):
    """Assert that the (C, H, W) part of a batched tensor equals `expected`."""
    actual = tuple(tensor.shape[1:])
    assert actual == expected, "{} expected shape {}, got {}".format(name, expected, actual)


class UNet(nn.Module):
    """UNet for binary segmentation: (N, 3, 572, 572) images -> (N, 1, 388, 388) logits.

    Unpadded convolutions shrink the feature maps, so every encoder map is
    centre-cropped before it is concatenated to the decoder (skip connection).
    Upsampling uses MaxUnpool2d with indices obtained by max-pooling those
    cropped encoder maps. The last layer is a single-channel BatchNorm and no
    sigmoid is applied, so the output is logits (the notebook trains it with
    BCEWithLogitsLoss). forward() asserts every intermediate shape, so any other
    input size fails there.
    """

    def __init__(self):
        super().__init__()
        # Modules are created in a fixed order: it determines the seeded
        # parameter initialisation and the order of parameters().

        # ---------------- encoder ----------------
        # level 0: (3, 572, 572) -> conv -> (64, 568, 568) -> pool -> (64, 284, 284)
        self.e0_conv = _double_conv(3, 64)
        self.e0_pool = nn.MaxPool2d(2, stride=2, return_indices=False)
        # skip 0: crop (64, 568, 568) -> (64, 392, 392); pooling the crop yields
        # (64, 196, 196) indices for d0_upsample
        self.e0_crop = nn.Sequential(torchvision.transforms.CenterCrop(392))
        self.e0_pool_idx = nn.MaxPool2d(2, stride=2, return_indices=True)

        # level 1: (64, 284, 284) -> (128, 280, 280) -> pool -> (128, 140, 140)
        self.e1_conv = _double_conv(64, 128)
        self.e1_pool = nn.MaxPool2d(2, stride=2, return_indices=False)
        # skip 1: crop -> (128, 200, 200); indices (128, 100, 100) for d1_upsample
        self.e1_crop = nn.Sequential(torchvision.transforms.CenterCrop(200))
        self.e1_pool_idx = nn.MaxPool2d(2, stride=2, return_indices=True)

        # level 2: (128, 140, 140) -> (256, 136, 136) -> pool -> (256, 68, 68)
        self.e2_conv = _double_conv(128, 256)
        self.e2_pool = nn.MaxPool2d(2, stride=2, return_indices=False)
        # skip 2: crop -> (256, 104, 104); indices (256, 52, 52) for d2_upsample
        self.e2_crop = nn.Sequential(torchvision.transforms.CenterCrop(104))
        self.e2_pool_idx = nn.MaxPool2d(2, stride=2, return_indices=True)

        # level 3: (256, 68, 68) -> (512, 64, 64) -> pool -> (512, 32, 32)
        self.e3_conv = _double_conv(256, 512)
        self.e3_pool = nn.MaxPool2d(2, stride=2, return_indices=False)
        # skip 3: crop -> (512, 56, 56); indices (512, 28, 28) for d3_upsample
        self.e3_crop = nn.Sequential(torchvision.transforms.CenterCrop(56))
        self.e3_pool_idx = nn.MaxPool2d(2, stride=2, return_indices=True)

        # bottleneck: (512, 32, 32) -> (1024, 28, 28)
        self.bottleneck_conv = _double_conv(512, 1024)

        # ---------------- decoder ----------------
        # Each level: upconv (change channels, keep H/W) -> unpool with the skip
        # indices (double H/W) -> concat with the cropped skip map -> double conv.
        # level 3: (1024, 28, 28) -> (512, 28, 28) -> (512, 56, 56)
        #          -> concat (1024, 56, 56) -> (512, 52, 52)
        self.d3_upsample = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.d3_upconv = _up_conv(1024, 512)
        self.d3_conv = _double_conv(1024, 512)
        # level 2: (512, 52, 52) -> (256, 52, 52) -> (256, 104, 104)
        #          -> concat (512, 104, 104) -> (256, 100, 100)
        self.d2_upsample = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.d2_upconv = _up_conv(512, 256)
        self.d2_conv = _double_conv(512, 256)
        # level 1: (256, 100, 100) -> (128, 100, 100) -> (128, 200, 200)
        #          -> concat (256, 200, 200) -> (128, 196, 196)
        self.d1_upsample = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.d1_upconv = _up_conv(256, 128)
        self.d1_conv = _double_conv(256, 128)
        # level 0: (128, 196, 196) -> (64, 196, 196) -> (64, 392, 392)
        #          -> concat (128, 392, 392) -> (1, 388, 388) logits
        self.d0_upsample = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.d0_upconv = _up_conv(128, 64)
        self.d0_conv = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.Conv2d(64, 1, kernel_size=1),  # 1x1 projection to a single output channel
            nn.BatchNorm2d(1),
        )

    def forward(self, x):
        """Return (N, 1, 388, 388) logits for (N, 3, 572, 572) images.

        Every intermediate shape is asserted, so a wrong input size fails fast
        with a message naming the layer and the actual shape.
        """
        # ---------------- encoder ----------------
        # level 0
        e0 = self.e0_conv(x)
        _check_shape(e0, (64, 568, 568), "encoder layer e0")
        e0_pool = self.e0_pool(e0)
        _check_shape(e0_pool, (64, 284, 284), "encoder layer e0 after pooling")
        # skip connection: cropped map + unpooling indices (the pooled values are unused)
        e0_crop = self.e0_crop(e0)
        _, idx0_crop = self.e0_pool_idx(e0_crop)
        _check_shape(e0_crop, (64, 392, 392), "encoder layer e0 after cropping")
        _check_shape(idx0_crop, (64, 196, 196), "encoder indices idx0 after cropping")

        # level 1
        e1 = self.e1_conv(e0_pool)
        _check_shape(e1, (128, 280, 280), "encoder layer e1")
        e1_pool = self.e1_pool(e1)
        _check_shape(e1_pool, (128, 140, 140), "encoder layer e1 after pooling")
        e1_crop = self.e1_crop(e1)
        _, idx1_crop = self.e1_pool_idx(e1_crop)
        _check_shape(e1_crop, (128, 200, 200), "encoder layer e1 after cropping")
        _check_shape(idx1_crop, (128, 100, 100), "encoder indices idx1 after cropping")

        # level 2
        e2 = self.e2_conv(e1_pool)
        _check_shape(e2, (256, 136, 136), "encoder layer e2")
        e2_pool = self.e2_pool(e2)
        _check_shape(e2_pool, (256, 68, 68), "encoder layer e2 after pooling")
        e2_crop = self.e2_crop(e2)
        _, idx2_crop = self.e2_pool_idx(e2_crop)
        _check_shape(e2_crop, (256, 104, 104), "encoder layer e2 after cropping")
        _check_shape(idx2_crop, (256, 52, 52), "encoder indices idx2 after cropping")

        # level 3
        e3 = self.e3_conv(e2_pool)
        _check_shape(e3, (512, 64, 64), "encoder layer e3")
        e3_pool = self.e3_pool(e3)
        _check_shape(e3_pool, (512, 32, 32), "encoder layer e3 after pooling")
        e3_crop = self.e3_crop(e3)
        _, idx3_crop = self.e3_pool_idx(e3_crop)
        _check_shape(e3_crop, (512, 56, 56), "encoder layer e3 after cropping")
        _check_shape(idx3_crop, (512, 28, 28), "encoder indices idx3 after cropping")

        # bottleneck
        b = self.bottleneck_conv(e3_pool)
        _check_shape(b, (1024, 28, 28), "bottleneck")

        # ---------------- decoder ----------------
        # level 3
        d3_upconv = self.d3_upconv(b)
        _check_shape(d3_upconv, (512, 28, 28), "decoder layer d3 after upconvolution")
        d3_upsample = self.d3_upsample(d3_upconv, idx3_crop)
        _check_shape(d3_upsample, (512, 56, 56), "decoder layer d3 after upsampling")
        d3_concat = torch.cat((e3_crop, d3_upsample), dim=1)
        _check_shape(d3_concat, (1024, 56, 56), "decoder layer d3 after concatenation")
        d3 = self.d3_conv(d3_concat)
        _check_shape(d3, (512, 52, 52), "decoder layer d3 final")

        # level 2
        d2_upconv = self.d2_upconv(d3)
        _check_shape(d2_upconv, (256, 52, 52), "decoder layer d2 after upconvolution")
        d2_upsample = self.d2_upsample(d2_upconv, idx2_crop)
        _check_shape(d2_upsample, (256, 104, 104), "decoder layer d2 after upsampling")
        d2_concat = torch.cat((e2_crop, d2_upsample), dim=1)
        _check_shape(d2_concat, (512, 104, 104), "decoder layer d2 after concatenation")
        d2 = self.d2_conv(d2_concat)
        _check_shape(d2, (256, 100, 100), "decoder layer d2 final")

        # level 1
        d1_upconv = self.d1_upconv(d2)
        _check_shape(d1_upconv, (128, 100, 100), "decoder layer d1 after upconvolution")
        d1_upsample = self.d1_upsample(d1_upconv, idx1_crop)
        _check_shape(d1_upsample, (128, 200, 200), "decoder layer d1 after upsampling")
        d1_concat = torch.cat((e1_crop, d1_upsample), dim=1)
        _check_shape(d1_concat, (256, 200, 200), "decoder layer d1 after concatenation")
        d1 = self.d1_conv(d1_concat)
        _check_shape(d1, (128, 196, 196), "decoder layer d1 final")

        # level 0
        d0_upconv = self.d0_upconv(d1)
        _check_shape(d0_upconv, (64, 196, 196), "decoder layer d0 after upconvolution")
        d0_upsample = self.d0_upsample(d0_upconv, idx0_crop)
        _check_shape(d0_upsample, (64, 392, 392), "decoder layer d0 after upsampling")
        d0_concat = torch.cat((e0_crop, d0_upsample), dim=1)
        _check_shape(d0_concat, (128, 392, 392), "decoder layer d0 after concatenation")
        d0 = self.d0_conv(d0_concat)
        _check_shape(d0, (1, 388, 388), "decoder layer d0 final")
        return d0

In [None]:
def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    """Thresholded per-sample IoU between binary predictions and binary labels.

    Both tensors are expected as BATCH x 1 x H x W; the channel axis is squeezed
    away. Values are cast to uint8 before the bitwise AND/OR, so they should
    already be 0/1. Returns a float tensor of shape BATCH holding the fraction of
    the thresholds 0.5, 0.55, ..., 0.95 that the IoU strictly exceeds (up to
    floating-point rounding); two empty masks count as IoU 1.
    """
    smooth = 1e-8  # avoids 0/0 when both masks are empty
    pred = outputs.squeeze(1).byte()
    true = labels.squeeze(1).byte()
    spatial = (1, 2)  # sum over H and W, keep the batch axis
    intersection = torch.sum((pred & true).float(), dim=spatial)  # 0 if either mask is empty
    union = torch.sum((pred | true).float(), dim=spatial)         # 0 only if both are empty
    iou = (intersection + smooth) / (union + smooth)
    return torch.ceil(torch.clamp(20 * (iou - 0.5), 0, 10)) / 10

## Define loss function (the IoU segmentation metric is defined in the cell above)

In [None]:
# Loss: binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
bce_loss = nn.BCEWithLogitsLoss(reduction='mean')

## Setup training pipeline

In [None]:
def _plot_val_predictions(X_val, Y_hat, Y_hat_bin, Y_val, num_images, title):
    """
    plots the first `num_images` samples of a validation batch as a 4-row grid:
    input image, raw model output (logits), binarized output, ground truth
    """
    plt.figure(figsize=(2 * num_images, 2 * 4))
    rows = (
        ('Input image', lambda k: np.rollaxis(X_val[k].numpy(), 0, 3)),  # CHW -> HWC
        ('Output', lambda k: Y_hat[k, 0]),
        ('Binary output', lambda k: Y_hat_bin[k, 0]),
        ('Ground truth', lambda k: Y_val[k, 0]),
    )
    for row, (label, get_image) in enumerate(rows):
        for k in range(num_images):
            # subplot(n_rows, n_cols, 1-based absolute position)
            plt.subplot(4, num_images, row * num_images + k + 1)
            plt.imshow(get_image(k), cmap='gray')
            plt.title(label)
            plt.axis('off')
    plt.tight_layout()
    plt.suptitle(title)
    plt.show()


def train(model, opt, loss_fn, epochs, batch_size, data_tr, data_val, metric, lr_sched=None):
    """
    - trains the model
    - computes training loss, validation loss and validation score
    - plots the progress
    ---
    input
    ---
    model:
        pytorch model to train; its outputs are treated as logits
    opt:
        optimizer
    loss_fn :
        loss function called as loss_fn(logits, targets)
    epochs: int
        number of epochs
    batch_size: int
        batch size of the loaders; only used to cap the number of plotted images (max 5)
    data_tr: dataloader
        training data
    data_val: dataloader
        validation data
    metric:
        segmentation metric called as metric(binary predictions, targets); its result is averaged per batch
    lr_sched:
        scheduler, stepped once per epoch
        default: None
    ---
    return
    ---
    train_loss_values, val_loss_values, avg_score_values: lists of float
        per-epoch mean training loss, validation loss and validation score
    ---
    raises
    ---
    ValueError if data_tr or data_val yields no batches
    """
    if len(data_tr) == 0 or len(data_val) == 0:
        raise ValueError('data_tr and data_val must each yield at least one batch')

    train_loss_values = []
    val_loss_values = []
    avg_score_values = []

    for epoch in range(epochs):
        print('* Epoch %d/%d' % (epoch + 1, epochs))

        ##############
        # train model
        ##############
        model.train()
        running_loss = 0.0
        for X_batch, Y_batch in data_tr:
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)
            opt.zero_grad()
            Y_pred = model(X_batch)  # logits
            loss = loss_fn(Y_pred, Y_batch)
            loss.backward()
            opt.step()
            # .item() detaches: summing the loss tensors themselves would chain every
            # batch's autograd graph into one ever-growing graph for the whole epoch
            running_loss += loss.item()
        avg_loss = running_loss / len(data_tr)
        print('loss: %f' % avg_loss)
        train_loss_values.append(avg_loss)

        ##############
        # validate model
        ##############
        # evaluation mode for dropout / batch normalization layers
        model.eval()
        val_loss_sum = 0.0
        score_sum = 0.0
        with torch.no_grad():
            for X_val, Y_val in data_val:
                Y_hat = model(X_val.to(device)).cpu()  # logits
                # sigmoid + round -> binary mask, used for the metric and the plots
                Y_hat_bin = torch.round(torch.sigmoid(Y_hat))
                val_loss_sum += loss_fn(Y_hat, Y_val).item()
                score_sum += metric(Y_hat_bin, Y_val).mean().item()
        val_avg_loss = val_loss_sum / len(data_val)
        avg_score = score_sum / len(data_val)
        val_loss_values.append(val_avg_loss)
        avg_score_values.append(avg_score)

        ##############
        # plot the last validation batch
        ##############
        clear_output(wait=True)
        # the last batch may hold fewer samples than batch_size
        num_images_to_plot = min(5, batch_size, len(X_val))
        _plot_val_predictions(
            X_val, Y_hat, Y_hat_bin, Y_val, num_images_to_plot,
            '%d / %d - train loss: %f - val. loss: %f'
            % (epoch + 1, epochs, avg_loss, val_avg_loss))

        # make a scheduler step if required
        if lr_sched is not None:
            lr_sched.step()

    plt.figure()
    plt.plot(train_loss_values)
    plt.plot(val_loss_values)
    plt.plot(avg_score_values)
    plt.xlabel('epoch')
    plt.legend(["train_loss", "val_loss", "val_score"], loc="lower right")
    plt.show()

    return train_loss_values, val_loss_values, avg_score_values

In [None]:
def score_model(model, metric, data):
    """
    evaluates `model` on every batch of `data` with `metric`
    ---
    input
    ---
    model:
        pytorch model producing logits
    metric:
        called as metric(binary predictions, labels); averaged within each batch
    data: dataloader
        batches of (images, binary labels)
    ---
    return
    ---
    float
        mean over batches of the per-batch average metric
    """
    # evaluation mode for dropout / batch normalization layers
    model.eval()
    batch_scores = []
    for X_batch, Y_label in data:
        with torch.no_grad():
            logits = model(X_batch.to(device))
            # sigmoid, then round to {0, 1} to match the binary ground truth
            Y_pred = torch.round(torch.sigmoid(logits))
        batch_scores.append(metric(Y_pred, Y_label.to(device)).mean().item())
    return sum(batch_scores) / len(data)

In [None]:
# Model (on the selected device), loss, optimizer and learning-rate schedule.
unet_model = UNet().to(device)

max_epochs, batch_size = MAX_EPOCHS, BATCH_SIZE
bce_loss = nn.BCEWithLogitsLoss()
unet_optimizer = optim.AdamW(unet_model.parameters(), lr=LEARNING_RATE)
# StepLR multiplies the learning rate by SCHEDULER_GAMMA every SCHEDULER_STEP calls to step()
scheduler = optim.lr_scheduler.StepLR(
    optimizer=unet_optimizer,
    step_size=SCHEDULER_STEP,
    gamma=SCHEDULER_GAMMA,
)

In [None]:
# Train with BCE-with-logits loss, IoU as validation metric and the StepLR schedule.
train_loss_values, val_loss_values, avg_score_values = train(
    unet_model,
    unet_optimizer,
    bce_loss,
    max_epochs,
    batch_size,
    data_tr,
    data_val,
    iou_pytorch,
    lr_sched=scheduler,
)

In [None]:
# IoU on the validation split (monitored during training) and on the held-out
# test split, which is otherwise never evaluated in this notebook.
{
    'val': score_model(unet_model, iou_pytorch, data_val),
    'test': score_model(unet_model, iou_pytorch, data_ts),
}

## Save loss, score and model

In [None]:
import pandas as pd

# per-epoch histories converted element-wise to plain Python floats
train_loss_values_2save = [float(v) for v in train_loss_values]
val_loss_values_2save = [float(v) for v in val_loss_values]
avg_score_values_2save = [float(v) for v in avg_score_values]

# one single-column CSV per history
for history, csv_path in (
    (train_loss_values_2save, '/kaggle/working/train_loss_values.csv'),
    (val_loss_values_2save, '/kaggle/working/val_loss_values.csv'),
    (avg_score_values_2save, '/kaggle/working/avg_score_values.csv'),
):
    pd.DataFrame(history).to_csv(csv_path, index=False)

# model weights only (state_dict)
torch.save(unet_model.state_dict(), '/kaggle/working/my-unet-model.pt')

In [None]:
# Bundle the epoch count with the model and optimizer states into one checkpoint file.
states = {}
states['number of epochs'] = MAX_EPOCHS
states['state_dict'] = unet_model.state_dict()
states['optimizer'] = unet_optimizer.state_dict()
torch.save(states, '/kaggle/working/my-unet-model-states.pt')