In [None]:
import os
# Choose the dataset root: a non-empty KAGGLE_KERNEL_RUN_TYPE environment variable
# selects the Kaggle input directory, otherwise the local ./data/ folder is used.
if os.environ.get("KAGGLE_KERNEL_RUN_TYPE", ""):
    data_folder = "../input/uw-madison-gi-tract-image-segmentation/"
else:
    data_folder = "./data/"

# List all imports below
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scv_utility import *
import torch

# Seed PyTorch's and NumPy's global random number generators with a fixed value
# so that repeated runs start from the same random state.
torch.manual_seed(0)
np.random.seed(0)

# Let pandas use up to 120 characters per line when printing frames.
pd.options.display.width = 120

In [None]:
# Read the label table, parsing the id, class and segmentation columns as plain strings.
labels = pd.read_csv(
    f"{data_folder}train.csv",
    converters=dict.fromkeys(("id", "class", "segmentation"), str),
)
print(f"Classes in train set: {labels['class'].unique()}")

# Case prefixes (trailing underscore included) that form the reduced train / test splits.
train_cases = [
    "case2_", "case7_", "case15_", "case20_", "case22_",
    "case24_", "case29_", "case30_", "case32_", "case123_",
]
test_cases = ["case156_", "case154_", "case149_"]

# Keep every row (all classes) whose id contains one of the selected case prefixes;
# the prefixes are joined with "|" into a single regular expression.
train_labels = labels.loc[labels["id"].str.contains("|".join(train_cases))]
test_labels = labels.loc[labels["id"].str.contains("|".join(test_cases))]

In [None]:
# Check that every train and test image reports the same resolution and pixel size.
# A problem with one image is reported and the scan continues (best effort), but the
# ids of failing images are remembered so the final message reflects the real outcome
# instead of always claiming success.
failed_sample_ids = []
for sample_id in np.concatenate((train_labels["id"].unique(), test_labels["id"].unique())):
    try:
        sample_image, sample_image_res, sample_pixel_size = get_image_data_from_id(sample_id, data_folder)
        # Explicit raise instead of `assert`, so the check also runs under `python -O`.
        if not (sample_image_res == (266, 266) and sample_pixel_size == 1.50):
            raise ValueError("Incorrect resolution or pixel size")
    except Exception as e:
        failed_sample_ids.append(sample_id)
        print(f"Exception {e} while reading image {sample_id}")

if failed_sample_ids:
    print(f"Dataset analysis failed for {len(failed_sample_ids)} image(s)")
else:
    print("Dataset analysis successful")

## Classification network

In [None]:
TRAIN_CLASSIFICATION = False


def collect_probabilities(model, batches, device):
    """Run ``model`` over ``batches`` and return targets and sigmoid probabilities.

    Args:
        model: Callable that maps an image batch to logits.
        batches: Iterable of ``(images, targets)`` tensor pairs. The channel axis of
            ``images`` is expanded to 3 before the forward pass (assumes
            single-channel input -- TODO confirm against the dataset class).
        device: Device on which the forward pass is executed.

    Returns:
        ``(y_true, y_prob)``: two flat NumPy arrays with one entry per target /
        model output element, in iteration order.
    """
    y_true, y_prob = [], []
    with torch.no_grad():  # inference only, no autograd graph needed
        for x_batch, y_batch in batches:
            logits = model(x_batch.expand(-1, 3, -1, -1).to(device))
            y_prob.extend(torch.sigmoid(logits).cpu().numpy().ravel())
            y_true.extend(y_batch.cpu().numpy().ravel())
    return np.asarray(y_true), np.asarray(y_prob)


if TRAIN_CLASSIFICATION:
    # Training parameters
    batch_size = 32
    learning_rate = 0.01
    criterion = torch.nn.BCEWithLogitsLoss()
    epochs = 3

    # Try using gpu instead of cpu
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # Stomach-only label subsets, stored under their own names so that the shared
    # train_labels / test_labels frames stay untouched for the cells further down.
    clf_train_labels = train_labels[train_labels["class"] == "stomach"]
    clf_test_labels = test_labels[test_labels["class"] == "stomach"]
    train_data = MRIClassificationDataset(data_folder, clf_train_labels)
    test_data = MRIClassificationDataset(data_folder, clf_test_labels)
    print(f"Number of train images: {len(train_data)}, test images: {len(test_data)}")

    # Initialize network: ResNet-50 followed by a linear layer that reduces its
    # 1000 outputs to a single logit
    from torchvision.models import resnet50
    net = torch.nn.Sequential(resnet50(pretrained=False), torch.nn.Linear(1000, 1))
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # Training loop
    train(net, train_data, test_data, criterion, optimizer, batch_size, epochs, "classifier")

    # Evaluate classification network on the held-out test data
    from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
    net.eval()  # inference behaviour for batch-norm layers
    # The probabilities do not depend on the threshold, so run the network only once.
    y_true, y_prob = collect_probabilities(net, DataLoader(test_data, batch_size=16), device)
    for threshold in np.linspace(0.1, 0.9, num=9):
        # Probabilities strictly above the threshold count as positive predictions
        y_pred = (y_prob > threshold).astype(y_prob.dtype)
        print(f"Threshold: {threshold}, precision: {precision_score(y_true, y_pred)}, recall: {recall_score(y_true, y_pred)}, F1-score: {f1_score(y_true, y_pred)}")
        # Build confusion matrix
        print(confusion_matrix(y_true, y_pred))

## Segmentation network

In [None]:
TRAIN_SEGMENTATION = True
if TRAIN_SEGMENTATION:
    from torchvision.models.segmentation import fcn_resnet50

    # Hyper-parameters for the segmentation run
    epochs = 3
    batch_size = 32
    learning_rate = 0.01
    criterion = torch.nn.BCEWithLogitsLoss()

    # Use the first CUDA device when one is available, otherwise stay on the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Drop rows with an empty segmentation string; rows of every class are kept
    # (there is no stomach-only filter here).
    train_labels = train_labels.loc[train_labels["segmentation"].ne("")]
    test_labels = test_labels.loc[test_labels["segmentation"].ne("")]
    train_data = MRISegmentationDataset(data_folder, train_labels)
    test_data = MRISegmentationDataset(data_folder, test_labels)
    print(f"Number of train images: {len(train_data)}, test images: {len(test_data)}")

    # Untrained FCN-ResNet50 with three output channels, moved to the chosen device
    net = fcn_resnet50(pretrained=False, num_classes=3).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # Run training; the network returns a dict, so the extra callable passed to
    # train() selects its "out" entry.
    train(
        net,
        train_data,
        test_data,
        criterion,
        optimizer,
        batch_size,
        epochs,
        "segmentation",
        lambda output: output["out"],
    )

    # TODO calculate segmentation metrics on test set

In [None]:
# Visualize example images and masks: take the first batch (at most 16 samples)
# from the train and the test dataset.
sample_train_images, sample_train_masks = next(iter(DataLoader(train_data, batch_size=16)))
sample_test_images, sample_test_masks = next(iter(DataLoader(test_data, batch_size=16)))

# Evaluate network on sample images. eval() switches batch-norm / dropout layers to
# inference behaviour and no_grad() avoids building an autograd graph.
net.eval()
with torch.no_grad():
    sample_train_predictions = torch.sigmoid(net(sample_train_images.expand(-1, 3, -1, -1).to(device))["out"]).cpu().numpy()
    sample_test_predictions = torch.sigmoid(net(sample_test_images.expand(-1, 3, -1, -1).to(device))["out"]).cpu().numpy()

# Samples to numpy (for visualization purposes)
sample_train_images, sample_train_masks = sample_train_images.cpu().detach().numpy(), sample_train_masks.cpu().detach().numpy()
sample_test_images, sample_test_masks = sample_test_images.cpu().detach().numpy(), sample_test_masks.cpu().detach().numpy()

# Transparency of the mask overlays
a = 0.3

plt.rcParams["figure.figsize"] = (10,20)

# One colormap per mask channel (channels 0, 1 and 2)
MASK_COLORMAPS = ("Greens", "Reds", "Blues")


def show_sample_with_masks(image, masks, alpha):
    """Show one sample as three panels: image, masks only, image with masks on top.

    Args:
        image: Array whose first channel (``image[0]``) is displayed.
        masks: Array whose first three channels are overlaid, one colormap each.
        alpha: Transparency used for the mask overlays.
    """
    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(image[0], cmap="binary")
    axes[2].imshow(image[0], cmap="binary")
    for ax in axes[1:]:
        for channel, cmap in enumerate(MASK_COLORMAPS):
            # Fixed 0..1 range keeps the colors comparable between samples
            ax.imshow(masks[channel], cmap=cmap, alpha=alpha, vmin=0.0, vmax=1.0)
    plt.show()


# Iterate over the samples actually loaded; a dataset with fewer than 16 entries
# yields a shorter batch.
for sample_id in range(len(sample_train_images)):
    show_sample_with_masks(sample_train_images[sample_id], sample_train_masks[sample_id], a)

for sample_id in range(len(sample_test_images)):
    show_sample_with_masks(sample_test_images[sample_id], sample_test_masks[sample_id], a)

In [None]:
plt.rcParams["figure.figsize"] = (20,20)


def show_sample_with_predictions(image, masks, predictions, alpha):
    """Show one sample as five panels comparing the masks with the predictions.

    Panels, left to right: image, masks only, image with masks, predictions only,
    image with predictions.

    Args:
        image: Array whose first channel (``image[0]``) is displayed.
        masks: Array whose first three channels are drawn in green, red and blue.
        predictions: Array with the same layout as ``masks``.
        alpha: Transparency used for the mask and prediction overlays.
    """
    fig, axes = plt.subplots(1, 5)
    # Draw the image first so that the overlays end up on top of it
    for background_ax in (axes[0], axes[2], axes[4]):
        background_ax.imshow(image[0], cmap="binary")
    overlays = ((axes[1], masks), (axes[2], masks), (axes[3], predictions), (axes[4], predictions))
    for ax, layers in overlays:
        for cmap, layer in zip(("Greens", "Reds", "Blues"), layers):
            ax.imshow(layer, cmap=cmap, alpha=alpha, vmin=0.0, vmax=1.0)
    plt.show()


# Plot samples from the train set; iterate over the samples actually available
# instead of assuming that the batch holds exactly 16 of them.
for sample_id in range(len(sample_train_images)):
    show_sample_with_predictions(
        sample_train_images[sample_id],
        sample_train_masks[sample_id],
        sample_train_predictions[sample_id],
        a,
    )

# Plot sample from test set
print('========================================= TEST IMAGES ===============================================')

for sample_id in range(len(sample_test_images)):
    show_sample_with_predictions(
        sample_test_images[sample_id],
        sample_test_masks[sample_id],
        sample_test_predictions[sample_id],
        a,
    )
    # Value range of the first prediction channel for this sample
    print(f"Test sample prediction min: {np.min(sample_test_predictions[sample_id][0])}")
    print(f"Test sample prediction max: {np.max(sample_test_predictions[sample_id][0])}")