In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch, torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from PIL import Image

In [None]:
# Work from the Kaggle root so the relative 'input/...' and 'stage1_*' paths used below resolve.
kaggle_root = '/kaggle'
os.chdir(kaggle_root)

In [None]:
# Create the directory the training archive is actually extracted into; -p keeps the cell re-runnable.
! mkdir -p 'stage1_train'

In [None]:
# -n: never overwrite (no interactive prompt on re-run); -q: do not print one line per extracted file.
! unzip -n -q 'input/data-science-bowl-2018/stage1_train.zip' -d 'stage1_train'

In [None]:
# -p makes the cell idempotent: no error if the directory already exists.
! mkdir -p 'stage1_test'

In [None]:
# -n: never overwrite (no interactive prompt on re-run); -q: do not print one line per extracted file.
! unzip -n -q 'input/data-science-bowl-2018/stage1_test.zip' -d 'stage1_test'

In [None]:
class CustomDataset(Dataset):
  """Dataset yielding (image, ground-truth mask, weight map) triples.

  Each entry of `data` is a directory path containing an `images/` folder
  (its first file, in sorted order, is used) and a `masks/` folder holding
  one mask file per instance.

  Args:
    data: sequence of sample directory paths.
    transforms: callable turning a PIL image into a (C, H, W) tensor. It is
      applied independently to the image and to every mask, so it should be
      deterministic. When None, a plain `ToTensor()` is used.
  """

  def __init__(self, data, transforms = None):

    self.data = data
    # The `transforms` argument shadows the torchvision.transforms module in
    # this scope, hence the fully qualified name for the fallback.
    if transforms is None:
      transforms = torchvision.transforms.ToTensor()
    self.transformations = transforms
    self.len = len(data)

  def __getitem__(self, index):
    """Load sample `index`.

    Returns:
      img_tensor: the transformed image.
      gt_mask: (1, H, W) union of all instance masks, capped at 1.
      weight_map: result of `make_weight_map` on the stacked instance masks.
    """
    sample_dir = self.data[index]

    # sorted() makes the choice of file independent of filesystem order
    img_name = sorted(os.listdir(f'{sample_dir}/images'))
    with Image.open(f'{sample_dir}/images/{img_name[0]}') as img:
      img_tensor = self.transformations(img)

    # Size the mask buffers from the transformed image instead of assuming 256x256.
    height, width = img_tensor.shape[-2:]
    gt_mask = torch.zeros((1, height, width))

    masks = sorted(os.listdir(f'{sample_dir}/masks'))
    wm_masks = torch.zeros((len(masks), height, width))

    for i, mask in enumerate(masks):
      # Context manager closes the file handle; a sample can have many masks.
      with Image.open(f'{sample_dir}/masks/{mask}') as mask_img:
        mask_tensor = self.transformations(mask_img)

      gt_mask += mask_tensor
      wm_masks[i] = mask_tensor[0]

    # Overlapping masks would sum above 1, which is not a valid target and
    # can push the argument of the log in the loss below zero.
    gt_mask.clamp_(max = 1.0)

    weight_map = make_weight_map(wm_masks.numpy())

    return img_tensor, gt_mask, weight_map

  def __len__(self):
    """Number of samples."""
    return self.len

In [None]:
from skimage.segmentation import find_boundaries

def make_weight_map(masks, w0 = 10, sigma = 5):
  """Build a per-pixel loss weight map from a stack of instance masks.

  The map is the sum of two terms:
    * a class-balancing weight: foreground pixels get (1 - foreground
      fraction), background pixels get the foreground fraction;
    * a border term w0 * exp(-(d1 + d2)^2 / (2 * sigma^2)), where d1 and d2
      are the distances to the inner boundaries of the nearest and second
      nearest instance (only d1 is used when there is a single instance).

  Args:
    masks: array of shape (n_masks, n_rows, n_cols). Any value > 0 counts as
      foreground, so interpolated (non-binary) masks are accepted.
    w0: amplitude of the border term.
    sigma: spatial scale (in pixels) of the border term.

  Returns:
    torch.float64 tensor of shape (1, n_rows, n_cols). With no non-empty
    mask the border term is zero and the class term is zero everywhere
    (the foreground fraction is 0).

  Raises:
    ValueError: if `masks` is not 3-dimensional.
  """
  masks = np.asarray(masks)
  if masks.ndim != 3:
    raise ValueError(f'masks must have shape (n_masks, n_rows, n_cols), got {masks.shape}')

  n_rows, n_cols = masks.shape[1:]

  # Binarise: resized masks contain fractional values that would otherwise
  # fail exact comparisons and be read as many distinct labels.
  masks = masks > 0
  # Empty masks have no boundary to measure a distance to; drop them.
  masks = masks[masks.any(axis = (1, 2))]
  n_masks = masks.shape[0]

  # Row/column index of every pixel, flattened in row-major order.
  rows, cols = np.indices((n_rows, n_cols)).reshape(2, -1)

  # inf = "no boundary found"; it makes the exponential below evaluate to 0.
  dist_map = np.full((n_rows*n_cols, n_masks), np.inf)

  for i, mask in enumerate(masks):
    boundaries = find_boundaries(mask, mode = 'inner')
    b_rows, b_cols = np.nonzero(boundaries)
    if b_rows.size == 0:
      # e.g. a mask covering the whole image has no inner boundary
      continue
    row_sq = (b_rows.reshape(-1,1) - rows.reshape(1,-1))**2
    col_sq = (b_cols.reshape(-1,1) - cols.reshape(1,-1))**2
    # distance from every pixel to the closest boundary pixel of this mask
    dist_map[:,i] = np.sqrt(row_sq + col_sq).min(axis = 0)

  if n_masks == 0:
    border_loss = np.zeros(n_rows*n_cols)

  elif n_masks == 1:
    d1 = dist_map[:,0]
    border_loss = w0*np.exp((-(d1**2))/(2*(sigma**2)))

  else:
    # Only the two smallest distances per pixel are needed; partition puts
    # them in columns 0 and 1 without fully sorting every row.
    nearest = np.partition(dist_map, 1, axis = 1)
    d1 = nearest[:,0]
    d2 = nearest[:,1]
    border_loss = w0*np.exp((-((d1 + d2)**2))/(2*(sigma**2)))

  wb_Loss = border_loss.reshape(n_rows, n_cols)

  # A pixel is foreground if any instance covers it (overlaps included).
  foreground = masks.any(axis = 0)
  w_1 = 1 - foreground.mean()
  w_0 = 1 - w_1

  class_Loss = np.where(foreground, w_1, w_0)

  total_Loss = class_Loss + wb_Loss

  return torch.from_numpy(total_Loss).reshape(1, n_rows, n_cols)

In [None]:
import glob

# One entry per sample directory. glob returns paths in arbitrary filesystem
# order, so sort them to give the later shuffle a stable starting point.
images = np.array(sorted(glob.glob('stage1_train/' + '*')))

In [None]:
# Fixed seed so the train/validation split is the same on every run.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)

# permutation() returns a shuffled copy and leaves `images` untouched, so
# re-running this cell gives the same split instead of re-shuffling.
shuffled_images = list(split_rng.permutation(images))

# Roughly 90% of the samples for training, the remainder for validation.
n_train = int(0.9*len(shuffled_images) + 1)
train_images = shuffled_images[:n_train]
validation_images = shuffled_images[n_train:]

In [None]:
# Preprocessing pipeline: resize to 256x256, then convert the PIL image to a tensor.
transformations = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
])

In [None]:
# Wrap each split in a CustomDataset; both use the same preprocessing pipeline.
custom_dataset_train, custom_dataset_validation = (
    CustomDataset(data = split, transforms = transformations)
    for split in (train_images, validation_images)
)

In [None]:
# Both loaders share the same batching configuration.
loader_options = dict(batch_size = 8, shuffle = True)

trainloader = torch.utils.data.DataLoader(custom_dataset_train, **loader_options)
validationloader = torch.utils.data.DataLoader(custom_dataset_validation, **loader_options)

In [None]:
class UNet(nn.Module):
    """Three-level U-Net with a sigmoid output.

    The encoder applies three double-convolution blocks (64, 128, 256
    channels), each followed by 2x2 max-pooling; the bottleneck and decoder
    upsample with stride-2 transposed convolutions and concatenate the
    matching encoder output along the channel axis. All convolutions are
    padded so spatial size is preserved for odd kernel sizes, hence the input
    height and width must be divisible by 8 for the skip connections to line up.

    Args:
        in_channel: number of channels of the input image.
        out_channel: number of channels of the output map.
    """

    @staticmethod
    def _double_conv(in_channels, out_channels, kernel_size):
        """Return the layers of two Conv -> BatchNorm -> ReLU stages as a list."""
        # "same" padding for odd kernels; equals the previous fixed padding of 1
        # for the default kernel_size of 3.
        padding = kernel_size // 2
        return [
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels,
                      out_channels=out_channels, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels,
                      out_channels=out_channels, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        """Encoder block: two size-preserving Conv -> BatchNorm -> ReLU stages."""
        return nn.Sequential(*UNet._double_conv(in_channels, out_channels, kernel_size))

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Decoder block: double convolution to `mid_channel`, then 2x upsampling.

        The transposed convolution always uses a 3x3 kernel with stride 2,
        which exactly doubles height and width.
        """
        return nn.Sequential(
            *UNet._double_conv(in_channels, mid_channel, kernel_size),
            nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels,
                               kernel_size=3, stride=2, padding=1, output_padding=1),
        )

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Output head: double convolution, projection to `out_channels`, sigmoid."""
        return nn.Sequential(
            *UNet._double_conv(in_channels, mid_channel, kernel_size),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel,
                      out_channels=out_channels, padding=kernel_size // 2),
            nn.Sigmoid(),
        )

    def __init__(self, in_channel, out_channel):
        super().__init__()

        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck: same layer layout as an expansive block (256 -> 512 -> 512,
        # then upsample to 256), so the submodule indices are unchanged.
        self.bottleneck = self.expansive_block(256, 512, 256)

        # Decode: input channel counts are doubled by the skip concatenation.
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate a decoder tensor with its encoder skip tensor on the channel axis.

        When `crop` is True, `bypass` is first centre-cropped to the spatial
        size of `upsampled`. Height and width are handled independently, and
        an odd size difference removes the extra row/column at the
        bottom/right.
        """
        if crop:
            diff_h = bypass.size()[2] - upsampled.size()[2]
            diff_w = bypass.size()[3] - upsampled.size()[3]
            top, left = diff_h // 2, diff_w // 2
            # F.pad takes (left, right, top, bottom); negative values crop.
            bypass = F.pad(bypass, (-left, -(diff_w - left), -top, -(diff_h - top)))

        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the network on a (N, in_channel, H, W) batch; returns (N, out_channel, H, W)."""
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)

        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)

        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1)
        final_layer = self.final_layer(decode_block1)

        return final_layer

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [None]:
# 4 input channels, 1 output channel; the parameters are moved to the selected device.
unet = UNet(in_channel = 4, out_channel = 1)
unet = unet.to(device)

In [None]:
# Training configuration: number of passes over the training set and the Adam optimiser.
epochs = 20
optimizer = optim.Adam(params = unet.parameters(), lr = 0.01)

In [None]:
def weighted_log_loss(pred, target, weights):
  """Pixel-weighted negative log-likelihood, summed and divided by the batch size.

  `pred` is clipped to [0.05, 0.95] before the log so it never sees 0.
  NOTE(review): the clip also zeroes the gradient for predictions outside
  that range; kept as in the original training code.
  """
  pred = torch.clip(pred, 0.05, 0.95)
  likelihood = pred.mul(target) + (1 - pred).mul(1 - target)
  # Divide by the size of *this* batch; the last batch of a loader can be smaller.
  return -((weights.mul(torch.log(likelihood))).sum())/(pred.shape[0])


costs_train = []
costs_validation = []

has_validation = len(validationloader) > 0

for epoch in range(epochs):

  # Restart the validation stream at the beginning of every epoch.
  valid_iter = iter(validationloader)

  # `batch_images` rather than `images`, so the global list of sample paths is not overwritten.
  for i, (batch_images, gt_masks, wt_maps) in enumerate(trainloader):

    batch_images = batch_images.to(device)
    gt_masks = gt_masks.to(device)
    # The weight map arrives as float64; cast it so the loss stays in float32.
    wt_maps = wt_maps.to(device, dtype = torch.float32)

    unet.train()
    train_loss = weighted_log_loss(unet(batch_images), gt_masks, wt_maps)
    train_loss_value = train_loss.item()
    costs_train.append(train_loss_value)

    # One validation batch per training step, cycling through the loader.
    valid_loss_value = float('nan')
    if has_validation:
      try:
        valid_batch = next(valid_iter)
      except StopIteration:
        valid_iter = iter(validationloader)
        valid_batch = next(valid_iter)

      valid_images, gt_masks_valid, wt_maps_valid = valid_batch
      valid_images = valid_images.to(device)
      gt_masks_valid = gt_masks_valid.to(device)
      wt_maps_valid = wt_maps_valid.to(device, dtype = torch.float32)

      # eval() keeps validation data out of the BatchNorm running statistics;
      # no_grad() avoids building an autograd graph that is never used.
      unet.eval()
      with torch.no_grad():
        valid_loss = weighted_log_loss(unet(valid_images), gt_masks_valid, wt_maps_valid)
      unet.train()

      valid_loss_value = valid_loss.item()
      costs_validation.append(valid_loss_value)

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    if (((i+1) %10) == 0):
      print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(trainloader)}], Training loss: {train_loss_value}, Validation loss: {valid_loss_value} ')

# Loss curves, one point per training step.
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(costs_train)+1), costs_train, label = 'Training Loss')
ax.plot(np.arange(1, len(costs_validation)+1), costs_validation, label = 'Validation Loss' )
ax.set_xlabel('Training step')
ax.set_ylabel('Weighted loss')
ax.legend()
plt.show()