In [None]:
import cv2
import numpy as np
import os
import glob as glob
import time
import matplotlib.pyplot as plt
from functools import partial
from tqdm.auto import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torch.optim.lr_scheduler import StepLR


plt.style.use("ggplot")



BATCH_SIZE = 4  # Increase / decrease according to GPU memory.
RESIZE_TO = 640  # Resize the image for training and transforms.
NUM_EPOCHS = 40  # Number of epochs to train for.
NUM_WORKERS = 4  # Number of parallel workers for data loading.

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training images and labels files directory.
TRAIN_DIR = "data/train"
# Validation images and labels files directory.
VALID_DIR = "data/valid"

# Classes: 0 index is reserved for background.
CLASSES = [
    "__background__",
    "elbow positive",
    "fingers positive",
    "forearm fracture",
    "humerus fracture",
    "humerus",
    "shoulder fracture",
    "wrist positive",
]

# Number of classes, counting the "__background__" entry.
NUM_CLASSES = len(CLASSES)

# Whether to visualize images after creating the data loaders.
VISUALIZE_TRANSFORMED_IMAGES = True

# Location to save model and plots.
OUT_DIR = "outputs"

In [None]:
class Averager:
    """
    Accumulates values and exposes their arithmetic mean
    (e.g. to follow the mean training loss over an epoch).
    """

    def __init__(self):
        self.reset()

    def send(self, value):
        """Add `value` to the running total and count one more sample."""
        self.iterations += 1
        self.current_total += value

    @property
    def value(self):
        """Mean of everything sent since the last reset; 0 when nothing was sent."""
        if self.iterations:
            return self.current_total / self.iterations
        return 0

    def reset(self):
        """Discard everything accumulated so far."""
        self.current_total = 0.0
        self.iterations = 0.0


class SaveBestModel:
    """
    Checkpoints the model whenever the validation mAP of the current epoch
    is strictly higher than the best value seen so far.
    """

    def __init__(self, best_valid_map=float(0)):
        self.best_valid_map = best_valid_map

    def __call__(
        self,
        model,
        current_valid_map,
        epoch,
        OUT_DIR,
    ):
        """
        Save `model.state_dict()` to `<OUT_DIR>/best_model.pth` if
        `current_valid_map` beats the stored best; otherwise do nothing.

        :param model: Model whose state dict is written.
        :param current_valid_map: Validation mAP of the epoch just finished.
        :param epoch: Zero-based epoch index (stored and printed as epoch + 1).
        :param OUT_DIR: Directory that receives the checkpoint file.
        """
        # Guard clause: only a strict improvement triggers a save.
        if not (current_valid_map > self.best_valid_map):
            return

        self.best_valid_map = current_valid_map
        print(f"\nBEST VALIDATION mAP: {self.best_valid_map}")
        print(f"SAVING BEST MODEL FOR EPOCH: {epoch+1}\n")

        checkpoint = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
        }
        torch.save(checkpoint, f"{OUT_DIR}/best_model.pth")


def collate_fn(batch):
    """
    Transpose a batch of samples into one tuple per field.

    A batch such as `[(img0, tgt0), (img1, tgt1)]` becomes
    `((img0, img1), (tgt0, tgt1))`. Nothing is stacked, so images of
    different sizes and targets with different numbers of objects can
    share a batch.
    """
    per_field = zip(*batch)
    return tuple(per_field)


def get_train_transform():
    """
    Build the training pipeline: random horizontal flip (p=0.5) followed by
    conversion to a tensor.
    """
    # "pascal_voc" matches the box layout used here: [x_min, y_min, x_max, y_max].
    bbox_settings = {"format": "pascal_voc", "label_fields": ["labels"]}
    steps = [
        A.HorizontalFlip(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, bbox_params=bbox_settings)


def get_valid_transform():
    """
    Build the validation pipeline: tensor conversion only, no augmentation.
    """
    # Same box layout as for training: [x_min, y_min, x_max, y_max].
    bbox_settings = {"format": "pascal_voc", "label_fields": ["labels"]}
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps, bbox_params=bbox_settings)


def show_tranformed_image(train_loader):
    """
    Visualize transformed images from the `train_loader` for debugging.

    Walks through at most `BATCH_SIZE` batches of the loader and shows every
    image that has at least one box in an OpenCV window, with the boxes and
    class names drawn on top. Each window blocks until a key is pressed.
    The caller decides whether to run this (see `VISUALIZE_TRANSFORMED_IMAGES`).

    The misspelled function name is kept so existing callers keep working.

    :param train_loader: Loader yielding `(images, targets)` batches; images
        are expected to be channel-first tensors (they are permuted to
        height x width x channel for display) and each target a dict with
        "boxes" and "labels" tensors.
    """
    if len(train_loader) == 0:
        return

    # Iterate the loader once instead of calling `next(iter(train_loader))`
    # repeatedly, which rebuilt the iterator (and its workers) for every batch.
    for batch_idx, (images, targets) in enumerate(train_loader):
        if batch_idx >= BATCH_SIZE:
            break

        for image, target in zip(images, targets):
            if len(target["boxes"]) == 0:
                continue

            # Everything is drawn with OpenCV on the CPU, so no device transfer is needed.
            boxes = target["boxes"].cpu().numpy().astype(np.int32)
            labels = target["labels"].cpu().numpy().astype(np.int32)
            sample = image.permute(1, 2, 0).cpu().numpy()
            sample = cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)

            for box, label in zip(boxes, labels):
                # Plain Python ints for the OpenCV drawing calls.
                x_min, y_min, x_max, y_max = (int(coord) for coord in box)
                cv2.rectangle(sample, (x_min, y_min), (x_max, y_max), (0, 0, 255), 2)
                cv2.putText(
                    sample,
                    CLASSES[label],
                    (x_min, y_min - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    (0, 0, 255),
                    2,
                )
            cv2.imshow("Transformed image", sample)
            cv2.waitKey(0)
            cv2.destroyAllWindows()


def save_model(epoch, model, optimizer, out_dir="outputs"):
    """
    Save the trained model (state dict) and optimizer state to disk.

    The checkpoint is written to `<out_dir>/last_model.pth` and overwrites
    any previous file of that name.

    :param epoch: Zero-based epoch index; stored in the checkpoint as epoch + 1.
    :param model: Model whose `state_dict()` is saved.
    :param optimizer: Optimizer whose `state_dict()` is saved.
    :param out_dir: Existing directory to write into. Defaults to "outputs",
        the location that was previously hard-coded.
    """
    checkpoint = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(checkpoint, f"{out_dir}/last_model.pth")


def save_loss_plot(OUT_DIR, train_loss_list, x_label="iterations", y_label="train loss", save_name="train_loss"):
    """
    Saves the training loss curve to `<OUT_DIR>/<save_name>.png`.

    :param OUT_DIR: Existing directory that receives the image.
    :param train_loss_list: Sequence of loss values, plotted in order.
    :param x_label: Label of the x axis.
    :param y_label: Label of the y axis.
    :param save_name: File name of the image, without extension.
    """
    figure = plt.figure(figsize=(10, 7))
    try:
        plt.plot(train_loss_list, color="tab:blue")
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.savefig(f"{OUT_DIR}/{save_name}.png")
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # new figure would accumulate on each call.
        plt.close(figure)
    print("SAVING PLOTS COMPLETE...")


def save_mAP(OUT_DIR, map_05, map):
    """
    Saves the mAP@0.5 and mAP@0.5:0.95 curves per epoch to `<OUT_DIR>/map.png`.

    :param OUT_DIR: Existing directory that receives the image.
    :param map_05: Sequence of mAP@0.5 values, one per epoch.
    :param map: Sequence of mAP@0.5:0.95 values, one per epoch. The name
        shadows the builtin inside this function; it is kept so that
        existing callers are unaffected.
    """
    figure = plt.figure(figsize=(10, 7))
    try:
        plt.plot(map_05, color="tab:orange", linestyle="-", label="mAP@0.5")
        plt.plot(map, color="tab:red", linestyle="-", label="mAP@0.5:0.95")
        plt.xlabel("Epochs")
        plt.ylabel("mAP")
        plt.legend()
        plt.savefig(f"{OUT_DIR}/map.png")
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # new figure would accumulate on each call.
        plt.close(figure)
    print("SAVING mAP PLOTS COMPLETE...")


In [None]:
class CustomDataset(Dataset):
    """
    Object-detection dataset reading images from `<dir_path>/images` and box
    annotations from `<dir_path>/labels`.

    The label file of an image has the same base name with a `.txt`
    extension. Each non-empty line is read as
    `class_id x_min y_min x_max y_max`, with the coordinates normalized to
    [0, 1]. An image without a label file yields an empty target.

    `dataset[idx]` returns `(image, target)`:

    * `image` is the RGB image resized to `(width, height)` and scaled to
      [0, 1]; a float32 NumPy array when no transforms are given, otherwise
      whatever the transforms return under the "image" key.
    * `target` is a dict with "boxes" (N x 4 float32, pixel coordinates of
      the resized image), "labels" (N int64, `class_id + 1`), "area"
      (N float32), "iscrowd" (N int64 zeros) and "image_id" (`tensor([idx])`).
    """

    def __init__(self, dir_path, width, height, classes, transforms=None):
        """
        :param dir_path: Directory with 'images/' and 'labels/' sub-folders.
        :param width: Resized width for images.
        :param height: Resized height for images.
        :param classes: List of class names.
        :param transforms: Albumentations transforms to apply.
        """
        self.transforms = transforms
        self.dir_path = dir_path
        self.image_dir = os.path.join(self.dir_path, "images")
        self.label_dir = os.path.join(self.dir_path, "labels")
        self.width = width
        self.height = height
        self.classes = classes

        self.image_file_types = ["*.jpg", "*.jpeg", "*.png", "*.ppm", "*.JPG"]
        # Collect into a set: where glob matching ignores case, "*.jpg" and
        # "*.JPG" return the same files and each image would be listed twice.
        matched_paths = set()
        for file_type in self.image_file_types:
            matched_paths.update(glob.glob(os.path.join(self.image_dir, file_type)))

        # Sort the paths so that images and labels stay in consistent order.
        self.all_image_paths = sorted(matched_paths)
        self.all_image_names = [os.path.basename(img_p) for img_p in self.all_image_paths]

    def _read_annotations(self, label_path):
        """
        Parse one label file into pixel-space boxes and label indices.

        :param label_path: Path of the `.txt` label file; it may not exist.
        :return: `(boxes, labels)` where `boxes` is a list of
            `[x_min, y_min, x_max, y_max]` floats scaled to the resized image
            and `labels` is a list of ints. Both are empty when the file is
            missing or contains no annotation lines.
        """
        boxes = []
        labels = []
        if not os.path.exists(label_path):
            return boxes, labels

        with open(label_path, "r") as f:
            lines = f.readlines()

        for line in lines:
            line = line.strip()
            if not line:
                continue
            # Each line: class_id x_min y_min x_max y_max (normalized)
            parts = line.split()
            class_id = parts[0]
            xmin = float(parts[1])
            ymin = float(parts[2])
            xmax = float(parts[3])
            ymax = float(parts[4])

            # Shift by one so that label 0 stays free for the background class.
            label_idx = int(class_id) + 1

            # Because bounding boxes are normalized [0..1],
            # directly multiply by self.width, self.height:
            x_min_final = xmin * self.width
            x_max_final = xmax * self.width
            y_min_final = ymin * self.height
            y_max_final = ymax * self.height

            # Give zero-width / zero-height boxes a one pixel extent.
            if x_max_final == x_min_final:
                x_max_final += 1
            if y_max_final == y_min_final:
                y_max_final += 1

            # Clip coords if they exceed boundaries after scaling
            x_min_final = max(0, min(x_min_final, self.width - 1))
            x_max_final = max(0, min(x_max_final, self.width))
            y_min_final = max(0, min(y_min_final, self.height - 1))
            y_max_final = max(0, min(y_max_final, self.height))

            boxes.append([x_min_final, y_min_final, x_max_final, y_max_final])
            labels.append(label_idx)

        return boxes, labels

    def __getitem__(self, idx):
        """
        Load, resize and annotate the image at position `idx`.

        :raises FileNotFoundError: If the image cannot be read by OpenCV.
        """
        image_name = self.all_image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        label_filename = os.path.splitext(image_name)[0] + ".txt"
        label_path = os.path.join(self.label_dir, label_filename)

        # Read and preprocess image. `cv2.imread` returns None instead of
        # raising when the file is missing or cannot be decoded.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

        # Resize the image to (self.width, self.height)
        image_resized = cv2.resize(image, (self.width, self.height))
        image_resized /= 255.0  # scale to [0,1]

        box_list, label_list = self._read_annotations(label_path)

        # Apply transforms if any. Plain lists go in; boxes AND labels are
        # read back so both stay aligned if the pipeline drops a box.
        if self.transforms:
            sample = self.transforms(image=image_resized, bboxes=box_list, labels=label_list)
            image_resized = sample["image"]
            box_list = sample["bboxes"]
            label_list = sample["labels"]

        # Convert to tensors. The reshape keeps "boxes" at shape (0, 4) when
        # there are no annotations.
        boxes = torch.from_numpy(np.asarray(box_list, dtype=np.float32).reshape(-1, 4))
        labels = torch.from_numpy(np.asarray(label_list, dtype=np.int64).reshape(-1))

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target["iscrowd"] = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        target["image_id"] = torch.tensor([idx])

        return image_resized, target

    def __len__(self):
        """Number of images found in the image directory."""
        return len(self.all_image_paths)


# Create dataset & loader
def create_train_dataset(DIR):
    """
    Build the training dataset rooted at `DIR`, resized to
    `RESIZE_TO` x `RESIZE_TO` and using the training transforms.
    """
    return CustomDataset(
        dir_path=DIR,
        width=RESIZE_TO,
        height=RESIZE_TO,
        classes=CLASSES,
        transforms=get_train_transform(),
    )


def create_valid_dataset(DIR):
    """
    Build the validation dataset rooted at `DIR`, resized to
    `RESIZE_TO` x `RESIZE_TO` and using the validation transforms.
    """
    return CustomDataset(
        dir_path=DIR,
        width=RESIZE_TO,
        height=RESIZE_TO,
        classes=CLASSES,
        transforms=get_valid_transform(),
    )


def create_train_loader(train_dataset, num_workers=0):
    """
    Wrap `train_dataset` in a shuffling DataLoader with `BATCH_SIZE` samples
    per batch; a trailing incomplete batch is dropped.

    :param train_dataset: Dataset to draw training samples from.
    :param num_workers: Number of worker processes used for loading.
    """
    loader_options = {
        "batch_size": BATCH_SIZE,
        "shuffle": True,
        "num_workers": num_workers,
        "collate_fn": collate_fn,
        "drop_last": True,
    }
    return DataLoader(train_dataset, **loader_options)


def create_valid_loader(valid_dataset, num_workers=0):
    """
    Wrap `valid_dataset` in a non-shuffling DataLoader with `BATCH_SIZE`
    samples per batch.

    :param valid_dataset: Dataset to draw validation samples from.
    :param num_workers: Number of worker processes used for loading.
    """
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # Keep the final, possibly smaller batch: dropping it would leave up
        # to BATCH_SIZE - 1 validation images out of every evaluation.
        drop_last=False,
    )
    return valid_loader


# Quick visual check of the dataset when this file is run directly.
if __name__ == "__main__":
    dataset = CustomDataset(TRAIN_DIR, RESIZE_TO, RESIZE_TO, CLASSES)
    print(f"Number of training images: {len(dataset)}")

    def visualize_sample(image, target):
        """
        Show one dataset sample with its boxes and class names drawn on it.

        :param image: Image as returned by the dataset built above without
            transforms: a height x width x channel RGB array scaled to [0, 1].
        :param target: Target dict with "boxes" and "labels" tensors.
        """
        # Back to 8-bit BGR, the layout OpenCV draws on and displays.
        img = (image * 255).astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        boxes = target["boxes"].cpu().numpy().astype(np.int32)
        labels = target["labels"].cpu().numpy().astype(np.int32)
        for box, label in zip(boxes, labels):
            # Plain Python ints for the OpenCV drawing calls.
            x_min, y_min, x_max, y_max = (int(coord) for coord in box)
            cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 0, 255), 2)
            class_str = CLASSES[label]
            cv2.putText(img, class_str, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
        cv2.imshow("Sample", img)
        cv2.waitKey(0)

    NUM_SAMPLES_TO_VISUALIZE = 3
    # Never index past the end of a dataset with fewer images than requested.
    for i in range(min(NUM_SAMPLES_TO_VISUALIZE, len(dataset))):
        image, target = dataset[i]
        visualize_sample(image, target)
    cv2.destroyAllWindows()

In [None]:
def create_model(num_classes=91):
    """
    Build a RetinaNet-ResNet50-FPN v2 detector initialised from the COCO_V1
    weights and give it a new classification head for `num_classes` classes.

    :param num_classes: Number of output classes of the new head.
    :return: The model with its classification head replaced.
    """
    pretrained_weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=pretrained_weights)

    # The new head reuses the anchor count of the head it replaces.
    anchor_count = model.head.classification_head.num_anchors
    group_norm = partial(torch.nn.GroupNorm, 32)
    new_head = RetinaNetClassificationHead(
        in_channels=256,
        num_anchors=anchor_count,
        num_classes=num_classes,
        norm_layer=group_norm,
    )
    model.head.classification_head = new_head
    return model


# Print the architecture and parameter counts when this file is run directly.
if __name__ == "__main__":
    # Use the configured class count rather than a separate literal so the
    # printed model always matches CLASSES.
    model = create_model(num_classes=NUM_CLASSES)
    print(model)
    # Total parameters:
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    # Trainable parameters:
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")


In [None]:
# Seed PyTorch's random number generators with a fixed value to make runs
# more repeatable.
# NOTE(review): Python's `random` and NumPy are not seeded here — confirm
# whether the augmentation pipeline needs them seeded as well.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


# Function for running training iterations.
def train(train_data_loader, model):
    """
    Run one training epoch over `train_data_loader`.

    Uses the module-level `optimizer`, `train_loss_hist` and `DEVICE`, which
    must exist before this is called. Every batch loss is sent to
    `train_loss_hist`.

    :param train_data_loader: Loader yielding `(images, targets)` batches.
    :param model: Detection model that returns a dict of losses in train mode.
    :return: Summed loss of the last processed batch as a float, or 0.0 when
        the loader produced no batches.
    """
    print("Training")
    model.train()

    # initialize tqdm progress bar
    prog_bar = tqdm(train_data_loader, total=len(train_data_loader))

    # Defined up front so an empty loader (e.g. fewer samples than one batch
    # with drop_last=True) returns 0.0 instead of raising UnboundLocalError.
    loss_value = 0.0
    for images, targets in prog_bar:
        optimizer.zero_grad()

        images = [image.to(DEVICE) for image in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)

        # The model reports several loss terms; they are summed for backprop.
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        train_loss_hist.send(loss_value)

        losses.backward()
        optimizer.step()

        # update the loss value beside the progress bar for each iteration
        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
    return loss_value


# Function for running validation iterations.
def validate(valid_data_loader, model):
    """
    Run the model over `valid_data_loader` in eval mode and compute the
    detection metrics.

    Uses the module-level `metric` and `DEVICE`. The predictions and ground
    truths of all batches are collected on the CPU first; the metric is then
    reset, updated once with everything, and computed.

    :param valid_data_loader: Loader yielding `(images, targets)` batches.
    :param model: Detection model that returns per-image predictions in eval mode.
    :return: The result of `metric.compute()`.
    """
    print("Validating")
    model.eval()

    # Initialize tqdm progress bar.
    prog_bar = tqdm(valid_data_loader, total=len(valid_data_loader))
    ground_truths = []
    predictions = []
    for images, targets in prog_bar:
        images = [image.to(DEVICE) for image in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        with torch.no_grad():
            outputs = model(images, targets)

        # Gather the inputs for the Torchmetrics mAP calculation, one dict per image.
        for output, truth in zip(outputs, targets):
            ground_truths.append(
                {
                    "boxes": truth["boxes"].detach().cpu(),
                    "labels": truth["labels"].detach().cpu(),
                }
            )
            predictions.append(
                {
                    "boxes": output["boxes"].detach().cpu(),
                    "scores": output["scores"].detach().cpu(),
                    "labels": output["labels"].detach().cpu(),
                }
            )

    metric.reset()
    metric.update(predictions, ground_truths)
    return metric.compute()


# Full training run: builds the data loaders, model, optimizer and scheduler,
# then trains for NUM_EPOCHS, validating and checkpointing after every epoch.
# `optimizer`, `train_loss_hist` and `metric` must stay module-level names
# because `train()` and `validate()` read them as globals.
if __name__ == "__main__":
    os.makedirs(OUT_DIR, exist_ok=True)
    train_dataset = create_train_dataset(TRAIN_DIR)
    valid_dataset = create_valid_dataset(VALID_DIR)
    train_loader = create_train_loader(train_dataset, NUM_WORKERS)
    valid_loader = create_valid_loader(valid_dataset, NUM_WORKERS)
    print(f"Number of training samples: {len(train_dataset)}")
    print(f"Number of validation samples: {len(valid_dataset)}\n")

    # Initialize the model and move to the computation device.
    model = create_model(num_classes=NUM_CLASSES)
    model = model.to(DEVICE)
    print(model)
    # Total parameters and trainable parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.002, momentum=0.9, nesterov=True)
    # Multiply the learning rate by 0.1 every 15 epochs.
    scheduler = StepLR(optimizer=optimizer, step_size=15, gamma=0.1)

    # To monitor training loss
    train_loss_hist = Averager()
    # To store training loss and mAP values.
    train_loss_list = []
    map_50_list = []
    map_list = []

    # Name to save the trained model with.
    MODEL_NAME = "model"

    # Whether to show transformed images from data loader or not.
    # `show_tranformed_image` is defined in this file, so it is called
    # directly rather than imported from a separate module.
    if VISUALIZE_TRANSFORMED_IMAGES:
        show_tranformed_image(train_loader)

    # To save best model.
    save_best_model = SaveBestModel()

    metric = MeanAveragePrecision()

    # Training loop.
    for epoch in range(NUM_EPOCHS):
        print(f"\nEPOCH {epoch+1} of {NUM_EPOCHS}")

        # Reset the training loss histories for the current epoch.
        train_loss_hist.reset()

        # Start timer and carry out training and validation.
        start = time.time()
        train_loss = train(train_loader, model)
        metric_summary = validate(valid_loader, model)
        print(f"Epoch #{epoch+1} train loss: {train_loss_hist.value:.3f}")
        print(f"Epoch #{epoch+1} mAP: {metric_summary['map']:.3f}")
        end = time.time()
        # Report the 1-based epoch number, like every other message above.
        print(f"Took {((end - start) / 60):.3f} minutes for epoch {epoch+1}")

        # Store plain floats so the histories hold no tensors.
        train_loss_list.append(train_loss)
        map_50_list.append(float(metric_summary["map_50"]))
        map_list.append(float(metric_summary["map"]))

        # save the best model till now.
        save_best_model(model, float(metric_summary["map"]), epoch, OUT_DIR)
        # Save the current epoch model.
        save_model(epoch, model, optimizer)

        # Save loss plot. One value is recorded per epoch, so label the axis accordingly.
        save_loss_plot(OUT_DIR, train_loss_list, x_label="epochs")

        # Save mAP plot.
        save_mAP(OUT_DIR, map_50_list, map_list)
        scheduler.step()
        print("Current LR:", scheduler.get_last_lr())
