In [2]:
from dataclasses import dataclass
from collections import deque

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch import optim
from tqdm import tqdm
import pytorch_lightning as pl
from torchvision.ops import sigmoid_focal_loss, batched_nms

from modules.utils import convert_to_xywh, convert_to_xyxy, generate_subset, calc_iou
from modules.datasets import CocoDetection
import modules.transforms as T
from modules.models import RetinaNet


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
@torch.no_grad()
def post_process(
    preds_class: torch.Tensor,
    preds_box: torch.Tensor,
    anchors: torch.Tensor,
    targets: list[dict],
    conf_threshold: float = 0.05,
    nms_threshold: float = 0.5,
):
    batch_size = preds_class.shape[0]

    anchors_xywh = convert_to_xywh(anchors)

    preds_box[:, :, :2] = anchors_xywh[:, :2] + preds_box[:, :, :2] * anchors_xywh[:, 2:]
    preds_box[:, :, 2:] = preds_box[:, :, 2:].exp() * anchors_xywh[:, 2:]

    preds_box = convert_to_xyxy(preds_box)

    preds_class = preds_class.sigmoid()

    scores = []
    labels = []
    boxes = []
    for img_preds_class, img_preds_box, img_targets in zip(preds_class, preds_box, targets):
        # Clamp bounding box into image size.
        img_preds_box[:, ::2] = img_preds_box[:, ::2].clamp(min=0, max=img_targets["img_size"][0]) # x
        img_preds_box[:, 1::2] = img_preds_box[:, 1::2].clamp(min=0, max=img_targets["img_size"][1]) # y

        # Rescale bounding box to fit with original image.
        img_preds_box *= img_targets["orig_img_size"][0] / img_targets["img_size"][0]

        img_preds_score, img_preds_label = img_preds_class.max(dim=1)

        keep = img_preds_score > conf_threshold
        img_preds_score = img_preds_score[keep]
        img_preds_label = img_preds_label[keep]
        img_preds_box = img_preds_box[keep]

        # Apply NMS per class
        keep_indices = batched_nms(img_preds_box, img_preds_score, img_preds_label, nms_threshold)
        scores.append(img_preds_score[keep_indices])
        labels.append(img_preds_label[keep_indices])
        boxes.append(img_preds_box[keep_indices])

    return scores, labels, boxes


In [4]:
def loss_fn(
    preds_class: torch.Tensor,
    preds_box: torch.Tensor,
    anchors: torch.Tensor,
    targets: list[dict],
    iou_lower_threshold: float = 0.4,
    iou_upper_threshold: float = 0.5,
):
    """Compute Focal Loss.
    Args:
        preds_class (Tensor[N, num_anchors, num_classes]): Classes.
        preds_box (Tensor[N, num_anchors, 4]): Bounding boxes.
            Coordinate should be (x_diff, y_diff, w_diff, h_diff).
        anchors (Tensor[num_anchors, 4]): Coordinate should be (xmin, ymin, xmax, ymax).
        targets: Labels.
    """
    anchors_xywh = convert_to_xywh(anchors)

    # Calculate target function per image
    loss_class = preds_class.new_tensor(0)
    loss_box = preds_box.new_tensor(0)
    for img_preds_class, img_preds_box, img_targets in zip(
        preds_class, preds_box, targets
    ):
        # If no ground truth for this image.
        if img_targets["classes"].shape[0] == 0:
            # Create target class as background.
            targets_class = torch.zeros_like(img_preds_class)
            loss_class += sigmoid_focal_loss(
                img_preds_class, targets_class, reduction="sum"
            )
            continue

        # Get a bounding box which has max IoU.
        ious = calc_iou(anchors, img_targets["boxes"])[0]
        ious_max, ious_argmax = ious.max(dim=1)

        # Init class label as -1.
        # Set label of anchor box as -1 if iou_lower_threshold <= IoU <= iou_upper_threshold
        # in order not to calculate loss.
        targets_class = torch.full_like(img_preds_class, -1)

        targets_class[ious_max < iou_lower_threshold] = 0

        # If IoU > iou_upper_threshold, set as classification/regression target.
        positive_masks = ious_max > iou_upper_threshold
        num_positive_anchors = positive_masks.sum()

        targets_class[positive_masks] = 0
        assigned_classes = img_targets["classes"][ious_argmax]
        targets_class[positive_masks, assigned_classes[positive_masks]] = 1

        loss_class += (
            (targets_class != -1) * sigmoid_focal_loss(img_preds_class, targets_class)
        ).sum() / num_positive_anchors.clamp(min=1)

        # If no positive anchors, skip calculation of loss_box
        if num_positive_anchors == 0:
            continue

        # Get ground truth per anchor
        assgined_boxes = img_targets["boxes"][ious_argmax]
        assgined_boxes_xywh = convert_to_xywh(assgined_boxes)

        targets_box = torch.zeros_like(img_preds_box)
        targets_box[:, :2] = assgined_boxes_xywh[:, :2] - anchors_xywh[:, :2] / anchors_xywh[:, 2:]
        targets_box[:, 2:] = (assgined_boxes_xywh[:, 2:] / assgined_boxes_xywh[:, 2:]).log()

        loss_box += F.smooth_l1_loss(img_preds_box[positive_masks], targets_box[positive_masks], beta=1/9)

    batch_size = preds_class.shape[0]
    loss_class /= batch_size
    loss_box /= batch_size

    return loss_class, loss_box


In [5]:
@dataclass
class Config:
    img_dir: str = "./data/coco/val2014"
    annot_file: str = "./data/coco/instances_val2014_small.json"
    save_file: str = "./workdir/model/retinanet.pth"

    train_ratio: float = 0.8
    num_epochs: int = 50
    lr_drop: int = 45
    val_interval: int = 1
    lr: float = 1e-5
    clip: float = 0.1
    moving_avg: int = 100
    batch_size: int = 8
    num_workers: int = 4
    device: str = "cuda"


In [6]:
def collate_fn(batch):
    max_height = 0
    max_width = 0
    for img, _ in batch:
        h, w = img.shape[1:]
        max_height = max(max_height, h)
        max_width = max(max_width, w)

    height = (max_height + 31) // 32 * 32
    width = (max_width + 31) // 32 * 32

    imgs = batch[0][0].new_zeros((len(batch), 3, height, width))
    targets = []
    for i, (img, target) in enumerate(batch):
        h, w = img.shape[1:]
        imgs[i, :, :h, :w] = img
        targets.append(target)

    return imgs, targets


In [7]:
@torch.no_grad()
def evaluate(
    loader, model, loss_fn, conf_threshold=0.05, nms_threshold=0.5, device="cuda"
):
    model.eval()

    loss_class_list = []
    loss_box_list = []
    loss_list = []
    preds = []
    img_ids = []
    for imgs, targets in tqdm(loader, desc="[Validation]"):
        imgs = imgs.to(device)
        targets = [{k: v.to(device) for k, v in target.items()} for target in targets]

        preds_class, preds_box, anchors = model(imgs)

        loss_class, loss_box = loss_fn(preds_class, preds_box, anchors, targets)
        loss = loss_class + loss_box

        loss_class_list.append(loss_class)
        loss_box_list.append(loss_box)
        loss_list.append(loss)

        scores, labels, boxes = post_process(
            preds_class, preds_box, anchors, targets, conf_threshold, nms_threshold
        )
        for img_scores, img_labels, img_boxes, img_targets in zip(
            scores, labels, boxes, targets
        ):
            img_ids.append(img_targets["image_id"].item())

            # To xywh
            img_boxes[:, 2:] -= img_boxes[:, :2]

            for score, label, box in zip(img_scores, img_labels, img_boxes):
                preds.append(
                    {
                        "image_id": img_targets["image_id"].item(),
                        "category_id": loader.dataset.to_coco_label(label.item()),
                        "score": score.item(),
                        "bbox": box.to("cpu").numpy().tolist(),
                    }
                )

    loss_class = torch.stack(loss_class_list).mean().item()
    loss_box = torch.stack(loss_box_list).mean().item()
    loss = torch.stack(loss_list).mean().item()
    print(
        f"Validation loss = {loss:.3f}, class loss = {loss_class:.3f}, box loss = {loss_box:.3f}"
    )

    if len(preds) == 0:
        print("Nothing detected, skip evaluation.")
        return

    with open("tmp.json", "w") as f:
        json.dump(preds, f)

    coco_results = loader.dataset.coco.loadRes("tmp.json")

    coco_eval = COCOeval(loader.dataset.coco, coco_results, "bbox")
    coco_eval.params.imgIds = img_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()


In [8]:
def train_and_eval(config: Config):
    # Data augmentation
    min_sizes = (480, 512, 544, 576, 608)
    train_transform = T.Compose(
        [
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(min_sizes, max_size=1024),
                T.Compose(
                    [
                        T.RandomSizeCrop(scale=(0.8, 1.0), ratio=(0.75, 1.333)),
                        T.RandomResize(min_sizes, max_size=1024),
                    ]
                ),
            ),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )

    test_transform = T.Compose(
        [
            T.RandomResize([min_sizes[-1]], max_size=1024),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )

    train_dataset = CocoDetection(
        config.img_dir, config.annot_file, transform=train_transform
    )
    val_dataset = CocoDetection(
        config.img_dir, config.annot_file, transform=test_transform
    )

    train_set, val_set = generate_subset(train_dataset, config.train_ratio)
    print(f"num of train samples", len(train_set))
    print(f"num of val samples", len(val_set))

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        sampler=SubsetRandomSampler(train_set),
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        sampler=val_set,
        collate_fn=collate_fn,
    )

    model = RetinaNet(len(train_dataset.classes))
    torch.compile(model)
    model.to(config.device)

    optimizer = optim.AdamW(model.parameters(), lr=config.lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[config.lr_drop], gamma=0.1)

    for epoch in range(config.num_epochs):
        model.train()

        with tqdm(train_loader) as pbar:
            pbar.set_description(f"[Epoch {epoch + 1}]")

            loss_class_hist = deque()
            loss_box_hist = deque()
            loss_hist = deque()
            for imgs, targets in pbar:
                imgs = imgs.to(config.device)
                targets = [{k: v.to(config.device)for k, v in target.items()} for target in targets]

                optimizer.zero_grad()
                preds_class, preds_box, anchors = model(imgs)
                loss_class, loss_box = loss_fn(preds_class, preds_box, anchors, targets)
                loss = loss_class + loss_box

                loss.backward()
                torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), config.clip)

                optimizer.step()

                loss_class_hist.append(loss_class.item())
                loss_box_hist.append(loss_box.item())
                loss_hist.append(loss.item())
                if len(loss_hist) > config.moving_avg:
                    loss_class_hist.popleft()
                    loss_box_hist.popleft()
                    loss_hist.popleft()
                pbar.set_postfix({
                    "loss": torch.Tensor(loss_hist).mean().item(),
                    "loss_class": torch.Tensor(loss_class_hist).mean().item(),
                    "loss_box": torch.Tensor(loss_box_hist).mean().item(),
                })
        scheduler.step()

        torch.save(model.state_dict(), config.save_file)

        if (epoch + 1) % config.val_interval == 0:
            evaluate(val_loader, model, loss_fn)


In [26]:
class CocoDetDataModule(pl.LightningDataModule):
    def __init__(
        self, root_dir, annot_file, batch_size=16, train_ratio: float = 0.8, num_workers: int = 4
    ):
        super().__init__()
        self.root_dir = root_dir
        self.annot_file = annot_file
        self.batch_size = batch_size
        self.train_ratio = train_ratio
        self.num_workers = num_workers

        min_sizes = (480, 512, 544, 576, 608)
        self.train_transform = T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomSelect(
                    T.RandomResize(min_sizes, max_size=1024),
                    T.Compose(
                        [
                            T.RandomSizeCrop(scale=(0.8, 1.0), ratio=(0.75, 1.333)),
                            T.RandomResize(min_sizes, max_size=1024),
                        ]
                    ),
                ),
                T.ToTensor(),
                T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        )
        self.test_transform = T.Compose(
            [
                T.RandomResize([min_sizes[-1]], max_size=1024),
                T.ToTensor(),
                T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        )

    def setup(self, stage: str) -> None:
        self.train_dataset = CocoDetection(self.root_dir, self.annot_file, self.train_transform)
        self.train_set, self.val_set = generate_subset(self.train_dataset, self.train_ratio)
        self.test_dataset = CocoDetection(self.root_dir, self.annot_file, self.test_transform)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=SubsetRandomSampler(self.train_set),
            collate_fn=collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=self.val_set,
            collate_fn=collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=self.val_set,
            collate_fn=collate_fn,
        )


In [27]:
class RetinaNetModule(pl.LightningModule):
    def __init__(self, num_classes:int, learning_rate:float, lr_drop:float):
        super().__init__()
        self.model = RetinaNet(num_classes)
        self.learning_rate = learning_rate
        self.lr_drop = lr_drop
        self.save_hyperparameters()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[self.lr_drop], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch):
        imgs, targets = batch
        preds_class, preds_box, anchors = self.model(imgs)
        loss_class, loss_box = loss_fn(preds_class, preds_box, anchors, targets)
        loss = loss_class + loss_box
        self.log("train_loss_class", loss_class)
        self.log("train_loss_box", loss_box)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch):
        imgs, targets = batch
        preds_class, preds_box, anchors = self.model(imgs)
        loss_class, loss_box = loss_fn(preds_class, preds_box, anchors, targets)
        loss = loss_class + loss_box
        self.log("val_loss_class", loss_class)
        self.log("val_loss_box", loss_box)
        self.log("val_loss", loss)
        return loss


In [28]:
config = Config()
data_module = CocoDetDataModule(config.img_dir, config.annot_file, config.batch_size, config.train_ratio, config.num_workers)
model_module = RetinaNetModule(num_classes=2, learning_rate=config.lr, lr_drop=config.lr_drop)
trainer = pl.Trainer(accelerator="gpu", max_epochs=10)
trainer.fit(model_module, datamodule=data_module)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


loading annotations into memory...
Done (t=0.48s)                                    
creating index...
index created!
loading annotations into memory...
Done (t=0.44s)
creating index...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type      | Params
------------------------------------
0 | model | RetinaNet | 19.8 M
------------------------------------
19.8 M    Trainable params
0         Non-trainable params
19.8 M    Total params
79.168    Total estimated model params size (MB)


index created!
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  1.37it/s]

/opt/venv/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 8. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 1:  10%|▉         | 96/1000 [00:56<08:48,  1.71it/s, v_num=4]        

/opt/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
