In [1]:
import torch

from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData


# Seed torch's global RNG so the stochastic steps in this notebook (random
# crop/flip, DataLoader shuffling, CutMix/MixUp sampling) give the same
# results on every Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Number of target classes; passed to the dataset and to CutMix/MixUp so both
# agree on the label space.
NUM_CLASSES = 17

In [2]:
# Per-sample preprocessing / augmentation pipeline; steps run in the order listed.
_norm_mean = (0.485, 0.456, 0.406)  # typically from ImageNet
_norm_std = (0.229, 0.224, 0.225)  # typically from ImageNet

_steps = [
    v2.PILToTensor(),
    # Random crop, resized to a fixed 224x224 output.
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    # Mirror left/right with probability 0.5.
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=_norm_mean, std=_norm_std),
]
transform = v2.Compose(_steps)

In [3]:
# Synthetic stand-in dataset: 1000 samples labelled with one of NUM_CLASSES
# classes, each passed through `transform` on access.
dataset = FakeData(
    size=1000,
    num_classes=NUM_CLASSES,
    transform=transform,
)

# Fetch one sample to check what the pipeline hands back.
first_sample = dataset[0]
img, label = first_sample
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")

type(img) = <class 'torch.Tensor'>, img.dtype = torch.float32, img.shape = torch.Size([3, 224, 224]), label = 16


In [4]:
# Batches of 4 samples, drawn in shuffled order.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
)

# Inspect a single batch, then stop: without CutMix/MixUp the labels are a
# 1-D tensor with one entry per sample.
for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    print(labels.dtype)
    # <rest of the training loop here>
    break

images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
torch.int64


In [6]:
# Fresh loader with the same settings as before: batches of 4, shuffled.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
)

# Both batch-level augmentations need the class count to build the mixed
# label vectors; RandomChoice applies one of the two per call.
cutmix = v2.CutMix(num_classes=NUM_CLASSES)
mixup = v2.MixUp(num_classes=NUM_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

# Apply the augmentation to one whole batch (after the DataLoader, not per
# sample) and compare shapes before and after.
for images, labels in dataloader:
    print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")

    # Images keep their shape; labels go from (batch,) to (batch, NUM_CLASSES).
    images, labels = cutmix_or_mixup(images, labels)
    print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")

    print(labels)
    break

Before CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
After CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 17])
tensor([[0.0000, 0.0000, 0.0000, 0.2740, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.7260, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.7260, 0.0000, 0.2740, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.2740, 0.0000, 0.0000, 0.0000, 0.7260, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.7260, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.2740, 0.0000, 0.0000, 0.0000]])
