At this stage, you have a complete dataset of images and annotations tailored to our study, so let’s now write a training script for an AI model that detects Data Matrix code cells.

To better understand the design of our model’s architecture, I strongly encourage you to refer to section 'x' of my report 'y'.

Let’s start by loading all the libraries we will need.

In [1]:
import argparse
import datetime
import json
import math
import os
import random
import subprocess
import threading
import time
import warnings

import numpy as np
import pandas as pd
from PIL import Image
import cv2
from scipy.spatial.distance import cdist
from scipy.ndimage import maximum_filter
from sklearn.metrics import average_precision_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torchvision.models as models
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

  check_for_updates()


Let's define the class that will let us handle our data (annotations and images), build heatmaps, and apply random visual transformations, so that the images and annotations seen by our model vary from one epoch to the next—thus limiting the risk of overfitting.

In [2]:
class KeypointDatasetAugmented(Dataset):
    """COCO-style keypoint dataset that renders Gaussian heatmaps and applies augmentation.

    Each item is ``(image, (heatmap_first, heatmap_second))`` where ``image`` is a
    ``(3, H, W)`` float tensor scaled to ``[0, 1]`` and each heatmap is a ``(1, H, W)``
    float tensor. With ``unified_cells=False`` the first heatmap holds keypoints whose
    COCO visibility flag is 2 and the second those whose flag is 1 (the black / white
    cells); with ``unified_cells=True`` every visible keypoint goes into the first
    heatmap and the second one is all zeros.

    Args:
        coco_json: Path to a COCO-format annotation file.
        img_dir: Directory containing the files referenced by each image's ``file_name``.
        img_list: Optional collection of file names to restrict the dataset to. A falsy
            value (``None`` or empty) selects every image that has annotations.
        sigma: Standard deviation, in pixels of the resized image, of each Gaussian peak.
        target_size: ``(width, height)`` the images are resized to (PIL ``resize``
            convention); heatmaps are produced at the same size.
        augment: Whether random augmentation is enabled.
        augment_prob: Probability that a sample is augmented when ``augment`` is True.
        unified_cells: Merge all keypoints into a single heatmap.
    """

    def __init__(self, coco_json, img_dir, img_list=None, sigma=2, target_size=(512, 512),
                 augment=True, augment_prob=0.8, unified_cells=False):

        with open(coco_json, 'r', encoding='utf-8') as f:
            self.coco_data = json.load(f)

        self._clean_invalid_keypoints()
        self.img_dir = img_dir
        self.sigma = sigma
        self.target_size = target_size
        self.augment = augment
        self.augment_prob = augment_prob
        self.unified_cells = unified_cells

        annotations = self.coco_data['annotations']
        self.id_list = {ann['id'] for ann in annotations}
        self.id_to_image_id = {ann['id']: ann['image_id'] for ann in annotations}
        self.img_id_to_file = {img['id']: img['file_name'] for img in self.coco_data['images']}
        self.id_to_keypoints = {ann['id']: ann['keypoints'] for ann in annotations}

        # Group annotations per image; images without annotations never appear here.
        self.img_to_anns = {}
        for ann in annotations:
            self.img_to_anns.setdefault(ann['image_id'], []).append(ann)

        if img_list:
            wanted = set(img_list)  # O(1) membership instead of scanning the list per image
            self.img_ids = [
                img_id for img_id, file_name in self.img_id_to_file.items()
                if file_name in wanted and img_id in self.img_to_anns
            ]
        else:
            self.img_ids = list(self.img_to_anns.keys())

        if self.augment:
            self.geometric_transform = A.Compose([
                A.Rotate(limit=15, p=0.85, border_mode=cv2.BORDER_REFLECT),
                A.HorizontalFlip(p=0.15),
                A.VerticalFlip(p=0),
                A.ShiftScaleRotate(
                    shift_limit=0.1,
                    scale_limit=0.2,
                    rotate_limit=15,
                    border_mode=cv2.BORDER_REFLECT,
                    p=0
                ),
            ], keypoint_params=A.KeypointParams(
                format='xy',
                remove_invisible=True,
                # Each keypoint carries its index so that, when out-of-image keypoints are
                # removed, the surviving ones can be re-associated with their labels.
                label_fields=['kp_index'],
            ))

            self.photometric_transform = A.Compose([
                A.RandomBrightnessContrast(
                    brightness_limit=0.3,
                    contrast_limit=0.2,
                    p=0.7
                ),
                A.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.1,
                    p=0.5
                ),
                A.RandomGamma(gamma_limit=(80, 120), p=0.4),
                A.HueSaturationValue(
                    hue_shift_limit=20,
                    sat_shift_limit=30,
                    val_shift_limit=20,
                    p=0.5
                ),
                A.OneOf([
                    A.GaussNoise(p=0.3),
                    A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.3),
                ], p=0.4),
                A.OneOf([
                    A.MotionBlur(blur_limit=3, p=0.3),
                    A.GaussianBlur(blur_limit=3, p=0.3),
                ], p=0.3),
            ])
        else:
            self.geometric_transform = None
            self.photometric_transform = None

    @staticmethod
    def _is_finite_number(value):
        """Return True if *value* is a finite real number (False for None, strings, NaN, inf)."""
        try:
            return bool(np.isfinite(value))
        except TypeError:
            return False

    def _clean_invalid_keypoints(self):
        """Zero out unusable keypoints in every annotation, in place.

        Keypoints are flat COCO ``[x, y, v, ...]`` triplets. A triplet whose ``x`` or ``y``
        is not a finite number becomes ``(0.0, 0.0, 0)`` (visibility 0, so it is ignored
        later); valid triplets are normalised to ``(float, float, int)``, with a
        non-numeric visibility mapped to 0.
        """
        for ann in self.coco_data['annotations']:
            kpts = ann['keypoints']
            cleaned_kpts = []
            for i in range(0, len(kpts), 3):
                x, y, v = kpts[i:i + 3]
                if self._is_finite_number(x) and self._is_finite_number(y):
                    visibility = int(v) if self._is_finite_number(v) else 0
                    cleaned_kpts.extend([float(x), float(y), visibility])
                else:
                    cleaned_kpts.extend([0.0, 0.0, 0])
            ann['keypoints'] = cleaned_kpts

    def __len__(self):
        return len(self.img_ids)

    def _extract_keypoints(self, annotations, original_width, original_height):
        """Collect visible, in-bounds keypoints of *annotations* in original-image pixels.

        Returns:
            ``(keypoints, keypoint_info)``: parallel lists of ``(x, y)`` float tuples and
            dicts with ``'visibility'`` and ``'label'`` keys (both forced to 1 when
            ``unified_cells`` is set, otherwise the COCO visibility flag).
        """
        keypoints = []
        keypoint_info = []
        for ann in annotations:
            kpts = ann['keypoints']
            for i in range(0, len(kpts), 3):
                x, y, v = kpts[i:i + 3]
                if not (self._is_finite_number(x) and self._is_finite_number(y)):
                    continue
                if not v > 0:
                    continue  # not labelled / not visible
                if not (0 <= x < original_width and 0 <= y < original_height):
                    continue
                label = 1 if self.unified_cells else v
                keypoints.append((float(x), float(y)))
                keypoint_info.append({'visibility': label, 'label': label})
        return keypoints, keypoint_info

    def _geometric_augment_with_info(self, image, keypoints, keypoint_info):
        """Apply ``self.geometric_transform`` while keeping *keypoint_info* aligned.

        The transform may drop keypoints that leave the image, so each keypoint is passed
        with its index in the ``kp_index`` label field and the surviving indices select
        the matching ``keypoint_info`` entries. If no transform is configured, or it
        fails, the inputs are returned unchanged.

        Returns:
            ``(image, keypoints, keypoint_info)``.
        """
        if self.geometric_transform is None:
            return image, keypoints, keypoint_info

        valid = [
            (kp, info) for kp, info in zip(keypoints, keypoint_info)
            if len(kp) >= 2 and np.isfinite(kp[0]) and np.isfinite(kp[1])
        ]
        valid_keypoints = [kp for kp, _ in valid]
        valid_info = [info for _, info in valid]

        try:
            transformed = self.geometric_transform(
                image=image,
                keypoints=valid_keypoints,
                kp_index=list(range(len(valid_keypoints))),
            )
            kept = [int(i) for i in transformed['kp_index']]
            new_keypoints = [(float(kp[0]), float(kp[1])) for kp in transformed['keypoints']]
            return transformed['image'], new_keypoints, [valid_info[i] for i in kept]
        except Exception as exc:  # best effort: fall back to the un-augmented sample
            warnings.warn(f"Geometric augmentation failed, using original sample: {exc!r}")
            return image, keypoints, keypoint_info

    def _apply_geometric_augmentation(self, image, keypoints):
        """Backward-compatible wrapper returning only ``(image, keypoints)``."""
        image, keypoints, _ = self._geometric_augment_with_info(
            image, keypoints, [None] * len(keypoints))
        return image, keypoints

    def _apply_photometric_augmentation(self, image):
        """Apply the colour/noise/blur pipeline; return *image* unchanged if unavailable or failing."""
        if self.photometric_transform is None:
            return image
        try:
            return self.photometric_transform(image=image)['image']
        except Exception as exc:  # best effort: keep the un-augmented image
            warnings.warn(f"Photometric augmentation failed, using original image: {exc!r}")
            return image

    def _create_heatmaps_from_keypoints(self, keypoints, keypoint_info, width, height):
        """Render *keypoints* into two ``(height, width)`` float32 heatmaps.

        With ``unified_cells`` every keypoint goes into the first map and the second stays
        zero; otherwise visibility-2 keypoints go into the first (black-cell) map and
        visibility-1 keypoints into the second (white-cell) map. *keypoint_info* must be
        index-aligned with *keypoints*.
        """
        first = np.zeros((height, width), dtype=np.float32)
        second = np.zeros((height, width), dtype=np.float32)
        for (x, y), info in zip(keypoints, keypoint_info):
            if self.unified_cells or info['visibility'] == 2:
                target = first
            elif info['visibility'] == 1:
                target = second
            else:
                continue
            self._add_gaussian(target, x, y, self.sigma)
        return first, second

    def _add_gaussian(self, heatmap, x, y, sigma):
        """Max-merge an unnormalised Gaussian (peak 1.0) at the rounded ``(x, y)`` into *heatmap*.

        The centre is clamped into the heatmap, and the Gaussian is only evaluated in a
        window of radius ``int((6 * sigma + 1) // 2)`` around it. Non-finite coordinates
        are ignored. *heatmap* is modified in place.
        """
        height, width = heatmap.shape
        if not (self._is_finite_number(x) and self._is_finite_number(y)):
            return

        cx = min(max(int(round(float(x))), 0), width - 1)
        cy = min(max(int(round(float(y))), 0), height - 1)

        radius = int((6 * sigma + 1) // 2)
        x0, x1 = max(0, cx - radius), min(width, cx + radius + 1)
        y0, y1 = max(0, cy - radius), min(height, cy + radius + 1)
        if x0 >= x1 or y0 >= y1:
            return

        xx, yy = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))
        gaussian = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma ** 2))
        heatmap[y0:y1, x0:x1] = np.maximum(heatmap[y0:y1, x0:x1], gaussian)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_file = self.img_id_to_file[img_id]
        img_path = os.path.join(self.img_dir, img_file)
        # target_size follows PIL's (width, height) order; the resized image array is
        # (height, width, 3) and the heatmaps must have the same spatial shape.
        width, height = self.target_size

        try:
            with Image.open(img_path) as raw_image:
                rgb_image = raw_image.convert('RGB')
            original_width, original_height = rgb_image.size
            image = np.array(rgb_image.resize((width, height), Image.LANCZOS))

            keypoints, keypoint_info = self._extract_keypoints(
                self.img_to_anns[img_id], original_width, original_height)

            scale_x = width / original_width
            scale_y = height / original_height

            # Keep keypoints and their labels in lockstep while rescaling.
            scaled_keypoints = []
            scaled_info = []
            for (x, y), info in zip(keypoints, keypoint_info):
                new_x, new_y = x * scale_x, y * scale_y
                if np.isfinite(new_x) and np.isfinite(new_y):
                    scaled_keypoints.append((new_x, new_y))
                    scaled_info.append(info)

            if self.augment and random.random() < self.augment_prob:
                image, scaled_keypoints, scaled_info = self._geometric_augment_with_info(
                    image, scaled_keypoints, scaled_info)
                image = self._apply_photometric_augmentation(image)

            heatmap_first, heatmap_second = self._create_heatmaps_from_keypoints(
                scaled_keypoints, scaled_info, width, height)

            if not (np.isfinite(heatmap_first).all() and np.isfinite(heatmap_second).all()):
                heatmap_first = np.nan_to_num(heatmap_first, nan=0.0, posinf=0.0, neginf=0.0)
                heatmap_second = np.nan_to_num(heatmap_second, nan=0.0, posinf=0.0, neginf=0.0)

            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            heatmap_first = torch.from_numpy(heatmap_first).unsqueeze(0).float()
            heatmap_second = torch.from_numpy(heatmap_second).unsqueeze(0).float()

            return image, (heatmap_first, heatmap_second)

        except Exception as exc:
            # Best effort: an unreadable sample becomes an all-zero image/heatmap pair, but
            # say so instead of silently feeding blank targets to training.
            warnings.warn(f"Failed to load sample {img_path!r}; returning an empty sample: {exc!r}")
            empty_image = torch.zeros(3, height, width, dtype=torch.float32)
            empty_heatmap = torch.zeros(1, height, width, dtype=torch.float32)
            return empty_image, (empty_heatmap, empty_heatmap)

    def set_augmentation(self, enabled, prob=0.8):
        """Enable/disable augmentation and set its per-sample probability."""
        self.augment = enabled
        self.augment_prob = prob

Let’s define the model class, which lets us build our keypoint detector on top of one of the following backbones: ResNet-18, ResNet-34, or ResNet-50.

In [3]:
class KeypointDetector(nn.Module):
    """Heatmap regressor: a truncated torchvision ResNet followed by a transposed-conv decoder.

    The backbone is the ResNet with its last two children (global pooling and the
    classification head) removed. The decoder applies five stride-2 transposed
    convolutions (x32 upsampling overall) and a final 3x3 convolution with a sigmoid,
    producing 1 output channel when ``unified_cells`` is set and 2 otherwise.

    Args:
        backbone: One of ``'resnet18'``, ``'resnet34'`` or ``'resnet50'``.
        pretrained: Load torchvision's ``DEFAULT`` weights for the backbone.
        unified_cells: Predict a single heatmap instead of two.

    Raises:
        ValueError: If ``backbone`` is not supported.
    """

    # backbone name -> (torchvision constructor, weights enum, channels of the last stage)
    _BACKBONES = {
        'resnet18': ('resnet18', 'ResNet18_Weights', 512),
        'resnet34': ('resnet34', 'ResNet34_Weights', 512),
        'resnet50': ('resnet50', 'ResNet50_Weights', 2048),
    }

    def __init__(self, backbone='resnet18', pretrained=True, unified_cells=False):
        super().__init__()
        self.unified_cells = unified_cells

        if backbone not in self._BACKBONES:
            raise ValueError(f"Backbone '{backbone}' not supported.")
        factory_name, weights_name, backbone_features = self._BACKBONES[backbone]

        weights = getattr(models, weights_name).DEFAULT if pretrained else None
        resnet = getattr(models, factory_name)(weights=weights)
        # Keep only the convolutional trunk (drop avgpool + fc).
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        widths = [backbone_features, 256, 128, 64, 32, 16]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
            if stage < 3:  # dropout only after the first three decoder stages
                layers.append(nn.Dropout2d(0.1))
        layers += [
            nn.Conv2d(16, 1 if unified_cells else 2, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.upsample = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-pixel heatmap probabilities for the image batch *x*."""
        return self.upsample(self.backbone(x))

Let’s now define all the loss functions to obtain two final losses (each combining multiple loss functions): one for training and one for evaluation.

The idea is to vary the training loss used to optimize our model while always keeping the same evaluation loss, so that we can compare performance across the different training losses.

In [4]:
def focal_loss(pred, target, zeta=0.25, eta=2.0):
    """Mean focal loss between probabilities *pred* and (possibly soft) targets *target*.

    Element-wise ``zeta * (1 - p_t) ** eta * BCE`` with ``p_t = exp(-BCE)``, averaged over
    all elements: *zeta* is the balancing weight and *eta* the focusing exponent.
    """
    per_element_bce = F.binary_cross_entropy(pred, target, reduction='none')
    p_t = torch.exp(-per_element_bce)
    modulating = (1 - p_t) ** eta
    return (zeta * modulating * per_element_bce).mean()


EPS = 1e-8  # keeps the contrast-loss divisions finite when a mask region is empty
SIGMA = 2


def f(d, SIGMA):
    """Return the unnormalised Gaussian ``exp(-d**2 / (2 * SIGMA**2))`` at distance *d*."""
    return np.exp(-(d ** 2) / (2 * SIGMA ** 2))


# Heatmap value one standard deviation from a peak, i.e. exp(-1/2) for any SIGMA;
# target pixels at or above it are treated as "cell" pixels by the contrast loss.
CONTRAST_BIN_THRESHOLD = f(1 * SIGMA, SIGMA)


def L_pull_torch(H: torch.Tensor, G: torch.Tensor, eps: float = EPS) -> torch.Tensor:
    """Pull term: mean of the within-region variances of *H* for mask *G* and its complement.

    *G* is a binary mask (1 = cell pixel) broadcastable against *H*. For each region the
    mean of ``H`` inside it is taken, then the mean squared deviation from that mean;
    the two region variances are averaged. *eps* guards against empty regions.
    """
    cell_mask = G.to(device=H.device, dtype=H.dtype)
    background_mask = torch.ones_like(cell_mask) - cell_mask

    def region_variance(mask):
        region = mask * H
        count = mask.sum()
        mean = region.sum() / (count + eps)
        return ((region - mean * mask) ** 2).sum() / (count + eps)

    return 0.5 * (region_variance(cell_mask) + region_variance(background_mask))


def L_push_torch(H: torch.Tensor, G: torch.Tensor, eps: float = EPS) -> torch.Tensor:
    """Push term: hinge penalty keeping each region's values away from the other region's mean.

    With ``mu_cell`` / ``mu_bg`` the means of *H* over mask *G* and its complement, each
    cell pixel contributes ``max(0, 1 - (h - mu_bg) ** 2)`` and each background pixel
    ``max(0, 1 - (h - mu_cell) ** 2)``; the sum is divided by the total pixel count
    (plus *eps*).
    """
    cell_mask = G.to(device=H.device, dtype=H.dtype)
    background_mask = torch.ones_like(cell_mask) - cell_mask

    n_cell = cell_mask.sum()
    n_background = background_mask.sum()
    mean_cell = (cell_mask * H).sum() / (n_cell + eps)
    mean_background = (background_mask * H).sum() / (n_background + eps)

    def hinge(mask, other_mean):
        deviation = mask * H - other_mean * mask
        return F.relu(mask - deviation ** 2).sum()

    pushed = hinge(cell_mask, mean_background) + hinge(background_mask, mean_cell)
    return pushed / (n_cell + n_background + eps)


def contrast_loss_torch(predicted_heatmaps: torch.Tensor,
                                     target_heatmaps,
                                     sigma: float = 2.0,
                                     unified_cells: bool = False):
    """Batch-averaged push and pull contrast terms.

    For each sample ``b`` and each predicted channel ``c`` (only channel 0 when
    *unified_cells* is set, channels 0 and 1 otherwise), ``target_heatmaps[c][b, 0]`` is
    binarised at ``CONTRAST_BIN_THRESHOLD`` and passed with ``predicted_heatmaps[b, c]``
    to ``L_push_torch`` and ``L_pull_torch``. The terms are summed and then divided by
    the batch size.

    NOTE(review): *sigma* is accepted but not used; the threshold comes from the
    module-level ``CONTRAST_BIN_THRESHOLD``.

    Returns:
        ``(l_push, l_pull)`` as scalar tensors.
    """
    device, dtype = predicted_heatmaps.device, predicted_heatmaps.dtype
    channels = range(1) if unified_cells else range(2)

    push_sum = predicted_heatmaps.new_zeros(())
    pull_sum = predicted_heatmaps.new_zeros(())

    for sample, prediction in enumerate(predicted_heatmaps):
        for c in channels:
            target = target_heatmaps[c][sample, 0].to(device=device, dtype=dtype)
            mask = (target >= CONTRAST_BIN_THRESHOLD).float()
            push_sum = push_sum + L_push_torch(prediction[c], mask)
            pull_sum = pull_sum + L_pull_torch(prediction[c], mask)

    batch_size = predicted_heatmaps.shape[0]
    return push_sum / max(1, batch_size), pull_sum / max(1, batch_size)


def mse_loss_only(outputs, targets, unified_cells=False):
    """Plain MSE training loss, summed over the predicted heatmap channels.

    Channel ``c`` of *outputs* (indexed as ``(B, C, H, W)``) is compared with
    ``targets[c]``; only channel 0 is used when *unified_cells* is set.

    Returns:
        ``(loss, stats)``; ``stats`` has the same keys as the other training losses,
        with the focal and contrast entries fixed at 0.
    """
    channels = (0,) if unified_cells else (0, 1)
    per_channel = [F.mse_loss(outputs[:, c:c + 1], targets[c]) for c in channels]
    total_loss = sum(per_channel[1:], per_channel[0])

    value = total_loss.item()
    stats = {
        'mse_loss': value,
        'focal_loss': 0.0,
        'l_push': 0.0,
        'l_pull': 0.0,
        'total_loss': value,
    }
    return total_loss, stats


def mse_focal_loss(outputs, targets, alpha=1.0, beta=0.5, unified_cells=False):
    """Training loss ``alpha * MSE + beta * focal``, each summed over heatmap channels.

    Channel ``c`` of *outputs* is compared with ``targets[c]``; only channel 0 is used
    when *unified_cells* is set.

    Returns:
        ``(loss, stats)`` where ``stats`` reports the summed MSE and focal terms and the
        total; contrast entries are 0.
    """
    channels = (0,) if unified_cells else (0, 1)
    mse_terms = []
    focal_terms = []
    for c in channels:
        prediction = outputs[:, c:c + 1]
        mse_terms.append(F.mse_loss(prediction, targets[c]))
        focal_terms.append(focal_loss(prediction, targets[c]))

    mse_total = sum(mse_terms[1:], mse_terms[0])
    focal_total = sum(focal_terms[1:], focal_terms[0])
    total_loss = alpha * mse_total + beta * focal_total

    stats = {
        'mse_loss': mse_total.item(),
        'focal_loss': focal_total.item(),
        'l_push': 0.0,
        'l_pull': 0.0,
        'total_loss': total_loss.item(),
    }
    return total_loss, stats


def combined_loss(outputs, targets, alpha=1.0, beta=0.5, gamma=0.3, delta=0.2, sigma=2, unified_cells=False):
    """Training loss ``alpha*MSE + beta*focal + gamma*push + delta*pull``.

    MSE and focal terms are computed per heatmap channel (channel 0 only when
    *unified_cells* is set) and summed; the push/pull terms come from
    ``contrast_loss_torch``.

    Returns:
        ``(loss, stats)`` with the individual terms and the total as Python floats.
    """
    channels = (0,) if unified_cells else (0, 1)
    mse_terms = []
    focal_terms = []
    for c in channels:
        prediction = outputs[:, c:c + 1]
        mse_terms.append(F.mse_loss(prediction, targets[c]))
        focal_terms.append(focal_loss(prediction, targets[c]))

    mse_loss_val = sum(mse_terms[1:], mse_terms[0])
    focal_loss_val = sum(focal_terms[1:], focal_terms[0])

    l_push, l_pull = contrast_loss_torch(outputs, targets, sigma, unified_cells)

    total_loss = alpha * mse_loss_val + beta * focal_loss_val + gamma * l_push + delta * l_pull

    stats = {
        'mse_loss': mse_loss_val.item(),
        'focal_loss': focal_loss_val.item(),
        'l_push': float(l_push.detach().cpu()),
        'l_pull': float(l_pull.detach().cpu()),
        'total_loss': float(total_loss.detach().cpu()),
    }
    return total_loss, stats


def evaluation_mse_loss(outputs, targets, unified_cells=False):
    """Fixed evaluation metric: per-channel MSE between *outputs* and *targets*, summed.

    Kept independent of the training loss so runs with different training losses can be
    compared. Only channel 0 is used when *unified_cells* is set.
    """
    channels = (0,) if unified_cells else (0, 1)
    per_channel = [F.mse_loss(outputs[:, c:c + 1], targets[c]) for c in channels]
    return sum(per_channel[1:], per_channel[0])


def safe_tensor_check(tensor, name="tensor"):
    """Return *tensor*, replacing NaN with 0, +inf with 1 and -inf with 0 if any are present.

    A fully finite tensor is returned as the same object. *name* is accepted for
    call-site readability but is not used.
    """
    if torch.isfinite(tensor).all():
        return tensor
    return torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0)


def get_loss_function(loss_type, loss_weights, unified_cells=False):
    """Return a ``loss_fn(outputs, targets) -> (loss, stats)`` callable for *loss_type*.

    Supported types are ``'mse'``, ``'mse_focal'`` and ``'mse_focal_contrast'``. Weights
    are read from *loss_weights* on every call, with the same defaults as the underlying
    loss functions.

    Raises:
        ValueError: For an unknown *loss_type*.
    """
    def mse(outputs, targets):
        return mse_loss_only(outputs, targets, unified_cells)

    def mse_focal(outputs, targets):
        return mse_focal_loss(
            outputs, targets,
            alpha=loss_weights.get('alpha', 1.0),
            beta=loss_weights.get('beta', 0.5),
            unified_cells=unified_cells
        )

    def mse_focal_contrast(outputs, targets):
        return combined_loss(
            outputs, targets,
            alpha=loss_weights.get('alpha', 1.0),
            beta=loss_weights.get('beta', 0.5),
            gamma=loss_weights.get('gamma', 0.3),
            delta=loss_weights.get('delta', 0.2),
            sigma=loss_weights.get('sigma', 2.0),
            unified_cells=unified_cells
        )

    registry = {
        'mse': mse,
        'mse_focal': mse_focal,
        'mse_focal_contrast': mse_focal_contrast,
    }
    try:
        return registry[loss_type]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown loss type: {loss_type}") from None

Now let’s define our evaluation metrics and their visualizations to assess the model’s performance.

In [5]:
def extract_keypoints_from_heatmap(heatmap, threshold=0.5, min_distance=10):
    """Return ``(x, y)`` pixel coordinates of the peaks of a 2-D heatmap.

    A pixel is a peak when it is strictly above *threshold* and equals the maximum of the
    ``min_distance``-sized window around it (``scipy.ndimage.maximum_filter``). Peaks are
    listed in row-major order; tensors are moved to the CPU first.
    """
    values = heatmap.cpu().numpy() if torch.is_tensor(heatmap) else heatmap
    is_peak = (maximum_filter(values, size=min_distance) == values) & (values > threshold)
    return [(x, y) for y, x in np.argwhere(is_peak)]


def calculate_pck(pred_keypoints, true_keypoints, threshold=20):
    """Fraction of predicted keypoints lying strictly within *threshold* pixels of a true one.

    Returns ``0.0`` when either list is empty.

    NOTE(review): the score is normalised by the number of *predictions* (precision-like),
    and several predictions may match the same true keypoint; classic PCK normalises by
    the number of ground-truth keypoints.
    """
    if not len(pred_keypoints) or not len(true_keypoints):
        return 0.0

    nearest = cdist(np.asarray(pred_keypoints), np.asarray(true_keypoints)).min(axis=1)
    return np.sum(nearest < threshold) / len(pred_keypoints)


def calculate_ap(pred_heatmap, true_heatmap, threshold=0.5):
    """Pixel-wise average precision of a predicted heatmap against a binarised target.

    Ground-truth pixels strictly above *threshold* are positives; the raw predicted
    values are the scores passed to ``sklearn.metrics.average_precision_score``.
    Tensors are detached and moved to the CPU first.

    Returns:
        The AP score, or ``0.0`` when scikit-learn rejects the inputs with a
        ``ValueError`` (its input validation error, e.g. for NaN scores).
    """
    if torch.is_tensor(pred_heatmap):
        pred_heatmap = pred_heatmap.detach().cpu().numpy()
    if torch.is_tensor(true_heatmap):
        true_heatmap = true_heatmap.detach().cpu().numpy()

    scores = pred_heatmap.flatten()
    labels = (true_heatmap > threshold).astype(int).flatten()

    try:
        return average_precision_score(labels, scores)
    except ValueError:
        # Best effort for metrics logging; a bare `except:` here also swallowed
        # KeyboardInterrupt/SystemExit and programming errors.
        return 0.0


def create_output_directory(dataset_name):
    """Create ``training_results_<dataset_name>_<YYYYmmdd_HHMMSS>`` and its subfolders.

    The run directory is created relative to the current working directory, together
    with ``predictions/``, ``models/`` and ``plots/`` (existing ones are reused).
    Returns the run directory path.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = f"training_results_{dataset_name}_{stamp}"

    os.makedirs(run_dir, exist_ok=True)
    for subfolder in ("predictions", "models", "plots"):
        os.makedirs(os.path.join(run_dir, subfolder), exist_ok=True)

    return run_dir


def create_custom_colormap():
    """Return matplotlib's built-in ``Reds`` and ``Blues`` colormaps as a pair."""
    return plt.cm.Reds, plt.cm.Blues


def visualize_predictions_improved(model, data_loader, epoch, output_dir, num_samples=4, unified_cells=False):
    """Save image / ground-truth / prediction figures for the first batch of *data_loader*.

    Runs *model* without gradients on that batch and writes up to *num_samples* PNGs to
    ``<output_dir>/predictions/epoch_XXX_sample_YY.png``. Heatmaps are drawn as RGB
    images: green for the single channel when *unified_cells* is set, otherwise red for
    channel 0 (black cells) and blue for channel 1 (white cells).

    The model's previous train/eval mode is restored afterwards, so this can be called
    from inside a training loop.
    """
    device = next(model.parameters()).device
    images, targets = next(iter(data_loader))
    images = images.to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(images)
    finally:
        model.train(was_training)

    images_np = images.cpu().numpy()
    targets_first_np = targets[0].cpu().numpy()
    targets_second_np = targets[1].cpu().numpy()
    outputs_np = outputs.cpu().numpy()

    # (RGB channel to paint, ground-truth array, model output channel)
    if unified_cells:
        paint_plan = [(1, targets_first_np, 0)]
        legend = 'Green: All Cells'
    else:
        paint_plan = [(0, targets_first_np, 0), (2, targets_second_np, 1)]
        legend = 'Red: Black, Blue: White'

    for i in range(min(num_samples, images.shape[0])):
        img = np.transpose(images_np[i], (1, 2, 0))
        gt_combined = np.zeros((img.shape[0], img.shape[1], 3))
        pred_combined = np.zeros_like(gt_combined)
        for rgb_channel, gt_array, out_channel in paint_plan:
            gt_combined[:, :, rgb_channel] = gt_array[i, 0]
            pred_combined[:, :, rgb_channel] = outputs_np[i, out_channel]

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        try:
            axes[0].imshow(img)
            axes[0].set_title('DataMatrix Image', fontsize=14, fontweight='bold')

            axes[1].imshow(gt_combined, vmin=0, vmax=1)
            axes[1].set_title(f'Ground Truth Heatmap\n({legend})', fontsize=14, fontweight='bold')

            axes[2].imshow(pred_combined, vmin=0, vmax=1)
            axes[2].set_title(f'Predicted Heatmap\n({legend})', fontsize=14, fontweight='bold')

            for ax in axes:
                ax.axis('off')

            fig.tight_layout()
            filename = f"epoch_{epoch:03d}_sample_{i:02d}.png"
            filepath = os.path.join(output_dir, "predictions", filename)
            fig.savefig(filepath, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)  # never leak figures, even if saving fails


def visualize_augmented_samples(dataset, num_samples=4, save_path=None):
    """Show the first *num_samples* items without (top row) and with (bottom row) augmentation.

    Augmentation is forced on with probability 1.0 for the bottom row. Heatmaps are
    overlaid on the image (green for unified cells, otherwise red for the first heatmap
    and blue for the second). The figure is optionally saved to *save_path* and then
    shown. *num_samples* is capped at ``len(dataset)``; nothing is drawn when it is 0.
    The dataset's augmentation settings are restored afterwards, even on error.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    def overlay(image, heatmaps):
        picture = image.permute(1, 2, 0).numpy().copy()
        first, second = (h[0].numpy() * 0.8 for h in heatmaps)
        if dataset.unified_cells:
            picture[:, :, 1] = np.maximum(picture[:, :, 1], first)
        else:
            picture[:, :, 0] = np.maximum(picture[:, :, 0], first)
            picture[:, :, 2] = np.maximum(picture[:, :, 2], second)
        return picture

    original_augment = dataset.augment
    original_prob = dataset.augment_prob
    try:
        # squeeze=False keeps a 2-D axes grid even for a single column.
        fig, axes = plt.subplots(2, num_samples, figsize=(5 * num_samples, 10), squeeze=False)

        for i in range(num_samples):
            dataset.set_augmentation(False, 0.0)
            img_orig, heatmaps_orig = dataset[i]

            dataset.set_augmentation(True, 1.0)
            img_aug, heatmaps_aug = dataset[i]

            axes[0, i].imshow(overlay(img_orig, heatmaps_orig))
            axes[0, i].set_title(f'Original Sample {i + 1}', fontweight='bold')
            axes[0, i].axis('off')

            axes[1, i].imshow(overlay(img_aug, heatmaps_aug))
            axes[1, i].set_title(f'Augmented Sample {i + 1}', fontweight='bold')
            axes[1, i].axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.show()
    finally:
        dataset.set_augmentation(original_augment, original_prob)


def plot_training_metrics(train_losses, val_losses, ap_scores, pck_black_scores, pck_white_scores,
                          contrast_push_losses, contrast_pull_losses, output_dir, loss_type, unified_cells=False):
    """Write the training-curve figures of a run into ``<output_dir>/plots``.

    Produces:
      * ``training_metrics_<loss_type><suffix>.png`` — a dashboard with training
        loss, validation MSE, AP and PCK (plus L_push / L_pull panels when
        ``loss_type == 'mse_focal_contrast'``);
      * ``train_vs_val_losses<suffix>.png`` — training vs. validation loss side by side;
      * ``contrast_losses_comparison<suffix>.png`` — only for the contrast loss.

    ``<suffix>`` is ``_unified`` or ``_separated`` depending on *unified_cells*.
    In unified mode only *pck_black_scores* is plotted (labelled "All Cells").
    Every figure is closed after saving.
    """
    epochs = range(1, len(train_losses) + 1)
    with_contrast = loss_type == 'mse_focal_contrast'
    mode_suffix = "_unified" if unified_cells else "_separated"
    plots_dir = os.path.join(output_dir, "plots")

    def finish_axis(axis, title, ylabel):
        # Shared decoration for every panel of the dashboard and the loss figure.
        axis.set_title(title, fontsize=16, fontweight='bold')
        axis.set_xlabel('Epoch', fontsize=12)
        axis.set_ylabel(ylabel, fontsize=12)
        axis.legend(fontsize=12)
        axis.grid(True, alpha=0.3)

    # ---- Dashboard: a 2x2 grid, or 3x2 when the contrast-loss panels are needed.
    n_rows = 3 if with_contrast else 2
    _, grid = plt.subplots(n_rows, 2, figsize=(20, 18) if with_contrast else (16, 12))
    slots = [grid[row, col] for row in range(n_rows) for col in range(2)]

    slots[0].plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=3, marker='o', markersize=4)
    finish_axis(slots[0], f'Training Loss ({loss_type.upper()})', 'Loss')

    slots[1].plot(epochs, val_losses, 'r-', label='Validation Loss (MSE)', linewidth=3, marker='s', markersize=4)
    finish_axis(slots[1], 'Validation Loss (MSE - Standardized)', 'Loss')

    slots[2].plot(epochs, ap_scores, 'g-', label='Average Precision', linewidth=3, marker='^', markersize=4)
    finish_axis(slots[2], 'Average Precision Score', 'AP Score')

    pck_axis = slots[3]
    if not unified_cells:
        pck_axis.plot(epochs, pck_black_scores, 'darkred', label='PCK Black Cells', linewidth=3, marker='v', markersize=4)
        pck_axis.plot(epochs, pck_white_scores, 'darkblue', label='PCK White Cells', linewidth=3, marker='*', markersize=6)
        finish_axis(pck_axis, 'PCK Scores - Black vs White Cells', 'PCK Score')
    else:
        pck_axis.plot(epochs, pck_black_scores, 'purple', label='PCK All Cells', linewidth=3, marker='D', markersize=4)
        finish_axis(pck_axis, 'PCK Score - All Cells', 'PCK Score')

    if with_contrast:
        slots[4].plot(epochs, contrast_push_losses, 'orange', label='L_push', linewidth=3, marker='h', markersize=4)
        finish_axis(slots[4], 'Contrast Loss - L_push Loss', 'L_push Loss')

        slots[5].plot(epochs, contrast_pull_losses, 'purple', label='L_pull', linewidth=3, marker='p', markersize=4)
        finish_axis(slots[5], 'Contrast Loss - L_pull Loss', 'L_pull Loss')

    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f"training_metrics_{loss_type}{mode_suffix}.png"),
                dpi=150, bbox_inches='tight')
    plt.close()

    # ---- Training vs. validation loss, side by side.
    plt.figure(figsize=(14, 8))
    left = plt.subplot(1, 2, 1)
    left.plot(epochs, train_losses, 'b-', label=f'Train Loss ({loss_type.upper()})',
              linewidth=3, marker='o', markersize=6)
    finish_axis(left, 'Training Loss Evolution', 'Loss')

    right = plt.subplot(1, 2, 2)
    right.plot(epochs, val_losses, 'r-', label='Validation Loss (MSE)',
               linewidth=3, marker='s', markersize=6)
    finish_axis(right, 'Validation Loss Evolution', 'Loss')

    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f"train_vs_val_losses{mode_suffix}.png"),
                dpi=150, bbox_inches='tight')
    plt.close()

    if not with_contrast:
        return

    # ---- Both contrast-loss terms on a single axis (larger fonts than above).
    plt.figure(figsize=(12, 7))
    plt.plot(epochs, contrast_push_losses, 'orange', label='L_push', linewidth=3,
             marker='o', markersize=5, alpha=0.8)
    plt.plot(epochs, contrast_pull_losses, 'purple', label='L_pull', linewidth=3,
             marker='s', markersize=5, alpha=0.8)
    plt.title('Contrast Losses Comparison', fontsize=18, fontweight='bold')
    plt.xlabel('Epoch', fontsize=14)
    plt.ylabel('Loss Value', fontsize=14)
    plt.legend(fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f"contrast_losses_comparison{mode_suffix}.png"),
                dpi=150, bbox_inches='tight')
    plt.close()

Let’s implement a mechanism to save model weights during training (checkpoints).

In [6]:
def save_model_checkpoint(model, epoch, optimizer, scheduler, output_dir, is_best=False):
    """Save model, optimizer and scheduler state to ``<output_dir>/models``.

    The file is named ``model_epoch_XXX.pth``, or ``best_model_epoch_XXX.pth``
    when *is_best* is true (``XXX`` is *epoch* zero-padded to three digits).
    Earlier best-model files are not removed, so each improvement leaves its
    own file behind. The ``models`` directory is created if it does not exist.

    Args:
        model: Module whose ``state_dict()`` is stored.
        epoch: Epoch index recorded in the checkpoint and used in the filename.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        output_dir: Run directory; the checkpoint goes into its ``models`` subfolder.
        is_best: Use the ``best_model_`` filename prefix.

    Returns:
        The path of the written checkpoint file.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }

    prefix = 'best_model' if is_best else 'model'
    models_dir = os.path.join(output_dir, "models")
    # torch.save does not create missing parent directories.
    os.makedirs(models_dir, exist_ok=True)

    filepath = os.path.join(models_dir, f'{prefix}_epoch_{epoch:03d}.pth')
    torch.save(checkpoint, filepath)
    return filepath

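Let’s implement helpers to monitor system resource usage during training: one prints the GPU memory used by PyTorch, the other periodically logs the GPU power draw reported by `nvidia-smi`.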

In [7]:
def print_memory_usage():
    """Print the CUDA memory currently allocated and reserved by PyTorch, in GiB.

    Prints nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    bytes_per_gib = 1024 ** 3
    allocated_gib = torch.cuda.memory_allocated() / bytes_per_gib
    reserved_gib = torch.cuda.memory_reserved() / bytes_per_gib
    print(f"GPU Memory - Allocated: {allocated_gib:.2f}GB, Reserved: {reserved_gib:.2f}GB")


def monitor_gpu_power(log_interval, stop_event, power_data):
    """Poll ``nvidia-smi`` for the GPU power draw until *stop_event* is set.

    Each successful reading is appended to *power_data* as
    ``(unix_timestamp, watts)``. Only the first line of the output is used, so
    on a multi-GPU machine only the first GPU is sampled. Unparsable output and
    timed-out calls skip that sample. If ``nvidia-smi`` cannot be executed at all
    (e.g. not installed), polling stops quietly instead of crashing the thread.

    Args:
        log_interval: Seconds to wait between two readings.
        stop_event: ``threading.Event``; setting it ends the loop promptly,
            even in the middle of the wait.
        power_data: List that receives the ``(timestamp, watts)`` tuples.
    """
    query = ['nvidia-smi', '--query-gpu=power.draw', '--format=csv,noheader,nounits']
    while not stop_event.is_set():
        try:
            result = subprocess.run(
                query, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, timeout=5
            )
            power = float(result.stdout.strip().split('\n')[0])
        except OSError:
            # Executable missing / not runnable: retrying can never succeed.
            return
        except (ValueError, subprocess.TimeoutExpired):
            # Empty or non-numeric output (e.g. "[N/A]"), or a hung call: skip this sample.
            pass
        else:
            power_data.append((time.time(), power))
        # Unlike time.sleep, this wakes up as soon as the caller sets stop_event.
        stop_event.wait(log_interval)

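Finally, let’s implement our model’s training function. It covers the training loop, the periodic visualization of predictions and evaluation metrics, checkpointing, and an estimate of the training’s environmental impact (carbon footprint). We also add an interactive prompt for choosing the model’s configuration (backbone, cell detection mode, loss function and data augmentation).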

In [8]:
def _train_one_epoch(model, train_loader, optimizer, train_loss_fn, device):
    """Run one optimization pass over *train_loader*.

    Returns:
        ``(mean per-sample training loss, per-batch l_push list, per-batch l_pull list)``.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    push_losses = []
    pull_losses = []

    for batch_idx, (images, targets) in enumerate(train_loader):
        images = images.to(device)
        heatmap_first, heatmap_second = targets
        heatmap_first = heatmap_first.to(device)
        heatmap_second = heatmap_second.to(device)

        outputs = model(images)

        loss, loss_components = train_loss_fn(outputs, (heatmap_first, heatmap_second))

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight by batch size so the epoch value is a per-sample mean.
        loss_sum += loss.item() * images.size(0)
        n_seen += images.size(0)
        push_losses.append(loss_components['l_push'])
        pull_losses.append(loss_components['l_pull'])

        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()

    # Divide by the samples actually seen (robust to samplers / drop_last).
    return loss_sum / max(n_seen, 1), push_losses, pull_losses


def _validate_one_epoch(model, val_loader, loss_type, loss_weights, unified_cells, device):
    """Evaluate *model* on *val_loader* without gradients.

    Returns:
        dict with ``val_loss`` (mean per-sample ``evaluation_mse_loss``), per-sample
        lists ``ap``, ``pck_black``, ``pck_white`` and per-batch lists ``push`` /
        ``pull`` (all zeros unless ``loss_type == 'mse_focal_contrast'``).
    """
    model.eval()
    mse_sum = 0.0
    n_seen = 0
    ap_list = []
    pck_black_list = []
    pck_white_list = []
    push_list = []
    pull_list = []

    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(val_loader):
            images = images.to(device)
            heatmap_first, heatmap_second = targets
            heatmap_first = heatmap_first.to(device)
            heatmap_second = heatmap_second.to(device)

            outputs = model(images)

            # Validation always uses MSE so runs trained with different losses are comparable.
            mse_loss = evaluation_mse_loss(outputs, (heatmap_first, heatmap_second), unified_cells)
            mse_sum += mse_loss.item() * images.size(0)
            n_seen += images.size(0)

            if loss_type == 'mse_focal_contrast':
                _, contrast_components = combined_loss(
                    outputs, (heatmap_first, heatmap_second),
                    alpha=loss_weights.get('alpha', 1.0),
                    beta=loss_weights.get('beta', 0.5),
                    gamma=loss_weights.get('gamma', 0.3),
                    delta=loss_weights.get('delta', 0.2),
                    sigma=loss_weights.get('sigma', 2.0),
                    unified_cells=unified_cells
                )
                push_list.append(contrast_components['l_push'])
                pull_list.append(contrast_components['l_pull'])
            else:
                push_list.append(0.0)
                pull_list.append(0.0)

            for i in range(images.size(0)):
                if unified_cells:
                    ap_list.append(calculate_ap(outputs[i, 0], heatmap_first[i, 0]))

                    pred_unified_kpts = extract_keypoints_from_heatmap(outputs[i, 0])
                    true_unified_kpts = extract_keypoints_from_heatmap(heatmap_first[i, 0])
                    pck_unified = calculate_pck(pred_unified_kpts, true_unified_kpts)
                    # Unified mode has a single PCK; it is stored in both lists.
                    pck_black_list.append(pck_unified)
                    pck_white_list.append(pck_unified)
                else:
                    ap_black = calculate_ap(outputs[i, 0], heatmap_first[i, 0])
                    ap_white = calculate_ap(outputs[i, 1], heatmap_second[i, 0])
                    ap_list.append((ap_black + ap_white) / 2)

                    pred_black_kpts = extract_keypoints_from_heatmap(outputs[i, 0])
                    true_black_kpts = extract_keypoints_from_heatmap(heatmap_first[i, 0])
                    pck_black_list.append(calculate_pck(pred_black_kpts, true_black_kpts))

                    pred_white_kpts = extract_keypoints_from_heatmap(outputs[i, 1])
                    true_white_kpts = extract_keypoints_from_heatmap(heatmap_second[i, 0])
                    pck_white_list.append(calculate_pck(pred_white_kpts, true_white_kpts))

            if batch_idx % 10 == 0:
                torch.cuda.empty_cache()

    return {
        'val_loss': mse_sum / max(n_seen, 1),
        'ap': ap_list,
        'pck_black': pck_black_list,
        'pck_white': pck_white_list,
        'push': push_list,
        'pull': pull_list,
    }


def train_model_configurable(model, train_loader, val_loader, optimizer, scheduler, num_epochs, output_dir,
                             dataset_name, loss_type, loss_weights, use_augmentation, backbone, unified_cells=False):
    """Train *model*, validate and checkpoint every epoch, and report GPU energy use.

    Each epoch runs one optimization pass with the loss returned by
    ``get_loss_function(loss_type, loss_weights, unified_cells)``. It then runs a
    validation pass scored with ``evaluation_mse_loss``, AP and PCK, and calls
    ``scheduler.step(val_loss)``.

    Side effects per epoch:
      * best-validation checkpoint whenever the validation MSE improves;
      * regular checkpoint every 10 epochs;
      * prediction previews on epoch 1 and every 5 epochs;
      * metric plots every 10 epochs, plus a final preview and plots at the end.

    A background thread samples GPU power every 5 s via ``monitor_gpu_power``.
    It is always stopped, even if training raises or is interrupted. When
    samples exist, an energy / CO2 report is printed and the samples are
    written to ``gpu_power_log.csv``.

    Returns:
        dict summarizing the run. ``best_epoch`` is 0-based; the ``final_*`` keys
        hold the last epoch's metrics.
    """
    device = next(model.parameters()).device
    best_val_loss = float('inf')
    best_epoch = 0

    train_losses = []
    val_losses = []
    ap_scores = []
    pck_black_scores = []
    pck_white_scores = []
    contrast_push_losses = []
    contrast_pull_losses = []

    train_loss_fn = get_loss_function(loss_type, loss_weights, unified_cells)

    power_data = []
    stop_event = threading.Event()
    # daemon=True: even if the join below were skipped, the poller cannot keep the process alive.
    monitor_thread = threading.Thread(target=monitor_gpu_power, args=(5, stop_event, power_data), daemon=True)
    monitor_thread.start()
    start_time = time.time()

    print(f" Starting configurable training - Dataset: {dataset_name}")
    print(f"️  Backbone: {backbone.upper()}")
    print(f" Training loss: {loss_type.upper()}")
    print(f" Evaluation loss: MSE (standardized)")
    print(f" Data augmentation: {'ENABLED' if use_augmentation else 'DISABLED'}")
    print(f" Cell detection mode: {'UNIFIED (all cells)' if unified_cells else 'SEPARATED (black/white)'}")
    print(f"️  Loss weights: {loss_weights}")
    print(f" Results saved to: {output_dir}")
    print("=" * 90)

    try:
        for epoch in range(num_epochs):
            torch.cuda.empty_cache()
            print_memory_usage()

            train_loss, epoch_push_losses, epoch_pull_losses = _train_one_epoch(
                model, train_loader, optimizer, train_loss_fn, device)
            val_stats = _validate_one_epoch(
                model, val_loader, loss_type, loss_weights, unified_cells, device)

            val_loss = val_stats['val_loss']
            avg_ap = np.mean(val_stats['ap']) if val_stats['ap'] else 0.0
            avg_pck_black = np.mean(val_stats['pck_black']) if val_stats['pck_black'] else 0.0
            avg_pck_white = np.mean(val_stats['pck_white']) if val_stats['pck_white'] else 0.0
            # Contrast terms are averaged over train and validation batches together.
            avg_push_loss = np.mean(epoch_push_losses + val_stats['push'])
            avg_pull_loss = np.mean(epoch_pull_losses + val_stats['pull'])

            scheduler.step(val_loss)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            ap_scores.append(avg_ap)
            pck_black_scores.append(avg_pck_black)
            pck_white_scores.append(avg_pck_white)
            contrast_push_losses.append(avg_push_loss)
            contrast_pull_losses.append(avg_pull_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch
                save_model_checkpoint(model, epoch, optimizer, scheduler, output_dir, is_best=True)

            if (epoch + 1) % 10 == 0:
                save_model_checkpoint(model, epoch, optimizer, scheduler, output_dir)

            loss_display = f'Train({loss_type.upper()}): {train_loss:.4f} | Val(MSE): {val_loss:.4f}'

            if unified_cells:
                metrics_display = f'AP: {avg_ap:.4f} | PCK_All: {avg_pck_black:.4f}'
            else:
                metrics_display = f'AP: {avg_ap:.4f} | PCK_B: {avg_pck_black:.4f} | PCK_W: {avg_pck_white:.4f}'

            if loss_type == 'mse_focal_contrast':
                contrast_display = f' | L_push: {avg_push_loss:.4f} | L_pull: {avg_pull_loss:.4f}'
            else:
                contrast_display = ''

            print(
                f'Epoch {epoch + 1:3d}/{num_epochs} | {loss_display} | {metrics_display}{contrast_display} | LR: {optimizer.param_groups[0]["lr"]:.6f}')

            if (epoch + 1) % 5 == 0 or epoch == 0:
                visualize_predictions_improved(model, val_loader, epoch + 1, output_dir, unified_cells=unified_cells)

            if (epoch + 1) % 10 == 0:
                plot_training_metrics(train_losses, val_losses, ap_scores,
                                      pck_black_scores, pck_white_scores,
                                      contrast_push_losses, contrast_pull_losses, output_dir, loss_type, unified_cells)
    finally:
        # Always stop the power poller, even when training raises or is interrupted.
        stop_event.set()
        monitor_thread.join()
    end_time = time.time()

    visualize_predictions_improved(model, val_loader, num_epochs, output_dir, num_samples=8,
                                   unified_cells=unified_cells)
    plot_training_metrics(train_losses, val_losses, ap_scores,
                          pck_black_scores, pck_white_scores,
                          contrast_push_losses, contrast_pull_losses, output_dir, loss_type, unified_cells)

    duration_hours = (end_time - start_time) / 3600
    if power_data:
        df_power = pd.DataFrame(power_data, columns=['timestamp', 'power_watts'])
        mean_power = df_power['power_watts'].mean()
        # Energy approximated as mean sampled power x wall-clock duration.
        energy_consumed_kwh = (mean_power * duration_hours) / 1000
        # kg CO2 per kWh; presumably an average grid intensity for Belgium — verify the source.
        CO2_FACTOR_BELGIUM = 0.11
        carbon_impact_kg = energy_consumed_kwh * CO2_FACTOR_BELGIUM

        print("\n" + "=" * 50)
        print(" ENERGY REPORT")
        print("=" * 50)
        print(f" Training duration: {duration_hours:.2f} h")
        print(f" Average GPU power: {mean_power:.2f} W")
        print(f" GPU energy consumed: {energy_consumed_kwh:.4f} kWh")
        print(f" Estimated carbon footprint (Belgium): {carbon_impact_kg * 1000:.2f} g CO2")

        df_power.to_csv(os.path.join(output_dir, "gpu_power_log.csv"), index=False)

    print(f"\n Best model: Epoch {best_epoch + 1} with validation loss (MSE) = {best_val_loss:.4f}")
    print(f" Complete results available in: {output_dir}")

    return {
        'best_epoch': best_epoch,
        'best_val_loss': best_val_loss,
        'final_ap': ap_scores[-1] if ap_scores else 0,
        'final_pck_black': pck_black_scores[-1] if pck_black_scores else 0,
        'final_pck_white': pck_white_scores[-1] if pck_white_scores else 0,
        'final_l_push': contrast_push_losses[-1] if contrast_push_losses else 0,
        'final_l_pull': contrast_pull_losses[-1] if contrast_pull_losses else 0,
        'training_time_hours': duration_hours,
        'loss_type': loss_type,
        'use_augmentation': use_augmentation,
        'backbone': backbone,
        'unified_cells': unified_cells
    }


def get_user_configuration():
    """Interactively ask for the training configuration on stdin.

    Four numbered menus are shown in turn: backbone, cell detection mode, loss
    function and data augmentation. Each one re-prompts until one of the listed
    numbers is entered. A summary of the choices is printed at the end.

    Returns:
        ``(backbone, loss_type, use_augmentation, unified_cells)``. Here
        *backbone* is ``'resnet18'``, ``'resnet34'`` or ``'resnet50'``, and
        *loss_type* is ``'mse'``, ``'mse_focal'`` or ``'mse_focal_contrast'``.
        The last two values are booleans.
    """
    def ask(prompt, allowed, out_of_range_msg, not_a_number_msg):
        # Keep asking until the answer is an integer contained in `allowed`.
        while True:
            try:
                picked = int(input(prompt))
            except ValueError:
                print(not_a_number_msg)
                continue
            if picked in allowed:
                return picked
            print(out_of_range_msg)

    banner = "=" * 70
    print(banner)
    print("CONFIGURATION SETUP")
    print(banner)

    print("\n  SELECT BACKBONE ARCHITECTURE:")
    for option in ("1. ResNet-18 ", "2. ResNet-34 ", "3. ResNet-50 "):
        print(option)
    backbone_idx = ask("\nEnter your choice (1-3): ", (1, 2, 3),
                       "Please enter 1, 2, or 3", "Please enter a valid number")
    backbone = {1: 'resnet18', 2: 'resnet34', 3: 'resnet50'}[backbone_idx]

    print("\n SELECT CELL DETECTION MODE:")
    print("1. Separated mode - Distinguish black (v=2) and white (v=1) cells")
    print("2. Unified mode - Detect all cells (v=1 or v=2) without polarity distinction")
    cell_mode_idx = ask("\nEnter your choice (1-2): ", (1, 2),
                        " Please enter 1 or 2", " Please enter a valid number")
    unified_cells = cell_mode_idx == 2

    print("\n SELECT LOSS FUNCTION:")
    for option in ("1. MSE only", "2. MSE + Focal Loss", "3. MSE + Focal Loss + Contrast Loss"):
        print(option)
    loss_idx = ask("\nEnter your choice (1-3): ", (1, 2, 3),
                   " Please enter 1, 2, or 3", " Please enter a valid number")
    loss_type = {1: 'mse', 2: 'mse_focal', 3: 'mse_focal_contrast'}[loss_idx]

    print("\n ENABLE DATA AUGMENTATION:")
    print("1. Yes - Enable online data augmentation")
    print("2. No - Disable data augmentation")
    aug_idx = ask("\nEnter your choice (1-2): ", (1, 2),
                  " Please enter 1 or 2", " Please enter a valid number")
    use_augmentation = aug_idx == 1

    cell_mode_label = 'UNIFIED (all cells)' if unified_cells else 'SEPARATED (black/white)'
    augmentation_label = 'ENABLED' if use_augmentation else 'DISABLED'

    print("\n" + banner)
    print(" CONFIGURATION SELECTED:")
    print(f"️     Backbone: {backbone.upper()}")
    print(f"     Cell mode: {cell_mode_label}")
    print(f"     Loss function: {loss_type.upper()}")
    print(f"     Data augmentation: {augmentation_label}")
    print("     Evaluation: Always MSE (for fair comparison)")
    print(banner)

    return backbone, loss_type, use_augmentation, unified_cells

Time to train the model.

Please fill in the dataset paths in CONFIG (`train_json`, `train_img_dir`, `val_json`, `val_img_dir`) before running this cell.

After running the cell, you’ll be prompted to choose the backbone, the cell detection mode (separated or unified), the training loss function, and whether to apply data augmentation.

In [None]:
# Training entry point: ask for the run configuration, build the data pipeline,
# model, optimizer and scheduler, then train.
backbone, loss_type, use_augmentation, unified_cells = get_user_configuration()


mode_suffix = "_unified" if unified_cells else "_separated"
CONFIG = {
    'dataset_name': f'DataMatrix_{backbone}_{loss_type}{"_aug" if use_augmentation else "_no_aug"}{mode_suffix}',
    'batch_size': 8,
    'learning_rate': 0.001,
    'epochs': 80,
    'sigma': 2,
    'weight_decay': 1e-4,
    'target_size': (512, 512),

    # TODO(user): fill in these paths (COCO annotation JSON files and their image folders).
    'train_json': '',
    'train_img_dir': '',
    'val_json': '',
    'val_img_dir': '',

    'backbone': backbone,
    'loss_type': loss_type,
    'use_augmentation': use_augmentation,
    'unified_cells': unified_cells,

    'augmentation': {
        'train_prob': 0.80 if use_augmentation else 0.0,
        'val_prob': 0.0,
    },

    'loss_weights': {
        'alpha': 1.0,  # MSE loss weight
        'beta': 0.5,  # Focal loss weight
        'gamma': 0.1,  # L_push loss weight
        'delta': 0.4,  # L_pull loss weight
        'sigma': 2
    }
}

# Fail fast, before any output directory is created, if a dataset path was left blank.
_missing_paths = [key for key in ('train_json', 'train_img_dir', 'val_json', 'val_img_dir') if not CONFIG[key]]
if _missing_paths:
    raise ValueError(f"Please fill in the following CONFIG entries before training: {', '.join(_missing_paths)}")


output_dir = create_output_directory(CONFIG['dataset_name'])

with open(os.path.join(output_dir, 'config.json'), 'w') as f:
    json.dump(CONFIG, f, indent=2)


IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
# Sorted so the dataset ordering (and therefore dataset[i]) is identical across machines.
train_img_list = sorted(f for f in os.listdir(CONFIG['train_img_dir']) if f.lower().endswith(IMAGE_EXTENSIONS))
val_img_list = sorted(f for f in os.listdir(CONFIG['val_img_dir']) if f.lower().endswith(IMAGE_EXTENSIONS))

train_dataset = KeypointDatasetAugmented(
    coco_json=CONFIG['train_json'],
    img_dir=CONFIG['train_img_dir'],
    img_list=train_img_list,
    sigma=CONFIG['sigma'],
    target_size=CONFIG['target_size'],
    augment=CONFIG['use_augmentation'],
    augment_prob=CONFIG['augmentation']['train_prob'],
    unified_cells=CONFIG['unified_cells']
)

# Validation data is never augmented.
val_dataset = KeypointDatasetAugmented(
    coco_json=CONFIG['val_json'],
    img_dir=CONFIG['val_img_dir'],
    img_list=val_img_list,
    sigma=CONFIG['sigma'],
    target_size=CONFIG['target_size'],
    augment=False,
    augment_prob=0.0,
    unified_cells=CONFIG['unified_cells']
)

use_cuda = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=0,
    pin_memory=use_cuda
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=use_cuda
)

device = torch.device('cuda' if use_cuda else 'cpu')

model = KeypointDetector(backbone=CONFIG['backbone'], pretrained=True, unified_cells=CONFIG['unified_cells'])
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {total_params:,} total | {trainable_params:,} trainable")

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

# Multiply the LR by 0.7 after 8 epochs without validation-loss improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.7,
    patience=8,
    min_lr=1e-6
)

if CONFIG['use_augmentation']:
    viz_path = os.path.join(output_dir, "augmentation_preview.png")
    visualize_augmented_samples(train_dataset, num_samples=4, save_path=viz_path)

results = train_model_configurable(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=CONFIG['epochs'],
    output_dir=output_dir,
    dataset_name=CONFIG['dataset_name'],
    loss_type=CONFIG['loss_type'],
    loss_weights=CONFIG['loss_weights'],
    use_augmentation=CONFIG['use_augmentation'],
    backbone=CONFIG['backbone'],
    unified_cells=CONFIG['unified_cells']
)
