In [86]:
import torch
import torchvision.models as models
from torchvision import datasets, transforms as T
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.models import ResNet50_Weights  # <-- import this
from torchvision.transforms import v2
from torchvision.utils import draw_bounding_boxes


In [55]:
# ResNet-50 from torchvision, initialised with the IMAGENET1K_V2 weights.
resnet50 = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

In [146]:
# Index-to-label lookup: position k holds the class name for label id k + 1
# (ids are 1-based, hence the `- 1` whenever this list is indexed).
# Not yet verified against the official class list (website was down).
itol = [
    "aeroplane",    # 1
    "bicycle",      # 2
    "bird",         # 3
    "boat",         # 4
    "bottle",       # 5
    "bus",          # 6
    "car",          # 7
    "cat",          # 8
    "chair",        # 9
    "cow",          # 10
    "diningtable",  # 11
    "dog",          # 12
    "horse",        # 13
    "motorbike",    # 14
    "person",       # 15
    "pottedplant",  # 16
    "sheep",        # 17
    "sofa",         # 18
    "train",        # 19
    "tvmonitor",    # 20
]
# Spot-check two ids.
(itol[13 - 1], itol[15 - 1])

('horse', 'person')

In [None]:
def plot(sample, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized detection sample with its labelled boxes drawn on it.

    Args:
        sample: ``(img, target)`` pair. ``img`` is a channel-first float image
            tensor that was normalized with ``mean``/``std``. ``target`` is a
            mapping with a ``'boxes'`` tensor (handed unchanged to
            ``draw_bounding_boxes``) and a ``'labels'`` tensor of 1-based
            class ids that index into the module-level ``itol`` list.
        mean: Per-channel mean that was subtracted during normalization.
            Defaults to the values used by the notebook's ``Normalize`` step.
        std: Per-channel std that the image was divided by during
            normalization. Same default source as ``mean``.

    Returns:
        None. The annotated image is opened through PIL's ``Image.show()``.
    """
    img, target = sample

    # Work on the underlying tensors rather than the wrapped sample objects.
    pixels = img.data
    boxes = target['boxes'].data

    # Undo Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean.
    mean_t = torch.tensor(mean, dtype=torch.float32, device=pixels.device).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32, device=pixels.device).view(-1, 1, 1)
    pixels = pixels * std_t + mean_t

    # Round-off can push values marginally outside [0, 1]; clamp before
    # quantising. Drawing on a uint8 image also keeps this working on
    # torchvision releases whose draw_bounding_boxes only accepts uint8.
    pixels = pixels.clamp(0.0, 1.0).mul(255).round().to(torch.uint8)

    # Label ids are 1-based, so shift by one to index the name list.
    labels = [itol[i - 1] for i in target['labels'].tolist()]

    annotated = draw_bounding_boxes(pixels, boxes, width=3, labels=labels)
    v2.ToPILImage()(annotated).show()

In [178]:
# Preprocessing for the pretrained backbone: convert to a tensor image, cast to
# float32 with value scaling, then normalize with the ImageNet statistics.
# imagenet stats here: https://docs.pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# NOTE: machine-specific absolute path; adjust when running elsewhere.
data_filepath = '/Users/veb/ms/nanoDETR/data'

# The pipeline is passed via the joint `transforms=` hook (image AND target)
# rather than the image-only `transform=` hook. The current steps only touch
# the image, so the samples are unchanged, but any geometric transform added
# to `transform` later will now be applied to the bounding boxes as well.
dataset = datasets.VOCDetection(
    root=data_filepath,
    year='2012',
    image_set='train',
    download=False,
    transforms=transform,
)  # len 5717, consistent with data
# Wrap so targets come back as v2-compatible boxes/labels.
dataset = wrap_dataset_for_transforms_v2(dataset)

In [190]:
# Visual sanity check on the first sample of the dataset.
first_sample = dataset[0]
plot(first_sample)

In [None]:
# original img: (3, H0, W0)
# backbone output: (C, H, W), where typically C = 2048, H = H0 / 32, W = W0 / 32

# The input images are batched together, applying 0-padding adequately to ensure
# they all have the same dimensions (H0,W0) as the largest image of the batch.