# Mitral Valve Segmentation

## 1. Setup

### 1.1. Library Imports

In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import jaccard_score
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import v2

from src.dataset import EchoDataset
from src.model import UNet
from src.train_utils import Trainer, WeightedCE

### 1.2. Configuration

In [None]:
# Paths, device, seeds and hyperparameters shared by every later cell.
ROOT = Path(".")

PATH_TO_DATA = ROOT / "data"
PATH_TO_TRAIN = PATH_TO_DATA / "train.pkl"
# NOTE(review): `PATH_TO_TEST` is not referenced elsewhere in this notebook (that split has no labels).
PATH_TO_TEST = PATH_TO_DATA / "test.pkl"

PATH_TO_OUTPUTS = ROOT / "outputs"
# Checkpoint directory read by `load_latest_model`; presumably `Trainer` (which receives
# PATH_TO_OUTPUTS) writes its checkpoints here -- TODO confirm.
PATH_TO_MODELS = PATH_TO_OUTPUTS / "models"


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(DEVICE)


##################################################
# Reproducibility
##################################################
SEED = 7
random.seed(SEED)  # Used in some of the `torchvision` transformations.
# NOTE(review): `RS_NUMPY` is not referenced elsewhere in this notebook.
RS_NUMPY = np.random.default_rng(SEED)
# `torch.manual_seed` seeds and returns the global default `torch.Generator`, so `RS_TORCH` is
# that shared generator; it is later passed to `random_split` and to both DataLoaders.
# Seeding happens before `UNet(...)` below, which keeps the initial weights reproducible.
RS_TORCH = torch.manual_seed(SEED)


# Warning: the settings commented below may slow down the execution time.

# # PyTorch will only use deterministic operations.
# # If no deterministic alternative exists, an error will be raised.
# torch.use_deterministic_algorithms(True)


# # Reproducibility when using GPUs

# # Choice of algorithms (in `cuDNN`) is deterministic.
# torch.backends.cudnn.benchmark = False

# # Algorithms themselves (only the ones in `cuDNN`) are deterministic.
# torch.backends.cudnn.deterministic = True

"""
In some CUDA versions:
- We need to set as well the `CUBLAS_WORKSPACE_CONFIG` environment variable.
- RNN and LSTM networks may have non-deterministic behavior.
"""


##################################################
# Hyperparameters
##################################################
N_EPOCHS = 80

# Dataset
IMG_SIZE = (320, 384)  # (W, H) Used by the resize transformation.
N_CLASSES = 2  # Background (0), Mitral Valve (1).


# Data Loader
BATCH_SIZE = 25
NUM_WORKERS = 2


# Model
# Single-channel (grayscale) input, one output channel per class.
MODEL_PARAMS = {
    "in_channels": 1,
    "out_channels": N_CLASSES,
    "mid_channels": [64, 128, 256, 512, 1024],
    "kernel_size": 3,
    "max_pool_kernel_size": 2,
    "up_kernel_size": 2,
    "up_stride": 2,
    "dropout_p": 0.0,
}
model = UNet(**MODEL_PARAMS)


# Loss Function
# Weighting for classes to address imbalance.
# Presumably index i weights class i (0 = background, 1 = mitral valve), up-weighting the valve
# class 10x -- TODO confirm against `WeightedCE`.
CLASS_WEIGHT = torch.tensor([1.0, 10.0], dtype=torch.float)
# Weighting for data importance based on labeling source ('expert' vs. 'amateur' labeling).
# NOTE(review): the index order is not shown here; expert labels presumably get the larger
# weight (3.0), i.e. index 0 = amateur, 1 = expert -- confirm in EchoDataset/WeightedCE.
DATA_WEIGHT = torch.tensor([1.0, 3.0], dtype=torch.float)
LOSS_PARAMS = {"class_weight": CLASS_WEIGHT, "data_weight": DATA_WEIGHT}
LOSS_FN = WeightedCE(**LOSS_PARAMS)


# Optimizer
# NOTE(review): lr=0.01 is ten times Adam's default (1e-3); confirm it was tuned deliberately.
OPTIMIZER_CLASS = Adam
OPTIMIZER_PARAMS = {"lr": 0.01}
optimizer = OPTIMIZER_CLASS(model.parameters(), **OPTIMIZER_PARAMS)


# Scheduler
# StepLR multiplies the learning rate by `gamma` once every `step_size` calls to `scheduler.step()`.
SCHEDULER_CLASS = StepLR
SCHEDULER_PARAMS = {"gamma": 0.9, "step_size": 15}
scheduler = SCHEDULER_CLASS(optimizer, **SCHEDULER_PARAMS)

## 2. Data Preparation

### 2.1. Dataset Loading and Transformations

In [None]:
def seed_worker(worker_id):
    """
    Reseed Python's `random` module and NumPy's global RNG inside a DataLoader worker.

    Meant to be passed as `worker_init_fn`. The seed is derived from the worker's own
    PyTorch seed (`torch.initial_seed()`), so per-worker randomness used by the image
    transformations is reproducible across runs. `worker_id` is not needed because the
    torch seed is already worker-specific. Note that shuffling indices are drawn by the
    sampler in the main process, not in the workers.
    """
    # NumPy's legacy seeding only accepts values in [0, 2**32), hence the modulo.
    seed = torch.initial_seed() % 2**32
    for reseed in (np.random.seed, random.seed):
        reseed(seed)


"""
Data augmentation should create realistic variations likely to be encountered in clinical settings.
This ensures the model's generalizability and effectiveness in real-world scenarios without training on implausible data.
"""

# Assuming EchoDataset yields pixel values in [0, 1] (TODO confirm), `mean=0.5` and `std=0.5`
# "normalize" them to the range [-1, 1].
# NOTE(review): `transforms` is built here but never passed to `EchoDataset` or to the DataLoaders
# in this notebook, so no augmentation (and no normalization) is currently applied. If it is wired in:
#   - apply it to the training split only, not to `test_dataset`;
#   - rotation/affine must transform image and mask together or the labels stop lining up
#     (torchvision v2 does this when the mask is passed as a `tv_tensors.Mask`, which the
#     color and normalization ops leave untouched).
transforms = v2.Compose(
    [
        v2.RandomRotation(degrees=(-15, 15)),
        v2.ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2)),
        v2.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.9, 1.1)),
        v2.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Labeled training data; IMG_SIZE is presumably applied as a resize inside EchoDataset.
dataset = EchoDataset(PATH_TO_TRAIN, IMG_SIZE)
# We were actually given a test dataset. However, this one didn't come with the labels.
# Therefore, to assess the model's performance, we partition the training dataset in two.
# The 90/10 split is drawn from the seeded generator, so it is the same on every run.
# NOTE(review): if several frames come from the same video, a frame-level random split can put
# near-identical frames in both subsets and inflate the test score -- consider splitting by video.
train_dataset, test_dataset = random_split(
    dataset, lengths=[0.9, 0.1], generator=RS_TORCH
)
# Shuffled training batches; the seeded generator and `seed_worker` keep loading reproducible.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=RS_TORCH,
)
# Evaluation batches in a fixed order (shuffle=False).
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    worker_init_fn=seed_worker,
    generator=RS_TORCH,
)

### 2.2. Data Visualization

In [None]:
def plot_echo(img, target=None, prediction=None):
    """Plot an echocardiogram with optional ground-truth and predicted masks.

    The image is always shown in the first panel. Each mask that is provided gets
    its own additional panel, overlaid semi-transparently on the image, so any
    combination of `target` / `prediction` works (including a prediction alone).

    Args:
        img: Image tensor. It is permuted with ``(2, 1, 0)`` before plotting, i.e.
            the code expects a channel-first layout whose two spatial axes are
            swapped relative to matplotlib's (rows, cols) -- presumably (C, W, H);
            TODO confirm against `EchoDataset`.
        target: Optional 2-D ground-truth mask tensor, transposed the same way.
        prediction: Optional 2-D predicted mask tensor, transposed the same way.
    """
    # Convert to (rows, cols[, channels]) NumPy arrays that matplotlib can draw;
    # `.cpu()` lets callers pass tensors that still live on the GPU.
    image = img.permute(2, 1, 0).detach().cpu().numpy()

    panels = [("Echocardiogram", None)]
    if target is not None:
        panels.append(("Ground Truth", target.permute(1, 0).detach().cpu().numpy()))
    if prediction is not None:
        panels.append(("Prediction", prediction.permute(1, 0).detach().cpu().numpy()))

    # Panels are numbered consecutively, so the subplot index never exceeds the count.
    plt.figure(figsize=(5 * len(panels), 15))
    for position, (title, mask) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(image, cmap="gray")
        if mask is not None:
            plt.imshow(mask, alpha=0.5, cmap="binary_r")
        plt.title(title)

    plt.show()


# Sanity check: show the first training-set frame with its ground-truth mask.
frame_idx = 0
frame, segmentation, _ = dataset[frame_idx]
plot_echo(frame, segmentation)

## 3. Training Process

In [None]:
# Bundle the data, model, loss, optimizer, scheduler and settings configured above.
# Only `train_loader` is passed; `test_loader` is used separately in the evaluation cell.
trainer = Trainer(
    train_loader,
    model,
    LOSS_FN,
    optimizer,
    scheduler,
    N_EPOCHS,
    DEVICE,
    PATH_TO_OUTPUTS,
)

# Start training process.
trainer.train()

# Save trainer state.
# Presumably written under PATH_TO_OUTPUTS; `load_latest_model` later reads from
# PATH_TO_MODELS (= PATH_TO_OUTPUTS / "models") -- TODO confirm they match.
trainer.save()

## 4. Evaluation

### 4.1. Metrics Computation

In [None]:
def mean_jaccard_score(prediction, target):
    """Return the Jaccard score (IoU) of ``prediction`` against ``target``.

    If ``target`` has more than two dimensions, its leading axis is treated as a
    batch axis: each ``target[i]`` / ``prediction[i]`` pair is scored separately and
    the per-sample scores are averaged. Otherwise the pair is scored directly.

    Scores come from ``sklearn.metrics.jaccard_score`` with ``average="micro"``.
    Assuming 0/1 masks, sklearn reads a 2-D mask as a multilabel-indicator matrix,
    so the micro average pools all pixels, i.e. the IoU of the foreground. If both
    masks are empty, sklearn's default ``zero_division`` yields 0.0 with a warning.

    Returns:
        ``np.float64`` in both branches, so callers can rely on ``.item()`` / ``float()``.
    """
    if target.ndim > 2:
        per_sample = [
            jaccard_score(sample_target, prediction[i], average="micro")
            for i, sample_target in enumerate(target)
        ]
        return np.mean(per_sample)

    # Recent sklearn versions return a plain Python float, which has no `.item()`;
    # wrap it so this branch matches the `np.mean` branch above.
    return np.float64(jaccard_score(target, prediction, average="micro"))


@torch.no_grad()
def evaluate(eval_loader, model, eval_metric_fn, device):
    """Compute the per-sample average of an evaluation metric over a DataLoader.

    The model is put in eval mode (and left there) and run without gradient
    tracking via the ``torch.no_grad`` decorator. For each batch
    ``(inputs, target, _)``, the inputs are moved to ``device``. ``model.predict``
    produces the prediction, which is moved back to the CPU and scored against
    ``target``.

    Batch scores are weighted by batch size (``len(target)``), so a smaller final
    batch does not skew the result. This assumes ``eval_metric_fn`` returns the
    mean score of the samples in the batch.

    Args:
        eval_loader: Iterable of ``(inputs, target, extra)`` batches.
        model: Object with ``eval()`` and ``predict(inputs)`` methods.
        eval_metric_fn: ``fn(prediction, target)`` returning a scalar (Python float,
            NumPy scalar or 0-d tensor).
        device: Device the inputs are moved to before prediction.

    Returns:
        The sample-weighted mean metric as a Python float.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()

    weighted_sum = 0.0
    n_samples = 0
    for inputs, target, _ in eval_loader:
        prediction = model.predict(inputs.to(device)).detach().cpu()

        batch_size = len(target)
        # float() accepts Python floats, NumPy scalars and 0-d tensors alike,
        # whereas plain floats have no `.item()`.
        weighted_sum += float(eval_metric_fn(prediction, target)) * batch_size
        n_samples += batch_size

    return weighted_sum / n_samples

In [None]:
def load_latest_model(path_to_models, model, device):
    """Load the most recently saved checkpoint in ``path_to_models`` into ``model``.

    "Most recent" means the regular file with the newest modification time (ties
    broken by file name). Sub-directories are ignored. Ordering by modification
    time rather than by name avoids e.g. ``model_10.pt`` sorting before
    ``model_9.pt``.

    The checkpoint must be a dict holding the weights under ``"model_state_dict"``.

    Args:
        path_to_models: Directory containing checkpoint files (str or Path).
        model: Module whose state dict is loaded in place.
        device: Used as ``map_location`` for ``torch.load`` and as the target of ``model.to``.

    Returns:
        ``model``, moved to ``device``.

    Raises:
        FileNotFoundError: If the directory does not exist or contains no files.
    """
    checkpoints = [p for p in Path(path_to_models).iterdir() if p.is_file()]
    if not checkpoints:
        raise FileNotFoundError(
            "No model files found in the directory: {}".format(path_to_models)
        )
    latest_model_path = max(checkpoints, key=lambda p: (p.stat().st_mtime, p.name))

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state = torch.load(latest_model_path, map_location=device)
    model.load_state_dict(state["model_state_dict"])

    return model.to(device)


if "trainer" not in globals() or not hasattr(trainer, "model"):
    # The training cell was not run in this session: fall back to the newest checkpoint
    # on disk, using the freshly built `model` as the architecture to load it into.
    model = load_latest_model(PATH_TO_MODELS, model, DEVICE)
    print("Using last saved model.")
else:
    # Reuse the model produced by the training cell.
    model = trainer.model
    print("Using recently trained model.")


# Score the selected model on the held-out split.
eval_metric = evaluate(test_loader, model, mean_jaccard_score, DEVICE)
print(f"Jaccard Score: {eval_metric}")

### 4.2. Result Visualization

In [None]:
# Plot a prediction.
# Shows sample `obs_idx` of the first test batch with its ground truth and the model's prediction.
model.eval()
obs_idx = 0
with torch.no_grad():
    # A fresh iterator each run, so this is always the first batch of the (unshuffled) test loader.
    inputs, targets, _ = next(iter(test_loader))
    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

    # `predict` comes from UNet (src.model); presumably it returns per-pixel class labels -- TODO confirm.
    predictions = model.predict(inputs)

    # plot_echo detaches and moves each tensor to the CPU itself.
    plot_echo(inputs[obs_idx], targets[obs_idx], predictions[obs_idx])