In [None]:
import os
# Kaggle kernels define KAGGLE_KERNEL_RUN_TYPE; elsewhere the data is expected under ./data/
if os.environ.get("KAGGLE_KERNEL_RUN_TYPE", ""):
    data_folder = "../input/uw-madison-gi-tract-image-segmentation/"
else:
    data_folder = "./data/"

# List all imports below
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scv_utility import *
import torch
from torch.utils.data import DataLoader

# Seed the RNGs used below (numpy + torch) so shuffling and weight init are reproducible
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
# Wider console output for printed DataFrames
pd.set_option("display.width", 120)

In [None]:
# Load the labels and keep a small stomach-only subset for quick experiments
labels = pd.read_csv(data_folder + "train.csv", converters={"id": str, "class": str, "segmentation": str})
print(f"Classes in train set: {labels['class'].unique()}")

# Case-id prefixes used for each split. Further candidates for a larger run:
#   train: "case2_", "case7_", "case15_", "case20_", "case22_", "case24_", "case29_", "case30_", "case32_"
#   test:  "case154_", "case149_"
train_cases = ["case123_"]
test_cases = ["case156_"]

# The class filter is the same for both splits; only the case selection differs
is_stomach = labels["class"] == "stomach"
train_labels = labels[is_stomach & labels["id"].str.contains("|".join(train_cases))]
test_labels = labels[is_stomach & labels["id"].str.contains("|".join(test_cases))]

train_data, test_data = MRIDataset(data_folder, train_labels), MRIDataset(data_folder, test_labels)
print(f"Number of train images: {len(train_data)}, test images: {len(test_data)}")

In [None]:
# Check that every train and test image has the same resolution and pixel size,
# and report the ones that do not (or cannot be read) instead of stopping at the first one.
failed_ids = []
for sample_id in np.concatenate((train_labels["id"].unique(), test_labels["id"].unique())):
    try:
        sample_image, sample_image_res, sample_pixel_size = get_image_data_from_id(sample_id, data_folder)
        # Explicit raise instead of `assert`, which is stripped when Python runs with -O;
        # allclose avoids exact float equality on the reported pixel size.
        if not (sample_image_res == (266, 266) and np.allclose(sample_pixel_size, 1.50)):
            raise ValueError("Incorrect resolution or pixel size")
    except Exception as e:
        failed_ids.append(sample_id)
        print(f"Exception {e} while reading image {sample_id}")

# Only claim success when nothing failed (previously printed unconditionally)
if failed_ids:
    print(f"Dataset analysis found {len(failed_ids)} problematic images")
else:
    print("Dataset analysis successful")

In [None]:
# Visualize an example image, its mask, and the mask blended over the image
sample_train_loader = DataLoader(train_data, batch_size=2)
sample_images, sample_masks = next(iter(sample_train_loader))

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(sample_images[0][0])
axes[0].set_title("Image")
axes[1].imshow(sample_masks[0][0], cmap="jet")
axes[1].set_title("Stomach mask")
# alpha only has an effect on top of another image, so blend the mask over the scan here
axes[2].imshow(sample_images[0][0])
axes[2].imshow(sample_masks[0][0], cmap="jet", alpha=0.3)
axes[2].set_title("Mask overlay")
plt.show()

print(type(train_data[0][0]))

In [None]:
# CREDITS for a big portion of the training loop: CS4240 DL assignment 3
def try_gpu():
    """
    Select the compute device for training.

    Returns:
        torch.device: ``cuda:0`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

In [None]:
# Data loaders: only the training split is shuffled
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)

# Training parameters
learning_rate = 0.01
epochs = 2

# Try using gpu instead of cpu
device = try_gpu()

# Initialize network with a single output channel (one logit per pixel for "stomach").
# Move it to the device once, before building the optimizer, as recommended by torch.optim.
net = UNet(enc_chs=(1,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=True, out_sz=(266,266))
net.to(device)
optimizer = torch.optim.SGD(net.parameters(), lr = learning_rate)

# CrossEntropyLoss with num_class=1 applies softmax over a single channel, which is always 1,
# so the loss is identically 0 and no gradient ever reaches the network. Binary segmentation
# with one logit per pixel needs a sigmoid-based loss instead.
# NOTE(review): assumes masks have the same shape as y_pred and hold values in {0, 1} -- confirm in MRIDataset.
criterion = torch.nn.BCEWithLogitsLoss()

# Define list to store losses and performances of each iteration
train_losses = []
train_accs = []
test_accs = []

for epoch in range(epochs):

    # (Re-)enter training mode each epoch in case an evaluation cell switched to eval mode
    net.train()
    epoch_losses = []

    # Training loop
    for x_batch, y_batch in train_loader:

        # Same device as the network, and float32 to match its weights and the BCE target dtype
        x_batch = x_batch.to(device=device, dtype=torch.float32)
        y_batch = y_batch.to(device=device, dtype=torch.float32)

        # Set the gradients to zero
        optimizer.zero_grad()

        # Perform forward pass
        y_pred = net(x_batch)

        # Compute the loss; store a Python float, since keeping the tensor would keep
        # every batch's autograd graph (and its activations) alive
        loss = criterion(y_pred, y_batch)
        epoch_losses.append(loss.item())

        # Backward computation and update
        loss.backward()
        optimizer.step()

    train_losses.extend(epoch_losses)

    # TODO: Print performance (IoU metric?)
    print('Epoch: {:.0f}'.format(epoch+1))
    if epoch_losses:
        print('Mean train loss: {:.4f}'.format(sum(epoch_losses) / len(epoch_losses)))
    print('')

In [None]:
# Shape/dtype sanity check of a freshly built UNet on a random input and on the real sample batch
unet = UNet(enc_chs=(1,64,128,256,512,1024),retain_dim=True, out_sz=(266,266))
x    = torch.randn(1, 1, 266, 266)
print(x.dtype)
print(sample_images.dtype)
# These forward passes only inspect output shapes, so skip building autograd graphs
with torch.no_grad():
    print(unet(x).shape)
    print(unet(sample_images.to(torch.float32)).shape)

In [None]:
# Run the trained network on the sample batch for inspection (inference only)
net.eval()
# Feed the input on the network's own device and as float32 to match its weights;
# the raw batch is on the CPU and is not float32 (see the dtype check above)
net_device = next(net.parameters()).device
with torch.no_grad():
    result = net(sample_images.to(device=net_device, dtype=torch.float32)).cpu()

print(sample_images.shape)
print(result.shape)

In [None]:
# Compare the sample input, the network output for it, and the two overlaid.
# .cpu() keeps this working even if `result` still lives on the GPU.
# NOTE(review): with BCEWithLogitsLoss the output is a logit map; apply torch.sigmoid for probabilities.
prediction = result[0][0].detach().cpu().numpy()

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(sample_images[0][0])
axes[0].set_title("Input image")
axes[1].imshow(prediction, cmap="Greys")
axes[1].set_title("Network output")
axes[2].imshow(sample_images[0][0])
axes[2].imshow(prediction, cmap="Greys", alpha=0.3)
axes[2].set_title("Output overlaid on input")
plt.show()