## Import Packages

In [None]:
#Importing modules
import os
import sys
import cv2
import torch
import random
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
from IPython import display
import albumentations as albu
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedShuffleSplit

## Progress Bar

In [None]:
def update_progress(progress,line):
    '''
    update_progress() : Displays or updates a single-line console progress bar.

    progress: fraction completed. An int is converted to float; any other
        non-float value resets the bar to 0 and shows an error status.
        A value under 0 is drawn as 0 with the status "Halt...".
        A value at 1 or bigger is drawn as 100% with the status "Done...".
    line: status text printed after the percentage.

    The text is written to sys.stdout with a leading carriage return so that
    successive calls overwrite the same console line. Returns None.
    '''
    barLength = 50 # Modify this to change the length of the progress bar
    status = line
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        progress = 0
        status = "error: progress var must be float"
    if progress < 0:
        progress = 0
        status = "Halt..."
    if progress >= 1:
        progress = 1
        # Trailing spaces wipe whatever longer status was on the line before
        status = "Done..."+" "*50
    block = int(round(barLength*progress))
    # The filled glyph is written as an escape (U+2588 FULL BLOCK) so it cannot
    # be corrupted by a wrong source-file encoding again.
    filled = "\u2588"*block
    text = "\rPercent: [{:s}] {:.2f}% {:s}".format(filled + "-"*(barLength-block), progress*100, status)
    sys.stdout.write(text)
    sys.stdout.flush()

## Mask Visualize Func

In [None]:
# Plotting helper used to look at images and masks side by side
def visualize(**images):
    """Draw every keyword argument as a titled image in a single row.

    The keyword name becomes the subplot title (underscores turned into
    spaces, then title-cased); the value is converted with ``np.array``
    and rendered with ``plt.imshow``. Axis ticks are hidden.
    """
    total = len(images)
    plt.figure(figsize=(16, 5))
    position = 0
    for label in images:
        position += 1
        plt.subplot(1, total, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(np.array(images[label]))
    plt.show()

## Class List

In [None]:
# Identifiers of the defect classes handled by this notebook. They are kept as
# strings because they are used as dictionary keys for the per-class datasets,
# models, optimizers and data loaders built in later cells.
classList = ['1','2','3','4']

## Image Preprocessor

In [None]:
#Processing of Image
class ImageProcessor:
    """Thin factory around albumentations transforms plus RLE mask helpers.

    Every transform method simply builds and returns the albumentations
    object of the same name, forwarding its arguments. ``augment`` assembles
    them into the pipeline used for training; ``rle2mask`` / ``mask2rle``
    convert between run-length strings and binary masks.

    NOTE(review): the ``IAA*``, ``RandomBrightness`` and ``RandomContrast``
    transforms are presumably unavailable in recent albumentations releases -
    confirm against the installed version.
    """

    def HorizontalFlip(self,p = 0.5):
        """Return ``albu.HorizontalFlip(p=p)``."""
        return albu.HorizontalFlip(p = p)

    def ShiftScaleRotate(self,scale_limit = 0.5, rotate_limit = 0, shift_limit = 0.1, p = 1, border_mode = 0):
        """Return ``albu.ShiftScaleRotate`` configured with the given limits."""
        return albu.ShiftScaleRotate(
            scale_limit = scale_limit,
            rotate_limit = rotate_limit,
            # Fixed: rotate_limit used to be forwarded here, so the caller's
            # shift_limit was ignored (and became 0 in the default pipeline).
            shift_limit = shift_limit,
            p = p, border_mode = border_mode
        )

    def PadIfNeeded(self,min_height = 256, min_width = 256, always_apply = True, border_mode = 0):
        """Return ``albu.PadIfNeeded`` with the given minimum size."""
        return albu.PadIfNeeded(
            min_height = min_height,
            min_width = min_width,
            always_apply = always_apply,
            border_mode = border_mode
        )

    def RandomCrop(self,height = 256, width = 256, always_apply = True):
        """Return ``albu.RandomCrop`` of ``height`` x ``width``."""
        return albu.RandomCrop(height = height, width = width, always_apply = always_apply)

    def IAAAdditiveGaussianNoise(self,p = 0.2):
        """Return ``albu.IAAAdditiveGaussianNoise(p=p)``."""
        return albu.IAAAdditiveGaussianNoise(p = p)

    def IAAPerspective(self,p = 1):
        """Return ``albu.IAAPerspective(p=p)``."""
        return albu.IAAPerspective(p = p)

    def CLAHE(self,p = 1):
        """Return ``albu.CLAHE(p=p)``."""
        return albu.CLAHE(p = p)

    def RandomBrightness(self,p = 1):
        """Return ``albu.RandomBrightness(p=p)``."""
        return albu.RandomBrightness(p = p)

    def RandomGamma(self,p = 1):
        """Return ``albu.RandomGamma(p=p)``."""
        return albu.RandomGamma(p = p)

    def IAASharpen(self,p = 1):
        """Return ``albu.IAASharpen(p=p)``."""
        return albu.IAASharpen(p = p)

    def Blur(self,blur_limit = 3, p = 1):
        """Return ``albu.Blur`` with the given ``blur_limit``."""
        return albu.Blur(blur_limit = blur_limit, p = p)

    def MotionBlur(self,blur_limit = 3, p = 1):
        """Return ``albu.MotionBlur`` with the given ``blur_limit``."""
        return albu.MotionBlur(blur_limit = blur_limit, p = p)

    def RandomContrast(self, p = 1):
        """Return ``albu.RandomContrast(p=p)``."""
        return albu.RandomContrast(p = p)

    def HueSaturationValue(self,p = 1):
        """Return ``albu.HueSaturationValue(p=p)``."""
        return albu.HueSaturationValue(p = p)

    def OneOf(self,operations = None, p = 0.9):
        """Return ``albu.OneOf`` over ``operations`` (empty list when omitted)."""
        # None instead of a shared mutable [] default; behaviour is unchanged.
        operations = [] if operations is None else operations
        return albu.OneOf(operations, p = p)

    def Compose(self,trans = None):
        """Return ``albu.Compose`` over ``trans`` (empty list when omitted)."""
        trans = [] if trans is None else trans
        return albu.Compose(trans)

    def Lambda(self,**kwargs):
        """Return ``albu.Lambda(**kwargs)``."""
        return albu.Lambda(**kwargs)

    def augment(self,**kwargs):
        """Run the training augmentation pipeline on the given targets.

        Every keyword value (e.g. ``image=...``, ``mask=...``) is converted
        to a numpy array and passed to a freshly built ``Compose`` pipeline.

        Returns a dict values view holding one PIL image per keyword, in the
        order the keywords were given, so callers can unpack it directly.
        """
        arrays = {name: np.array(value) for name, value in kwargs.items()}
        transform = self.Compose([
            self.HorizontalFlip(p=0.5),

            self.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),

            self.PadIfNeeded(min_height=256, min_width=256, always_apply=True, border_mode=0),
            self.RandomCrop(height=256, width=256, always_apply=True),

            self.IAAAdditiveGaussianNoise(p=0.2),
            self.IAAPerspective(p=0.5),

            self.OneOf(
                [
                    self.CLAHE(p=1),
                    self.RandomBrightness(p=1),
                    self.RandomGamma(p=1),
                ],
                p=0.9,
            ),

            self.OneOf(
                [
                    self.IAASharpen(p=1),
                    self.Blur(blur_limit=3, p=1),
                    self.MotionBlur(blur_limit=3, p=1),
                ],
                p=0.9,
            ),

            self.OneOf(
                [
                    self.RandomContrast(p=1),
                    self.HueSaturationValue(p=1),
                ],
                p=0.9,
            )
        ])
        sample = transform(**arrays)
        return {name: Image.fromarray(sample[name]) for name in arrays}.values()

    #Run Length Encoding
    def rle2mask(self,mask_rle, shape = (1600, 256)):
        """Decode a run-length string into a binary uint8 mask.

        mask_rle: whitespace-separated "start length start length ..." pairs
            with 1-based starts; a falsy value yields an all-zero mask.
        shape: size of the flat buffer as (shape[0], shape[1]).

        Returns an array of shape (shape[1], shape[0]): the flat buffer is
        reshaped to ``shape`` and transposed, i.e. runs are laid out
        column by column in the returned mask.
        """
        img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        if mask_rle:
            s = mask_rle.split()
            starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
            starts -= 1  # RLE starts are 1-based, array indices are 0-based
            ends = starts + lengths
            for lo, hi in zip(starts, ends):
                img[lo:hi] = 1
        return img.reshape(shape).T

    def mask2rle(self,img):
        '''
        Encode a binary mask as a run-length string (inverse of rle2mask).

        img: numpy array, 1 - mask, 0 - background
        Returns the runs as a space-separated "start length ..." string with
        1-based starts, scanning the mask column by column.
        '''
        pixels= np.array(img).T.flatten()
        # Pad with zeros so runs touching either end still produce two edges
        pixels = np.concatenate([[0], pixels, [0]])
        runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
        # Odd positions hold run ends; turn them into run lengths
        runs[1::2] -= runs[::2]
        return ' '.join(str(x) for x in runs)

## Train Dataset

In [None]:
#Class for getting Datas from training set and processing Image
class TrainDataset:
    """Augmented (image, one-hot mask) samples for a single defect class.

    Reads ``<path>/train.csv`` and keeps the rows whose ``ClassId`` equals
    ``class_id`` and that contain no missing value. Rows with a missing value
    are kept separately in ``naInput`` and serve as a pool of replacement
    images when a crop without defect is wanted.

    NOTE(review): ``naInput`` is collected before the class filter is applied,
    so it holds incomplete rows of every class - confirm that images taken
    from it really are defect-free for this class.
    """
    # One-hot lookup indexed by the rounded mask value:
    # 0 -> [1, 0] (background channel), 1 -> [0, 1] (defect channel).
    class_map = torch.Tensor([
        [1,0],
        [0,1]
    ])
    def __init__(self,path,class_id):
        """Load the CSV under ``path`` and keep the rows of ``class_id``."""
        self.path = path
        self.input = pd.read_csv('%s/train.csv'%self.path)

        self.naInput = self.input[self.input.isnull().any(axis=1)].reset_index(drop=True)

        self.input.dropna(inplace=True)
        self.input = self.input[self.input.ClassId == int(class_id)].reset_index(drop=True)

        self.ip = ImageProcessor()

    def __len__(self):
        """Number of complete rows for this class."""
        return len(self.input)

    def __getitem__(self,idx):
        """Return a normalized image tensor and its 2-channel one-hot mask.

        With probability of about one half a crop containing defect pixels is
        returned, otherwise a crop without any. Augmentation is repeated until
        the crop matches the wanted kind. First column of the CSV is used as
        the image file name and third column as the run-length mask.
        """
        image = Image.open('%s/train_images/%s'%(self.path,self.input.iloc[idx,0]))
        mask = Image.fromarray(self.ip.rle2mask(self.input.iloc[idx,2])*255)
        # Decided once per sample. `>=` so that a draw of exactly 0.5 cannot
        # fall between the two cases and spin forever.
        want_defect = random.random() >= 0.5
        while True:
            img,msk = self.ip.augment(image = image, mask = mask)
            img = TF.to_tensor(img)
            msk = TF.to_tensor(msk)
            has_defect = bool(msk.sum())
            if has_defect == want_defect:
                break
            if not want_defect and len(self.naInput):
                # A negative crop is wanted but this one shows a defect: swap
                # in a random image from the incomplete rows. The mask is left
                # untouched, so the loop ends once a crop has an empty mask.
                # Fixed: the random index is now actually used (it was `idx`),
                # and randrange avoids the inclusive upper bound of randint.
                n = random.randrange(len(self.naInput))
                image = Image.open('%s/train_images/%s'%(self.path,self.naInput.iloc[n,0]))
        image, mask = img, msk
        image = TF.normalize(image,mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
        # mask[0] -> (H, W) label map; class_map lookup -> (H, W, 2); permute -> (2, H, W)
        return image,self.class_map[mask[0].round().long()].permute(2,0,1)

In [None]:
# Build one TrainDataset per class id, keyed by the class id string.
# Each instance reads train.csv itself and keeps only the rows of its class.
train_dataset = {}
for classId in classList:
    train_dataset[classId] = TrainDataset('../input/severstal-steel-defect-detection/',classId)

In [None]:
# Cell output: echo the class id -> TrainDataset mapping
train_dataset

## Visualize Mask

In [None]:
## Visualize one random augmented sample and its mask
# NOTE(review): [1,4] is a two-element list, so only class '1' or class '4'
# is ever picked here - confirm whether all ids in classList were intended.
cls = str(np.random.choice([1,4]))
print('Class: %s'%cls)
# Random sample index within the chosen class
n = np.random.choice(len(train_dataset[cls]))
image,mask = train_dataset[cls][n]
# The mask is the 2-channel one-hot tensor returned by TrainDataset; it is
# converted to a PIL image and then to single-channel 'L' for display.
visualize(image = TF.to_pil_image(image),mask = TF.to_pil_image(mask).convert('L'))

## Encoder

In [None]:
#Spatial attention: one sigmoid gate per pixel, shared by all channels
class SpatialAttention2d(nn.Module):
    """Scale the input by a per-pixel gate computed with a 1x1 convolution."""

    def __init__(self, channel):
        """:param channel: number of channels of the input feature map."""
        super().__init__()
        self.squeeze = nn.Conv2d(channel, 1, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by the single-channel sigmoid gate."""
        gate = self.sigmoid(self.squeeze(x))
        return x * gate

In [None]:
#Channel attention: one sigmoid gate per channel from globally pooled features
class GAB(nn.Module):
    """Scale each channel of the input by a gate derived from its global average."""

    def __init__(self, input_dim, reduction=4):
        """
        :param input_dim: number of channels of the input feature map.
        :param reduction: divisor for the width of the bottleneck 1x1 conv.
        """
        super().__init__()
        squeezed_dim = input_dim // reduction
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(input_dim, squeezed_dim, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(squeezed_dim, input_dim, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by the per-channel sigmoid gate."""
        pooled = self.global_avgpool(x)
        hidden = self.relu(self.conv1(pooled))
        gate = self.sigmoid(self.conv2(hidden))
        return x * gate

In [None]:
#Concurrent Spatial and Channel Squeeze and Excitation
class SCse(nn.Module):
    """Sum of the spatially gated and the channel-gated versions of the input."""

    def __init__(self, dim):
        """:param dim: number of channels of the input feature map."""
        super().__init__()
        self.satt = SpatialAttention2d(dim)
        self.catt = GAB(dim)

    def forward(self, x):
        """Return ``satt(x) + catt(x)``."""
        spatial = self.satt(x)
        channel = self.catt(x)
        return spatial + channel

In [None]:
#ResNet-34 encoder with an SCse attention block appended to every stage
class ResNet34(nn.Module):
    def __init__(self, pretrained=True):
        """Take the stages of torchvision's resnet34 and add attention to each.

        :param pretrained: forwarded to ``models.resnet34``.
        """
        super().__init__()
        backbone = models.resnet34(pretrained=pretrained)
        # The stem convolution is a new 3x3 layer, not the backbone's conv1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = nn.Sequential(backbone.layer1, SCse(64))
        self.layer2 = nn.Sequential(backbone.layer2, SCse(128))
        self.layer3 = nn.Sequential(backbone.layer3, SCse(256))
        self.layer4 = nn.Sequential(backbone.layer4, SCse(512))
        del backbone

    def forward(self, x):
        """Return ``(feature_map, out)``.

        ``feature_map`` lists the outputs of layer1..layer4 in order;
        ``out`` is the last output averaged over its spatial extent and
        flattened to one vector per sample.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        feature_map = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            feature_map.append(x)

        # Kernel equal to the spatial size, i.e. a global average
        pooled = nn.AvgPool2d(x.shape[2:])(x)
        out = pooled.view(x.shape[0], -1)

        return feature_map, out

## Decoder

In [None]:
#Feature Pyramid Attention: multi-scale context used to re-weight a feature map
class FPA(nn.Module):
    def __init__(self, channels=512):
        """
        Feature Pyramid Attention
        :param channels: channels of the input and of the output; the pyramid
            branches work with int(channels/4) channels.
        :type channels: int
        """
        super().__init__()
        mid = int(channels/4)

        self.channels_cond = channels

        def down(in_ch, k):
            # Stride-2 convolution that halves the resolution
            return nn.Conv2d(in_ch, mid, kernel_size=(k, k), stride=2, padding=k // 2, bias=False)

        def keep(k):
            # Stride-1 convolution that keeps the resolution
            return nn.Conv2d(mid, mid, kernel_size=(k, k), stride=1, padding=k // 2, bias=False)

        def up(out_ch):
            # Transposed convolution that doubles the resolution
            return nn.ConvTranspose2d(mid, out_ch, kernel_size=4, stride=2, padding=1, bias=False)

        # Master branch
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Global pooling branch
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_gpb = nn.BatchNorm2d(channels)

        # Downsampling chain 7x7 -> 5x5 -> 3x3, each halving the resolution
        self.conv7x7_1 = down(self.channels_cond, 7)
        self.bn1_1 = nn.BatchNorm2d(mid)
        self.conv5x5_1 = down(mid, 5)
        self.bn2_1 = nn.BatchNorm2d(mid)
        self.conv3x3_1 = down(mid, 3)
        self.bn3_1 = nn.BatchNorm2d(mid)

        # Same-resolution refinement at each pyramid level
        self.conv7x7_2 = keep(7)
        self.bn1_2 = nn.BatchNorm2d(mid)
        self.conv5x5_2 = keep(5)
        self.bn2_2 = nn.BatchNorm2d(mid)
        self.conv3x3_2 = keep(3)
        self.bn3_2 = nn.BatchNorm2d(mid)

        # Convolution Upsample
        self.conv_upsample_3 = up(mid)
        self.bn_upsample_3 = nn.BatchNorm2d(mid)

        self.conv_upsample_2 = up(mid)
        self.bn_upsample_2 = nn.BatchNorm2d(mid)

        self.conv_upsample_1 = up(channels)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: Shape: [b, channels, h, w]; h and w are halved three times
            and doubled three times again, so they are presumably expected to
            be multiples of 8 - confirm against callers.
        :return: out: Feature maps with the same shape as x.
        """
        # Master branch
        master = self.bn_master(self.conv_master(x))

        # Global pooling branch
        pooled = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], self.channels_cond, 1, 1)
        gpb = self.bn_gpb(self.conv_gpb(pooled))

        # Level 1 (1/2 resolution)
        down1 = self.relu(self.bn1_1(self.conv7x7_1(x)))
        feat1 = self.bn1_2(self.conv7x7_2(down1))

        # Level 2 (1/4 resolution)
        down2 = self.relu(self.bn2_1(self.conv5x5_1(down1)))
        feat2 = self.bn2_2(self.conv5x5_2(down2))

        # Level 3 (1/8 resolution)
        down3 = self.relu(self.bn3_1(self.conv3x3_1(down2)))
        feat3 = self.bn3_2(self.conv3x3_2(down3))

        # Merge the levels from coarse to fine
        up3 = self.relu(self.bn_upsample_3(self.conv_upsample_3(feat3)))
        merged2 = self.relu(feat2 + up3)
        up2 = self.relu(self.bn_upsample_2(self.conv_upsample_2(merged2)))
        merged1 = self.relu(feat1 + up2)

        # Pyramid output weights the master branch; global context is added on top
        attention = self.relu(self.bn_upsample_1(self.conv_upsample_1(merged1)))
        master = master * attention

        return self.relu(master + gpb)

In [None]:
#Global Attention Upsample unit
class GAU(nn.Module):
    def __init__(self, channels_high, channels_low, upsample=True):
        """
        :param channels_high: channels of the high-level feature map.
        :param channels_low: channels of the low-level feature map and of the output.
        :param upsample: when True the high-level map is doubled in size by a
            transposed convolution, otherwise only its channels are reduced.
        """
        super().__init__()
        # Global Attention Upsample
        self.upsample = upsample
        self.conv3x3 = nn.Conv2d(channels_low, channels_low, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_low)

        self.conv1x1 = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
        self.bn_high = nn.BatchNorm2d(channels_low)

        if not upsample:
            self.conv_reduction = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
            self.bn_reduction = nn.BatchNorm2d(channels_low)
        else:
            self.conv_upsample = nn.ConvTranspose2d(channels_high, channels_low, kernel_size=4, stride=2, padding=1, bias=False)
            self.bn_upsample = nn.BatchNorm2d(channels_low)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_high, fms_low, fm_mask=None):
        """
        Weight the low-level features with a channel vector pooled from the
        high-level features, then add the (upsampled or channel-reduced)
        high-level features.
        :param fms_high: Features of high level. Tensor of shape [b, c, h, w].
        :param fms_low: Features of low level.  Tensor.
        :param fm_mask: accepted but not used by this implementation.
        :return: fms_att_upsample
        """
        batch, chans, _, _ = fms_high.shape

        # Global average of the high-level map -> one weight per low-level channel
        pooled = nn.AvgPool2d(fms_high.shape[2:])(fms_high).view(batch, chans, 1, 1)
        weights = self.relu(self.bn_high(self.conv1x1(pooled)))

        low = self.bn_low(self.conv3x3(fms_low))
        attended = low * weights

        if self.upsample:
            high = self.bn_upsample(self.conv_upsample(fms_high))
        else:
            high = self.bn_reduction(self.conv_reduction(fms_high))

        return self.relu(high + attended)

In [None]:
class PANPlusPlus(nn.Module):
    """Decoder that fuses four encoder feature maps through nested FPA/GAU blocks."""

    def __init__(self, channels_blocks = [64,128,256,512]):
        """
        :param channels_blocks: channel counts of the four encoder feature
            maps, from the shallowest to the deepest.
        """
        super().__init__()
        c0 = channels_blocks[0]
        c1 = channels_blocks[1]
        c2 = channels_blocks[2]
        c3 = channels_blocks[3]

        # Context modules applied to encoder levels 1..3
        self.fpa_1 = FPA(channels=c1)
        self.fpa_2 = FPA(channels=c2)
        self.fpa_3 = FPA(channels=c3)

        # First column: each level attends to the level below it
        self.gau_01 = GAU(c1, c0)
        self.gau_11 = GAU(c2, c1)
        self.gau_21 = GAU(c3, c2)

        # Second column: low-level inputs are concatenations of two maps
        self.gau_02 = GAU(c1, 2*c0)
        self.gau_12 = GAU(c2, 2*c1)

        # Third column: final fusion at the shallowest resolution
        self.gau_03 = GAU(2*c1, 4*c0)

    def forward(self, fms=[]):
        """
        :param fms: Feature maps of forward propagation in the network,
            shallowest first. Each has shape [b, c, h, w].
        :return: the fused map produced by the last GAU.
        """
        enc0 = fms[0]
        enc1 = fms[1]
        ctx1 = self.fpa_1(enc1)
        fused_01 = self.gau_01(ctx1, enc0)

        enc2 = fms[2]
        ctx2 = self.fpa_2(enc2)
        fused_11 = self.gau_11(ctx2, enc1)
        fused_02 = self.gau_02(fused_11, torch.cat([enc0, fused_01], dim=1))

        enc3 = fms[3]
        ctx3 = self.fpa_3(enc3)
        fused_21 = self.gau_21(ctx3, enc2)
        fused_12 = self.gau_12(fused_21, torch.cat([enc1, fused_11], dim=1))

        shallow_stack = torch.cat([enc0, fused_01, fused_02], dim=1)
        return self.gau_03(fused_12, shallow_stack)

In [None]:
class RESPANPLUS(nn.Module):
    """ResNet34 encoder + PANPlusPlus decoder with a defect classifier and a mask head.

    The classifier decides per sample whether a defect is present; the
    decoder and mask head are only run on the samples flagged as positive,
    all other samples keep an all-zero output.
    """

    def __init__(self,num_class,mask_criterion,defect_criterion):
        """
        :param num_class: number of outputs of both heads.
        :param mask_criterion: loss applied to channel 1 of the mask output.
        :param defect_criterion: loss applied to the classifier output.
        """
        super().__init__()

        self.mask_criterion = mask_criterion
        self.defect_criterion = defect_criterion

        self.enc = ResNet34(False)
        self.dec = PANPlusPlus()

        # NOTE(review): this head ends in Softmax and its output is handed to
        # defect_criterion; with a criterion that applies its own softmax the
        # normalisation happens twice - confirm this is intended.
        self.defect_classifier = nn.Sequential(
            nn.Linear(512,256),
            nn.ReLU(inplace=True),
            nn.Linear(256,num_class),
            nn.Softmax(dim=1))

        self.mask_classifier = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2, padding=0, bias=False), #kernel_size=4, stride=2, padding=1
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, num_class, kernel_size=1, bias=False),
            nn.Softmax(dim=1))

    def forward(self,x, y = None):
        """
        :param x: input batch; its trailing two dimensions give the output size.
        :param y: optional one-hot target masks. When given, the positive
            samples are taken from the targets and ``(out, loss)`` is returned;
            otherwise they are taken from the classifier and only ``out`` is
            returned.
        """
        fms,cls_feats = self.enc(x)
        logits = self.defect_classifier(cls_feats)
        out = torch.zeros(*logits.shape,*x.shape[2:], device = x.device)

        if y is None:
            # Inference: trust the classifier's arg-max
            c_y = logits.max(1)[1]
            if c_y.sum() > 0:
                picked = c_y == 1
                out[picked] = self.mask_classifier(self.dec([fm[picked] for fm in fms]))
            return out

        # A sample is positive when its target has at least one defect pixel
        c_y = y.max(1)[1].sum([1,2]) > 0
        mask_loss = 0
        # More than one positive sample is required before the decoder runs
        # (presumably so its BatchNorm layers see a real batch - confirm)
        if c_y.sum() > 1:
            picked = c_y == 1
            out[picked] = self.mask_classifier(self.dec([fm[picked] for fm in fms]))
            mask_loss = self.mask_criterion(out[:,1,:,:],y[:,1,:,:])
        class_loss = self.defect_criterion(logits,c_y.long())
        loss = class_loss + mask_loss
        return out,loss

## Focal BCE Loss

In [None]:
#Focal loss built on top of binary cross entropy
class FocalLoss(nn.Module):
    """Binary cross entropy scaled by ``alpha * (1 - pt) ** gamma``, pt = exp(-BCE)."""

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        """
        :param alpha: constant factor applied to the loss.
        :param gamma: exponent of the ``(1 - pt)`` modulating term.
        :param logits: when True the inputs are raw scores and the
            with-logits BCE is used; otherwise inputs are probabilities.
        :param reduce: when True the mean is returned, otherwise the
            element-wise loss.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        """Return the focal loss between ``inputs`` and ``targets``."""
        bce_fn = F.binary_cross_entropy_with_logits if self.logits else F.binary_cross_entropy
        bce = bce_fn(inputs, targets, reduction='none')
        pt = torch.exp(-bce)
        focal = self.alpha * (1-pt)**self.gamma * bce

        if not self.reduce:
            return focal
        return torch.mean(focal)

## Load Checkpoint

In [None]:
# Resume from the previously saved checkpoint when it exists, otherwise start
# from an empty one. Either way the result is written to ./model.pth.
if os.path.exists('../input/steel-defect-detection-last/model.pth'):
    # Fixed: tensors are only moved to the GPU when one is available, matching
    # the torch.cuda.is_available() guards used everywhere else; previously
    # loading failed on CPU-only machines.
    # NOTE: torch.load unpickles arbitrary objects - only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(
        '../input/steel-defect-detection-last/model.pth',
        map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else 'cpu'
    )
else:
    checkpoint = {}
torch.save(checkpoint,'model.pth')

## Model and loss function

In [None]:
# Per-class models and optimizers, restored from the checkpoint when present.
model = checkpoint.get('model',{})
optimizer = checkpoint.get('optimizer',{})
# Schedulers are not stored in the checkpoint, so they always start afresh.
exp_lr_scheduler = {}

mask_criterion = FocalLoss()
defect_criterion = nn.CrossEntropyLoss()

if torch.cuda.is_available():
    mask_criterion = mask_criterion.cuda()
    defect_criterion = defect_criterion.cuda()

for key in classList:
    # Only build a new model / optimizer when the checkpoint has none for this
    # class (dict.get with a constructed default built one every time).
    if key not in model:
        model[key] = RESPANPLUS(2,mask_criterion,defect_criterion)
    if torch.cuda.is_available():
        model[key] = model[key].cuda()
    if key not in optimizer:
        optimizer[key] = torch.optim.Adam(
            [
                {'params': model[key].dec.parameters(), 'lr': 1e-4},

                # Fixed: the two heads were missing from the optimizer and
                # therefore never left their random initialisation.
                {'params': model[key].defect_classifier.parameters(), 'lr': 1e-4},
                {'params': model[key].mask_classifier.parameters(), 'lr': 1e-4},

                # smaller learning rate for the encoder
                {'params': model[key].enc.parameters(), 'lr': 1e-5}
            ]
        )
    exp_lr_scheduler[key] = torch.optim.lr_scheduler.StepLR(optimizer[key], step_size=7, gamma=0.1)


## Train and validation split with loader

In [None]:
# Build a train and a validation DataLoader per class.
batch_size = 10
train_loader = {}
val_loader = {}
# Reuse the split stored in the checkpoint so resumed runs keep the same
# train/validation partition.
train_indices = checkpoint.get('train_indices',{})
val_indices = checkpoint.get('val_indices',{})
for key in classList:
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.33, random_state=42)
    # Stratification label: whether EncodedPixels holds at least one token
    splitter = sss.split(train_dataset[key].input, train_dataset[key].input.EncodedPixels.apply(str.split).apply(len)>0)
    # Only draw a new split when the checkpoint has none for this class
    if not (key in train_indices and key in val_indices):
        train_indices[key], val_indices[key] = next(splitter)
    # Creating PT data samplers and loaders
    train_sampler = SubsetRandomSampler(train_indices[key])
    val_sampler = SubsetRandomSampler(val_indices[key])
    # Both loaders read from the same TrainDataset, so validation samples go
    # through the same random augmentation as training samples.
    train_loader[key] = DataLoader(train_dataset[key],batch_size=batch_size,sampler=train_sampler)
    val_loader[key] = DataLoader(train_dataset[key], batch_size=batch_size,sampler=val_sampler)

## Metrics

In [None]:
#Dice coefficient: 1 for identical masks, close to 0 for disjoint ones
def dice_coeff(prob, target):
    """Return the mean per-sample Dice score of ``prob`` against ``target``.

    Both tensors are flattened per sample (first dimension is the batch).
    A small constant keeps the score defined, and equal to 1, when both
    masks are empty.
    """
    n_samples = target.size(0)
    flat_prob = prob.view(n_samples, -1)
    flat_target = target.view(n_samples, -1)
    eps = 1e-6
    overlap = (flat_prob * flat_target).sum(1)
    total = flat_prob.sum(1) + flat_target.sum(1) + eps
    per_sample = (2. * overlap + eps) / total
    return per_sample.mean()

## Validation Function

In [None]:
def validate():
    """Evaluate every per-class model on its validation loader.

    For each class the model is switched to eval mode and run over
    ``val_loader[key]``; the Dice score of the arg-max prediction and the
    loss returned by the model are averaged over the batches, then over the
    classes. Progress and results are printed.

    Returns ``(mean_accuracy, mean_loss)`` over all classes.
    """
    total_loss = 0
    total_accuracy = 0
    for i,key in enumerate(classList,start=1):
        print('\nvalidating of type %s'%key)
        label_loss = 0
        label_accuracy = 0
        completed = 0
        model[key].eval()
        # No parameter update happens here, so autograd bookkeeping is
        # disabled to avoid building graphs for every validation batch.
        with torch.no_grad():
            for batch_idx,(data,target) in enumerate(val_loader[key],start=1):
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()
                output,loss = model[key](data,target)
                accuracy = dice_coeff(output.max(1)[1].float(), target.max(1)[1].float()).item()
                label_loss += loss.item()
                label_accuracy += accuracy
                # print statistics
                completed += target.size(0)/len(val_indices[key])
                update_progress(completed,"Acuuracy: {:.2f} Loss: {:.4f}".format(accuracy*100,loss.item()))
                torch.cuda.empty_cache()
        # batch_idx holds the number of batches seen for this class
        print('\nAccuracy:',round(label_accuracy*100/batch_idx,2),'Loss:',round(label_loss/batch_idx,4))
        total_loss += label_loss/batch_idx
        total_accuracy += label_accuracy/batch_idx
    print("\nValidation Accuracy:",round(total_accuracy*100/i,2),'Validation Loss:',round(total_loss/i,4))
    return total_accuracy/i, total_loss/i

## Training Function

In [None]:
def train(epoch):
    """Train every per-class model for one pass over its loader, then validate.

    epoch: epoch number, only used for the printed header.

    For each class the model is put in train mode and optimised batch by
    batch with the loss returned by the model itself; the Dice score of the
    arg-max prediction is tracked as "accuracy". The learning-rate scheduler
    of the class is stepped once after its batches.

    Returns ``(train_accuracy, val_accuracy, train_loss, val_loss)``, each
    averaged over the classes.
    """
    total_loss = 0
    total_accuracy = 0
    print("Epoch:",epoch)
    for i,key in enumerate(classList,start=1):
        print('\nTraining of type %s'%key)
        label_loss = 0
        label_accuracy = 0
        completed = 0
        model[key].train()
        for batch_idx,(data,target) in enumerate(train_loader[key],start=1):
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            # The model computes its own loss when targets are supplied
            output,loss = model[key](data,target)
            optimizer[key].zero_grad()
            loss.backward()
            optimizer[key].step()
            accuracy = dice_coeff(output.max(1)[1].float(), target.max(1)[1].float()).item()
            label_loss += loss.item()
            label_accuracy += accuracy
            # print statistics
            # fraction of this class's training samples processed so far
            completed += target.size(0)/len(train_indices[key])
            update_progress(completed,"Acuuracy: {:.2f} Loss: {:.4f}".format(accuracy*100,loss.item()))
            torch.cuda.empty_cache()
        # batch_idx holds the number of batches seen for this class
        print('\nAccuracy:',round(label_accuracy*100/batch_idx,2),'Loss:',round(label_loss/batch_idx,4))
        total_loss += label_loss/batch_idx
        total_accuracy += label_accuracy/batch_idx
        exp_lr_scheduler[key].step()
    # i holds the number of classes processed
    print("\nTrain Accuracy:",round(total_accuracy*100/i,2),'Train Loss:',round(total_loss/i,4))
    val_accuracy, val_loss = validate()
    return total_accuracy/i,val_accuracy,total_loss/i,val_loss

## Train and Validate

In [None]:
# Training history restored from the checkpoint (empty lists on a fresh start).
# train_data / val_data hold [accuracy, loss] pairs, epoch_data the epoch numbers.
train_data = checkpoint.get('train_data',[])
val_data = checkpoint.get('val_data',[])
epoch_data = checkpoint.get('epoch_data',[])
for i in range(len(epoch_data)+1,len(epoch_data)+1+1):
    # The training step below is currently disabled, so this loop only
    # re-plots the restored history. Uncomment it to train one more epoch
    # and save the checkpoint.
#     train_accuracy,val_accuracy,train_loss,val_loss = train(i)
#     train_data.append([train_accuracy,train_loss])
#     val_data.append([val_accuracy,val_loss])
#     epoch_data.append(i)
#     # Save model
#     #if np.array(val_data)[:,1].min() == val_loss:
#     torch.save({'model':model,'optimizer':optimizer,'train_indices':train_indices,'val_indices':val_indices,'train_data':train_data,'val_data':val_data,'epoch_data':epoch_data},'model.pth')
    # Visualize
    if not epoch_data:
        # Fixed: with no recorded epoch the column indexing and the [-1]
        # lookups below raised IndexError; skip the plot instead.
        print('No training history to plot.')
        continue
    fig, ax = plt.subplots(1, 2, figsize=(10*2,7*1))
    ax[0].plot(epoch_data, np.array(train_data)[:,0], label="Train Accuracy {:.2f}".format(train_data[-1][0]*100))
    ax[0].plot(epoch_data, np.array(val_data)[:,0], label="Validation Accuracy {:.2f}".format(val_data[-1][0]*100))
    ax[1].plot(epoch_data, np.array(train_data)[:,1], label="Train Loss {:.4f}".format(train_data[-1][1]))
    ax[1].plot(epoch_data,np.array(val_data)[:,1], label="Validation Loss {:.4f}".format(val_data[-1][1]))
    display.clear_output(wait=False)
    ax[0].legend()
    ax[1].legend()
    plt.show()

## Visualize on train dataset

In [None]:
def predict(typ,ind):
    """Run the model of class ``typ`` on one training sample and print its Dice score.

    typ: class id (converted to str to index the per-class dictionaries).
    ind: index of the sample inside ``train_dataset[typ]``.

    Returns four PIL images: the (normalized) input, the one-hot ground
    truth, the raw model output and the arg-max mask scaled to 0/255.
    """
    typ = str(typ)
    x,y = train_dataset[typ][ind]
    # Add the batch dimension expected by the model
    x = torch.Tensor(x).unsqueeze(0)
    y = torch.Tensor(y).unsqueeze(0)
    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()
    model[typ].eval()
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        pred = model[typ](x)
    print('Dice Score:',dice_coeff(pred.max(1)[1].float(), y.max(1)[1].float()).item())
    x = transforms.ToPILImage()(x.squeeze(0).cpu())
    y = transforms.ToPILImage()(y.squeeze(0).cpu())
    p = transforms.ToPILImage()(pred.squeeze(0).cpu())
    mask = transforms.ToPILImage()((pred.max(1)[1]*255).type(torch.uint8).cpu())
    return x,y,p,mask

## Prediction Visualize

In [None]:
## Visualize the prediction for one random training sample
# NOTE(review): [1,4] is a two-element list, so only class '1' or class '4'
# is ever picked here - confirm whether all ids in classList were intended.
cls = str(np.random.choice([1,4]))
print('Class: %s'%cls)
# Random sample index within the chosen class
n = np.random.choice(len(train_dataset[cls]))
# input image, ground truth, raw model output, arg-max mask (all PIL images)
i,gt,pp,pm = predict(cls,n)


visualize(
    image = i,
    ground_truth = gt.convert('L'),
    predicted_mask = pm.convert('L'),
    predicted_proba = pp.convert('L'),
)