In [None]:
import os

# Select the dataset root. When the KAGGLE_KERNEL_RUN_TYPE environment variable is set
# and non-empty (presumably only inside a Kaggle kernel), use the competition input mount;
# otherwise fall back to a local ./data/ directory. Paths end in "/" because later cells
# build file paths by plain string concatenation.
if os.environ.get("KAGGLE_KERNEL_RUN_TYPE", ""):
    data_folder = "../input/uw-madison-gi-tract-image-segmentation/"
else:
    data_folder = "./data/"

# List all imports below
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from torchvision.models.segmentation import fcn_resnet50

from scv_utility import *

# Seed every random number generator the notebook (and its helper module) may draw from,
# so that "Restart & Run All" reproduces the same splits, initial weights and shuffles.
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # per the PyTorch docs this also seeds the RNGs of all CUDA devices
pd.set_option("display.width", 120)

In [None]:
# Load the full label table and keep only a small subset of cases for train and test.
# Rows of every class are kept here; the model cells below select the "stomach" rows.
labels = pd.read_csv(data_folder + "train.csv", converters={"id": str, "class": str, "segmentation": str})
print(f"Classes in train set: {labels['class'].unique()}")

train_cases = ["case2_", "case7_", "case15_", "case20_", "case22_", "case24_", "case29_", "case30_", "case32_", "case123_"]
test_cases = ["case156_", "case154_", "case149_"]
# A case in both splits would leak training images into the test evaluation.
assert not set(train_cases) & set(test_cases), "train and test cases overlap"

# str.match anchors the alternation at the start of the id, so each "caseN_" prefix can only
# match the leading case token. Assumes ids start with "caseN_" — TODO confirm against train.csv.
train_labels = labels[labels["id"].str.match("|".join(train_cases))]
test_labels = labels[labels["id"].str.match("|".join(test_cases))]

In [None]:
# Check that every selected train and test image reports the expected resolution and pixel size.
# Problems are reported per image instead of aborting the notebook, and the final message
# reflects whether any problem was found.
expected_resolution = (266, 266)
expected_pixel_size = 1.50
failed_ids = []
for sample_id in np.concatenate((train_labels["id"].unique(), test_labels["id"].unique())):
    try:
        sample_image, sample_image_res, sample_pixel_size = get_image_data_from_id(sample_id, data_folder)
        # Explicit check rather than `assert`, which is skipped when Python runs with -O.
        if not (sample_image_res == expected_resolution and sample_pixel_size == expected_pixel_size):
            raise ValueError("Incorrect resolution or pixel size")
    except Exception as e:
        print(f"Exception {e} while reading image {sample_id}")
        failed_ids.append(sample_id)

if failed_ids:
    print(f"Dataset analysis found problems in {len(failed_ids)} image(s)")
else:
    print("Dataset analysis successful")

## Classification network

In [None]:
# Train a single-logit binary image classifier on the stomach rows
# (the target presumably encodes stomach present / absent — confirm in MRIClassificationDataset).

# Training parameters
batch_size = 32
learning_rate = 0.01  # NOTE(review): fairly high for Adam; kept as in the original experiment
criterion = torch.nn.BCEWithLogitsLoss()  # expects one raw logit per sample
epochs = 3

# Try using gpu instead of cpu
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Select the stomach rows into new frames, so the class-wide train_labels / test_labels
# from the loading cell are not overwritten for later cells.
train_cls_labels = train_labels[train_labels["class"] == "stomach"]
test_cls_labels = test_labels[test_labels["class"] == "stomach"]
train_data = MRIClassificationDataset(data_folder, train_cls_labels)
test_data = MRIClassificationDataset(data_folder, test_cls_labels)
print(f"Number of train images: {len(train_data)}, test images: {len(test_data)}")

# Initialize network: an untrained ResNet-50 whose 1000 outputs are reduced to a single logit.
net = torch.nn.Sequential(resnet50(pretrained=False), torch.nn.Linear(1000, 1))
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Training loop. The final string argument is presumably a run tag used by train()
# (e.g. for logs or checkpoints) — confirm in scv_utility.
train(net, train_data, test_data, criterion, optimizer, batch_size, epochs, "classifier")

# TODO calculate accuracy, precision, recall, etc on test set

## Segmentation network

In [None]:
# Train an FCN-ResNet50 segmentation network on stomach slices that have a non-empty mask.

# Training parameters
batch_size = 32
learning_rate = 0.01  # NOTE(review): fairly high for Adam; kept as in the original experiment
criterion = torch.nn.BCEWithLogitsLoss()  # per-pixel binary loss on raw logits
epochs = 3

# Try using gpu instead of cpu
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Select stomach rows with a non-empty segmentation into new frames, leaving the
# class-wide train_labels / test_labels untouched for any other cell.
train_seg_labels = train_labels[(train_labels["class"] == "stomach") & (train_labels["segmentation"] != "")]
test_seg_labels = test_labels[(test_labels["class"] == "stomach") & (test_labels["segmentation"] != "")]
train_data = MRISegmentationDataset(data_folder, train_seg_labels)
test_data = MRISegmentationDataset(data_folder, test_seg_labels)
print(f"Number of train images: {len(train_data)}, test images: {len(test_data)}")

# Initialize network.
# NOTE(review): num_classes=3 although only stomach rows are used. BCEWithLogitsLoss requires
# the 3-channel output to have the same shape as the masks from MRISegmentationDataset — confirm.
# NOTE(review): depending on the torchvision version, `pretrained=False` may still load ImageNet
# backbone weights (pretrained_backbone / weights_backbone default) — confirm if training from scratch is intended.
net = fcn_resnet50(pretrained=False, num_classes=3)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Training loop. The model returns a dict of heads; the lambda selects the main "out" head for the loss.
train(net, train_data, test_data, criterion, optimizer, batch_size, epochs, "segmentation", lambda out : out["out"])

# TODO calculate segmentation metrics on test set

In [None]:
# Show one example image with its ground-truth mask overlaid: first from the train split, then from test.
# The first (unshuffled) batch of 16 samples is taken from each dataset.
sample_train_images, sample_train_masks = next(iter(DataLoader(train_data, batch_size=16)))
sample_test_images, sample_test_masks = next(iter(DataLoader(test_data, batch_size=16)))

sample_id = 6
split_batches = (
    (sample_train_images, sample_train_masks),
    (sample_test_images, sample_test_masks),
)
for batch_images, batch_masks in split_batches:
    plt.imshow(batch_images[sample_id][0])
    plt.imshow(batch_masks[sample_id][0], cmap="jet", alpha=0.3)
    plt.show()

In [None]:
# Run the segmentation network on the sample batches and convert logits to per-pixel probabilities.
# eval() switches normalization/dropout layers to inference behaviour, so BatchNorm uses (and no
# longer updates) its running statistics. no_grad() skips building an autograd graph we never use.
net.eval()
with torch.no_grad():
    # expand(-1, 3, -1, -1) repeats the single image channel to the 3 channels the backbone expects.
    sample_train_predictions = torch.sigmoid(net(sample_train_images.expand(-1, 3, -1, -1).to(device))["out"]).cpu().numpy()
    sample_test_predictions = torch.sigmoid(net(sample_test_images.expand(-1, 3, -1, -1).to(device))["out"]).cpu().numpy()

In [None]:
# For one train and one test sample, show the ground-truth mask (left) next to the predicted
# probability map (right), then print the prediction's value range.
sample_id = 6

for split_name, images, masks, predictions in (
    ("Train", sample_train_images, sample_train_masks, sample_train_predictions),
    ("Test", sample_test_images, sample_test_masks, sample_test_predictions),
):
    image = images[sample_id][0]
    prediction = predictions[sample_id][0]

    fig, (ax_truth, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))
    ax_truth.imshow(image)
    ax_truth.imshow(masks[sample_id][0], cmap="jet", alpha=0.3)
    ax_truth.set_title(f"{split_name} sample {sample_id}: ground truth")

    ax_pred.imshow(image)
    # Pin the colour scale to the sigmoid range [0, 1]. With autoscaling, a near-constant
    # prediction would be stretched across the whole colormap and look like a confident mask.
    ax_pred.imshow(prediction, cmap="jet", alpha=0.3, vmin=0.0, vmax=1.0)
    ax_pred.set_title(f"{split_name} sample {sample_id}: predicted probability")
    plt.show()

    print(f"{split_name} sample prediction min: {np.min(prediction)}")
    print(f"{split_name} sample prediction max: {np.max(prediction)}")