In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
input_file_paths = (
    os.path.join(folder, name)
    for folder, _, names in os.walk('/kaggle/input')
    for name in names
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install torchsummary

In [None]:
import numpy as np
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torchvision
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.optim import lr_scheduler, Adam
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    HueSaturationValue, CoarseDropout, ToGray
    )
from albumentations.pytorch import ToTensorV2

device = ("cuda" if torch.cuda.is_available() else "cpu")

device

In [None]:
# Root folder of the competition data.
dataset_folder = "/kaggle/input/hpa-single-cell-image-classification/"
# Kept with a trailing slash: image IDs are appended to it by plain string concatenation.
training_image_folder = dataset_folder+"train/"
# Table of training samples; the dataset code below reads its `ID` column to locate images.
train_df = pd.read_csv(dataset_folder+"train.csv")
# Show the number of rows as the cell output.
len(train_df)

In [None]:
def get_binary_mask(img):
    """
    Derive a binary segmentation mask from a stacked multi-channel image.

    The image is Gaussian-blurred (25x25 kernel), converted to grayscale
    with ``cv2.COLOR_RGBA2GRAY``, binarised with an Otsu threshold and
    finally cleaned up by a morphological closing with a 40x40 kernel.
    The closed mask is returned.
    """
    smoothed = cv2.GaussianBlur(img, (25, 25), 0)
    grayscale = cv2.cvtColor(smoothed, cv2.COLOR_RGBA2GRAY)

    # Only the mask is needed; the threshold value picked by Otsu is discarded.
    _, otsu_mask = cv2.threshold(
        grayscale, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )

    closing_kernel = np.ones((40, 40), dtype=np.uint8)
    return cv2.morphologyEx(otsu_mask, cv2.MORPH_CLOSE, closing_kernel)

def load_RGBY_image(img):
    '''
    Load the four separately stored channel images of one sample and stack
    them along a new last axis in the order red, green, blue, yellow.

    Parameters
    ----------
    img : str
        Path prefix of the sample. "_red.png", "_green.png", "_blue.png"
        and "_yellow.png" are appended to it to find the channel files.

    Returns
    -------
    numpy.ndarray
        The stacked channels; for 2-D channel images this is (H, W, 4).

    Raises
    ------
    FileNotFoundError
        If any of the four channel files cannot be read.
    '''
    channels = []
    for colour in ("red", "green", "blue", "yellow"):
        channel_path = f"{img}_{colour}.png"
        channel = cv2.imread(channel_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface later as an obscure numpy error.
        if channel is None:
            raise FileNotFoundError(f"Could not read channel image: {channel_path}")
        channels.append(channel)

    return np.stack(channels, axis=-1)

# Sanity check: load one training sample and display its first three stacked channels (red, green, blue).
img = load_RGBY_image('/kaggle/input/hpa-single-cell-image-classification/train/8061ee18-bbb2-11e8-b2ba-ac1f6b6435d0')
plt.imshow(img[:, :, :3])
plt.show()

In [None]:
def threshold_img(img):
    """
    Return a copy of `img` in which every element below 0.5 is set to 0.

    The input array itself is never modified, so callers can keep using
    the original values after the call. The dtype of the input is kept.
    """
    # Work on a copy: boolean-mask assignment would otherwise write straight
    # into the caller's array.
    img_thres = np.array(img, copy=True)
    img_thres[img_thres < 0.5] = 0

    return img_thres


class CellMaskDataset(Dataset):
    """
    Dataset yielding (image, mask) pairs built from rows of the global
    `train_df`.

    For every sample the four channel images are loaded, a binary mask is
    derived from the stacked image, and the first three channels are kept
    as the network input.

    Parameters
    ----------
    stacked_transform : callable
        Called as ``stacked_transform(image=array)``; must return a mapping
        whose ``'image'`` entry is the transformed input image.
    mask_transform : callable
        Called as ``mask_transform(image=array)``; must return a mapping
        whose ``'image'`` entry supports ``.float()``.
    train : bool
        Selects the training rows when True, the validation rows otherwise.
    """

    # Positional row ranges of `train_df` for each split. Row 20000 belongs
    # to neither split, exactly as in the original split definition.
    TRAIN_ROWS = slice(None, 20000)
    VAL_ROWS = slice(20001, 20050)

    def __init__(self, stacked_transform=None, mask_transform=None, train=False):

        self.stacked_transform = stacked_transform
        self.mask_transform = mask_transform
        self.train = train

        # Resolve the split once so that __len__ and __getitem__ can never
        # disagree about how many samples exist.
        rows = self.TRAIN_ROWS if train else self.VAL_ROWS
        self.image_ids = train_df.iloc[rows]["ID"].tolist()

    def __len__(self):

        return len(self.image_ids)

    def __getitem__(self, idx):

        image_path = training_image_folder + self.image_ids[idx]
        stacked_images = load_RGBY_image(image_path)

        # The mask is computed from all stacked channels, before the image
        # is cut down to its first three channels.
        binary_mask = threshold_img(get_binary_mask(stacked_images))
        stacked_images = stacked_images[:, :, :3]

        binary_mask = self.mask_transform(image=binary_mask)
        stacked_images = self.stacked_transform(image=stacked_images)

        return stacked_images['image'], binary_mask['image'].float()
    
    
def get_transforms(data):
    """
    Build the albumentations pipeline for one kind of input.

    ``data == 'mask'`` selects the single-channel variant; any other value
    selects the three-channel variant. Both resize to 64x64, apply
    ``Normalize`` with zero mean and unit std per channel, and finish with
    ``ToTensorV2``.
    """
    n_channels = 1 if data == 'mask' else 3

    steps = [
        Resize(64, 64),  # For more accurately segmented images: Resize(256, 256)
        Normalize(
            mean=[0.0] * n_channels,
            std=[1.0] * n_channels,
        ),
        ToTensorV2(),
    ]
    return Compose(steps)



# Any key other than 'mask' makes get_transforms return the 3-channel image pipeline.
train_dataset = CellMaskDataset(
    stacked_transform=get_transforms('stupid'),
    mask_transform=get_transforms('mask'),
    train=True,
)
# Training batches of one sample, reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)

val_dataset = CellMaskDataset(
    stacked_transform=get_transforms('stupid'),
    mask_transform=get_transforms('mask'),
    train=False,
)
# Validation batches keep the dataset order.
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)
from torchvision.utils import save_image

#torch.set_printoptions(profile="full")
#for idx, (image, mask) in enumerate(train_dataloader):
    #if idx == 22:
        #save_image(mask, "real.png")

In [None]:
from torch.nn.modules.loss import _Loss


class SoftDiceLoss(nn.Module):
    """
    Soft Dice coefficient between sigmoid-activated logits and targets.

    Despite the class name, `forward` returns the Dice coefficient itself
    (larger means more overlap), not ``1 - dice``; it is therefore suited
    for use as a metric. `weight` and `size_average` are accepted for
    signature compatibility and are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(SoftDiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """
        Parameters
        ----------
        inputs : torch.Tensor
            Raw network outputs (logits); a sigmoid is applied here.
        targets : torch.Tensor
            Target tensor with the same number of elements as `inputs`.
        smooth : number
            Added to numerator and denominator to avoid division by zero.
        """
        # Remove this line if the model already ends in a sigmoid (or equivalent).
        inputs = torch.sigmoid(inputs)

        # Flatten predictions and labels. reshape (unlike view) also accepts
        # non-contiguous tensors such as transposed or sliced inputs.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return dice

class First2D(nn.Module):
    """
    Two consecutive 3x3 convolution -> batch-norm -> ReLU stages
    (in -> middle -> out channels), optionally followed by Dropout2d.
    The convolutions use padding=1, so the spatial size is unchanged.
    """

    def __init__(self, in_channels, middle_channels, out_channels, dropout=False):
        super().__init__()

        stages = []
        channel_pairs = ((in_channels, middle_channels), (middle_channels, out_channels))
        for c_in, c_out in channel_pairs:
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # A falsy `dropout` disables the layer; otherwise it is the drop probability.
        if dropout:
            stages.append(nn.Dropout2d(p=dropout))

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the sequential block."""
        return self.block(x)
    
class Upsample2D(nn.Module):
    """
    Two 3x3 convolution -> batch-norm -> ReLU stages followed by a
    kernel-2 / stride-2 transposed convolution that maps `out_channels`
    to `deconv_channels`. An optional Dropout2d is appended last.
    """

    def __init__(self, in_channels, middle_channels, out_channels, deconv_channels, dropout=False):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # One convolution stage; padding=1 keeps the spatial size.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        stages = conv_bn_relu(in_channels, middle_channels) + conv_bn_relu(middle_channels, out_channels)
        stages.append(nn.ConvTranspose2d(out_channels, deconv_channels, kernel_size=2, stride=2))

        # A falsy `dropout` disables the layer; otherwise it is the drop probability.
        if dropout:
            stages.append(nn.Dropout2d(p=dropout))

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the sequential block."""
        return self.block(x)
    
    
class Downsample2D(nn.Module):
    """
    Max-pooling with kernel size `downsample_kernel`, then two 3x3
    convolution -> batch-norm -> ReLU stages (in -> middle -> out
    channels), optionally followed by Dropout2d.
    """

    def __init__(self, in_channels, middle_channels, out_channels, dropout=False, downsample_kernel=2):
        super().__init__()

        pooling = [nn.MaxPool2d(kernel_size=downsample_kernel)]
        first_stage = [
            nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # A falsy `dropout` disables the layer; otherwise it is the drop probability.
        regularisation = [nn.Dropout2d(p=dropout)] if dropout else []

        self.block = nn.Sequential(*pooling, *first_stage, *second_stage, *regularisation)

    def forward(self, x):
        """Run `x` through the sequential block."""
        return self.block(x)
    
class Center2D(nn.Module):
    """
    2x2 max-pooling, two 3x3 convolution -> batch-norm -> ReLU stages,
    then a kernel-2 / stride-2 transposed convolution mapping
    `out_channels` to `deconv_channels`. Optional Dropout2d comes last.
    """

    def __init__(self, in_channels, middle_channels, out_channels, deconv_channels, dropout=False):
        super().__init__()

        stages = [nn.MaxPool2d(kernel_size=2)]
        for c_in, c_out in zip((in_channels, middle_channels), (middle_channels, out_channels)):
            stages.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ))
        stages.append(nn.ConvTranspose2d(out_channels, deconv_channels, kernel_size=2, stride=2))

        # A falsy `dropout` disables the layer; otherwise it is the drop probability.
        if dropout:
            stages.append(nn.Dropout2d(p=dropout))

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the sequential block."""
        return self.block(x)
    
class Last2D(nn.Module):
    """
    Two 3x3 convolution -> batch-norm -> ReLU stages (in -> middle ->
    middle channels) followed by a 1x1 convolution that produces
    `out_channels` channels. No activation is applied to the output.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            # 1x1 projection to the requested number of output channels.
            nn.Conv2d(middle_channels, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Run `x` through the sequential block."""
        return self.block(x)
            



In [None]:
class UNET(nn.Module):
    """
    U-Net built from First2D / Downsample2D encoder blocks, a Center2D
    bottleneck and Upsample2D / Last2D decoder blocks with skip connections.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input image.
    out_channels : int
        Number of channels of the output map.
    conv_depths : sequence of int
        Channel count per level; the last entry is the bottleneck width.
        At least two entries are required.
    """

    def __init__(self, in_channels, out_channels, conv_depths=(64, 128, 256, 512, 1024)):
        super(UNET, self).__init__()

        if len(conv_depths) < 2:
            raise ValueError("conv_depths needs at least two entries (one encoder level and the bottleneck)")

        n_inner = len(conv_depths) - 2

        encoder_layers = [First2D(in_channels, conv_depths[0], conv_depths[0])]
        encoder_layers.extend(
            Downsample2D(conv_depths[i], conv_depths[i + 1], conv_depths[i + 1])
            for i in range(n_inner)
        )

        # Every decoder block receives the previous decoder output concatenated
        # with the encoder output of the same level, hence the doubled inputs.
        decoder_layers = [
            Upsample2D(2 * conv_depths[i + 1], 2 * conv_depths[i], 2 * conv_depths[i], conv_depths[i])
            for i in reversed(range(n_inner))
        ]
        # The final block sees conv_depths[0] decoder channels plus
        # conv_depths[0] skip channels. Using 2 * conv_depths[0] instead of
        # conv_depths[1] keeps this correct when depths do not double.
        decoder_layers.append(Last2D(2 * conv_depths[0], conv_depths[0], out_channels))

        self.encoder_block = nn.Sequential(*encoder_layers)
        self.center_block = Center2D(conv_depths[-2], conv_depths[-1], conv_depths[-1], conv_depths[-2])
        self.decoder_block = nn.Sequential(*decoder_layers)

    def forward(self, x, return_all=False):
        """
        Run the network on `x`.

        Returns the output of the last decoder block, or, when `return_all`
        is true, a list holding the input, every encoder output, the
        bottleneck output and every decoder output in that order.
        """
        x_enc = [x]
        for enc_layer in self.encoder_block:

            x_enc.append(enc_layer(x_enc[-1]))

        x_dec = [self.center_block(x_enc[-1])]

        for dec_layer_idx, dec_layer in enumerate(self.decoder_block):

            # Walk the encoder outputs backwards to find the matching level.
            x_opposite = x_enc[-1 - dec_layer_idx]
            x_cat = torch.cat([pad_to_shape(x_dec[-1], x_opposite.shape), x_opposite],
                             dim=1)
            x_dec.append(dec_layer(x_cat))

        if not return_all:

            return x_dec[-1]

        else:

            return x_enc + x_dec
        
        
def pad_to_shape(current, targ_shp):
    """
    Pad `current` at the end of every dimension from index 2 onwards so
    that those dimensions match `targ_shp`.

    Dimensions 0 and 1 are never touched. The per-dimension difference is
    handed to ``F.pad`` unchanged, so a negative difference is passed on
    as negative padding.

    Parameters
    ----------
    current : torch.Tensor
        Tensor to pad.
    targ_shp : sequence of int
        Target shape; must have the same rank as `current` and rank >= 3.

    Raises
    ------
    ValueError
        If the target rank is below 3 or differs from the rank of `current`.
    """
    rank = len(targ_shp)
    if rank < 3 or current.dim() != rank:
        raise ValueError(
            f"pad_to_shape needs tensors of equal rank >= 3, got current rank "
            f"{current.dim()} and target rank {rank}"
        )

    # F.pad expects (before, after) pairs starting with the LAST dimension.
    pad = []
    for axis in range(rank - 1, 1, -1):
        pad.extend((0, targ_shp[axis] - current.shape[axis]))

    return F.pad(current, pad)



In [None]:
# Build a U-Net with 3 input channels, 1 output channel and the default conv_depths, then print its layer summary for a 3x64x64 input.
u = UNET(3, 1).to(device)

summary(u, input_size=(3, 64, 64))

In [None]:
import time
import matplotlib.pyplot as plt
from torchvision.utils import save_image

class Trainer():
    """
    Runs training, validation and prediction loops for a segmentation
    network.

    Parameters
    ----------
    net : torch.nn.Module
        Network to train; moved to `device`.
    loss : torch.nn.Module
        Loss called as ``loss(output, target)``; moved to `device`.
    optimizer : torch.optim.Optimizer
        Optimizer over the parameters of `net`.
    device : str or torch.device
        Device used for the network, the loss and every batch.
    scheduler : optional
        If given, its ``step()`` is called without arguments once per epoch.

    After `train_unet` has run, `self.history` holds one log dict per epoch.
    """

    def __init__(self, net, loss, optimizer, device, scheduler=None):

        self.net = net
        self.loss = loss
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.device = device
        self.net.to(self.device)
        self.loss.to(self.device)

        self.history = []

    def val_epoch(self, dataloader, dice_loss):
        """
        Evaluate the network on every batch of `dataloader` without
        gradient tracking.

        Returns
        -------
        tuple of float
            (mean loss, mean value of `dice_loss`) over all batches.
        """
        was_training = self.net.training
        # Evaluation mode: batch-norm uses running statistics, dropout is off.
        self.net.eval()

        running_val_loss = 0.0
        running_val_dice = 0.0
        n_batches = 0

        with torch.no_grad():

            for x_batch, y_batch in dataloader:

                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)

                y_out = self.net(x_batch)
                running_val_loss += self.loss(y_out, y_batch).item()
                # Keep the metric callable intact; store only its value.
                running_val_dice += float(dice_loss(y_out, y_batch))
                n_batches += 1

        # Leave the network in the mode it was in before validation.
        self.net.train(was_training)

        n_batches = max(n_batches, 1)
        return running_val_loss / n_batches, running_val_dice / n_batches

    def train_epoch(self, dataloader, dice_loss, epoch):
        """
        Train the network for one pass over `dataloader`.

        Every 200 steps the running means are printed and the thresholded
        prediction, the ground truth and the input of the current batch
        are written as PNG files to the working directory.

        Returns
        -------
        tuple of float
            (mean loss, mean value of `dice_loss`) over all batches.
        """
        self.net.train()

        epoch_running_loss = 0.0
        running_dice_loss = 0.0
        n_batches = 0

        for batch_idx, (x_batch, y_batch) in enumerate(dataloader):

            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)

            self.optimizer.zero_grad()
            y_out = self.net(x_batch)
            training_loss = self.loss(y_out, y_batch)

            training_loss.backward()
            self.optimizer.step()

            # The metric is only reported, so it needs no autograd graph.
            with torch.no_grad():
                train_dice = float(dice_loss(y_out, y_batch))

            # Accumulate plain floats, exactly once per batch.
            epoch_running_loss += training_loss.item()
            running_dice_loss += train_dice
            n_batches += 1

            if batch_idx % 200 == 0 and batch_idx != 0:

                print(f"Step - {batch_idx} | Training Loss - {epoch_running_loss/n_batches} | Dice Loss - {running_dice_loss/n_batches}")

                # Sigmoid maps the logits into the (0, 1) range.
                img1 = torch.sigmoid(y_out.detach())
                # Binarize at the midpoint between the smallest and largest value.
                threshold = (img1.min() + img1.max()) * 0.5
                ima = torch.where(img1 > threshold, 1.0, 0.0)

                save_image(ima, f'BIN_ima_{batch_idx}_{epoch}.png')
                save_image(y_batch.squeeze(1), f'BIN_ima_groundtruth_{batch_idx}_{epoch}.png')
                save_image(x_batch.squeeze(1), f'BIN_original_{batch_idx}_{epoch}.png')

        n_batches = max(n_batches, 1)
        return epoch_running_loss / n_batches, running_dice_loss / n_batches

    def train_unet(self, train_loader, val_loader, n_epochs, dice_metric):
        """
        Train for `n_epochs` epochs, validating after each one, and print
        a summary per epoch. The per-epoch logs are kept in `self.history`.
        """
        train_start = time.time()
        # Metrics that are modules follow the trainer's device; plain
        # callables are used as they are.
        if hasattr(dice_metric, "to"):
            dice_metric = dice_metric.to(self.device)

        self.history = []

        for epoch in range(1, n_epochs+1):

            train_loss, train_dice = self.train_epoch(train_loader, dice_metric, epoch)

            if self.scheduler is not None:
                self.scheduler.step()

            val_loss, val_dice = self.val_epoch(val_loader, dice_metric)

            logs = {'epoch': epoch,
                    'time': time.time() - train_start,
                    'train_loss': train_loss,
                    'validation_loss': val_loss,
                    'train_dice': train_dice,
                    'validation_dice': val_dice
                    }
            self.history.append(logs)

            print("-" * 20)
            print(f"Epoch - {logs['epoch']} | Time Elapsed - {logs['time']} | Training Loss - {logs['train_loss']} | Train Dice Coeff - {logs['train_dice']}") 
            print(f"Validation Loss - {logs['validation_loss']} | Validation Dice - {logs['validation_dice']}")

    def predict_dataset(self, dataloader, export_path):
        """
        Write one thresholded prediction image per batch into `export_path`
        (created if missing), named 001.png, 002.png, ...

        Batches may be plain input tensors or (input, ...) tuples/lists, in
        which case the first element is used as the network input.
        """
        os.makedirs(export_path, exist_ok=True)

        was_training = self.net.training
        self.net.eval()

        with torch.no_grad():

            for batch_idx, batch in enumerate(dataloader):

                x_batch = batch[0] if isinstance(batch, (list, tuple)) else batch

                image_filename = '%s.png' % str(batch_idx + 1).zfill(3)

                x_batch = x_batch.to(self.device)
                y_out = self.net(x_batch)

                # Sigmoid maps the logits into the (0, 1) range.
                img1 = torch.sigmoid(y_out)
                # Binarize at the midpoint between the smallest and largest value.
                threshold = (img1.min() + img1.max()) * 0.5
                ima = torch.where(img1 > threshold, 1.0, 0.0)

                save_image(ima, os.path.join(export_path, image_filename))

        self.net.train(was_training)

            
            

In [None]:
# Network shape: channel widths start at `width` and double `depth` times (32 ... 1024).
in_channels, out_channels = 3, 1
width, depth = 32, 6
conv_depths = [int(width * 2 ** level) for level in range(depth)]

# Model first, then the optimizer that needs its parameters.
unet = UNET(in_channels, out_channels, conv_depths).to(device)
optimizer = optim.Adam(unet.parameters(), lr=1e-3)

# BCE on raw logits is optimised; the soft Dice value is passed to train_unet as a metric.
loss = nn.BCEWithLogitsLoss()
dice = SoftDiceLoss()

trainer = Trainer(unet, loss, optimizer, device=device)

#summary(unet, input_size=(3, 256, 256))

In [None]:
trainer.train_unet(train_dataloader, val_dataloader, 5, dice)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# `%pylab inline` was dropped: it star-imports numpy and matplotlib into the
# notebook namespace, and earlier cells already show figures without it.

# Snapshot files written by Trainer.train_epoch at step 1000 of epoch 1.
snapshots = [
    ("Original Image", './BIN_original_1000_1.png'),
    ("Ground Truth", './BIN_ima_groundtruth_1000_1.png'),
    ("Predicted", './BIN_ima_1000_1.png'),
]

for title, snapshot_path in snapshots:
    print(title)
    # The snapshot only exists if training reached that step.
    if not os.path.exists(snapshot_path):
        print(f"{snapshot_path} not found - skipped")
        continue
    img = mpimg.imread(snapshot_path)
    imgplot = plt.imshow(img)
    plt.show()