In [None]:
import os
import random
import matplotlib.pyplot as plt
from torchvision import transforms
from torch import nn
import torch

import numpy as np
from PIL import Image

from tqdm import tqdm
from IPython.display import clear_output

In [None]:
cpu = torch.device("cpu")

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
# The previous version assigned `device` twice in a row, so the MPS choice was
# always overwritten by the CUDA/CPU check and never took effect.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = cpu

device

In [None]:
# Directories that hold the input images and their labels.
images = "./data/train/kidney_1_dense/images"
labels = "./data/train/kidney_1_dense/labels"

# images = "./data/train/kidney_1_voi/images"
# labels = "./data/train/kidney_1_voi/labels"

# List every file in each folder, in sorted order, so that position i in
# image_set lines up with position i in label_set.
image_set = sorted(os.listdir(images))
label_set = sorted(os.listdir(labels))

# Number of image files found.
set_length = len(image_set)

# index = random.randint(0, set_length)

# image = image_set[index]
# image_name = os.path.join(images, image)

# label = label_set[index]
# label_name = os.path.join(labels, label)

# print (index)
# print (image_name)

# image_pil = Image.open(image_name)
# label_pil = Image.open(label_name)

# image_tensor = transforms.ToTensor()(image_pil).float()
# label_tensor = transforms.ToTensor()(label_pil)

# image_shape = image_tensor.shape
# print (image_shape)

# # Find the maximum value in the tensor
# max_value = torch.max(image_tensor)
# min_value = torch.min(image_tensor)

# # # Normalize the tensor by dividing by the maximum value
# # image_tensor = (image_tensor - min_value) / (max_value - min_value)

# print (torch.max(image_tensor))
# print (torch.min(image_tensor))

# print (image_tensor)

# # plot the image and label side by side
# fig, axes = plt.subplots(nrows=1, ncols=2)
# ax = axes.ravel()

# ax[0].imshow(image_tensor.permute(1, 2, 0))
# ax[0].set_title("Image")

# ax[1].imshow(label_tensor.permute(1, 2, 0))
# ax[1].set_title("Label")

# plt.tight_layout()
# plt.show()



In [None]:
import pandas as pd

rles_file = "./data/train_rles.csv"

# read the csv file
rles = pd.read_csv(rles_file)

# Choose which sample to inspect. `index` used to be defined only in
# commented-out code, so this cell raised NameError on a fresh kernel.
# The same index is used for the image list, so it must be valid for both.
# NOTE(review): assumes row i of the csv corresponds to image_set[i] — confirm.
index = random.randrange(min(set_length, len(rles)))
rle = rles.iloc[index]["rle"]

# convert the rle to a mask
def rle2mask(rle, shape):
    """
    Decode a run-length-encoded string into a binary mask.

    rle: run-length string formatted as "start length start length ...".
        Starts are treated as 1-based positions in the flattened (row-major)
        mask. None, NaN and empty/whitespace-only strings decode to an
        all-background mask.
    shape: (height, width) of array to return
    Returns numpy uint8 array of the given shape, 1 - mask, 0 - background
    Raises ValueError if the string does not contain an even number of tokens.
    """
    mask = np.zeros(shape[0]*shape[1], dtype=np.uint8)

    # pandas reads an empty csv cell as NaN (a float), which has no .split()
    if rle is None or (isinstance(rle, float) and np.isnan(rle)):
        return mask.reshape(shape)

    s = rle.split()
    if len(s) % 2 != 0:
        # Without this check numpy broadcasting can silently pair a start
        # with the wrong length instead of failing.
        raise ValueError(
            "RLE string must contain (start, length) pairs, got {} tokens".format(len(s))
        )

    starts = np.asarray(s[0::2], dtype=int)
    lengths = np.asarray(s[1::2], dtype=int)
    # Shift 1-based starts to 0-based indices.
    # NOTE(review): the original author flagged this shift as "evaluate whether
    # or not this is needed" — confirm against the RLE convention of the data.
    starts -= 1

    ends = starts + lengths

    for lo, hi in zip(starts, ends):
        mask[lo:hi] = 1

    return mask.reshape(shape)

# def mask2rle(mask, shape):
#     """
#     mask: numpy array, 1 - mask, 0 - background
#     Returns run length as string formatted
#     """

#     mask_1d = mask.reshape(1, shape[0]*shape[1])


# convert the rle to a mask
# The mask needs the (height, width) of the image at `index`. `image_shape`
# used to come from commented-out code, so read the size from the file instead.
with Image.open(os.path.join(images, image_set[index])) as reference_image:
    width, height = reference_image.size  # PIL reports size as (width, height)
mask = rle2mask(rle, (height, width))

# plot the mask
plt.imshow(mask)
plt.show()


In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU.

    The first convolution maps in_channels -> out_channels and the second maps
    out_channels -> out_channels. With kernel_size=3 and padding=1 the spatial
    size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Build conv -> batch norm -> ReLU twice; only the input width differs.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run x through both conv/norm/activation stages."""
        return self.conv(x)

class UNet(nn.Module):
    """Encoder/decoder network built from DoubleConv blocks with skip connections.

    The encoder alternates DoubleConv and 2x2 max-pooling over widths
    64/128/256/512, a 1024-channel DoubleConv forms the bottleneck, and the
    decoder alternates 2x2 transposed convolutions with DoubleConv blocks.
    A final 1x1 convolution maps 64 channels to out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        widths = (64, 128, 256, 512)

        # Contracting path: DoubleConv at even positions, pooling at odd ones.
        down_layers = []
        previous = in_channels
        for width in widths:
            down_layers.append(DoubleConv(previous, width))
            down_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            previous = width
        self.encoder = nn.Sequential(*down_layers)

        self.bottleneck = DoubleConv(widths[-1], widths[-1] * 2)

        # Expanding path: upsample at even positions, DoubleConv at odd ones.
        # Each DoubleConv takes 2 * width channels because the upsampled
        # tensor is concatenated with the matching encoder output first.
        up_layers = []
        for width in reversed(widths):
            up_layers.append(nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2))
            up_layers.append(DoubleConv(width * 2, width))
        self.decoder = nn.Sequential(*up_layers)

        self.final = nn.Conv2d(widths[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the output of the final 1x1 convolution for input batch x."""
        # Contracting path: remember every DoubleConv output (taken before pooling).
        skips = []
        for position, layer in enumerate(self.encoder):
            x = layer(x)
            if position % 2 == 0:
                skips.append(x)

        x = self.bottleneck(x)

        # Expanding path: after each upsampling step, join with the encoder
        # output of the same depth, deepest first.
        for position, layer in enumerate(self.decoder):
            x = layer(x)
            if position % 2 != 0:
                continue

            skip = skips[-1 - position // 2]

            # Resize when the upsampled tensor does not match the skip tensor.
            if x.shape != skip.shape:
                x = transforms.functional.resize(x, size=skip.shape[2:])

            x = torch.cat((x, skip), dim=1)

        return self.final(x)
    
# sample = image_tensor.unsqueeze(0).to(device)
# print (sample.shape)

# unet = UNet(in_channels=1, out_channels=1).to(device)
# output = unet(sample)

# output = output.squeeze(0)
# plt.imshow(output.permute(1, 2, 0).detach().to(cpu).numpy())

In [None]:
# Images and labels are paired by position in their sorted listings, which is
# only meaningful when both folders hold the same number of files. Fail loudly
# instead of raising IndexError mid-loop or silently mispairing.
if len(label_set) != set_length:
    raise ValueError(
        "image/label count mismatch: {} images vs {} labels".format(
            set_length, len(label_set)
        )
    )

# (image path, label path) tuples
image_label_set = [
    (os.path.join(images, image), os.path.join(labels, label))
    for image, label in zip(image_set, label_set)
]

random.shuffle(image_label_set)
print (len(image_label_set))

In [None]:
# create a dataset
from torch.utils.data import Dataset, DataLoader

#Adjust this to take a dir name instead [VK]
class KidneyDataset(Dataset):
    """Dataset that loads (image, label) tensor pairs from file paths.

    image_label_set: sequence of (image_path, label_path) tuples.
    device: device each loaded tensor is moved to before it is returned.
    """

    def __init__(self, image_label_set, device):
        self.image_label_set = image_label_set
        self.device = device

    def __len__(self):
        """Number of (image, label) pairs."""
        return len(self.image_label_set)

    def __getitem__(self, index):
        """Load pair `index` and return (image_tensor, label_tensor)."""
        image_name, label_name = self.image_label_set[index]

        to_tensor = transforms.ToTensor()

        # Convert inside the `with` blocks so the pixel data is read before
        # the file handles are closed; previously the images were never closed.
        with Image.open(image_name) as image_pil:
            image_tensor = to_tensor(image_pil).float().to(self.device)
        with Image.open(label_name) as label_pil:
            label_tensor = to_tensor(label_pil).to(self.device)

        return image_tensor, label_tensor

In [None]:
# Binary cross entropy computed on raw logits.
criterion = nn.BCEWithLogitsLoss()

# Single-channel input, single-channel output network, moved to the device.
model = UNet(1, 1)
model = model.to(device)

# Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Dataset of (image, label) pairs, served one sample per batch in shuffled order.
dataset = KidneyDataset(image_label_set, device=device)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

In [None]:
# train the model
num_epochs = 1
log_every = 100  # steps between progress prints / loss-curve refreshes
checkpoint_path = './checkpoints/model.ckpt'

losses = []

# Make sure the model is in training mode even if the evaluation cell
# (which calls model.eval()) was run before this one.
model.train()

for epoch in range(num_epochs):
    loop = tqdm(dataloader)
    # Batch names are chosen so they do not overwrite the module-level
    # `labels` directory path that earlier cells rely on.
    for step, (image_batch, label_batch) in enumerate(loop, start=1):
        # forward pass
        outputs = model(image_batch)
        loss = criterion(outputs, label_batch)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # read the scalar once; each .item() call copies from the device
        loss_value = loss.item()
        losses.append(loss_value)
        loop.set_postfix(loss=loss_value)

        # Refresh the printed progress and the loss curve periodically rather
        # than every step: re-plotting the full history each iteration is
        # quadratic work, and clearing the output erased the print right away.
        if step % log_every == 0:
            clear_output(wait=True)
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, step, len(dataloader), loss_value))
            plt.plot(losses)
            plt.show()

# final loss curve, including the steps after the last periodic refresh
plt.plot(losses)
plt.show()

# save the model (create the folder first so a missing directory cannot
# throw away a finished training run)
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

# # plot the loss
# plt.plot(losses)
# plt.show()


In [None]:
# test the model
model.eval()

# get a random image
# randrange excludes len(dataset); randint(0, len(dataset)) included it and
# could therefore pick an index one past the end.
index = random.randrange(len(dataset))
sample_image, sample_label = dataset[index]

print (index)

sample_image_tensor = sample_image.unsqueeze(0).to(device)

print (sample_image_tensor.shape)

# get the output (inference only, so skip gradient tracking)
with torch.no_grad():
    output = model(sample_image_tensor)
output_image = output.squeeze(0)

# summarize the logits instead of dumping the whole tensor
print (output_image.shape, output_image.min().item(), output_image.max().item())

# raw logits rounded to the nearest integer (kept for comparison only)
output_image_rounded = torch.round(output_image)

# The model is trained with BCEWithLogitsLoss, so its outputs are logits:
# apply the sigmoid first, then round to get a 0/1 mask.
output_image_sigmoid = torch.sigmoid(output_image)

output_image_sigmoid_rounded = torch.round(output_image_sigmoid)

# plot the sample, label, and predicted mask side by side
# set color map to binary
# plt.rcParams['image.cmap'] = 'binary'
fig, axes = plt.subplots(nrows=1, ncols=3)
ax = axes.ravel()

ax[0].imshow(sample_image.permute(1, 2, 0).detach().to(cpu).numpy())
ax[0].set_title("Image")

ax[1].imshow(sample_label.permute(1, 2, 0).detach().to(cpu).numpy())
ax[1].set_title("Label")

# show the thresholded probabilities; the rounded raw logits shown before
# are not a 0/1 mask
ax[2].imshow(output_image_sigmoid_rounded.permute(1, 2, 0).detach().to(cpu).numpy())
ax[2].set_title("Output")

plt.tight_layout()
plt.show()
