In [None]:
import os
import time
data_folder = "../input/uw-madison-gi-tract-image-segmentation/" if os.environ.get("KAGGLE_KERNEL_RUN_TYPE", "") else "./data/"

# List all imports below
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scv_utility import *
import torch
from torchvision import transforms

# Reproducibility: pin the PyTorch and NumPy global RNGs to the same fixed seed.
torch.manual_seed(0)
np.random.seed(0)

# Allow pandas to use up to 120 characters per line when printing frames.
pd.options.display.width = 120

In [None]:
# Load the complete label table (all classes) and split it by case into
# train / validation / test subsets. Class-specific filtering is done later,
# in the training cells.
# Every column is read through `str`, so an empty `segmentation` field is kept
# as "" (the training cells compare against "" to detect empty ground truth).
labels = pd.read_csv(data_folder + "train.csv", converters={"id": str, "class": str, "segmentation": str})
print(f"Classes in train set: {labels['class'].unique()}")

# Set to True to run on a small hand-picked subset of cases instead of the
# positional split of all cases.
USE_TOY_DATA = False

all_cases = get_all_cases(data_folder)
if USE_TOY_DATA:
    train_cases = ["case2_", "case7_", "case15_", "case20_", "case22_", "case24_", "case29_", "case30_", "case32_", "case123_"]
    val_cases = ["case146_", "case147_", "case148_"]
    test_cases = ["case156_", "case154_", "case149_"]
else:
    # Positional split: first 30 cases train, next 5 validate, the rest test.
    train_cases = all_cases[:30]
    val_cases = all_cases[30:35]
    test_cases = all_cases[35:]


def _labels_for_cases(label_table, cases):
    """Return the rows of `label_table` whose "id" contains any string in `cases`.

    Matching is literal substring matching (not regex). An empty `cases`
    selects no rows; joining the cases into one regex alternation would
    instead yield the empty pattern, which matches every row.

    Args:
        label_table: DataFrame with a string "id" column.
        cases: iterable of case-name substrings.

    Returns:
        The matching rows of `label_table`, in their original order.
    """
    in_split = pd.Series(False, index=label_table.index)
    for case in cases:
        in_split |= label_table["id"].str.contains(case, regex=False)
    return label_table[in_split]


train_labels = _labels_for_cases(labels, train_cases)
val_labels = _labels_for_cases(labels, val_cases)
test_labels = _labels_for_cases(labels, test_cases)
print(f"Data split sizes: train: {len(train_labels)}, val: {len(val_labels)}, test: {len(test_labels)}")

## Classification network

In [None]:
# Train a ResNet-50 based classifier on the stomach rows of each split.
# Flip the flag to True to run this cell's training; it is skipped by default.
TRAIN_CLASSIFICATION = False

if TRAIN_CLASSIFICATION:
    # Training parameters
    batch_size = 32
    learning_rate = 0.001
    criterion = torch.nn.BCEWithLogitsLoss()
    max_epochs = 30

    # Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
    # NOTE(review): `device` is only used for `net.to(...)` here and is not passed
    # to train(); presumably train() selects the device itself - confirm.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # Keep the stomach-only subsets under cell-local names. Overwriting
    # train_labels / val_labels / test_labels here would leak the stomach filter
    # into every later cell that reads them (e.g. the segmentation training).
    cls_train_labels = train_labels[train_labels["class"] == "stomach"]
    cls_val_labels = val_labels[val_labels["class"] == "stomach"]
    cls_test_labels = test_labels[test_labels["class"] == "stomach"]

    # Augmentation (random crop, label smoothing) is applied to the train set only;
    # validation and test images are just rescaled.
    train_data = MRIClassificationDataset(data_folder, cls_train_labels, transform=transforms.Compose([
        Rescale((266, 266)),
        RandomCrop(),
        LabelSmoothing(p=0.4)]))
    val_data = MRIClassificationDataset(data_folder, cls_val_labels, transform=transforms.Compose([Rescale((266, 266))]))
    test_data = MRIClassificationDataset(data_folder, cls_test_labels, transform=transforms.Compose([Rescale((266, 266))]))
    print(f"Number of train images: {len(train_data)}, val images: {len(val_data)}, test images: {len(test_data)}")

    # Initialize network: an untrained ResNet-50 whose 1000 outputs are mapped
    # to 3 logits by an extra linear layer.
    from torchvision.models import resnet50
    net = torch.nn.Sequential(resnet50(pretrained=False), torch.nn.Linear(1000, 3))
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # Training loop; train() returns one loss value per epoch for each split.
    train_losses, val_losses, test_losses = train(net, train_data, val_data, test_data, criterion, optimizer, batch_size, max_epochs, "classifier")

    # Loss curves, epochs numbered from 1.
    for losses, curve_label in ((train_losses, "Train set"), (val_losses, "Validation set"), (test_losses, "Test set")):
        plt.plot(list(range(1, len(losses) + 1)), losses, label=curve_label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss (BCE with logits)")
    plt.legend()
    plt.show()

## Segmentation network

In [None]:
# Train an FCN-ResNet50 segmentation network with 3 output classes.
# TRAIN_SEGMENTATION toggles the whole cell; ONLY_NON_EMPTY_GT restricts every
# split to rows whose `segmentation` string is non-empty.
TRAIN_SEGMENTATION = True
ONLY_NON_EMPTY_GT = True

if TRAIN_SEGMENTATION:
    # Training parameters
    batch_size = 32
    learning_rate = 0.0001
    criterion = torch.nn.BCEWithLogitsLoss()
    max_epochs = 30

    # Preprocessing constants shared by the train / val / test pipelines.
    image_size = (266, 266)
    norm_mean, norm_std = 0.458, 0.229

    # Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
    # NOTE(review): `device` is only used for `net.to(...)` here and is not passed
    # to train(); presumably train() selects the device itself - confirm.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # Work on cell-local names so train_labels / val_labels / test_labels stay
    # untouched: re-running this cell after changing ONLY_NON_EMPTY_GT then
    # takes effect without having to re-run the data-loading cell first.
    seg_train_labels, seg_val_labels, seg_test_labels = train_labels, val_labels, test_labels
    if ONLY_NON_EMPTY_GT:
        seg_train_labels = train_labels[train_labels["segmentation"] != ""]
        seg_val_labels = val_labels[val_labels["segmentation"] != ""]
        seg_test_labels = test_labels[test_labels["segmentation"] != ""]

    # Augmentation (random crop, label smoothing) is applied to the train set only;
    # validation and test images are rescaled and normalised.
    train_data = MRISegmentationDataset(data_folder, seg_train_labels, transform=transforms.Compose([
        Rescale(image_size),
        RandomCrop(),
        LabelSmoothing(p=0.4),
        Normalize(mean=norm_mean, std=norm_std)]))
    val_data = MRISegmentationDataset(data_folder, seg_val_labels, transform=transforms.Compose([
        Rescale(image_size),
        Normalize(mean=norm_mean, std=norm_std)]))
    test_data = MRISegmentationDataset(data_folder, seg_test_labels, transform=transforms.Compose([
        Rescale(image_size),
        Normalize(mean=norm_mean, std=norm_std)]))
    print(f"Number of train images: {len(train_data)}, val images: {len(val_data)}, test images: {len(test_data)}")

    # Initialize network (no pretrained weights are loaded).
    from torchvision.models.segmentation import fcn_resnet50
    net = fcn_resnet50(pretrained=False, num_classes=3)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # Training loop, timed with wall-clock time. The lambda unwraps the model
    # output, which is indexed by the "out" key, before it is handed on by train().
    start_training_time = time.time()
    train_losses, val_losses, test_losses = train(net, train_data, val_data, test_data, criterion, optimizer, batch_size, max_epochs, "segmentation", lambda out : out["out"])
    end_training_time = time.time()
    time_training_lapsed = end_training_time - start_training_time
    time_convert("training", time_training_lapsed)

    # Loss curves, epochs numbered from 1.
    for losses, curve_label in ((train_losses, "Train set"), (val_losses, "Validation set"), (test_losses, "Test set")):
        plt.plot(list(range(1, len(losses) + 1)), losses, label=curve_label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss (BCE with logits)")
    plt.legend()
    plt.show()