# `qcardia-models` tutorial

In [1]:
from pathlib import Path

import torch
import yaml

from qcardia_models.losses import DiceCELoss, EarlyStopper, MultiScaleLoss, NTXentLoss
from qcardia_models.metrics import DiceMetric
from qcardia_models.models import EncoderMLP2d, UNet2d
from qcardia_models.utils import seed_everything

# U-Net

In [2]:
# Get the U-Net config dict from file. The context manager closes the file handle
# once parsing is done; safe_load only builds plain Python types (dict, list, str,
# numbers, bool, None), which is all a config file needs.
unet_config_path = Path("unet-config.yaml")
with unet_config_path.open() as unet_config_file:
    unet_config = yaml.safe_load(unet_config_file)

### Dummy data

In [3]:
# Random dummy data for the demo; channel and class counts follow the U-Net config.
nr_input_channels = unet_config["unet"]["nr_image_channels"]
nr_output_channels = unet_config["unet"]["nr_output_classes"]
batch_size = 4
height, width = 64, 64

# image batch of shape (batch_size, nr_input_channels, height, width), values in [0, 1)
images = torch.rand(batch_size, nr_input_channels, height, width)

# one-hot label batch: draw a random class index per pixel, then write 1.0 into that
# class channel of an all-zero (batch_size, nr_output_channels, height, width) tensor
idxs = torch.randint(0, nr_output_channels, (batch_size, 1, height, width))
labels = torch.zeros((batch_size, nr_output_channels, height, width)).scatter(1, idxs, 1.0)

# normalized positions for the positional contrastive loss; repeat_interleave(2) makes
# each consecutive pair of images share the same position
position_labels = torch.rand(batch_size // 2).repeat_interleave(2)

print(f"images shape: {images.shape}")
print(f"labels shape: {labels.shape}")
print(f"position shape: {position_labels.shape}")

images shape: torch.Size([4, 1, 64, 64])
labels shape: torch.Size([4, 4, 64, 64])
position shape: torch.Size([4])


### Model
This section only shows how to initialize and use the model. For model visualization, see `model_visualization.ipynb`.

In [4]:
# 2D U-Net built from the "unet" section of the config
unet_settings = unet_config["unet"]
unet_model = UNet2d(
    nr_input_channels=unet_settings["nr_image_channels"],
    channels_list=unet_settings["channels_list"],
    nr_output_classes=unet_settings["nr_output_classes"],
    nr_output_scales=unet_settings["nr_output_scales"],
)

# The encoder can be frozen and unfrozen again, e.g. to keep a pretrained encoder
# fixed during the first epochs of finetuning.
unet_model.set_encoder_requires_grad(False)  # freeze encoder
unet_model.set_encoder_requires_grad(True)  # unfreeze encoder

In [5]:
# Forward pass: the U-Net returns a list with one segmentation output per scale,
# from full resolution down to the coarsest returned scale.
outputs = unet_model(images)
print(f"outputs list length: {len(outputs)}, with shapes:")
for scale_idx, scale_output in enumerate(outputs):
    print(f"  {scale_idx}\t{scale_output.shape}")

# The list length is controlled by nr_output_scales: how many scales (decoder blocks)
# should give an output. A negative value reduces the number of scales by that amount,
# omitting the deepest scales first.
nr_output_scales = unet_config["unet"]["nr_output_scales"]
print(f"\nnr_output_scales: {nr_output_scales}")

outputs list length: 6, with shapes:
  0	torch.Size([4, 4, 64, 64])
  1	torch.Size([4, 4, 32, 32])
  2	torch.Size([4, 4, 16, 16])
  3	torch.Size([4, 4, 8, 8])
  4	torch.Size([4, 4, 4, 4])
  5	torch.Size([4, 4, 2, 2])

nr_output_scales: -1


### Loss

In [6]:
# Deep supervision: MultiScaleLoss wraps a Dice + cross-entropy loss so that it can be
# applied to the whole list of multi-scale outputs.
loss_settings = unet_config["loss"]
dice_ce_loss = DiceCELoss(
    cross_entropy_loss_weight=loss_settings["cross_entropy_loss_weight"],
    dice_loss_weight=loss_settings["dice_loss_weight"],
    dice_classes_weights=loss_settings["dice_classes_weights"],
)
loss_function = MultiScaleLoss(loss_function=dice_ce_loss)

# the call returns the total loss together with its two components
loss, ce_loss, dice_loss = loss_function(outputs, labels)

for component_name, component_value in (
    ("total loss", loss),
    ("ce loss component", ce_loss),
    ("dice loss component", dice_loss),
):
    print(f"{component_name}: {component_value.item():0.3f}")

total loss: 1.104
ce loss component: 1.449
dice loss component: 0.759


### Metric

In [7]:
# Dice metric restricted to the class indices listed in the config
dice_class_idxs = unet_config["metrics"]["dice_class_idxs"]
dice_metric = DiceMetric(dice_class_idxs)

# evaluate on the full-resolution output only (first entry of the outputs list)
dice_scores = dice_metric(outputs[0], labels)
dice_mean = dice_scores.mean().item()
print(f"dice scores: {dice_scores}")  # one score per class in dice_class_idxs
print(f"dice mean: {dice_mean:0.3f}")

dice scores: tensor([0.0634, 0.1986, 0.3389])
dice mean: 0.200


# Encoder + MLP

In [8]:
# Get the encoder-MLP config dict from file. The context manager closes the file handle
# once parsing is done; safe_load only builds plain Python types (dict, list, str,
# numbers, bool, None), which is all a config file needs.
encodermlp_config_path = Path("encodermlp-config.yaml")
with encodermlp_config_path.open() as encodermlp_config_file:
    encodermlp_config = yaml.safe_load(encodermlp_config_file)

### Model
This section only shows how to initialize and use the model. For model visualization, see `model_visualization.ipynb`.

In [9]:
# 2D encoder followed by an MLP head; the two parts are sized from separate config sections
encoder_settings = encodermlp_config["encoder"]
mlp_settings = encodermlp_config["mlp"]
encodermlp_model = EncoderMLP2d(
    nr_input_channels=encoder_settings["nr_image_channels"],
    encoder_channels_list=encoder_settings["channels_list"],
    mlp_channels_list=mlp_settings["channels_list"],
)

In [10]:
# Forward pass: the encoder-MLP maps each image in the batch to one output vector
output = encodermlp_model(images)
output_shape = output.shape
print(f"output tensor shape: {output_shape}")

output tensor shape: torch.Size([4, 512])


### Loss

In [11]:
# NT-Xent contrastive loss. A threshold is passed only when the config provides a number;
# any other value (e.g. the string "none") is passed as None.
loss_config = encodermlp_config["loss"]
ntxent_threshold = loss_config["ntxent_supervised_threshold"]
# Accept ints as well as floats (YAML parses `1` as int, which the previous float-only
# check silently turned into None), but reject bools, which subclass int in Python.
threshold_is_numeric = isinstance(ntxent_threshold, (int, float)) and not isinstance(
    ntxent_threshold, bool
)
loss_function = NTXentLoss(
    temperature=loss_config["ntxent_temperature"],
    threshold=float(ntxent_threshold) if threshold_is_numeric else None,
    cyclic_relative_labels=loss_config["ntxent_supervised_cyclic_labels"],
)

loss = loss_function(output, position_labels)  # calculate loss
print(loss)

tensor(1.1159, grad_fn=<NegBackward0>)


# General

### Seeding

In [12]:
# Seed the random, numpy and torch random number generators with a single call.
# NOTE: the dummy data and models above were created before this cell; to make them
# reproducible, call seed_everything before generating them.
seed = 0
seed_everything(seed)

### Weight loading

In [13]:
# Load pretrained weights if a path is configured. Both a missing value (YAML null)
# and the string "none" (any case) mean "no weights".
configured_weights_path = unet_config["unet"]["weights_path"]
if configured_weights_path is not None and str(configured_weights_path).lower() != "none":
    weights_path = Path(configured_weights_path)
    if not weights_path.exists():
        raise FileNotFoundError(f"weights not found at {weights_path}")
    nr_epochs_frozen_encoder = unet_config["training"]["nr_epochs_frozen_encoder"]
    # map_location="cpu" lets weights saved from a GPU load on a CPU-only machine;
    # load_state_dict below copies the values into the model's own parameters anyway.
    # NOTE: torch.load unpickles the file, so only load weight files you trust.
    state_dict = torch.load(weights_path, map_location="cpu")

    # fail early if not a single key in the weights file matches the model
    unet_keys = unet_model.state_dict().keys()
    nr_matching_keys = sum(key in unet_keys for key in state_dict)
    if nr_matching_keys == 0:
        raise ValueError("No keys match between weights and model.")

    # strict=False: keys not present in both the file and the model are ignored, so
    # partially matching weights can be loaded (at least one match is required above)
    unet_model.load_state_dict(state_dict, strict=False)

    # freeze encoder if specified
    if nr_epochs_frozen_encoder > 0:
        unet_model.set_encoder_requires_grad(False)

### Early stopping

In [14]:
# Initialize the early stopper. It is None when early stopping is disabled in the
# config, so the name is always bound and the loop does not re-read the config flag.
early_stopping_config = unet_config["training"]["early_stopping"]
early_stopper = (
    EarlyStopper(
        patience=early_stopping_config["patience"],
        min_delta=early_stopping_config["min_delta"],
    )
    if early_stopping_config["active"]
    else None
)

# Example training loop
for epoch_nr in range(1):
    valid_loss = 0.0  # placeholder for this epoch's real validation loss

    # stop training when the early stopper signals it (early_stop returns True)
    if early_stopper is not None and early_stopper.early_stop(valid_loss):
        break