In [16]:
from PIL import Image
from torchvision import transforms


def pad_img_by_size(img: "Image.Image", size: int = 512) -> "Image.Image":
    """
    Letterbox an image onto a black square, then resize it to size * size.

    The shorter side is padded with black (fill=0) until the image is exactly
    square, so the final resize never changes the aspect ratio. When the
    difference between width and height is odd, the leftover pixel is added
    to the bottom (landscape input) or right (portrait input) edge; padding
    both sides by ``diff // 2`` would leave the image one pixel short of
    square and the resize would then stretch it.

    Args:
        img: image exposing ``size`` as ``(width, height)`` and accepted by
            ``transforms.functional.pad`` / ``resize``.
        size: edge length, in pixels, of the returned square image.

    Returns:
        The padded image resized to ``(size, size)``.
    """
    w, h = img.size
    diff = abs(w - h)
    before = diff // 2
    after = diff - before  # absorbs the extra pixel when diff is odd
    if w > h:
        # Landscape: grow the height. Padding order is (left, top, right, bottom).
        padding = (0, before, 0, after)
    else:
        # Portrait (or already square, where both pads are 0): grow the width.
        padding = (before, 0, after, 0)
    img = transforms.functional.pad(img, padding, fill=0)
    img = transforms.functional.resize(img, (size, size))
    return img

In [17]:
from datasets import load_dataset
import torch
from torchvision import transforms

# Load the Food-101 dataset, requesting the "all" split rather than a single
# named split.
train_dataset = load_dataset("food101", split="all")

# Peek at the first three records to see the raw column layout.
train_dataset[:3]

# {'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=384x512>,
#   <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
#   <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x383>],
#  'label': [6, 6, 6]}

{'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=384x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x383>],
 'label': [6, 6, 6]}

In [23]:
# Per-image preprocessing, applied in the order listed.
train_transforms = transforms.Compose([
    # Pad to a square and resize to 512 x 512.
    transforms.Lambda(lambda image: pad_img_by_size(image, 512)),
    # Random left/right mirroring (vertical flipping is intentionally not used).
    transforms.RandomHorizontalFlip(),
    # PIL image -> tensor.
    transforms.ToTensor(),
    # Per channel: (x - 0.5) / 0.5, done in place on the tensor from ToTensor.
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True),
])
# img = Image.open("results/imagenette/images-24790.jpg")
# img = train_transforms(img)
# img.shape

tensor(0.9451)

In [25]:
# Total number of examples in the dataset loaded above.
len(train_dataset)

101000

In [24]:
def transform_images(examples):
    """
    Batch transform: convert each image to RGB and run it through train_transforms.

    Args:
        examples: mapping with an "image" sequence (objects providing
            ``convert``) and a "label" entry.

    Returns:
        Dict with the transformed images under "input" and the labels passed
        through unchanged under "label".
    """
    processed = []
    for raw_image in examples["image"]:
        # Force three channels before the transform pipeline runs.
        rgb_image = raw_image.convert("RGB")
        processed.append(train_transforms(rgb_image))
    labels = examples["label"]
    return {"input": processed, "label": labels}

In [26]:
# Register transform_images as the dataset's transform, so items read from
# train_dataset come back as {"input": ..., "label": ...}.
train_dataset.set_transform(transform_images)

In [27]:
# Shuffled mini-batches of 4 examples, loaded by 2 worker processes.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Sanity check: pull one batch from the loader and print it. The `break`
# stops after the first iteration and leaves `batch` bound for later cells.
for batch in train_dataloader:
    print(batch)
    break

In [32]:
# Labels of the batch captured by the loop above.
batch["label"]

tensor([35, 68,  0, 75])

In [32]:
# Fetch the first example by index.
# NOTE(review): transform_images returns a dict, but the saved output for this
# cell is a (tensor, label) tuple, so the output looks stale. Re-run to confirm.
train_dataset[0]

(tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],
 
         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],
 
         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]]),
 0)