In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import logging
import random
from pathlib import Path
from typing import Any

import torch
import torchvision as tv
from PIL import Image
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision.transforms import v2 as v2
from tqdm import tqdm

import bb
import tt

# Notebook-level logger.
LOG = logging.getLogger(__name__)
# Project helper from `tt`; presumably configures logging handlers/levels — TODO confirm.
tt.logging_init()

# Fixed seed for reproducible runs.
SEED = 325
# Project helper from `tt`; presumably seeds the python/torch RNGs — TODO confirm.
tt.seed(SEED)

In [3]:
# Root of all local datasets used in this notebook.
data_path = Path.home().joinpath("src", "data")
torch_root = data_path / "torchvision"

# Minecraft dataset in the project's own `bb` format.
mc_data_path = data_path.joinpath("minecraft", "info.json")
dset = bb.Dataset.load(mc_data_path)

# bb.TorchDataset

In [None]:
# Wrap the Minecraft folder as a torch-style dataset and show its repr.
tdset = bb.TorchDataset(data_path.joinpath("minecraft"))
tdset

In [None]:
# Plot the ground-truth boxes for one batch as a 2-column grid.
loader = DataLoader(tdset, batch_size=8, collate_fn=bb.TorchDataset.collate_fn)
images, targets = next(iter(loader))
annotated = [
    bb.torch_plot_bb(image, image_target, tdset.categories)
    for image, image_target in zip(images, targets)
]
result = tv.utils.make_grid(annotated, nrow=2)
v2.functional.to_pil_image(result)

In [None]:
# Draw the ground-truth boxes of a single sample.
img, target = tdset[10]
result = bb.torch_plot_bb(img, target, tdset.categories)
v2.functional.to_pil_image(result)

# Minecraft COCO

In [None]:
# Based on https://docs.pytorch.org/vision/main/auto_examples/transforms/plot_transforms_e2e.html

IMAGES_PATH = data_path / "coco/minecraft/images"
ANNOTATIONS_PATH = data_path / "coco/minecraft/annotations.json"

# The v2 wrapper handles the target format, so the transforms passed to
# CocoDetection can be v2 transforms.
raw_coco = tv.datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=v2.ToImage())
coco_dataset = tv.datasets.wrap_dataset_for_transforms_v2(raw_coco)

# Map COCO category id -> human-readable name.
coco_api = coco_dataset.coco
coco_categories = dict(
    (category["id"], category["name"])
    for category in coco_api.loadCats(coco_api.getCatIds())
)
print(coco_categories)

In [None]:
# Inspect the first COCO sample: raw target, its label names, and the image.
img, target = coco_dataset[0]
label_names = [coco_categories[int(label)] for label in target["labels"]]
print(target)
print(label_names)
v2.functional.to_pil_image(img)

In [None]:
def collate_fn(
    batch: list[tuple[tv.tv_tensors.Image, dict[str, Any]]],
) -> tuple[torch.Tensor, list[dict[str, Any]]]:
    """Collate detection samples for a DataLoader.

    Images are stacked into a single batched tensor (so they must all share
    one shape); targets are kept as a plain list, one dict per image.
    """
    image_list = []
    target_list = []
    for sample in batch:
        image_list.append(sample[0])
        target_list.append(sample[1])
    return torch.stack(image_list), target_list


# Smoke-test the collate function by pulling one batch.
coco_loader = DataLoader(
    coco_dataset,
    batch_size=8,
    collate_fn=collate_fn,
)
next(iter(coco_loader))

# MCDataset

In [4]:
# Project COCO-backed Minecraft dataset; grab its first raw sample.
mcd_root = data_path.joinpath("coco", "minecraft")
mcd = bb.MCDataset(mcd_root)
img, target = mcd.coco_dataset[0]

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


In [7]:
# Batched loader used for fine-tuning below.
mcd_loader = DataLoader(
    mcd,
    batch_size=8,
    collate_fn=bb.MCDataset.collate_fn,
)
images, targets = next(iter(mcd_loader))
# Optional visualization: bb.plot_bb_grid(images, targets, mcd.categories)

# Fine-tune FCOS

In [None]:
# NOTE: this cell needs `trainer`, which is created in the Trainer cell further
# down. On a fresh kernel, run that cell first (or move this cell below it).
print(img.shape)
# Summarise the fine-tune model for a 1x3x640x640 dummy input built by torchinfo.
# Commented-out entries are other available columns/rows, kept for quick toggling.
summary(
    trainer.model,
    input_size=[1, 3, 640, 640],
    col_names=[
        "input_size",
        "output_size",
        # "num_params",
        # "params_percent",
        "kernel_size",
        # "mult_adds",
        "trainable",
    ],
    row_settings=[
        # "ascii_only",
        "depth",
        "var_names",
    ],
    depth=10,
    verbose=0,
)

In [None]:
from torchvision.models.detection import fcos


class Trainer:
    """Fine-tune torchvision's COCO-pretrained FCOS ResNet50-FPN detector.

    The pretrained classification conv (``cls_logits``) is swapped for a fresh
    one sized for ``num_classes``. Every other parameter is frozen, so only the
    new conv is updated by :meth:`train_one_epoch`.
    """

    def __init__(
        self,
        device="mps",
        num_classes: int = 11,
        categories=None,
        lr: float = 1e-4,
        prior_probability: float = 0.01,
    ) -> None:
        """Build the model, replace its classification conv and set up AdamW.

        Args:
            device: device the model (and every batch) is moved to.
            num_classes: output channels of the new classification conv.
                The default of 11 is 10 classes plus background. Presumably it
                must exceed the largest label id in the targets — TODO confirm.
            categories: optional sequence or mapping from label id to display
                name, used by :meth:`infer`. Without it, ids are drawn as text.
            lr: AdamW learning rate.
            prior_probability: initial foreground probability encoded in the
                new conv's bias (the focal-loss prior init used by FCOS heads).
        """
        import math

        self.device = torch.device(device)
        self.num_classes = num_classes
        self.categories = categories

        self.weights = fcos.FCOS_ResNet50_FPN_Weights.COCO_V1
        self.model = fcos.fcos_resnet50_fpn(weights=self.weights)
        self.preprocess = self.weights.transforms()

        cls_head = self.model.head.classification_head
        # Pretrained conv: Conv2d(256, 91, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).
        self.orig_cls_logits = cls_head.cls_logits
        new_cls_logits = torch.nn.Conv2d(
            self.orig_cls_logits.in_channels, num_classes, kernel_size=3, stride=1, padding=1
        )
        # The default Conv2d init starts every class near probability 0.5, which makes
        # focal loss huge at the start. Use the FCOS/RetinaNet init: tiny weights and
        # a bias encoding a low foreground prior.
        torch.nn.init.normal_(new_cls_logits.weight, std=0.01)
        torch.nn.init.constant_(
            new_cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)
        )
        cls_head.cls_logits = new_cls_logits
        cls_head.num_classes = num_classes

        # Move to the device after the head is replaced so the new conv moves too.
        self.model.to(self.device)

        # Freeze everything, then unfreeze only the new classification conv.
        self.model.requires_grad_(False)
        cls_head.cls_logits.requires_grad_(True)
        # Frozen parameters never get a .grad, so AdamW skips them. Passing all of
        # them keeps the optimizer usable if more layers are unfrozen later.
        self.optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=lr)

    def _label_name(self, label_id: int) -> str:
        """Display name for a predicted label id; falls back to the id as text."""
        if self.categories is None:
            return str(label_id)
        try:
            return str(self.categories[label_id])
        except (IndexError, KeyError):
            return str(label_id)

    def infer(
        self,
        img: tv.tv_tensors.Image,
        min_score: float = 0.0,
        font="/System/Library/Fonts/Helvetica.ttc",  # macOS default; None = PIL default font
        font_size: int = 20,
    ) -> Image.Image:
        """Run the model on one image and return it with predicted boxes drawn.

        Args:
            img: a single CHW image tensor (presumably uint8, which
                draw_bounding_boxes expects — TODO confirm for the dataset).
            min_score: drop predictions scoring below this; 0.0 keeps every
                prediction the model returned.
            font: TrueType font path for the labels.
            font_size: label font size.

        Returns:
            PIL image with boxes and label names drawn in red.
        """
        self.model.eval()
        batch = [self.preprocess(img.to(self.device))]
        with torch.inference_mode():
            prediction = self.model(batch)[0]

        keep = prediction["scores"] >= min_score
        boxes = prediction["boxes"][keep].cpu()
        # The head was replaced, so ids index this dataset's categories, not COCO's.
        labels = [self._label_name(int(label)) for label in prediction["labels"][keep]]
        drawn = tv.utils.draw_bounding_boxes(
            img.cpu(),
            boxes=boxes,
            labels=labels,
            colors="red",
            width=4,
            font=font,
            font_size=font_size,
        )
        return v2.functional.to_pil_image(drawn.detach())

    def _prepare_images(self, images):
        """Move images (batched tensor or list of tensors) to the device and preprocess."""
        if isinstance(images, torch.Tensor):
            return self.preprocess(images.to(self.device))
        return [self.preprocess(image.to(self.device)) for image in images]

    def _prepare_target(self, target):
        """Keep only boxes/labels on the device; images without boxes get empty tensors."""
        boxes = target.get("boxes")
        if boxes is None:
            boxes = torch.zeros((0, 4))
        labels = target.get("labels")
        if labels is None:
            labels = torch.zeros(0, dtype=torch.int64)
        return {"boxes": boxes.to(self.device), "labels": labels.to(self.device)}

    def train_one_epoch(self, train_loader) -> float:
        """Run one pass over ``train_loader`` and update the trainable parameters.

        Args:
            train_loader: iterable of ``(images, targets)``. ``images`` is a
                batched image tensor or a list of per-image tensors. ``targets``
                is a list of dicts that may hold "boxes" and "labels".

        Returns:
            Mean total loss over the batches (nan if the loader was empty).
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for images, targets in tqdm(train_loader):
            batch = self._prepare_images(images)
            device_targets = [self._prepare_target(t) for t in targets]

            # torchvision detection models return a dict of losses in train mode.
            loss_dict = self.model(batch, device_targets)
            loss = sum(loss_dict.values())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches if num_batches else float("nan")


trainer = Trainer()
# First image of the Minecraft dataset; used by the inference cell below.
img = mcd[0][0]

In [None]:
# Predictions of the (not yet fine-tuned) model on the sample image.
trainer.infer(img=img)

In [None]:
# One fine-tuning pass over the Minecraft loader.
trainer.train_one_epoch(train_loader=mcd_loader)

In [None]:
# Sanity-check the shape the model receives after preprocessing a batch.
images, targets = next(iter(mcd_loader))
preprocessed = trainer.preprocess(images)
preprocessed.shape