In [None]:
# ! cp drive/My\ Drive/CT.zip .
# ! unzip CT.zip

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import pandas as pd
import os

ngpu = torch.cuda.device_count()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device, ngpu)

# Seed the RNGs the notebook relies on (np.random.permutation for batch shuffling,
# torch for weight initialisation) so Restart & Run All reproduces the same run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Preview one test slice and its seg_w mask, both resized to 256x256.
img_name = '62_50.bmp'
mask_name = img_name  # image and mask files share the same file name

im = np.array(Image.open(os.path.join('CT/test/img/', img_name)).resize((256, 256)))
plt.imshow(im, cmap="gray")
plt.show()

seg = np.array(Image.open(os.path.join('CT/test/seg_w', mask_name)).resize((256, 256)))
plt.imshow(seg, cmap="gray")
plt.show()

In [None]:
def _read_resized(path, dims, grayscale=False):
    """Open one image file, optionally convert it to grayscale, resize it to ``dims``
    (PIL order: width, height) and return it as a float array divided by 255."""
    with Image.open(path) as img:  # `with` closes the file handle
        if grayscale:
            img = ImageOps.grayscale(img)
        return np.array(img.resize(dims)) / 255.


def _to_batch_tensor(arrays):
    """Stack equally-sized 2-D arrays into a float32 (n, 1, H, W) tensor on ``device``."""
    return torch.from_numpy(np.stack(arrays)).float().unsqueeze(1).to(device)


def load_batch(batch_size = 20, dims=(256, 256), data_path='CT/train/'):
    """Yield shuffled (image, seg_w mask, seg_l mask) batches from ``data_path``.

    ``data_path`` must contain ``img/``, ``seg_w/`` and ``seg_l/`` folders whose files
    share names; the listing of ``img/`` drives which samples are loaded.

    Args:
        batch_size: samples per batch. The last batch holds the remainder and may be
            smaller; nothing is dropped.
        dims: target size passed to PIL ``resize``, i.e. (width, height).
        data_path: dataset root folder.

    Yields:
        Three float32 tensors of shape (n, 1, dims[1], dims[0]) on ``device``: the
        grayscale images, the seg_w masks and the seg_l masks, each divided by 255.
        (For the default square ``dims`` this is (n, 1, 256, 256).)

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
    # Sort before shuffling so a seeded run does not depend on filesystem listing order.
    names = np.array(sorted(os.listdir(os.path.join(data_path, 'img'))))
    names = names[np.random.permutation(len(names))]
    for start in range(0, len(names), batch_size):
        samples = [
            (
                _read_resized(os.path.join(data_path, 'img', name), dims, grayscale=True),
                _read_resized(os.path.join(data_path, 'seg_w', name), dims),
                _read_resized(os.path.join(data_path, 'seg_l', name), dims),
            )
            for name in names[start:start + batch_size]
        ]
        imgs, segs, p_segs = zip(*samples)
        yield _to_batch_tensor(imgs), _to_batch_tensor(segs), _to_batch_tensor(p_segs)

In [None]:
class conv_block(nn.Module):
    """Two stages of (3x3 conv -> BatchNorm -> ReLU).

    padding=1 keeps the spatial size unchanged.

    Args:
        num_channels: number of input channels.
        num_filters: number of output channels of both convolutions.
    """
    def __init__(self, num_channels, num_filters):
        super(conv_block, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_filters, (3, 3), padding=1)
        self.conv1_bn = nn.BatchNorm2d(num_filters)
        self.conv2 = nn.Conv2d(num_filters, num_filters, (3, 3), padding=1)
        self.conv2_bn = nn.BatchNorm2d(num_filters)

    def forward(self, inp_tensor):
        encoder = torch.relu(self.conv1_bn(self.conv1(inp_tensor)))
        encoder = torch.relu(self.conv2_bn(self.conv2(encoder)))
        return encoder

class encoder_block(nn.Module):
    """A conv_block followed by a 2x2, stride-2 max-pool.

    forward returns ``(pooled, features)``. ``pooled`` is the downsampled map fed to the
    next level; ``features`` is the un-pooled map the decoder reuses as a skip connection.
    """
    def __init__(self, num_channels, num_filters):
        super(encoder_block, self).__init__()
        self.conv_block1 = conv_block(num_channels, num_filters)
        self.max_pool1 = nn.MaxPool2d((2, 2), (2, 2))

    def forward(self, inp_tensor):
        features = self.conv_block1(inp_tensor)
        pooled = self.max_pool1(features)
        return pooled, features

class decoder_block(nn.Module):
    """One U-Net up-sampling stage.

    A 2x2 / stride-2 transposed conv doubles the spatial size and maps ``num_channels``
    to ``num_filters``. The skip tensor is concatenated in front along the channel axis;
    BatchNorm2d(2*num_filters) assumes the skip has ``num_filters`` channels. Then come
    BN -> ReLU and two (3x3 conv -> BN -> ReLU) stages.
    """
    def __init__(self, num_channels, num_filters):
        super(decoder_block, self).__init__()
        self.conv_tp1 = nn.ConvTranspose2d(num_channels, num_filters, (2, 2), stride=(2, 2))
        self.conv_tp1_bn = nn.BatchNorm2d(2*num_filters)
        self.conv_tp2 = nn.Conv2d(2*num_filters, num_filters, (3, 3), padding=1)
        self.conv_tp2_bn = nn.BatchNorm2d(num_filters)
        self.conv_tp3 = nn.Conv2d(num_filters, num_filters, (3, 3), padding=1)
        self.conv_tp3_bn = nn.BatchNorm2d(num_filters)

    def forward(self, inp_tensor, concat_tensor):
        upsampled = self.conv_tp1(inp_tensor)
        merged = torch.cat((concat_tensor, upsampled), 1)
        hidden = torch.relu(self.conv_tp1_bn(merged))
        hidden = torch.relu(self.conv_tp2_bn(self.conv_tp2(hidden)))
        return torch.relu(self.conv_tp3_bn(self.conv_tp3(hidden)))

In [None]:
class UNet2D(nn.Module):
    """Five-level U-Net producing a one-channel sigmoid map at the input resolution.

    Each encoder level halves the spatial size (input sides should be divisible by 32).
    The center block runs at the lowest resolution. Each decoder level doubles the size
    again and consumes the matching encoder feature map as a skip connection.

    Args:
        num_channels: number of input image channels.
    """
    def __init__(self, num_channels=3):
        super(UNet2D, self).__init__()
        self.encoder_block0 = encoder_block(num_channels, 32)
        self.encoder_block1 = encoder_block(32, 64)
        self.encoder_block2 = encoder_block(64, 128)
        self.encoder_block3 = encoder_block(128, 256)
        self.encoder_block4 = encoder_block(256, 512)
        self.center = conv_block(512, 1024)
        self.decoder_block4 = decoder_block(1024, 512)
        self.decoder_block3 = decoder_block(512, 256)
        self.decoder_block2 = decoder_block(256, 128)
        self.decoder_block1 = decoder_block(128, 64)
        self.decoder_block0 = decoder_block(64, 32)
        self.conv_final = nn.Conv2d(32, 1, (1, 1))

    def forward(self, inputs):
        down_path = (self.encoder_block0, self.encoder_block1, self.encoder_block2,
                     self.encoder_block3, self.encoder_block4)
        up_path = (self.decoder_block4, self.decoder_block3, self.decoder_block2,
                   self.decoder_block1, self.decoder_block0)

        skips = []
        x = inputs
        for enc in down_path:
            x, skip = enc(x)
            skips.append(skip)

        x = self.center(x)

        # Deepest decoder pairs with the deepest skip: decoder_block4 <-> encoder4, ...
        for dec, skip in zip(up_path, reversed(skips)):
            x = dec(x, skip)

        return torch.sigmoid(self.conv_final(x))

In [None]:
class MO_Net_encoder(nn.Module):
    """Shared contracting path (plus center block) of the multi-output network.

    Layers and attribute names match the encoder half and ``center`` of ``UNet2D``, so
    UNet2D weights can be copied in by ``state_dict`` key.

    ``forward`` returns ``(center, skips)`` with
    ``skips = (encoder4, encoder3, encoder2, encoder1, encoder0)``. These are the
    un-pooled feature maps that ``MO_Net_decoder`` concatenates back in; the decoder
    cannot run without them.
    """
    def __init__(self, num_channels=3):
        super(MO_Net_encoder, self).__init__()
        self.encoder_block0 = encoder_block(num_channels, 32)
        self.encoder_block1 = encoder_block(32, 64)
        self.encoder_block2 = encoder_block(64, 128)
        self.encoder_block3 = encoder_block(128, 256)
        self.encoder_block4 = encoder_block(256, 512)
        self.center = conv_block(512, 1024)

    def forward(self, inputs):
        # Spatial sizes in the comments assume a 256x256 input.
        encoder0_pool, encoder0 = self.encoder_block0(inputs) # 128
        encoder1_pool, encoder1 = self.encoder_block1(encoder0_pool) # 64
        encoder2_pool, encoder2 = self.encoder_block2(encoder1_pool) # 32
        encoder3_pool, encoder3 = self.encoder_block3(encoder2_pool) # 16
        encoder4_pool, encoder4 = self.encoder_block4(encoder3_pool) # 8
        center = self.center(encoder4_pool) # center (8)

        return center, (encoder4, encoder3, encoder2, encoder1, encoder0)


class MO_Net_decoder(nn.Module):
    """Expanding path of the multi-output network.

    Layers and attribute names match the decoder half of ``UNet2D``. It consumes the
    output of ``MO_Net_encoder`` and returns a one-channel sigmoid map.
    """
    def __init__(self):
        super(MO_Net_decoder, self).__init__()
        self.decoder_block4 = decoder_block(1024, 512)
        self.decoder_block3 = decoder_block(512, 256)
        self.decoder_block2 = decoder_block(256, 128)
        self.decoder_block1 = decoder_block(128, 64)
        self.decoder_block0 = decoder_block(64, 32)
        self.conv_final = nn.Conv2d(32, 1, (1, 1))

    def forward(self, center, skips=None):
        """Decode using the encoder's skip connections.

        Args:
            center: either the ``(center, skips)`` tuple returned by ``MO_Net_encoder``
                (so ``decoder(encoder(x))`` works), or the center tensor itself.
            skips: ``(encoder4, encoder3, encoder2, encoder1, encoder0)``. Required when
                ``center`` is a tensor.

        Raises:
            TypeError: if no skip connections are supplied.
        """
        if skips is None:
            # Guard explicitly: unpacking a bare tensor would silently split its batch axis.
            if not isinstance(center, (tuple, list)):
                raise TypeError('MO_Net_decoder needs the encoder skip connections: pass the '
                                '(center, skips) tuple from MO_Net_encoder or skips=...')
            center, skips = center
        encoder4, encoder3, encoder2, encoder1, encoder0 = skips

        decoder4 = self.decoder_block4(center, encoder4) # 16
        decoder3 = self.decoder_block3(decoder4, encoder3) # 32
        decoder2 = self.decoder_block2(decoder3, encoder2) # 64
        decoder1 = self.decoder_block1(decoder2, encoder1) # 128
        decoder0 = self.decoder_block0(decoder1, encoder0) # 256

        outputs = torch.sigmoid(self.conv_final(decoder0))
        return outputs

In [None]:
def dice_coeff(y_true, y_pred, smooth=1.):
    """Soft Dice coefficient computed over all elements of the batch at once.

    score = (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    Args:
        y_true: target tensor.
        y_pred: prediction tensor with the same shape as ``y_true``.
        smooth: additive smoothing. It keeps the score defined (equal to 1) when both
            tensors sum to zero.

    Returns:
        A 0-d tensor. It stays differentiable, so it can be used inside a loss.
    """
    assert y_true.shape == y_pred.shape, "Tensor dimensions must match"
    # reshape (not view): works for any rank and for non-contiguous tensors.
    y_true_flat = y_true.reshape(-1)
    y_pred_flat = y_pred.reshape(-1)
    intersection = torch.sum(y_true_flat * y_pred_flat)
    score = (2. * intersection + smooth) / (torch.sum(y_true_flat) + torch.sum(y_pred_flat) + smooth)
    return score

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient of ``y_pred`` against ``y_true``."""
    return 1 - dice_coeff(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus Dice loss.

    Argument order is targets first, predictions second (the reverse of
    ``F.binary_cross_entropy``). ``y_pred`` must already hold probabilities in [0, 1],
    e.g. sigmoid outputs.
    """
    bce_term = F.binary_cross_entropy(y_pred, y_true)
    dice_term = dice_loss(y_true, y_pred)
    return bce_term + dice_term

In [None]:
# PHASE 1: pre-train a plain U-Net on single-channel (grayscale) input.

epochs, batch_size = 10, 5

# Per-iteration histories and per-epoch averages, filled by the training loop.
losses, dscoeffs = [], []
avg_losses, val_avg_losses = [], []
avg_dscoeffs, val_avg_dscoeffs = [], []

model0 = UNet2D(1).to(device)
optimizer = optim.Adam(model0.parameters(), lr = 0.0001)

In [None]:
# PHASE 1 training: fit model0 on the third load_batch output (the seg_l masks) and
# validate on CT/test/ after every epoch.
for epoch in range(1, epochs+1):
    model0.train()
    avg_loss = 0.0
    avg_dscoeff = 0.0
    n_train = 0
    print('---------- EPOCH:', epoch, '----------')
    print('----------- TRAINING -------------')
    for imgs, _, segs in load_batch(batch_size):
        n_train += 1
        optimizer.zero_grad()
        outputs = model0(imgs)
        loss = bce_dice_loss(segs, outputs)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        dscoeff = dice_coeff(segs, outputs).item()
        print("Epoch:", epoch, "| Iter:", n_train, "| loss:", round(loss_value, 4), "| dsc:", round(dscoeff, 4))
        losses.append(loss_value)
        dscoeffs.append(dscoeff)
        avg_loss += loss_value
        avg_dscoeff += dscoeff
    print('-------------- DONE --------------')
    model0.eval()
    val_avg_loss = 0.0
    val_avg_dscoeff = 0.0
    n_val = 0
    print('---------- VALIDATING ------------')
    # Validation only reads outputs; skip building autograd graphs for every batch.
    with torch.no_grad():
        for imgs, _, segs in load_batch(batch_size, data_path='CT/test/'):
            n_val += 1
            outputs = model0(imgs)
            loss_value = bce_dice_loss(segs, outputs).item()
            dscoeff = dice_coeff(segs, outputs).item()
            print("Epoch:", epoch, "| Iter:", n_val, "| loss:", round(loss_value, 4), "| dsc:", round(dscoeff, 4))
            val_avg_loss += loss_value
            val_avg_dscoeff += dscoeff
    print('-------------- DONE --------------')
    print()
    if n_train == 0 or n_val == 0:
        raise RuntimeError('load_batch yielded no batches; check the CT/train/ and CT/test/ folders')
    avg_loss = round(avg_loss/n_train, 4)
    avg_dscoeff = round(avg_dscoeff/n_train, 4)
    val_avg_loss = round(val_avg_loss/n_val, 4)
    val_avg_dscoeff = round(val_avg_dscoeff/n_val, 4)
    avg_losses.append(avg_loss)
    val_avg_losses.append(val_avg_loss)
    avg_dscoeffs.append(avg_dscoeff)
    val_avg_dscoeffs.append(val_avg_dscoeff)
    print("Epoch:", epoch, "| loss:", avg_loss, "| dsc:", avg_dscoeff,
          "| val_loss:", val_avg_loss, "| val_dsc:", val_avg_dscoeff)
    print()

In [None]:
# PHASE 2: multi-output fine-tuning. One shared encoder feeds two decoders, and all
# three start from the Phase 1 U-Net weights.

epochs = 6
batch_size = 5
Lambda = 0.7  # weight of the model_dec2 (seg_w) loss; model_dec1 (seg_l) gets 1 - Lambda

losses = []
dscoeffs = []
avg_losses = []
val_avg_losses = []
avg_dscoeffs = []
val_avg_dscoeffs = []

model_enc = MO_Net_encoder(1).to(device)
model_dec1 = MO_Net_decoder().to(device)
model_dec2 = MO_Net_decoder().to(device)

# Warm-start each sub-network from model0. Their sub-module names mirror UNet2D's, so
# every sub-network key should exist in model0's state_dict. Fail loudly rather than
# silently leave a layer at its random initialisation if the names ever drift apart.
pretrained = model0.state_dict()
for sub_model in (model_enc, model_dec1, model_dec2):
    sub_keys = list(sub_model.state_dict())
    missing = [key for key in sub_keys if key not in pretrained]
    if missing:
        raise KeyError('Phase 1 weights missing for %s: %s' % (type(sub_model).__name__, missing))
    sub_model.load_state_dict({key: pretrained[key] for key in sub_keys})

# One optimizer per sub-network, each with its own learning rate.
optimizer0 = optim.Adam(model_enc.parameters(), lr = 0.000003)
optimizer1 = optim.Adam(model_dec1.parameters(), lr = 0.000001)
optimizer2 = optim.Adam(model_dec2.parameters(), lr = 0.00001)

In [None]:
# PHASE 2 training. segs are the seg_w masks (model_dec2's target, also used for DSC);
# p_segs are the seg_l masks (model_dec1's target).


def _mo_net_loss(imgs, segs, p_segs):
    """Run the shared encoder once, feed both decoders, and return
    (Lambda-weighted loss, model_dec2 output)."""
    # One encoder pass shared by both heads: half the encoder compute, and both decoder
    # losses backpropagate into the encoder through the same graph.
    features = model_enc(imgs)
    outputs1 = model_dec1(features)
    outputs2 = model_dec2(features)
    loss1 = bce_dice_loss(p_segs, outputs1)
    loss2 = bce_dice_loss(segs, outputs2)
    return ((1 - Lambda) * loss1) + (Lambda * loss2), outputs2


phase2_models = (model_enc, model_dec1, model_dec2)
# The optimizers created in the Phase 2 setup cell (encoder, decoder 1, decoder 2).
phase2_optimizers = (optimizer0, optimizer1, optimizer2)

for epoch in range(1, epochs+1):
    for sub_model in phase2_models:
        sub_model.train()
    avg_loss = 0.0
    avg_dscoeff = 0.0
    n_train = 0
    print('---------- EPOCH:', epoch, '----------')
    print('----------- TRAINING -------------')
    for imgs, segs, p_segs in load_batch(batch_size):
        n_train += 1
        for opt in phase2_optimizers:
            opt.zero_grad()
        loss, outputs2 = _mo_net_loss(imgs, segs, p_segs)
        loss.backward()
        for opt in phase2_optimizers:
            opt.step()
        loss_value = loss.item()
        dscoeff = dice_coeff(segs, outputs2).item()
        print("Epoch:", epoch, "| Iter:", n_train, "| loss:", round(loss_value, 4), "| dsc:", round(dscoeff, 4))
        losses.append(loss_value)
        dscoeffs.append(dscoeff)
        avg_loss += loss_value
        avg_dscoeff += dscoeff
    print('-------------- DONE --------------')
    for sub_model in phase2_models:
        sub_model.eval()
    val_avg_loss = 0.0
    val_avg_dscoeff = 0.0
    n_val = 0
    print('---------- VALIDATING ------------')
    # Validation only reads outputs; skip building autograd graphs for every batch.
    with torch.no_grad():
        for imgs, segs, p_segs in load_batch(batch_size, data_path='CT/test/'):
            n_val += 1
            loss, outputs2 = _mo_net_loss(imgs, segs, p_segs)
            loss_value = loss.item()
            dscoeff = dice_coeff(segs, outputs2).item()
            print("Epoch:", epoch, "| Iter:", n_val, "| loss:", round(loss_value, 4), "| dsc:", round(dscoeff, 4))
            val_avg_loss += loss_value
            val_avg_dscoeff += dscoeff
    print('-------------- DONE --------------')
    print()
    if n_train == 0 or n_val == 0:
        raise RuntimeError('load_batch yielded no batches; check the CT/train/ and CT/test/ folders')
    avg_loss = round(avg_loss/n_train, 4)
    avg_dscoeff = round(avg_dscoeff/n_train, 4)
    val_avg_loss = round(val_avg_loss/n_val, 4)
    val_avg_dscoeff = round(val_avg_dscoeff/n_val, 4)
    avg_losses.append(avg_loss)
    val_avg_losses.append(val_avg_loss)
    avg_dscoeffs.append(avg_dscoeff)
    val_avg_dscoeffs.append(val_avg_dscoeff)
    print("Epoch:", epoch, "| loss:", avg_loss, "| dsc:", avg_dscoeff,
          "| val_loss:", val_avg_loss, "| val_dsc:", val_avg_dscoeff)
    print()

In [None]:
# Per-iteration training loss and DSC of whichever phase ran last
# (each phase's setup cell resets these lists).
for series, title, ylabel in ((losses, 'Loss plot', 'Loss'),
                              (dscoeffs, 'Performance plot', 'DSC')):
    plt.plot(series)
    plt.title(title)
    plt.xlabel('Iterations')
    plt.ylabel(ylabel)
    plt.show()

In [None]:
# Per-epoch training averages (loss and DSC) of whichever phase ran last.
for series, title, ylabel in ((avg_losses, 'Loss plot', 'Loss'),
                              (avg_dscoeffs, 'Performance plot', 'DSC')):
    plt.plot(series)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.show()

In [None]:
# Per-epoch validation averages (loss and DSC) of whichever phase ran last.
for series, title, ylabel in ((val_avg_losses, 'Loss plot', 'Loss'),
                              (val_avg_dscoeffs, 'Performance plot', 'DSC')):
    plt.plot(series)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.show()

In [None]:
# Qualitative check: predict the seg_w mask for one test slice with the Phase 2 network.
img_name = '62_50.bmp'
mask_name = img_name

im = Image.open(os.path.join('CT/test/img', img_name))
# Match load_batch's preprocessing: the network was trained on grayscale input, and an
# RGB-stored slice would not fit the (1, 1, 256, 256) reshape below.
im = ImageOps.grayscale(im)
im = im.resize((256, 256))
im = np.array(im) / 255.
plt.imshow(im, cmap="gray")
plt.show()

seg = Image.open(os.path.join('CT/test/seg_w', mask_name))
seg = seg.resize((256, 256))
seg = np.array(seg)
plt.imshow(seg, cmap="gray")
plt.show()

model_enc.eval()
model_dec2.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    inp = torch.from_numpy(im.reshape((1, 1, 256, 256))).float().to(device)
    out = model_dec2(model_enc(inp))
plt.imshow(out.cpu().numpy().squeeze(), cmap="gray")
plt.show()