In [None]:
import os
# Pick input locations by environment: any non-empty KAGGLE_KERNEL_RUN_TYPE (presumably only set
# inside Kaggle kernels) selects the Kaggle dataset mounts, otherwise local folders are used.
_running_on_kaggle = os.environ.get("KAGGLE_KERNEL_RUN_TYPE", "") != ""
if _running_on_kaggle:
    data_folder = "../input/uw-madison-gi-tract-image-segmentation/"
    model_folder = "../input/scv-model-data/"
else:
    data_folder = "./data/"
    model_folder = "./model_data/"

# List all imports below
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scv_utility import *
import torch
from torchvision import transforms

# Seed every random number generator the pipeline may draw from, so runs are reproducible.
# Python's `random` is seeded too in case custom transforms (e.g. RandomCrop) use it.
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
pd.set_option("display.width", 120)

In [None]:
# Load the label table and split it into train/val/test sets by MRI case
labels = pd.read_csv(data_folder + "train.csv", converters={"id": str, "class": str, "segmentation": str})
print(f"Classes in train set: {labels['class'].unique()}")

all_cases = get_all_cases(data_folder)
train_cases = all_cases[:30]
val_cases = all_cases[30:35]
test_cases = all_cases[35:]

# Toy data; uncomment and comment the above values
# train_cases = ["case2_", "case7_", "case15_", "case20_", "case22_", "case24_", "case29_", "case30_", "case32_", "case123_"]
# val_cases = ["case146_", "case147_", "case148_"]
# test_cases = ["case156_", "case154_", "case149_"]

import re


def select_cases(label_df, cases):
    """Return the rows of ``label_df`` whose ``id`` contains one of ``cases``.

    Args:
        label_df: label table with a string ``id`` column.
        cases: sequence of case identifiers (e.g. ``"case2_"``) to keep.

    Returns:
        The matching rows; an empty frame (same columns) when ``cases`` is empty.
    """
    # "|".join([]) is the empty regex, which matches every id: without this guard an empty
    # split (e.g. test_cases when there are 35 or fewer cases) would silently get the whole dataset.
    if len(cases) == 0:
        return label_df.iloc[0:0]
    # Escape each identifier so any regex metacharacters in it are matched literally.
    pattern = "|".join(re.escape(case) for case in cases)
    return label_df[label_df["id"].str.contains(pattern, regex=True)]


train_labels = select_cases(labels, train_cases)
val_labels = select_cases(labels, val_cases)
test_labels = select_cases(labels, test_cases)

# Substring matching leaks a slice into two splits if one case id is contained in another
# (e.g. "case2" vs "case20"); fail loudly rather than evaluate on training data.
assert not set(train_labels["id"]) & set(val_labels["id"]), "train/val splits overlap"
assert not set(train_labels["id"]) & set(test_labels["id"]), "train/test splits overlap"
assert not set(val_labels["id"]) & set(test_labels["id"]), "val/test splits overlap"
print(f"Data split sizes: train: {len(train_labels)}, val: {len(val_labels)}, test: {len(test_labels)}")

In [None]:
EVAL_CLASSIFICATION = False

if EVAL_CLASSIFICATION:
    from sklearn.metrics import f1_score, multilabel_confusion_matrix, precision_score, recall_score
    from torchvision.models import resnet50

    # Try using gpu instead of cpu
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    val_data = MRIClassificationDataset(data_folder, val_labels, transform = transforms.Compose([Rescale((266,266))]))
    test_data = MRIClassificationDataset(data_folder, test_labels, transform = transforms.Compose([Rescale((266,266))]))
    print(f"Number of val images: {len(val_data)}, test images: {len(test_data)}")

    # Load trained network from disk: ResNet-50 backbone followed by a 1000 -> 3 linear head.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    net = torch.nn.Sequential(resnet50(pretrained=False), torch.nn.Linear(1000, 3))
    net.load_state_dict(torch.load(f"{model_folder}/best_classifier.pkl", map_location=device)["model_state_dict"])
    net.to(device)
    # Inference mode: without eval() BatchNorm normalises with per-batch statistics and
    # overwrites the trained running statistics.
    net.eval()

    # Generate per-class sigmoid probabilities on the validation set
    pred_batches, true_batches = [], []
    with torch.no_grad():
        for x_batch, y_batch in DataLoader(val_data, batch_size=16):
            # expand() repeats the single image channel 3x to match the 3-channel ResNet input
            logits = net(x_batch.expand(-1, 3, -1, -1).to(device))
            pred_batches.append(torch.sigmoid(logits).cpu().numpy())  # Save prediction
            true_batches.append(y_batch.numpy())                      # Save truth
    # Concatenate once; np.append inside the loop re-copies the whole array on every batch.
    y_pred = np.concatenate(pred_batches, axis=0) if pred_batches else np.zeros((0, 3))
    y_true = np.concatenate(true_batches, axis=0) if true_batches else np.zeros((0, 3))

    # Evaluate classification network. Targets have one column per class (multi-label), so the
    # sklearn scores need an explicit `average` (the default "binary" raises on multilabel input)
    # and confusion matrices must be computed per class (`confusion_matrix` rejects multilabel input).
    print("Scanning threshold values on validation set...")
    for threshold in np.linspace(0.1, 0.9, num=9):
        threshold = round(float(threshold), 1)
        y_pred_binary = (y_pred > threshold).astype(float)
        precision = precision_score(y_true, y_pred_binary, average="macro", zero_division=0)
        recall = recall_score(y_true, y_pred_binary, average="macro", zero_division=0)
        f1 = f1_score(y_true, y_pred_binary, average="macro", zero_division=0)
        print(f"Threshold: {threshold}, precision: {precision}, recall: {recall}, F1-score: {f1}")
        # One 2x2 matrix [[TN, FP], [FN, TP]] per class column
        print(multilabel_confusion_matrix(y_true, y_pred_binary))
    print("Done evaluating all threshold values.")

## Segmentation network

In [None]:
EVAL_SEGMENTATION = True
ONLY_NON_EMPTY_GT = True

if EVAL_SEGMENTATION:
    # Try using gpu instead of cpu
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    if ONLY_NON_EMPTY_GT:
        # Drop label rows whose segmentation string is empty (no ground-truth mask)
        train_labels = train_labels[train_labels["segmentation"] != ""]
        val_labels = val_labels[val_labels["segmentation"] != ""]
        test_labels = test_labels[test_labels["segmentation"] != ""]

    def make_eval_transform():
        """Deterministic preprocessing used for every dataset in this evaluation notebook."""
        return transforms.Compose([Rescale((266, 266)), Normalize(mean=0.458, std=0.229)])

    # train_data is only used further down to visualise samples next to their "Ground truth overlay",
    # so it uses the deterministic eval transform: the training augmentations (RandomCrop and
    # LabelSmoothing, which presumably softens the masks) would alter what is shown as ground truth.
    train_data = MRISegmentationDataset(data_folder, train_labels, transform=make_eval_transform())
    val_data = MRISegmentationDataset(data_folder, val_labels, transform=make_eval_transform())
    test_data = MRISegmentationDataset(data_folder, test_labels, transform=make_eval_transform())
    print(f"Number of val images: {len(val_data)}, test images: {len(test_data)}")

    # Initialize network and load the trained weights
    from torchvision.models.segmentation import fcn_resnet50
    net = fcn_resnet50(pretrained=False, num_classes=3)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    net.load_state_dict(torch.load(f"{model_folder}/segmentation_best_model.pkl", map_location=device)["model_state_dict"])
    net.to(device)
    # Inference mode: BatchNorm must use (and not overwrite) its trained running statistics.
    net.eval()

    print("Scanning threshold values on validation set...")
    # Start below any achievable score so the first evaluated threshold is always a candidate;
    # the old (0, 0) start fell back to threshold 0 (predict everything) if no score exceeded 0.
    best_threshold = None
    best_threshold_score = float("-inf")
    with torch.no_grad():  # pure inference: no autograd graph needed
        for threshold in np.linspace(0.1, 0.9, num=9):
            threshold = round(float(threshold), 1)
            dice_score = evaluate_segmentation_model(net, threshold, val_data)
            print(f"Threshold: {round(threshold, 1)}, average dice score: {round(dice_score, 4)}")
            if dice_score > best_threshold_score:
                best_threshold = threshold
                best_threshold_score = dice_score
    if best_threshold is None:
        # Only reachable if every dice score compared False (e.g. all NaN)
        raise RuntimeError("No valid dice score on the validation set; cannot select a threshold.")
    print(f"Done evaluating all threshold values. Best threshold: {best_threshold} (dice score: {best_threshold_score})")

    print("Evaluating model with final threshold on test set...")
    with torch.no_grad():
        dice_score = evaluate_segmentation_model(net, best_threshold, test_data)
    print(f"Evaluation done. Dice score on test set: {dice_score}")

In [None]:
import matplotlib.patches as mpatches

# Draw one batch of train and test samples for visual inspection.
# NOTE(review): relies on train_data, test_data, net and device from the segmentation cell above
# (EVAL_SEGMENTATION must be True).
sample_train_images, sample_train_masks = next(iter(DataLoader(train_data, batch_size=32)))
sample_test_images, sample_test_masks = next(iter(DataLoader(test_data, batch_size=32)))

# Evaluate network on sample images. eval() makes BatchNorm use (and not overwrite) its running
# statistics; no_grad() avoids building an autograd graph for pure inference.
net.eval()
with torch.no_grad():
    # expand() repeats the single image channel 3x to match the 3-channel ResNet input
    sample_train_predictions = torch.sigmoid(net(sample_train_images.expand(-1, 3, -1, -1).to(device))["out"]).detach().cpu().numpy()
    sample_test_predictions = torch.sigmoid(net(sample_test_images.expand(-1, 3, -1, -1).to(device))["out"]).detach().cpu().numpy()

# Samples to numpy (for visualization purposes)
sample_train_images, sample_train_masks = sample_train_images.cpu().detach().numpy(), sample_train_masks.cpu().detach().numpy()
sample_test_images, sample_test_masks = sample_test_images.cpu().detach().numpy(), sample_test_masks.cpu().detach().numpy()

a = 0.7  # overlay opacity

plt.rcParams["figure.figsize"] = (10,20)

# Mask channel -> colormap, matching the legend (0: stomach, 1: small bowel, 2: large bowel)
mask_cmaps = ("Greens", "Reds", "Blues")
legend_patches = [mpatches.Patch(color='green',label='stomach'), mpatches.Patch(color='red',label='small_bowel'), mpatches.Patch(color='blue',label='large_bowel')]

# Plot at most 16 samples; the batch is smaller when the dataset has fewer than 32 items.
for sample_id in range(min(16, len(sample_train_images))):
    _, axarr = plt.subplots(1,2)
    image = sample_train_images[sample_id][0]

    # Show original image
    axarr[0].imshow(image, cmap="binary")
    axarr[0].title.set_text("Original image")

    # Generate and show image + mask overlay (zero pixels are masked out so the image shows through)
    axarr[1].imshow(image, cmap="binary")
    for channel, cmap in enumerate(mask_cmaps):
        mask = sample_train_masks[sample_id][channel]
        # Fixed vmin/vmax: if every visible pixel has the same value (e.g. a binary mask),
        # autoscaling collapses to vmin == vmax and matplotlib paints it in the palest colour.
        axarr[1].imshow(np.ma.masked_where(mask == 0, mask), cmap=cmap, alpha=a, vmin=0.0, vmax=1.0)
    axarr[1].title.set_text("Ground truth overlay")

    axarr[1].legend(handles=legend_patches, bbox_to_anchor=(1.2, 1.0))
    plt.show()

In [None]:
plt.rcParams["figure.figsize"] = (20,20)
threshold = best_threshold

# Binarise predictions into new arrays (> threshold -> 1, else 0). The probabilities from the
# previous cell stay intact, so this cell can be re-run (e.g. with another threshold).
sample_train_predictions_binary = (sample_train_predictions > threshold).astype(sample_train_predictions.dtype)
sample_test_predictions_binary = (sample_test_predictions > threshold).astype(sample_test_predictions.dtype)


def plot_prediction_comparison(images, masks, predictions, max_samples, alpha=0.7):
    """Show image, ground-truth overlay and prediction overlay side by side for each sample.

    Args:
        images: array indexed as images[i][0] -> 2-D image (channel 0 is plotted).
        masks: array indexed as masks[i][c] -> 2-D mask for class channel c (0..2).
        predictions: binarised predictions, indexed like ``masks``.
        max_samples: upper bound on the number of samples plotted; fewer are shown when the
            batch is smaller (avoids IndexError on small datasets).
        alpha: overlay opacity.
    """
    # Mask channel -> colormap, matching the legend (0: stomach, 1: small bowel, 2: large bowel)
    mask_cmaps = ("Greens", "Reds", "Blues")
    legend_patches = [mpatches.Patch(color='green',label='stomach'), mpatches.Patch(color='red',label='small_bowel'), mpatches.Patch(color='blue',label='large_bowel')]

    for sample_id in range(min(max_samples, len(images))):
        _, axarr = plt.subplots(1,3)
        image = images[sample_id][0]

        # Show original image
        axarr[0].imshow(image, cmap="binary")
        axarr[0].title.set_text("Original image")

        # Ground-truth and predicted overlays are drawn identically on their own axes
        overlays = ((axarr[1], masks[sample_id], "Ground truth overlay"),
                    (axarr[2], predictions[sample_id], "Prediction overlay"))
        for ax, layers, title in overlays:
            ax.imshow(image, cmap="binary")
            for channel, cmap in enumerate(mask_cmaps):
                layer = layers[channel]
                # Fixed vmin/vmax: if every visible pixel has the same value (e.g. a binary mask),
                # autoscaling collapses to vmin == vmax and matplotlib paints it in the palest colour.
                ax.imshow(np.ma.masked_where(layer == 0, layer), cmap=cmap, alpha=alpha, vmin=0.0, vmax=1.0)
            ax.title.set_text(title)

        axarr[-1].legend(handles=legend_patches, bbox_to_anchor=(1.2, 1.0))
        plt.show()


plot_prediction_comparison(sample_train_images, sample_train_masks, sample_train_predictions_binary, 16, alpha=a)

# Plot sample from test set
print('========================================= TEST IMAGES ===============================================')
plot_prediction_comparison(sample_test_images, sample_test_masks, sample_test_predictions_binary, 32, alpha=a)