In [None]:
import pandas as pd
import os
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random
from mpl_toolkits.axes_grid1 import ImageGrid
from sklearn.model_selection import train_test_split
import copy
from tqdm import tqdm

import torch
from torch import nn, cat, squeeze
import torchvision
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
TRAIN_IMG_PATH = '../input/isic2018/ISIC2018_Task1-2_Training_Input/ISIC2018_Task1-2_Training_Input'
TRAIN_MASK_PATH = '../input/isic2018/ISIC2018_Task1_Training_GroundTruth/ISIC2018_Task1_Training_GroundTruth'
VAL_IMG_PATH = '../input/isic2018/ISIC2018_Task1-2_Validation_Input/ISIC2018_Task1-2_Validation_Input'
VAL_MASK_PATH = '../input/isic2018/ISIC2018_Task1_Validation_GroundTruth/ISIC2018_Task1_Validation_GroundTruth'


def _glob_split(root, ext):
    """List files `*.ext` inside `root` (the pattern `root*/` also matches sibling dirs starting with `root`)."""
    return glob.glob(f'{root}*/*.{ext}')


train_img, train_mask = _glob_split(TRAIN_IMG_PATH, 'jpg'), _glob_split(TRAIN_MASK_PATH, 'png')
val_img, val_mask = _glob_split(VAL_IMG_PATH, 'jpg'), _glob_split(VAL_MASK_PATH, 'png')

In [None]:
# Sanity check: each image list should be as long as its mask list.
print('\n'.join(
    f'Len {name}: {len(paths)}'
    for name, paths in (('train image', train_img), ('train mask', train_mask),
                        ('val image', val_img), ('val mask', val_mask))
))

In [None]:
# Pair every image with its mask by sorted path order.
# NOTE(review): this assumes image and mask file names sort identically
# (e.g. a shared ISIC id prefix) — confirm for the dataset version in use.
assert len(train_img) == len(train_mask), 'train image/mask counts differ'
assert len(val_img) == len(val_mask), 'val image/mask counts differ'
df = pd.DataFrame({'image_path': sorted(train_img), 'mask_path': sorted(train_mask), 'split': 'train'})
official_val_df = pd.DataFrame({'image_path': sorted(val_img), 'mask_path': sorted(val_mask), 'split': 'val'})
# DataFrame.append was removed in pandas 2.0; pd.concat keeps the original indices the same way.
df = pd.concat([df, official_val_df])
df

In [None]:
"""
The dataset was split into three subsets, training set, validation set, and test set, which the proportion is 70%, 10% and 20% of the whole dataset, respectively
"""
seed = 108
df = df.reset_index(drop=True)
df_train = df[df['split']=='train']
train_df, df_ = train_test_split(df_train, shuffle=True, train_size=0.8, random_state=seed)
val_df, test_df = train_test_split(df_, shuffle=True, train_size=0.5, random_state=seed)
test_df = test_df.append(df[df['split']=='val'])
print('Train - Valid - Test set size')
print(f'Train: {train_df.shape}\nValid: {val_df.shape}\nTest: {test_df.shape}')

In [None]:
# Show 5 random training images, their masks, and the masks overlaid on the images.
IMG_SIZE = 224
images = []
masks = []
for data in train_df.sample(5).values:
    # cv2 loads BGR; matplotlib expects RGB, so convert or the colours are swapped.
    img = cv2.cvtColor(cv2.resize(cv2.imread(data[0]), (IMG_SIZE, IMG_SIZE)), cv2.COLOR_BGR2RGB)
    # masks are single-valued per pixel, so channel order does not matter for them
    mask = cv2.resize(cv2.imread(data[1]), (IMG_SIZE, IMG_SIZE))
    images.append(img)
    masks.append(mask)
images = np.hstack(images)
masks = np.hstack(masks)

fig = plt.figure(figsize=(25, 25))
grid = ImageGrid(fig, 111, nrows_ncols=(3, 1), axes_pad=0.5)

grid[0].imshow(images)
grid[0].set_title('Images', fontsize=20)
grid[0].axis('off')
grid[1].imshow(masks)
grid[1].set_title('Masks', fontsize=20)
grid[1].axis('off')
grid[2].imshow(images)
grid[2].imshow(masks, alpha=0.4)
grid[2].set_title('Image with mask', fontsize=20)
grid[2].axis('off')
plt.show()

In [None]:
def show_img(inputs, nrows=4, ncols=4, image=True):
    """
    Plot the first nrows*ncols items of a batch in a grid.

    inputs: a batch of CPU tensors.
    image: True -> items are (C, H, W) images normalised with mean=std=0.5; they are
        transposed to (H, W, C) and de-normalised. Otherwise only channel 0 of each
        item is shown (used for masks).
    Returns the result of plt.show().
    """
    plt.figure(figsize=(10, 10))
    plt.subplots_adjust(wspace=0., hspace=0.)
    shown = inputs[:nrows * ncols]

    for slot, item in enumerate(shown, start=1):
        arr = item.numpy()
        if image is True:
            # undo Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5
            arr = (arr.transpose(1, 2, 0) * [0.5, 0.5, 0.5] + [0.5, 0.5, 0.5]).astype(np.float32)
        else:
            arr = arr.astype(np.float32)[0, :, :]
        plt.subplot(nrows, ncols, slot)
        plt.imshow(arr)
        plt.axis('off')
    return plt.show()

In [None]:
# Inspect the raw (un-resized) shapes of one training image/mask pair.
# Uses an explicit row instead of the `data` loop variable leaked from another cell.
sample_row = train_df.iloc[0]
img = cv2.imread(sample_row['image_path'])
mask = cv2.imread(sample_row['mask_path'])
print(img.shape)
print(mask.shape)

In [None]:
IMG_SIZE = 224
class SkinLesionDataset(torch.utils.data.Dataset):
    """
    Paired (image, mask) dataset for lesion segmentation.

    df: DataFrame whose column 0 holds image paths and column 1 mask paths.
    image_size: side length of the square output tensors (default IMG_SIZE).
    mode: 'train' enables random augmentation; any other value only resizes.
    augmentation_prob: probability that a training sample is augmented.

    __getitem__ returns (image, mask). image is a 3-channel tensor normalised with
    mean=std=0.5 (range [-1, 1]). mask is the ToTensor() output of the
    single-channel mask, in [0, 1].
    """
    def __init__(self, df, image_size=IMG_SIZE, mode='train', augmentation_prob=0.4):
        self.df = df
        self.image_size = image_size  # honour the argument (it used to be hard-wired to IMG_SIZE)
        self.mode = mode
        self.augmentation_prob = augmentation_prob
        self.RotationDegree = [0, 90, 180, 270]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Load eagerly so the file handles are closed. Force 3-channel images and
        # 1-channel masks, matching the model's img_ch=3 input and output_ch=1 output.
        with Image.open(self.df.iloc[idx, 0]) as img_file:
            image = img_file.convert('RGB')
        with Image.open(self.df.iloc[idx, 1]) as mask_file:
            mask = mask_file.convert('L')

        aspect_ratio = image.size[1] / image.size[0]  # height / width (PIL size is (w, h))

        Transform = []
        p_transform = random.random()
        Transform.append(T.Resize((self.image_size, self.image_size)))
        if (self.mode == 'train') and p_transform <= self.augmentation_prob:
            RotationDegree = self.RotationDegree[random.randint(0, 3)]
            if (RotationDegree == 90) or (RotationDegree == 270):
                aspect_ratio = 1 / aspect_ratio
            # The geometric pipeline is applied separately to image and mask, so every
            # random parameter must be fixed before it is built. A degenerate (a, a)
            # range makes RandomRotation deterministic. A scalar degree would draw a
            # fresh angle in [-a, a] on every call and misalign image and mask.
            Transform.append(T.RandomRotation((RotationDegree, RotationDegree)))
            SmallRotation = random.uniform(-10, 10)
            Transform.append(T.RandomRotation((SmallRotation, SmallRotation)))
            CropRange = random.randint(250, 270)
            # The crop may exceed the resized side; CenterCrop then zero-pads it.
            Transform.append(T.CenterCrop((int(CropRange * aspect_ratio), CropRange)))
            Transform = T.Compose(Transform)

            image = Transform(image)
            mask = Transform(mask)

            # flips: one random draw each, applied to both image and mask
            if random.random() < 0.5:
                image = F.hflip(image)
                mask = F.hflip(mask)
            if random.random() < 0.5:
                image = F.vflip(image)
                mask = F.vflip(mask)

            Transform = []
            # photometric jitter applies to the image only
            image = T.ColorJitter(brightness=0.2, contrast=0.2, hue=0.02)(image)

        Transform.append(T.Resize((self.image_size, self.image_size)))
        Transform.append(T.ToTensor())
        Transform = T.Compose(Transform)

        image = Transform(image)
        mask = Transform(mask)

        image = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(image)
        return image, mask

In [None]:
dataframe = {'train': train_df, 'val': val_df, 'test': test_df}
dataset = {x: SkinLesionDataset(dataframe[x], mode=x) for x in ['train', 'val', 'test']}
# Only the training loader needs shuffling. A fixed order for val/test makes evaluation
# and the "first test batch" previews reproducible.
dataloader = {x: torch.utils.data.DataLoader(dataset[x], batch_size=16, shuffle=(x == 'train'), num_workers=2)
              for x in ['train', 'val', 'test']}
# number of *batches* per split (not number of samples)
sizes = {x: len(dataloader[x]) for x in ['train', 'val', 'test']}
sizes

In [None]:
# Preview one (possibly augmented) training batch: de-normalised images, then masks.
images, masks = next(iter(dataloader['train']))
print(images.shape, masks.shape, sep='\n')
show_img(images)
show_img(masks, image=False)

In [None]:
from torch.nn import init

def init_weights(net, init_type='normal', gain=0.02):
    """
    Re-initialise, in place, every Conv*/Linear* and BatchNorm2d sub-module of `net`.

    init_type: 'normal' (N(0, gain)), 'xavier' (xavier_normal_ with gain),
        'kaiming' (kaiming_normal_, a=0, fan_in; gain unused) or 'orthogonal' (gain).
    gain: std / gain of the chosen scheme. BatchNorm2d weights are drawn from N(1, gain).

    Conv/Linear biases (when present) and BatchNorm2d biases are set to 0.
    An unknown init_type raises NotImplementedError, but only once a Conv/Linear
    module is actually visited.
    """
    def init_func(m):
        name = type(m).__name__
        if ('Conv' in name or 'Linear' in name) and hasattr(m, 'weight'):
            schemes = {
                'normal': lambda w: init.normal_(w, 0.0, gain),
                'xavier': lambda w: init.xavier_normal_(w, gain=gain),
                'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
                'orthogonal': lambda w: init.orthogonal_(w, gain=gain),
            }
            if init_type not in schemes:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            schemes[init_type](m.weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

class conv_block(nn.Module):
    """Two stages of 3x3 conv (padding 1) -> BatchNorm -> ReLU; keeps H, W and maps ch_in -> ch_out channels."""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        # first stage reads ch_in channels, the second reads the first stage's ch_out
        for stage_in in (ch_in, ch_out):
            layers += [
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class up_conv(nn.Module):
    """2x nn.Upsample followed by 3x3 conv (padding 1) -> BatchNorm -> ReLU, mapping ch_in -> ch_out channels."""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(ch_out)
        act = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, act)

    def forward(self, x):
        return self.up(x)
    
class Attention_block(nn.Module):
    """
    Additive attention gate.

    F_g: channels of the gating signal g; F_l: channels of the skip feature x;
    F_int: channels of the intermediate 1x1 projections.

    forward(g, x) projects both inputs to F_int channels, adds them, applies ReLU,
    reduces to a single-channel sigmoid map psi, and returns x * psi. The two
    projections are summed, so g and x are expected to share spatial size.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def proj(c_in, c_out):
            # 1x1 conv + BatchNorm projection
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(c_out),
            )

        self.W_g = proj(F_g, F_int)
        self.W_x = proj(F_l, F_int)
        self.psi = nn.Sequential(*proj(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)
    
class AttU_Net(nn.Module):
    """
    U-Net with an additive attention gate on every skip connection.

    img_ch: input channels; output_ch: output channels.
    forward returns sigmoid probabilities with the input's spatial size. There are
    four 2x2 max-pools, so H and W should be divisible by 16 for the skip
    concatenations to line up.
    Attribute names (Conv1..Conv5, Up5/Att5/Up_conv5, ..., Conv_1x1) define the
    state_dict keys and must stay stable for saved checkpoints.
    """
    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # encoder: img_ch -> 64 -> 128 -> 256 -> 512 -> 1024 channels
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # decoder stages: upsample, gate the matching skip, fuse the concatenation
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder: keep each level's features for the skip connections
        feats = [self.Conv1(x)]
        for encode in (self.Conv2, self.Conv3, self.Conv4, self.Conv5):
            feats.append(encode(self.Maxpool(feats[-1])))

        # decoder: start from the deepest features, consume skips x4 -> x1
        d = feats.pop()
        stages = ((self.Up5, self.Att5, self.Up_conv5),
                  (self.Up4, self.Att4, self.Up_conv4),
                  (self.Up3, self.Att3, self.Up_conv3),
                  (self.Up2, self.Att2, self.Up_conv2))
        for upsample, gate, fuse in stages:
            d = upsample(d)
            skip = gate(g=d, x=feats.pop())
            d = fuse(torch.cat((skip, d), dim=1))

        return torch.sigmoid(self.Conv_1x1(d))
    
# Build the 3-channel-in / 1-channel-out network, move it to `device`, then re-initialise its weights.
model = AttU_Net(img_ch=3, output_ch=1)
model = model.to(device)
init_weights(model)

In [None]:
def dice_coef_metric(pred, label):
    """
    Smoothed Dice coefficient (2|A.B| + 1) / (|A| + |B| + 1) over all elements.

    Works on anything with elementwise `*` and `.sum()` (numpy arrays or torch tensors).
    Values are used as given (soft Dice unless pred is already binarised).
    """
    return (2.0 * (pred * label).sum() + 1.0) / (pred.sum() + label.sum() + 1.0)

def dice_coef_loss(pred, label):
    """Soft Dice loss: 1 - (2|A.B| + 1) / (|A| + |B| + 1)."""
    numerator = 2.0 * (pred * label).sum() + 1.0
    denominator = pred.sum() + label.sum() + 1.0
    return 1 - numerator / denominator

def bce_dice_loss(pred, label):
    """Soft-Dice loss plus mean binary cross-entropy; `pred` must already be probabilities in [0, 1]."""
    return dice_coef_loss(pred, label) + nn.functional.binary_cross_entropy(pred, label)

def compute_iou(model, loader, threshold=0.3):
    """
    Mean thresholded Dice coefficient of `model` over the batches of `loader`.

    NOTE: despite its name this computes dice_coef_metric (Dice), not IoU; the name
    is kept for backward compatibility.

    model: called as model(data); its outputs are assumed to be probabilities.
    loader: iterable of (data, target) batches.
    threshold: predictions >= threshold become 1, the rest 0.

    Returns the average per-batch Dice (0.0 for an empty loader). The average is
    over all batches; the previous version divided by the last batch index
    (n - 1, and 0 for a single batch).
    """
    total = 0.0
    n_batches = 0
    with torch.no_grad():
        for data, target in loader:
            probs = model(data.to(device)).cpu().numpy()
            # binarise, keeping the model output's float dtype
            out_cut = (probs >= threshold).astype(probs.dtype)
            total += dice_coef_metric(out_cut, target.cpu().numpy())
            n_batches += 1
    return total / n_batches if n_batches else 0.0

def iou(y_true, y_pred):
    """Smoothed soft IoU: (|A.B| + 1) / (|A| + |B| - |A.B| + 1) on torch tensors."""
    inter = torch.sum(y_true * y_pred)
    union = torch.sum(y_true) + torch.sum(y_pred) - inter
    return (inter + 1.) / (union + 1.)

def falsepos(y_true, y_pred):
    """Soft false-positive mass: sum(y_pred) minus the overlap sum(y_true * y_pred)."""
    return torch.sum(y_pred) - torch.sum(y_true * y_pred)

def falseneg(y_true, y_pred):
    """Soft false-negative mass: sum(y_true) minus the overlap sum(y_true * y_pred)."""
    return torch.sum(y_true) - torch.sum(y_true * y_pred)

def precision(y_true, y_pred):
    """Soft precision: overlap / (sum(y_pred) + 1); the +1 smoothing is in the denominator only."""
    return torch.sum(y_true * y_pred) / (torch.sum(y_pred) + 1.)

def recall(y_true, y_pred):
    """Soft recall: overlap / (sum(y_true) + 1); the +1 smoothing is in the denominator only."""
    return torch.sum(y_true * y_pred) / (torch.sum(y_true) + 1.)

def fscore(y_true, y_pred):
    """
    Soft F1 score: harmonic mean of precision() and recall() on torch tensors.

    Returns 0 instead of NaN (0/0) when precision and recall are both 0, e.g. for an
    all-zero prediction or an empty label. A NaN in one batch would otherwise turn
    the whole epoch average into NaN.
    """
    presci = precision(y_true, y_pred)
    rec = recall(y_true, y_pred)
    # clamp only changes the value when presci + rec < 1e-7
    return 2 * (presci * rec) / torch.clamp(presci + rec, min=1e-7)

def weighted_fscore_loss(prew=1, recw=1):
    """
    Build a loss equal to minus a weighted F-score of precision() and recall().

    prew, recw: weights on precision and recall. With prew == recw == 1 the loss is -F1.

    The returned callable takes (y_true, y_pred). NOTE(review): train_model calls
    loss_func(outputs, masks), i.e. (prediction, label), so swap the arguments if
    this is ever used there.
    Returns 0 instead of NaN when the weighted denominator is 0.
    """
    def fscore_loss(y_true, y_pred):
        presci = precision(y_true, y_pred)
        rec = recall(y_true, y_pred)
        denom = torch.clamp(prew * presci + recw * rec, min=1e-7)  # guard 0/0
        return -(prew + recw) * (presci * rec) / denom
    return fscore_loss

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# mode='max': stepped with the validation Dice, where higher is better
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.5)
# Every measure is called as mobj(outputs, masks), i.e. (prediction, label).
# precision/recall are defined as (y_true, y_pred), so swap their arguments here;
# otherwise the reported precision and recall are exchanged. dice, iou and fscore
# are symmetric in their arguments.
measures = {'dice_coef': dice_coef_metric,
            'iou': iou,
            'precision': lambda pred, label: precision(label, pred),
            'recall': lambda pred, label: recall(label, pred),
            'fscore': fscore}

In [None]:
def train_model(model_name, model, dataloader, loss_func, optimizer, scheduler, measures, num_epochs, save_path='attn_unet.pt'):
    """
    Train `model` with a train/val loop and keep the weights with the best validation Dice.

    model_name: label printed in the header.
    dataloader: dict with 'train' and 'val' loaders yielding (images, masks).
    loss_func: callable(outputs, masks) -> scalar loss tensor.
    optimizer: optimizer over model parameters.
    scheduler: stepped once per epoch with the validation 'dice_coef'
        (e.g. ReduceLROnPlateau(mode='max')).
    measures: dict name -> callable(outputs, masks) -> scalar tensor; must contain 'dice_coef'.
    num_epochs: number of epochs to run.
    save_path: the state_dict is saved here each time validation Dice improves.

    Returns (train_log, val_log): dicts mapping every measure name and 'loss' to a
    per-epoch list of batch-averaged values. On return `model` holds the best weights.
    """
    print('-'*10 , model_name, '-'*10)

    train_log = {k: [] for k in measures.keys()}
    train_log['loss'] = []
    val_log = {k: [] for k in measures.keys()}
    val_log['loss'] = []

    best_val_score = 0.
    best_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):

        print(f'Epoch {epoch+1}/{num_epochs}')

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is the same as eval()

            running_losses = []
            measurements = {k: 0. for k in measures.keys()}

            for images, masks in tqdm(dataloader[phase]):
                images = images.to(device)
                masks = masks.to(device)

                # Build the autograd graph only while training; validation does not need it.
                with torch.set_grad_enabled(is_train):
                    outputs = model(images)
                    loss = loss_func(outputs, masks)
                running_losses.append(loss.item())

                # Metrics are for reporting only and must not extend the graph.
                with torch.no_grad():
                    for k, mobj in measures.items():
                        measurements[k] += mobj(outputs, masks).item()

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            n_batches = max(len(dataloader[phase]), 1)  # avoid ZeroDivisionError on an empty loader
            for k in measures.keys():
                measurements[k] = measurements[k] / n_batches
            measurements['loss'] = np.array(running_losses).mean()

            if not is_train:
                scheduler.step(measurements['dice_coef'])

            log = train_log if is_train else val_log
            for k, v in measurements.items():
                log[k].append(v)

            print(f'{phase}:', end='\t')
            for k, v in measurements.items():
                print(" {}:{:.4f}".format(k, v), end='  ')
            print()
            print(f'current lr:', optimizer.param_groups[0]['lr'])

            if not is_train and measurements['dice_coef'] > best_val_score:
                print(f"New score: {measurements['dice_coef']:.4f}\t Previous score: {best_val_score:.4f}")
                best_val_score = measurements['dice_coef']
                best_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), save_path)

    model.load_state_dict(best_wts)
    print(f'Training Completed. Best score: {best_val_score}')
    return train_log, val_log

In [None]:
num_epochs = 5
# Train, validating every epoch; train_model leaves the best-Dice weights in `model`.
train_log, val_log = train_model(
    "AttnUNet", model, dataloader,
    loss_func=bce_dice_loss, optimizer=optimizer, scheduler=scheduler,
    measures=measures, num_epochs=num_epochs,
)

In [None]:
def plot_model_history(model_name, train_log, val_log, num_epochs):
    """
    Plot train (blue) vs. validation (red) curves, one subplot per key of train_log.

    num_epochs is kept for backward compatibility. The x-axis is taken from the
    length of each logged series, so partially completed runs still plot. The grid
    has 2 columns and enough rows for any number of keys; unused cells are hidden.
    """
    keys = list(train_log.keys())
    ncols = 2
    nrows = max(1, -(-len(keys) // ncols))  # ceil division
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 10), squeeze=False)
    fig.suptitle(f"{model_name}", fontsize=15)
    for ax, k in zip(axes.flat, keys):
        ax.plot(np.arange(len(train_log[k])), train_log[k], label=f'train_{k}', lw=3, c="b")
        ax.plot(np.arange(len(val_log[k])), val_log[k], label=f'val_{k}', lw=3, c="r")
        ax.legend(fontsize=12)
        ax.set_xlabel("Epoch", fontsize=15)
        ax.set_ylabel(f"{k}", fontsize=15)
    for ax in axes.flat[len(keys):]:
        ax.axis('off')
    plt.show()
plot_model_history('Attn_UNet', train_log, val_log, num_epochs)

In [None]:
# Evaluation on the test split (model holds the best weights restored by train_model).
model.eval()  # explicit: do not rely on train_model having left the model in eval mode
test_log = {k: 0. for k in measures.keys()}
with torch.no_grad():
    for images, masks in tqdm(dataloader['test']):
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        for k, mobj in measures.items():
            test_log[k] += mobj(outputs, masks).item()
n_test_batches = max(len(dataloader['test']), 1)  # guard against an empty loader
for k in measures.keys():
    test_log[k] = test_log[k] / n_test_batches

print(test_log)

In [None]:
# Predict one random validation image and compare with its ground-truth mask.
test_sample = val_df.sample(1).values[0]
# cv2 loads BGR, but the model was trained on RGB (PIL) images; convert before
# inference (this also makes matplotlib show the true colours).
image = cv2.cvtColor(cv2.resize(cv2.imread(test_sample[0]), (IMG_SIZE, IMG_SIZE)), cv2.COLOR_BGR2RGB)
mask = cv2.resize(cv2.imread(test_sample[1]), (IMG_SIZE, IMG_SIZE))

# pred: same preprocessing as SkinLesionDataset (scale to [0, 1], Normalize(0.5, 0.5))
model.eval()
pred = torch.tensor(image.astype(np.float32) / 255.).unsqueeze(0).permute(0, 3, 1, 2)
pred = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(pred)
with torch.no_grad():
    pred = model(pred.to(device))
pred = pred.cpu().numpy()[0, 0, :, :]

# Binary mask: lesion (prob >= 0.3) -> 255, background -> 0. A single np.where avoids
# the in-place two-step assignment that overwrote its own 255s (always all zeros).
pred_t = np.where(pred >= 0.3, 255, 0).astype("uint8")

# plot
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

ax[0, 0].imshow(image)
ax[0, 0].set_title("image")
ax[0, 1].imshow(mask)
ax[0, 1].set_title("mask")
ax[1, 0].imshow(pred)
ax[1, 0].set_title("prediction")
ax[1, 1].imshow(pred_t)
ax[1, 1].set_title("prediction with threshold")
plt.show()

In [None]:
def makeAnnotatedImage(x, y, label, threshold=0.75):
    """
    Build a side-by-side RGB panel: Original | Label | Prediction | Combined.

    x: image tensor (C, H, W) normalised with mean=std=0.5 (as in SkinLesionDataset), RGB order
    y: prediction tensor; channel 0 is used
    label: mask tensor; channel 0 is used
    threshold: values above this are tinted (default 0.75)

    Returns a uint8 RGB array of shape (H, 4*W, 3) suitable for plt.imshow.
    Label pixels are tinted with OpenCV hue 100, prediction pixels with hue 50.
    In "Combined" both shifts are added to the original hue, wrapped into
    OpenCV's [0, 180) hue range.
    """
    y = y.cpu().detach().numpy()[0, :, :]
    label = label.cpu().detach().numpy()[0, :, :]
    pred_on = y > threshold
    label_on = label > threshold

    # undo Normalize(0.5, 0.5): [-1, 1] -> [0, 254]; channels stay in RGB order
    orig = x.cpu().detach().numpy()
    orig = ((orig + 1) * 127).astype(np.uint8)
    orig = np.dstack((orig[0, :, :], orig[1, :, :], orig[2, :, :]))

    def _caption(img, text):
        # black caption in the top-left corner (drawn in place)
        cv2.putText(img, text, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0, 0), 2)
        return img

    imgs = [_caption(orig.copy(), 'Original')]
    hsv = cv2.cvtColor(orig, cv2.COLOR_RGB2HSV)

    # label only
    img = hsv.copy()
    img[label_on, 0] = 100  # BLUE
    img[label_on, 1] = 250
    imgs.append(cv2.cvtColor(_caption(img, 'Label'), cv2.COLOR_HSV2RGB))

    # prediction only
    img = hsv.copy()
    img[pred_on, 0] = 50  # GREEN
    img[pred_on, 1] = 250
    imgs.append(cv2.cvtColor(_caption(img, 'Prediction'), cv2.COLOR_HSV2RGB))

    # combined: shift hues in int arithmetic, then wrap into [0, 180) instead of
    # letting uint8 overflow / produce invalid OpenCV hues
    img = hsv.copy()
    hue = img[:, :, 0].astype(np.int32)
    hue[pred_on] += 50
    hue[label_on] += 100
    img[:, :, 0] = (hue % 180).astype(np.uint8)
    img[pred_on | label_on, 1] = 250
    imgs.append(cv2.cvtColor(_caption(img, 'Combined'), cv2.COLOR_HSV2RGB))

    return np.hstack(imgs)

In [None]:
# Annotate the first test batch: original | label | prediction | combined, one figure per sample.
model.eval()
images, masks = next(iter(dataloader['test']))
with torch.no_grad():
    outputs = model(images.to(device))
samples = [makeAnnotatedImage(x, output, y) for x, y, output in zip(images, masks, outputs)]

for sample in samples:
    plt.figure()
    plt.axis('off')
    plt.imshow(sample)

In [None]:
test_samples = test_df.sample(len(test_df)).values

def batch_preds_overlap(model, samples, threshold=0.3):
    """
    Predict every sample and draw ground-truth and predicted contours on the image.

    model: segmentation model; output channel 0 is read as a probability map.
    samples: iterable of rows where [0] is an image path and [1] a mask path.
    threshold: probability at or above which a pixel counts as lesion (default 0.3).

    Returns a list of RGB uint8 images (IMG_SIZE x IMG_SIZE). The ground-truth
    contours are drawn in (0, 255, 0) (green) and the predicted contours in
    (255, 36, 0) (red).
    """
    model.eval()
    prediction_overlap = []
    with torch.no_grad():
        for test_sample in samples:
            # Read once. cv2 loads BGR, but the model was trained on RGB (PIL)
            # images, and the overlay is displayed with matplotlib (RGB).
            original_img = cv2.cvtColor(
                cv2.resize(cv2.imread(test_sample[0]), (IMG_SIZE, IMG_SIZE)), cv2.COLOR_BGR2RGB)
            ground_truth = cv2.resize(cv2.imread(test_sample[1], 0), (IMG_SIZE, IMG_SIZE)).astype("uint8")

            # pred: same preprocessing as SkinLesionDataset
            prediction = torch.tensor(original_img / 255.).unsqueeze(0).permute(0, 3, 1, 2)
            prediction = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(prediction)
            prediction = model(prediction.to(device).float())
            prediction = prediction.cpu().numpy()[0, 0, :, :]
            prediction = np.where(prediction >= threshold, 255, 0).astype("uint8")

            # overlap: draw ALL contours (index -1), not only the first one found
            _, thresh_gt = cv2.threshold(ground_truth, 127, 255, 0)
            _, thresh_p = cv2.threshold(prediction, 127, 255, 0)
            contours_gt, _ = cv2.findContours(thresh_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            contours_p, _ = cv2.findContours(thresh_p, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

            overlap_img = cv2.drawContours(original_img, contours_gt, -1, (0, 255, 0), 1)
            overlap_img = cv2.drawContours(overlap_img, contours_p, -1, (255, 36, 0), 1)
            prediction_overlap.append(overlap_img)

    return prediction_overlap

prediction_overlap_r = batch_preds_overlap(model, test_samples)

In [None]:
# Assemble overlays into plates of PLATE_ROWS rows x PLATE_COLS images, at most
# MAX_PLATES plates (7 plates of 5 x 3 = the first 105 overlays). Only complete
# rows and plates are built, so fewer test images no longer crash or give ragged plates.
PLATE_COLS, PLATE_ROWS, MAX_PLATES = 5, 3, 7
n_full_rows = min(len(prediction_overlap_r) // PLATE_COLS, PLATE_ROWS * MAX_PLATES)

pred_overlap_5x1_r = [np.hstack(prediction_overlap_r[r * PLATE_COLS:(r + 1) * PLATE_COLS])
                      for r in range(n_full_rows)]

pred_overlap_5x3_r = [np.vstack(pred_overlap_5x1_r[p * PLATE_ROWS:(p + 1) * PLATE_ROWS])
                      for p in range(len(pred_overlap_5x1_r) // PLATE_ROWS)]

In [None]:
def plot_plate_overlap(batch_preds, title, num):
    """
    Render one plate of overlays with a colour legend and save it as a PNG.

    The file name is `title + str(num)`, lower-cased, with whitespace runs
    replaced by '_' and '.png' appended. Saved on a black background; the figure
    is closed afterwards.
    """
    out_name = "_".join(f"{title}{num}".lower().split()) + ".png"

    plt.figure(figsize=(15, 15))
    plt.imshow(batch_preds)
    plt.axis("off")

    legend = (
        (0.76, "Green - Ground Truth", "lime"),
        (0.26, "Red - Prediction", "#ff0d00"),
    )
    for x_pos, text, colour in legend:
        plt.figtext(x_pos, 0.75, text, va="center", ha="center", size=20, color=colour)
    plt.suptitle(title, y=.80, fontsize=20, weight="bold", color="#00FFDE")

    plt.savefig(out_name, bbox_inches='tight', pad_inches=0.2, transparent=False, facecolor='black')
    plt.close()

title = "Predictions of Attn_UNet"

for num, batch in enumerate(pred_overlap_5x3_r):
    plot_plate_overlap(batch, title, num)

In [None]:
from PIL import Image

def make_gif(title):
    """
    Stitch the numbered plate PNGs into an animated GIF in the working directory.

    base_name is `title` lower-cased with whitespace runs replaced by '_'.
    Frames are the files '<base_name><N>.png' (N an integer), ordered by N, and
    written to '<base_name>.gif' with 2000 ms per frame, looping forever. Other
    PNGs that happen to match '<base_name>*.png' are ignored.

    Returns the GIF file name. Raises FileNotFoundError if no frames are found.
    """
    base_name = "_".join(title.lower().split())

    base_len = len(base_name)
    end_len = len(".png")
    fp_in = f"{base_name}*.png"
    fp_out = f"{base_name}.gif"

    # keep only '<base_name><digits>.png' so int() below cannot fail
    paths = [p for p in glob.glob(fp_in) if p[base_len:-end_len].isdigit()]
    paths.sort(key=lambda p: int(p[base_len:-end_len]))
    if not paths:
        raise FileNotFoundError(f"no frames matching {fp_in!r}")

    frames = [Image.open(p) for p in paths]
    try:
        img, *imgs = frames
        img.save(fp=fp_out, format='GIF', append_images=imgs,
                 save_all=True, duration=2000, loop=0)
    finally:
        # release the file handles held by the lazily loaded frames
        for frame in frames:
            frame.close()

    return fp_out

fn = make_gif(title)

In [None]:
from IPython.display import Image as Image_display
# The file is an animated GIF, so embed it with the matching format / mime type.
with open(fn, 'rb') as f:
    display(Image_display(data=f.read(), format='gif'))