In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
DataPath = "./lgg-mri-segmentation/kaggle_3m/"
dirs = []
images = []
masks = []
for dirpath, dirnames, filenames in os.walk(DataPath):
    for filename in filenames:
        if 'mask' in filename:
            dirs.append(dirpath.replace(DataPath, ''))
            masks.append(filename)
            images.append(filename.replace('_mask', ''))

In [None]:
def plot_image(idx):
    """Display sample ``idx`` as three panels: the MRI slice, its mask, and the mask overlaid on the slice.

    File locations come from the module-level ``DataPath``, ``dirs``, ``images`` and ``masks``.
    """
    sample_dir = os.path.join(DataPath, dirs[idx])
    mri = plt.imread(os.path.join(sample_dir, images[idx]))
    seg = plt.imread(os.path.join(sample_dir, masks[idx]))

    fig, (ax_mri, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=[13, 15])

    ax_mri.imshow(mri)
    ax_mri.set_title('Brain MRI')

    ax_mask.imshow(seg, cmap='gray')
    ax_mask.set_title('Mask')

    # Semi-transparent mask drawn over the slice so the labelled region is seen in context.
    ax_overlay.imshow(mri)
    ax_overlay.imshow(seg, cmap='gray', alpha=0.3)
    ax_overlay.set_title('MRI with mask')

    plt.show()

# Preview five consecutive samples, starting at dataset index 100.
for idx in range(5):
    plot_image(100 + idx)

In [4]:
class MRIDataset(Dataset):
    """Map-style dataset yielding ``(image, mask)`` float32 tensors scaled to [0, 1].

    Called with no arguments it uses the module-level ``images``, ``masks``, ``dirs``
    lists and the ``DataPath`` root. Pass them explicitly to build a dataset over a
    different file list (for example a subset).

    Args:
        image_files: image filenames, parallel to ``mask_files`` and ``sample_dirs``.
        mask_files: mask filenames.
        sample_dirs: per-sample directory relative to ``root``.
        root: directory that ``sample_dirs`` are relative to.

    Raises:
        ValueError: if the three file lists differ in length.
    """

    def __init__(self, image_files=None, mask_files=None, sample_dirs=None, root=None):
        # Fall back to the notebook globals so the original ``MRIDataset()`` call keeps working.
        self.images = images if image_files is None else image_files
        self.masks = masks if mask_files is None else mask_files
        self.dirs = dirs if sample_dirs is None else sample_dirs
        self.root = DataPath if root is None else root
        if not (len(self.images) == len(self.masks) == len(self.dirs)):
            raise ValueError("images, masks and dirs must have the same length")

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _read_unit_range(path):
        """Read an image file as a float32 array scaled to [0, 1]."""
        pixels = plt.imread(path)
        if np.issubdtype(pixels.dtype, np.integer):
            # Integer images (e.g. 8-bit TIFF): divide by the dtype maximum (255 for uint8).
            return (pixels / np.iinfo(pixels.dtype).max).astype(np.float32)
        # Float images (e.g. PNG via matplotlib) are already in [0, 1]; do not rescale.
        return pixels.astype(np.float32)

    def __getitem__(self, idx):
        # Use the instance's own lists, not the globals, so explicitly passed file lists are honoured.
        folder = os.path.join(self.root, self.dirs[idx])
        image = self._read_unit_range(os.path.join(folder, self.images[idx]))
        mask = self._read_unit_range(os.path.join(folder, self.masks[idx]))

        # Channels-last (H, W, C) image -> channels-first (C, H, W) as expected by Conv2d.
        image = torch.from_numpy(image).permute(2, 0, 1)
        # The mask is read as a 2-D (H, W) array; add a leading channel axis -> (1, H, W).
        mask = torch.from_numpy(mask).unsqueeze(0)
        return image, mask
    
dataset = MRIDataset()

# 80/20 train/test split. A fixed-seed generator makes the partition reproducible
# across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
train_size = int(0.8 * len(dataset))
# Derive the remainder from the dataset itself so the sizes always sum to len(dataset).
test_size = len(dataset) - train_size
train_set, test_set = random_split(dataset, [train_size, test_size], generator=split_generator)

In [5]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU.

    ``padding=1`` keeps H and W unchanged; only the channel count goes from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))
    
class Contract(nn.Module):
    """Encoder step: 2x2 max-pool (halving H and W, floor for odd sizes), then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.doubleConv = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        return self.doubleConv(F.max_pool2d(x, 2))
    
class Expand(nn.Module):
    """Decoder step: 2x transposed-conv upsample, concatenate the encoder skip tensor, then DoubleConv.

    Args:
        in_channels: channels of the incoming decoder tensor. The transposed conv halves
            this, and the skip tensor is expected to supply the other half, so the
            concatenation again has ``in_channels`` channels.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.doubleConv = DoubleConv(in_channels, out_channels)

    def forward(self, x, shortcut):
        x = self.upsample(x)
        # When an encoder level had an odd H or W, max-pooling floored it, so the upsampled
        # tensor is one pixel short of the skip. Pad it (split evenly) so torch.cat works.
        diff_h = shortcut.size(-2) - x.size(-2)
        diff_w = shortcut.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        return self.doubleConv(torch.cat([shortcut, x], dim=1))
    
class OutConv(nn.Module):
    """Final 1x1 convolution followed by a sigmoid, mapping features to per-pixel values in [0, 1]."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # torch.sigmoid is the supported spelling; torch.nn.functional.sigmoid is deprecated.
        return torch.sigmoid(self.conv(x))

In [6]:
class UNet(nn.Module):
    """Four-level U-Net producing a per-pixel map passed through OutConv's sigmoid.

    Args:
        in_channels: channels of the input image (3 for the RGB slices used here).
        out_channels: channels of the output map (1 for a single binary mask).
        base_channels: width of the first level; doubled at each of the four downsamplings.

    The defaults reproduce the original 3 -> 64 ... 1024 -> 1 architecture, with the same
    submodule names, so existing checkpoints still load.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super().__init__()
        c = base_channels

        self.inconv = DoubleConv(in_channels, c)
        self.down1 = Contract(c, 2 * c)
        self.down2 = Contract(2 * c, 4 * c)
        self.down3 = Contract(4 * c, 8 * c)
        self.down4 = Contract(8 * c, 16 * c)
        self.up1 = Expand(16 * c, 8 * c)
        self.up2 = Expand(8 * c, 4 * c)
        self.up3 = Expand(4 * c, 2 * c)
        self.up4 = Expand(2 * c, c)
        self.outconv = OutConv(c, out_channels)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        x1 = self.inconv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: each Expand upsamples and fuses the matching encoder output.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        return self.outconv(x)

In [7]:
def dice_metric(pred, label):
    """Dice coefficient ``2 * sum(pred * label) / (sum(pred) + sum(label))`` over all elements.

    Returns the int ``1`` when ``sum(pred) + sum(label)`` is zero (e.g. both masks empty).
    Otherwise it returns whatever the division yields for the input type (a 0-dim tensor
    for tensor inputs).
    """
    overlap = (pred * label).sum()
    total = pred.sum() + label.sum()
    return 1 if total == 0 else 2.0 * overlap / total

def dice_loss(pred, label):
    """Soft Dice loss ``1 - (2 * sum(pred * label) + 1) / (sum(pred) + sum(label) + 1)``.

    Sums run over every element of the batch (one global score, not per-sample).
    """
    # Additive smoothing keeps the ratio finite and gives loss 0 when both inputs are all zero.
    eps = 1.0
    soft_dice = (2.0 * (pred * label).sum() + eps) / (pred.sum() + label.sum() + eps)
    return 1 - soft_dice

In [8]:
# Training hyper-parameters.
learning_rate = 1e-3  # Adam step size
batch_size = 64       # samples per optimisation step
max_epoch = 5         # full passes over the training split

In [9]:
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
# Shuffling the test split only changes batch composition. That makes per-batch
# metrics vary from run to run without any benefit, so keep a fixed order.
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Smoke test: one random batch through the network to check the output shape.
# no_grad avoids building (and holding) an autograd graph for this throw-away pass.
with torch.no_grad():
    out = model(torch.randn(4, 3, 256, 256, device=device))
print(out.shape)

In [None]:
loss_set = []   # per-batch training Dice loss (Python floats)
accu_set = []   # per-batch Dice score of the thresholded prediction (Python floats)
disp_freq = 100
# Batches per epoch, including a final partial batch.
iter_per_batch = len(train_loader)

model.train()
for epoch in range(max_epoch):
    for batch_id, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        loss = dice_loss(output, target)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        loss_set.append(loss_value)

        # Score the binarised mask (threshold 0.5) without autograd, and store a plain float.
        # Appending the graph-attached tensor kept every batch's output/target alive.
        with torch.no_grad():
            pred_mask = (output >= 0.5).float()
            accuracy = float(dice_metric(pred_mask, target))
        accu_set.append(accuracy)

        if batch_id % disp_freq == 0:
            print("Epoch [{}][{}]\t Batch [{}][{}]\t Training Loss {:.4f}\t Accuracy {:.4f}".format(
                epoch, max_epoch, batch_id, iter_per_batch,
                loss_value, accuracy))

In [None]:
# Switch the model to evaluation mode before running inference on held-out data.
model.eval()