### Importing Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import time
import torch.optim as optim
import os

from PIL import Image
import PIL

import random
import wandb
#import any other library you need below this line
import math


### Loading data

Upload the data in zip format to Colab. Then run the cell below.

In [None]:
# Extract the uploaded archive; the training cell later reads from ./data/cells
# (presumably the layout inside data.zip -- verify against the archive).
!unzip data.zip

### Defining the Dataset Class

In [None]:
class Cell_data(Dataset):
    """Paired scans and segmentation masks with optional random augmentation.

    ``data_dir`` must contain a ``scans`` and a ``labels`` sub-directory whose
    files share the same names.  The sorted file list is split into a leading
    training portion and a trailing test portion.

    Args:
        data_dir (str): directory containing ``scans/`` and ``labels/``.
        size (int): side length (pixels) every image and mask is resized to.
        train (bool): serve the training portion if True, else the test portion.
        train_test_split (float): fraction of the files used for training.
        augment_data (bool): apply one randomly chosen augmentation per sample.
    """

    def __init__(self, data_dir, size, train=True, train_test_split=0.8, augment_data=True):
        super(Cell_data, self).__init__()
        self.data_dir = data_dir
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # train/test split identical across dataset instances, runs and machines.
        self.images = sorted(os.listdir(os.path.join(data_dir, "scans")))
        self.train_set, self.test_set = self._split(self.images, train_test_split)

        # Scans go through ToTensor (float tensor), masks through PILToTensor
        # (raw pixel values, no rescaling).
        self.transforms1 = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.transforms2 = transforms.Compose([
            transforms.PILToTensor(),
        ])

        self.isTrain = train
        self.augment_data = augment_data
        self.size = size

    @staticmethod
    def _split(names, fraction):
        """Return ``(train, test)``: the first ``int(len * fraction)`` sorted names and the rest."""
        ordered = sorted(names)
        end = int(len(ordered) * fraction)
        return ordered[:end], ordered[end:]

    def _augment(self, img, label):
        """Apply one randomly chosen augmentation to a PIL scan/mask pair.

        Geometric transforms are applied identically to both images; the gamma
        adjustment only touches the scan.
        """
        augment_mode = np.random.randint(0, 6)
        if augment_mode == 0:
            # mirror left <-> right (horizontal flip)
            img = img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
            label = label.transpose(PIL.Image.FLIP_LEFT_RIGHT)
        elif augment_mode == 1:
            # mirror top <-> bottom (vertical flip)
            img = img.transpose(PIL.Image.FLIP_TOP_BOTTOM)
            label = label.transpose(PIL.Image.FLIP_TOP_BOTTOM)
        elif augment_mode == 2:
            # 2x zoom: upscale, then cut the original-sized window from the centre
            width, height = img.size
            img = img.resize((width * 2, height * 2))
            label = label.resize((width * 2, height * 2), PIL.Image.NEAREST)

            left = (width * 2 - width) / 2
            top = (height * 2 - height) / 2
            right = (width * 2 + width) / 2
            bottom = (height * 2 + height) / 2

            img = img.crop((left, top, right, bottom))
            label = label.crop((left, top, right, bottom))
        elif augment_mode == 3:
            # gamma adjustment changes intensities only, so the mask is untouched
            rand_gamma = np.random.uniform(0.0, 1.5)
            img = TF.adjust_gamma(img, gamma=rand_gamma)
        elif augment_mode == 4:
            # shear transform with the same angle for scan and mask
            shear_angle = random.randint(-20, 20)
            img = TF.affine(img, angle=0, translate=(0, 0), scale=1.0, shear=shear_angle)
            label = TF.affine(label, angle=0, translate=(0, 0), scale=1.0, shear=shear_angle)
        else:
            # rotation by a random angle in [0, 90) degrees
            angle = np.random.randint(0, 90)
            img = img.rotate(angle)
            label = label.rotate(angle)
        return img, label

    def __getitem__(self, idx):
        """Return ``(img, label)`` tensors for sample ``idx`` of the active split.

        ``img`` is standardised per channel (zero mean, unit std over its own
        pixels); ``label`` holds the mask's raw pixel values.
        """
        names = self.train_set if self.isTrain else self.test_set
        img_item = names[idx]

        img = Image.open(os.path.join(self.data_dir, "scans", img_item))
        label = Image.open(os.path.join(self.data_dir, "labels", img_item))

        img = img.resize((self.size, self.size))
        # Masks carry class ids, so they must never be interpolated.
        label = label.resize((self.size, self.size), PIL.Image.NEAREST)

        if self.augment_data:
            img, label = self._augment(img, label)

        img = self.transforms1(img)
        label = self.transforms2(label)

        mean = img.mean([1, 2])
        # A constant image has zero std, which makes normalisation raise;
        # clamp so such samples become all-zero instead of crashing training.
        std = img.std([1, 2]).clamp_min(1e-8)
        img = TF.normalize(img, mean, std)

        return img, label

    def __len__(self):
        """Number of samples in the active (train or test) split."""
        if self.isTrain:
            return len(self.train_set)
        return len(self.test_set)

### Define the Model
1. Define the Convolution blocks
2. Define the down path
3. Define the up path
4. Combine the down and up paths to get the final model

In [None]:
class twoConvBlock(nn.Module):
    """Two unpadded 3x3 convolutions, each followed by ReLU.

    Batch normalisation is applied only after the second convolution.  Since
    no padding is used, every convolution trims 2 pixels from height and width.
    """

    def __init__(self, input, output):
        super().__init__()
        self.conv1 = nn.Conv2d(input, output, kernel_size=3)
        self.conv2 = nn.Conv2d(output, output, kernel_size=3)
        self.norm = nn.BatchNorm2d(output)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Run conv -> ReLU -> conv -> batch norm -> ReLU on ``input``."""
        hidden = self.relu(self.conv1(input))
        return self.relu(self.norm(self.conv2(hidden)))

class downStep(nn.Module):
    """One contracting-path stage: a double convolution followed by 2x2 max pooling."""

    def __init__(self, input, output):
        super().__init__()
        self.conv = twoConvBlock(input, output)
        self.maxPooling = nn.MaxPool2d((2, 2), stride=2)

    def forward(self, input):
        """Return ``(pooled, features)``; ``features`` is the pre-pooling map kept for the skip connection."""
        features = self.conv(input)
        pooled = self.maxPooling(features)
        return pooled, features

class upStep(nn.Module):
    """One expanding-path stage: 2x transposed-conv upsampling, skip concat, double conv.

    Args:
        input (int): number of input channels; the stage outputs ``input // 2``.
    """

    def __init__(self, input):
        super(upStep, self).__init__()
        output = input // 2
        self.upSampling = nn.ConvTranspose2d(input, output, kernel_size=(2, 2), stride=(2, 2))
        # After concatenation with the skip features the channel count is
        # ``input`` again, so the double convolution maps input -> output.
        self.conv = twoConvBlock(input, output)

    @staticmethod
    def _center_crop(tensor, height, width):
        """Centre-crop the last two dims of ``tensor`` to ``(height, width)``.

        Uses the same rounding as torchvision's ``center_crop`` and, like it,
        zero-pads first when the tensor is smaller than the requested size.
        Implemented with plain tensor ops so no transform object has to be
        rebuilt on every forward pass.
        """
        src_h, src_w = tensor.shape[-2:]
        pad_h = max(height - src_h, 0)
        pad_w = max(width - src_w, 0)
        if pad_h or pad_w:
            # F.pad order for the last two dims: (left, right, top, bottom)
            tensor = F.pad(tensor, (pad_w // 2, (pad_w + 1) // 2, pad_h // 2, (pad_h + 1) // 2))
            src_h, src_w = tensor.shape[-2:]
        top = int(round((src_h - height) / 2.0))
        left = int(round((src_w - width) / 2.0))
        return tensor[..., top:top + height, left:left + width]

    def forward(self, input, copy_input):
        """Upsample ``input``, concatenate the centre-cropped ``copy_input``, then convolve.

        Args:
            input: feature map from the previous (deeper) stage, 4-D.
            copy_input: skip-connection feature map from the matching down stage.
        """
        out = self.upSampling(input)
        _, _, h, w = out.size()
        copy = self._center_crop(copy_input, h, w)
        # skip features first, upsampled features second along the channel axis
        out = torch.cat((copy, out), dim=1)
        return self.conv(out)
        

class UNet(nn.Module):
    """U-Net built from four down stages, a bottleneck and four up stages.

    All convolutions are unpadded, so the output map is spatially smaller than
    the input; callers must crop their targets to the output size.

    Args:
        in_channels (int): channels of the input image (default 1).
        n_classes (int): number of output channels / classes (default 2).
    """

    def __init__(self, in_channels=1, n_classes=2):
        super(UNet, self).__init__()
        self.down1 = downStep(in_channels, 64)
        self.down2 = downStep(64, 128)
        self.down3 = downStep(128, 256)
        self.down4 = downStep(256, 512)
        self.conv = twoConvBlock(512, 1024)
        self.up1 = upStep(1024)
        self.up2 = upStep(512)
        self.up3 = upStep(256)
        self.up4 = upStep(128)
        # 1x1 convolution producing one score map per class
        self.endConv = nn.Conv2d(64, n_classes, kernel_size=(1, 1))

    def forward(self, input):
        """Return per-pixel class scores of shape (N, n_classes, h, w) for a 4-D ``input``."""
        out, copy_out1 = self.down1(input)
        out, copy_out2 = self.down2(out)
        out, copy_out3 = self.down3(out)
        out, copy_out4 = self.down4(out)
        out = self.conv(out)
        # Skip connections are consumed deepest-first on the way back up.
        out = self.up1(out, copy_out4)
        out = self.up2(out, copy_out3)
        out = self.up3(out, copy_out2)
        out = self.up4(out, copy_out1)
        return self.endConv(out)
        

### Training

In [None]:
# Parameters

# learning rate
lr = 0.001  # 0.005  # 1e-2
# number of training epochs
epoch_n = 1  # 20  # 30
# input image-mask size
image_size = 360  # 400  # 320  # 572
# root directory of project
root_dir = os.getcwd()
# training batch size
batch_size = 1  # 4
# use checkpoint model for training
load = False
# use GPU for training
gpu = True

augment_data = True
wandb.init(
    project="UNet Cell Segmentation",
    config={
        "learning_rate": lr,
        "epochs": epoch_n,
        "batch_size": batch_size,
        "image_size": image_size
    }
)

data_dir = os.path.join(root_dir, 'data', 'cells')

trainset = Cell_data(data_dir=data_dir, size=image_size, augment_data=augment_data)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = Cell_data(data_dir=data_dir, size=image_size, train=False, augment_data=False)
testloader = DataLoader(testset, batch_size=batch_size)

# Fall back to the CPU when no GPU is requested or none is available.
device = torch.device('cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

model = UNet().to(device)
# print(model)
if load:
    print('loading model')
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
    model.load_state_dict(torch.load('checkpoint.pt', map_location=device))

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.99, 0.999), weight_decay=0.0005)  # 0.99
# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.99, weight_decay=0.0005)

train_loss_log = []
test_loss_log = []


def crop_labels_to_prediction(label, pred):
    """Centre-crop (N, H, W) labels to the spatial size of (N, C, h, w) predictions.

    The network uses unpadded convolutions, so its output is smaller than the
    input.  The window is sized from ``pred`` so an odd size difference cannot
    leave the labels one pixel larger than the prediction.
    """
    crop_x = (label.shape[1] - pred.shape[2]) // 2
    crop_y = (label.shape[2] - pred.shape[3]) // 2
    return label[:, crop_x: crop_x + pred.shape[2], crop_y: crop_y + pred.shape[3]]


# Register the gradient/parameter hooks once, before training starts.
wandb.watch(model)

model.train()
begin = time.time()
for e in range(epoch_n):
    epoch_loss = 0
    correct_train = 0
    total_train = 0
    model.train()
    for image, label in trainloader:
        image = image.to(device)
        # drop the mask's channel dimension: (N, 1, H, W) -> (N, H, W)
        label = label.long().to(device).squeeze(1)

        pred = model(image)
        label = crop_labels_to_prediction(label, pred)

        loss = criterion(pred, label)

        total_train += label.numel()
        pred_labels_train = pred.argmax(dim=1)
        correct_train += (pred_labels_train == label).sum().item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()

    t_acc = correct_train / total_train
    # criterion already averages within a batch, so average over batches
    train_loss = epoch_loss / len(trainloader)
    print('Epoch %d / %d --- Loss: %.4f' % (e + 1, epoch_n, train_loss))
    train_loss_log.append(train_loss)

    model.eval()

    total = 0
    correct = 0
    total_loss = 0

    with torch.no_grad():
        for image, label in testloader:
            image = image.to(device)
            label = label.long().to(device).squeeze(1)

            pred = model(image)
            label = crop_labels_to_prediction(label, pred)

            total_loss += criterion(pred, label).item()

            pred_labels = pred.argmax(dim=1)
            total += label.numel()
            correct += (pred_labels == label).sum().item()

    v_acc = correct / total
    v_loss = total_loss / len(testloader)
    print('Accuracy: %.4f ---- Loss: %.4f' % (v_acc, v_loss))
    test_loss_log.append(v_loss)

    # keep a checkpoint whenever validation pixel accuracy exceeds 75%
    if v_acc > 0.75:
        torch.save(model.state_dict(), 'checkpoint.pt')

    wandb.log({
        'train_accuracy': t_acc,
        'validation_accuracy': v_acc,
        'time': time.time() - begin,
        'training_loss': train_loss,
        'validation_loss': v_loss,
    })


In [None]:
# torch.cuda.empty_cache()

### Testing and Visualization

In [None]:
# Run the trained model over the test split at full resolution and collect
# the predicted masks together with the matching (cropped) ground-truth masks.
model.eval()

testset = Cell_data(data_dir=data_dir, size=572, train=False, augment_data=False)
testloader = DataLoader(testset, batch_size=batch_size)

output_masks = []
output_labels = []

with torch.no_grad():
    for i in range(len(testset)):
        image, labels = testset[i]

        # add the batch dimension expected by the network
        input_image = image.unsqueeze(0).to(device)
        pred = model(input_image)

        labels = labels.squeeze(0)
        # per-pixel class index with the highest score
        output_mask = pred.argmax(dim=1).cpu().squeeze(0).numpy()

        # Centre-crop the mask to the (smaller) network output.  The window is
        # sized from the prediction so an odd size difference cannot leave the
        # label one pixel larger than the mask.
        out_h, out_w = output_mask.shape
        crop_x = (labels.shape[0] - out_h) // 2
        crop_y = (labels.shape[1] - out_w) // 2
        labels = labels[crop_x: crop_x + out_h, crop_y: crop_y + out_w]

        output_masks.append(output_mask)
        output_labels.append(labels.numpy())

In [None]:
# Plot the per-epoch training and testing loss curves.
# The x-axis is derived from the logs themselves so the plot still works when
# training stopped early or epoch_n was changed after training.
plt.plot(range(len(train_loss_log)), train_loss_log, 'g', label='Training Loss')
plt.plot(range(len(test_loss_log)), test_loss_log, 'r', label='Testing Loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Train-Test Loss')
plt.legend()
plt.show()

In [None]:
# Show ground-truth masks (left column) next to predicted masks (right column).
# squeeze=False keeps `axes` two-dimensional even when there is a single test
# sample, so the [row, column] indexing below always works.
fig, axes = plt.subplots(len(testset), 2, figsize=(20, 20), squeeze=False)

for row, (true_mask, predicted_mask) in enumerate(zip(output_labels, output_masks)):
    axes[row, 0].imshow(true_mask)
    axes[row, 0].axis('off')
    axes[row, 1].imshow(predicted_mask)
    axes[row, 1].axis('off')