In [None]:
import os
data_folder = "../input/uw-madison-gi-tract-image-segmentation/" if os.environ.get("KAGGLE_KERNEL_RUN_TYPE", "") else "./data/"

# List all imports below
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Project helpers (MRIDataset, get_image_data_from_id). The star import stays before the explicit
# torch/torchvision imports so those explicit names take precedence over any same-named export.
from scv_utility import *
import torch
from torch.utils.data import DataLoader
from torchvision.models.segmentation import fcn_resnet50

# Console display width for printed DataFrames
pd.set_option("display.width", 120)

# One seed for both numpy's legacy global RNG and torch's global RNG, so runs are repeatable
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load small train and test datasets with only stomach labels
labels = pd.read_csv(data_folder + "train.csv", converters={"id": str, "class": str, "segmentation": str})
print(f"Classes in train set: {labels['class'].unique()}")

train_cases = ["case2_", "case7_", "case15_", "case20_", "case22_", "case24_", "case29_", "case30_", "case32_", "case123_"]
test_cases = ["case156_", "case154_", "case149_"]

# Keep stomach rows that actually carry a mask (an empty segmentation string means no mask for that slice)
label_filters = (labels["class"] == "stomach") & (labels["segmentation"] != "")

# Select cases by exact id prefix ("caseN_") rather than a substring/regex search anywhere in the id.
# Assumes ids start with "caseN_" like the entries of train_cases/test_cases.
case_prefix = labels["id"].str.split("_", n=1).str[0] + "_"
train_labels = labels[case_prefix.isin(train_cases) & label_filters]
test_labels = labels[case_prefix.isin(test_cases) & label_filters]

# Sanity checks: both selections are non-empty and no slice leaks from train into test
assert len(train_labels) > 0 and len(test_labels) > 0, "Train or test selection is empty"
assert not set(train_labels["id"]) & set(test_labels["id"]), "Train and test sets share ids"

train_data = MRIDataset(data_folder, train_labels)
test_data = MRIDataset(data_folder, test_labels)
print(f"Number of train images: {len(train_data)}, test images: {len(test_data)}")

In [None]:
# Check that every train and test image has the same expected resolution and pixel size
EXPECTED_RESOLUTION = (266, 266)
EXPECTED_PIXEL_SIZE = 1.50

failed_ids = []
for image_id in np.concatenate((train_labels["id"].unique(), test_labels["id"].unique())):
    try:
        _, image_res, pixel_size = get_image_data_from_id(image_id, data_folder)
        # allclose tolerates float rounding and works whether pixel_size is a scalar or a per-axis pair
        assert image_res == EXPECTED_RESOLUTION and np.allclose(pixel_size, EXPECTED_PIXEL_SIZE), (
            f"Incorrect resolution {image_res} or pixel size {pixel_size}"
        )
    except Exception as e:
        # Best effort: record the failure and keep checking the remaining images
        failed_ids.append(image_id)
        print(f"Exception {e} while reading image {image_id}")

# Only claim success when no image failed the check
if failed_ids:
    print(f"Dataset analysis found problems in {len(failed_ids)} images")
else:
    print("Dataset analysis successful")

In [None]:
# Visualize an example image with its ground-truth mask overlaid, for both splits.
# The first 16 samples of each dataset (no shuffling) are kept for later comparison with predictions.
sample_train_images, sample_train_masks = next(iter(DataLoader(train_data, batch_size=16)))
sample_test_images, sample_test_masks = next(iter(DataLoader(test_data, batch_size=16)))

sample_id = 6  # index into the 16-sample batches above
for split_name, images, masks in (
    ("Train", sample_train_images, sample_train_masks),
    ("Test", sample_test_images, sample_test_masks),
):
    _, ax = plt.subplots()
    # [0] selects the first channel; presumably each sample is (C, H, W) with C == 1 — TODO confirm
    ax.imshow(images[sample_id][0])
    ax.imshow(masks[sample_id][0], cmap="jet", alpha=0.3)
    ax.set_title(f"{split_name} sample {sample_id}: image with ground-truth mask")
    plt.show()

In [None]:
# CREDITS for a big portion of the training loop: CS4240 DL assignment 3

# Training parameters
batch_size = 32
learning_rate = 0.01  # NOTE(review): fairly high for Adam; consider 1e-3 if training is unstable
criterion = torch.nn.BCEWithLogitsLoss()  # takes raw logits; the sigmoid is applied inside the loss
epochs = 20

# Use the first GPU when available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size)

# Initialize network: FCN-ResNet50 with a single output channel, i.e. one logit per pixel
net = fcn_resnet50(pretrained=False, num_classes=1)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)


def run_train_epoch(model, loader, optimizer, criterion, device):
    """Run one training pass over `loader` and return the mean loss per batch.

    Prints a progress bar of '#' marks (one roughly every 5% of the batches).
    Leaves `model` in train mode.
    """
    model.train()
    total_loss = 0.0
    # max(1, ...) keeps the modulo below valid when the loader has fewer than 20 batches
    progress_every = max(1, len(loader) // 20)
    print("0% [", end="")
    for i, (x_batch, y_batch) in enumerate(loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        # fcn_resnet50 expects 3 input channels; expand() repeats the size-1 channel dim without copying
        x_batch = x_batch.expand(-1, 3, -1, -1)
        optimizer.zero_grad()
        y_pred = model(x_batch)["out"]  # the model returns a dict; "out" holds the logits
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()  # .item() yields a plain float, so the autograd graph is not kept
        if i % progress_every == 0:
            print("#", end="")
    print("] 100%")
    return total_loss / max(1, len(loader))


def run_eval_epoch(model, loader, criterion, device):
    """Return the mean loss per batch of `model` over `loader` without updating any weights.

    Leaves `model` in eval mode.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for x_batch, y_batch in loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            x_batch = x_batch.expand(-1, 3, -1, -1)  # same 3-channel expansion as in training
            y_pred = model(x_batch)["out"]
            total_loss += criterion(y_pred, y_batch).item()
    return total_loss / max(1, len(loader))


# Training loop
print(f"Start training on device {device}, batch size {batch_size}, {len(train_data)} train samples ({len(train_loader)} batches)")
for epoch in range(epochs):
    avg_train_loss = run_train_epoch(net, train_loader, optimizer, criterion, device)
    avg_test_loss = run_eval_epoch(net, test_loader, criterion, device)
    # Per-batch averages are comparable between train and test even though the splits differ in size
    print(f"Epoch: {epoch+1}, train loss: {avg_train_loss}, test loss: {avg_test_loss}")
    # One checkpoint per epoch with model and optimizer state, so training can be resumed or rolled back
    model_state = {"model_state_dict": net.state_dict(), "optimizer_state_dict": optimizer.state_dict()}
    torch.save(model_state, f"checkpoint{epoch+1}.pkl")

print("Training done")

In [None]:
# Evaluate network on sample images; the stored predictions are raw logits (apply a sigmoid for probabilities)
net.eval()  # use running BatchNorm statistics instead of per-batch statistics
with torch.no_grad():  # inference only: do not build an autograd graph
    # fcn_resnet50 returns a dict whose "out" entry holds the logits, and it needs 3-channel input,
    # so the single channel is expanded exactly as in the training loop.
    sample_train_predictions = net(sample_train_images.to(device).expand(-1, 3, -1, -1))["out"].cpu().numpy()
    sample_test_predictions = net(sample_test_images.to(device).expand(-1, 3, -1, -1))["out"].cpu().numpy()

In [None]:
sample_id = 6  # index into the 16-sample batches drawn for visualization

# For one train and one test sample, show the ground-truth mask and the network prediction side by side,
# then report the range of predicted values.
for split_name, images, masks, predictions in (
    ("Train", sample_train_images, sample_train_masks, sample_train_predictions),
    ("Test", sample_test_images, sample_test_masks, sample_test_predictions),
):
    _, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))
    ax_true.imshow(images[sample_id][0])
    ax_true.imshow(masks[sample_id][0], cmap="jet", alpha=0.3)
    ax_true.set_title(f"{split_name} sample {sample_id}: ground truth")
    ax_pred.imshow(images[sample_id][0])
    ax_pred.imshow(predictions[sample_id][0], cmap="jet", alpha=0.3)
    ax_pred.set_title(f"{split_name} sample {sample_id}: prediction")
    plt.show()
    print(f"{split_name} sample prediction min: {np.min(predictions[sample_id][0])}")
    print(f"{split_name} sample prediction max: {np.max(predictions[sample_id][0])}")