In [None]:
import os
# Kaggle kernels set KAGGLE_KERNEL_RUN_TYPE; anywhere else, read data from the local ./data/ folder.
_kaggle_run_type = os.environ.get("KAGGLE_KERNEL_RUN_TYPE", "")
data_folder = "./data/"
if _kaggle_run_type:
    data_folder = "../input/uw-madison-gi-tract-image-segmentation/"

# List all imports below
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scv_utility import *
import torch
from torch.utils.data import DataLoader

# Seed every random number generator used in this notebook so runs are reproducible,
# and widen pandas' console output so printed frames do not wrap early.
SEED = 0
np.random.seed(SEED)     # legacy global numpy generator
torch.manual_seed(SEED)  # torch RNG (also drives DataLoader shuffling)
pd.set_option("display.width", 120)

In [None]:
# Load small train and test datasets with only stomach labels
labels = pd.read_csv(data_folder + "train.csv", converters={"id": str, "class": str, "segmentation": str})
print(f"Classes in train set: {labels['class'].unique()}")

# Patient cases per split. The trailing "_" keeps e.g. "case2_" from matching "case20_...".
train_cases = ["case2_", "case7_", "case15_", "case20_", "case22_", "case24_", "case29_", "case30_", "case32_", "case123_"]
test_cases = ["case156_"]  # further candidate test cases: "case154_", "case149_"

# Keep only stomach rows that actually have a segmentation.
label_filters = (labels["class"] == "stomach") & (labels["segmentation"] != "")
# Match case names as a literal prefix of the id instead of an unanchored regex search.
# NOTE(review): assumes ids start with their case name (e.g. "case123_day20_slice_0001") — confirm.
train_labels = labels[labels["id"].str.startswith(tuple(train_cases)) & label_filters]
test_labels = labels[labels["id"].str.startswith(tuple(test_cases)) & label_filters]

# Invariants: both splits are populated and no image leaks from train into test.
assert len(train_labels) > 0 and len(test_labels) > 0, "No stomach labels found for the selected cases"
assert set(train_labels["id"]).isdisjoint(test_labels["id"]), "Train and test sets share images"

train_data = MRIDataset(data_folder, train_labels)
test_data = MRIDataset(data_folder, test_labels)
print(f"Number of train images: {len(train_data)}, test images: {len(test_data)}")

In [None]:
# Analyze train and test dataset to assure they have the same resolution.
# Every image is checked and each failure is reported. The success message is printed only
# when no image failed: an `assert` inside `try/except Exception` would be swallowed silently.
expected_resolution = (266, 266)
expected_pixel_size = 1.50
failed_ids = []
for sample_id in np.concatenate((train_labels["id"].unique(), test_labels["id"].unique())):
    try:
        _, sample_image_res, sample_pixel_size = get_image_data_from_id(sample_id, data_folder)
        is_valid = sample_image_res == expected_resolution and sample_pixel_size == expected_pixel_size
    except Exception as e:
        print(f"Exception {e} while reading image {sample_id}")
        failed_ids.append(sample_id)
        continue
    if not is_valid:
        print(f"Incorrect resolution or pixel size for image {sample_id}: "
              f"resolution {sample_image_res}, pixel size {sample_pixel_size}")
        failed_ids.append(sample_id)

if failed_ids:
    print(f"Dataset analysis failed for {len(failed_ids)} image(s)")
else:
    print("Dataset analysis successful")

In [None]:
# Visualize example image and mask. Takes the first (unshuffled) batch of 84 training samples
# and overlays the mask of sample 66 on its image. These samples are reused for predictions below.
sample_train_loader = DataLoader(train_data, batch_size=84)
sample_batch = next(iter(sample_train_loader))
sample_images, sample_masks = sample_batch

shown_sample = 66
plt.imshow(sample_images[shown_sample][0])
plt.imshow(sample_masks[shown_sample][0], cmap="jet", alpha=0.3)  # semi-transparent mask overlay
plt.show()

In [None]:
# CREDITS for a big portion of the training loop: CS4240 DL assignment 3

# Training parameters
batch_size = 32
learning_rate = 0.01
momentum = 0.9
criterion = bce_dice_loss
epochs = 20

# Try using gpu instead of cpu
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Selected device: {device}")

train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size)

# Initialize network
net = UNet(enc_chs=(1,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=True, out_sz=(266,266))
net.train()
net.to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)


def evaluate_loss(model, loader, loss_fn, device):
    """Return the mean per-batch loss of `model` on `loader`.

    The model runs in eval mode without autograd, so no gradients or graph memory are kept.
    The model is switched back to train mode afterwards.

    Args:
        model: network to evaluate.
        loader: DataLoader yielding (input, target) batches.
        loss_fn: callable(prediction, target) returning a scalar loss tensor.
        device: device the batches are moved to.

    Returns:
        Mean loss over all batches, or nan when `loader` yields no batches.
    """
    model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for x_eval, y_eval in loader:
            x_eval, y_eval = x_eval.to(device), y_eval.to(device)
            total_loss += loss_fn(model(x_eval), y_eval).item()
            n_batches += 1
    model.train()
    return total_loss / n_batches if n_batches else float("nan")


# Training loop
for epoch in range(epochs):
    total_epoch_loss = 0
    n_train_batches = 0
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()               # Set the gradients to zero
        y_pred = net(x_batch)               # Perform forward pass
        loss = criterion(y_pred, y_batch)   # Compute the loss
        loss.backward()                     # Backward pass to compute gradients
        optimizer.step()                    # Update parameters
        total_epoch_loss += loss.item()     # Discard gradients and store total loss
        n_train_batches += 1

    # The summed loss depends on the number of batches, so also report the per-batch mean,
    # and track the held-out test loss to detect overfitting.
    mean_train_loss = total_epoch_loss / n_train_batches if n_train_batches else float("nan")
    test_loss = evaluate_loss(net, test_loader, criterion, device)
    print(f"Epoch: {epoch+1}, loss: {total_epoch_loss}, mean train loss: {mean_train_loss:.4f}, test loss: {test_loss:.4f}")

    model_state = {
        "epoch": epoch + 1,
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(model_state, f"checkpoint{epoch+1}.pkl")

print("Training done")

In [None]:
# Evaluate network on sample images.
# eval() switches layers such as BatchNorm/Dropout (if the UNet contains any) to inference
# behaviour. no_grad() avoids building an autograd graph for the whole 84-image batch;
# that graph would otherwise keep every intermediate activation alive on the device.
net.eval()
with torch.no_grad():
    predictions = net(sample_images.to(device)).cpu().numpy()

In [None]:
# Inspect one prediction: the input image, the raw network output, and the output overlaid on the input.
prediction_id = 66
input_slice = sample_images[prediction_id][0]
predicted_slice = predictions[prediction_id][0]

# Input image on its own
plt.imshow(input_slice)
plt.show()
# Raw network output on its own
plt.imshow(predicted_slice)
plt.show()

# Network output overlaid on the input image
plt.imshow(input_slice)
plt.imshow(predicted_slice, cmap="jet", alpha=0.3)
plt.show()
# Value range of the raw output. The visible code does not show whether it is a logit or a probability.
print(np.min(predicted_slice))
print(np.max(predicted_slice))