In [None]:
# Work from the project directory on Google Drive so that relative paths such as
# Path('Data') (and, presumably, the `src.*` imports) resolve.
# The directory only exists once Drive is mounted; the mount cell comes *after*
# this one, so on a fresh kernel mount here first instead of failing in os.chdir.
import os

project_dir = '/content/drive/MyDrive/python/Computer_Vision/TDT4265_Project/'
if not os.path.isdir(project_dir):
    from google.colab import drive
    drive.mount('/content/drive')
os.chdir(project_dir)

In [None]:
# Mount Google Drive at /content/drive so the project files stored there are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell runs, so edits to the project's `src` package are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
from src.data import DatasetLoader, train_test_val_split
from src.visualize import plot_loss_acc
from src.model import Unet2D
from src.train import train
from src.metrics import acc_metric, dice_metric
from src.utils import *

import random
from datetime import datetime
from pathlib import Path

import albumentations as A
import numpy as np

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, sampler

## Parameters

In [None]:
# --- Training hyper-parameters ---------------------------------------------
bs = 8             # batch size (was 12)
epochs_val = 1     # number of training epochs (was 50)
learn_rate = 0.01  # learning rate passed to the optimizer

# --- Reproducibility ---------------------------------------------------------
# Seed every RNG the pipeline may draw from (data split, loader shuffling,
# augmentation, weight initialisation) so "Restart & Run All" is repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Data ----------------------------------------------------------------------
# Location of the raw and ground-truth images
base_path = Path('Data')     # alternative: /work/datasets/medical_project
dataset = "extracted_CAMUS"  # alternative: CAMUS_resized

# Image resolution fed to the network - watch out for memory usage.
# Tried scales: 1, 1.5, 2, 3, 4
scale = 1
image_resolution = int(384 * scale)

## Prepare data

In [None]:
# Partition the dataset's files into training, test and validation subsets.
(train_files,
 test_files,
 val_files) = train_test_val_split(base_path, dataset)

In [None]:
# Pre-processing steps, applied in list order (the order matters).
# Other steps previously tried here: 'MedianFilter', 'bright', 'EDGE_ENHANCE',
# 'Sharp', 'MaxFilter'.
# Keep one entry per line WITH a trailing comma: adjacent string literals without
# a comma are silently concatenated by Python ('GaussBlur' 'MedianFilter' would
# become the single, unknown step 'GaussBlurMedianFilter').
prep_steps = [
    'GaussBlur',
]

pre_processing_verbose(prep_steps)

In [None]:
# Random augmentations for the training split only:
# mirror half of the images and jitter brightness/contrast on 20% of them.
train_transform = A.Compose(
    transforms=[
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
    ]
)

In [None]:
# Build the three datasets; only the training split gets augmentation.
# NOTE(review): every split reads ground truth from 'train_gt', which assumes
# train_test_val_split partitions the training images - confirm.
gt_dir = base_path / dataset / 'train_gt'

train_dataset = DatasetLoader(train_files, gt_dir,
                              prep_steps=prep_steps,
                              transform=train_transform,
                              image_resolution=image_resolution)
test_dataset = DatasetLoader(test_files, gt_dir,
                             prep_steps=prep_steps,
                             image_resolution=image_resolution)
valid_dataset = DatasetLoader(val_files, gt_dir,
                              prep_steps=prep_steps,
                              image_resolution=image_resolution)

# Shuffle only the training data. Evaluation loaders keep a fixed order so that
# metrics computed on "the first validation batch" are reproducible between runs.
train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_data = DataLoader(test_dataset, batch_size=bs, shuffle=False)
valid_data = DataLoader(valid_dataset, batch_size=bs, shuffle=False)
print(f"\nItems loaded: {len(train_dataset)+len(test_dataset)+len(valid_dataset)} [training: {len(train_dataset)}, test: {len(test_dataset)}, valid: {len(valid_dataset)}]")

# Visualize shape of raw and ground true images
xb, yb = next(iter(train_data))
print(f"RAW IMAGES: {xb.shape}\n GT IMAGES: {yb.shape}\n")

## Train

In [None]:
# U-Net taking a single input channel and producing 4 output channels,
# i.e. one score map per segmentation class.
# NOTE(review): nn.CrossEntropyLoss applies log-softmax itself, so these outputs
# should be unnormalised logits - confirm Unet2D does not already apply softmax.
# TODO: optionally load pretrained weights (unet.load_state_dict(torch.load(...))).
unet = Unet2D(1, out_channels=4)

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(params=unet.parameters(), lr=learn_rate)

# Train, collecting the loss/accuracy histories for the train and test splits.
(train_loss, test_loss,
 train_acc, test_acc) = train(unet, train_data, test_data, loss_fn, opt, acc_metric,
                              epochs=epochs_val)

In [None]:
# Learning curves returned by train(): loss and accuracy, train vs. test split.
plot_loss_acc(train_loss, test_loss,
              train_acc, test_acc)

In [None]:
# Save the trained weights together with an entry in a performance log.
# This cell runs *before* the evaluation section below, so it computes the
# validation metrics it logs itself instead of reading variables from a later cell.

# NOTE(review): the checkpoint location/name were never defined in this notebook;
# these are placeholder defaults - adjust them, or set model_path = None to skip saving.
model_path = Path('models')
file_name = f'unet2d_{dataset}_{image_resolution}.pt'

if model_path is not None:
    model_path.mkdir(parents=True, exist_ok=True)

    # Score one validation batch in inference mode (same procedure as the evaluation cell).
    unet.eval()
    xb_val, yb_val = next(iter(valid_data))
    with torch.no_grad():
        predb_val = unet(to_cuda(xb_val))
    accuracy = acc_metric(predb_val, yb_val).item()
    average_dice, class_dice = dice_metric(predb_val, yb_val)

    # Save performance
    now = datetime.now().strftime('%d/%m/%Y %H:%M:%S')
    with open(model_path / "performance.txt", "a") as text_file:
        print(f"data:{now}\ndataset:{dataset}\n"
              f"epoch:{epochs_val}\nimage_resolution:{image_resolution}\n" 
              f"pre_proc:{prep_steps}\nacc:{round(accuracy, 4)}\n"
              f"avg_dice:{round(average_dice, 4)}\nclass_dice_scores:{str(class_dice)}\n", 
              file=text_file)

    # Save model
    torch.save(unet.state_dict(), model_path / file_name)
    print(f"Model state has been saved in /{model_path}")


## Evaluate on the validation set

In [None]:
# Predict on one batch of the validation data and compare with recorded baselines.
# Switch to inference mode first: without eval(), dropout stays active and, if
# Unet2D contains batch-norm layers, this forward pass would use batch statistics
# and update the running statistics (torch.no_grad() does not prevent that).
unet.eval()

xb, yb = next(iter(valid_data))
with torch.no_grad():
    predb = unet(to_cuda(xb))

# Evaluation - Accuracy
accuracy = acc_metric(predb, yb).item()
baseline_accuracy = 0.9705810547  # TRAINING TIME: 102m 6s
print(f"\nFinal Accuracy: {round(accuracy, 4)} (delta to baseline {round(accuracy - baseline_accuracy, 4)})")

# Evaluation - Dice score
average_dice, class_dice = dice_metric(predb, yb)
baseline_dice = 0.607425  # per class: [0.9652, 0.5956, 0.3764, 0.4925]
print(f"Final average DICE score: {round(average_dice, 4)} {class_dice} (delta to baseline {round(average_dice - baseline_dice, 4)})")