In [None]:
# Reload imported modules before each cell executes so edits to project code are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
import random
from pathlib import Path
from typing import Any

import torch
import torchvision as tv
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import v2 as v2

import bb
import tt

LOG = logging.getLogger(__name__)
# Project helper from `tt`; presumably configures logging handlers/format — confirm in tt.
tt.logging_init()

# Seed both Python's and torch's RNGs so Restart & Run All gives reproducible results.
SEED = 325
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Root directory for all local datasets used in this notebook.
data_path = Path.home().joinpath("src", "data")
torch_root = data_path.joinpath("torchvision")

# Minecraft dataset metadata, loaded through the project's bb.Dataset API.
mc_data_path = data_path.joinpath("minecraft", "info.json")
dset = bb.Dataset.load(mc_data_path)

# bb.TorchDataset

In [None]:
# Build bb.TorchDataset over the Minecraft data directory and show its repr.
tdset = bb.TorchDataset(data_path.joinpath("minecraft"))
tdset

In [None]:
# Take one batch, annotate each image with bb.torch_plot_bb, and tile them 2 per row.
loader = DataLoader(tdset, batch_size=8, collate_fn=bb.TorchDataset.collate_fn)
images, targets = next(iter(loader))
annotated_images = [
    bb.torch_plot_bb(image, sample_target, tdset.categories)
    for image, sample_target in zip(images, targets)
]
result = tv.utils.make_grid(annotated_images, nrow=2)
v2.functional.to_pil_image(result)

In [None]:
# Inspect one sample: print its label names, then render it via bb.torch_plot_bb.
img, target = tdset[10]
# NOTE(review): label names come from tdset.dset.categories while plotting below uses
# tdset.categories — confirm these are the same mapping.
categories = tdset.dset.categories
label_names = [categories[label.item()] for label in target["labels"]]
print(label_names)
result = bb.torch_plot_bb(img, target, tdset.categories)
v2.functional.to_pil_image(result)

# Minecraft COCO

In [None]:
# https://docs.pytorch.org/vision/main/auto_examples/transforms/plot_transforms_e2e.html

IMAGES_PATH = data_path.joinpath("coco", "minecraft", "images")
ANNOTATIONS_PATH = data_path.joinpath("coco", "minecraft", "annotations.json")

# The v2 wrapper converts CocoDetection targets into tv_tensors and applies the
# dataset's transforms after wrapping, so a v2 transform can be passed in directly.
coco_dataset = tv.datasets.wrap_dataset_for_transforms_v2(
    tv.datasets.CocoDetection(
        IMAGES_PATH,
        ANNOTATIONS_PATH,
        transforms=v2.ToImage(),
    )
)

# COCO category id -> category name, read from the annotation file's category table.
coco_categories = dict(
    (cat_info["id"], cat_info["name"])
    for cat_info in coco_dataset.coco.loadCats(coco_dataset.coco.getCatIds())
)
print(coco_categories)

In [None]:
# First COCO sample: the raw target dict, its label names, then the image itself.
img, target = coco_dataset[0]
label_names = [coco_categories[label_id.item()] for label_id in target["labels"]]
print(target)
print(label_names)
v2.functional.to_pil_image(img)

In [None]:
def collate_fn(
    batch: list[tuple[torch.Tensor, dict[str, Any]]],
) -> tuple[torch.Tensor, list[dict[str, Any]]]:
    """Collate ``(image, target)`` samples for use with ``DataLoader``.

    Images are stacked into one tensor with a new leading batch dimension;
    targets are kept as a plain list (not collated) so each per-sample dict
    passes through unchanged.

    Args:
        batch: ``(image, target)`` samples. Any tensor type works for the image
            (``tv_tensors.Image`` is a ``torch.Tensor`` subclass).

    Returns:
        ``(images, targets)`` where ``images`` is the stacked batch and
        ``targets`` lists the target dicts in input order.

    Raises:
        RuntimeError: raised by ``torch.stack`` when the images do not all share
            the same shape, or when ``batch`` is empty.
    """
    images = torch.stack([sample[0] for sample in batch])
    targets = [sample[1] for sample in batch]
    return images, targets


# Batch the wrapped COCO dataset with the collate_fn defined above and show the first batch.
coco_loader = DataLoader(
    dataset=coco_dataset,
    collate_fn=collate_fn,
    batch_size=8,
)
next(iter(coco_loader))

# MCDataset

In [None]:
# Build bb.MCDataset over the same directory that holds the COCO images/annotations loaded above.
mcd_root = data_path.joinpath("coco", "minecraft")
mcd = bb.MCDataset(mcd_root)
# Pull the first sample from the COCO dataset held by MCDataset (not displayed).
img, target = mcd.coco_dataset[0]

In [None]:
# Batch MCDataset with its own collate function and show the first batch.
mcd_loader = DataLoader(dataset=mcd, batch_size=8, collate_fn=bb.MCDataset.collate_fn)
next(iter(mcd_loader))