In [1]:
import torch
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision.io import decode_image

## Dataset
A minimal example of how to write a custom dataset class, shaped to match the FashionMNIST format (each item is an `(image, label)` pair).

In [2]:
class CustomImageDataset(Dataset):
    """Map-style dataset backed by an image directory and a CSV annotations file.

    The annotations table is loaded with ``pd.read_csv``. For each row, column 0
    is used as the image file name (joined onto ``img_dir``) and column 1 as the
    label.
    """

    def __init__(self, annotations_file, img_dir, transform = None, target_transform = None):
        """Load the annotations table and remember where the images live.

        Args:
            annotations_file: Path (or anything ``pd.read_csv`` accepts) of the
                CSV listing image file names and labels.
            img_dir: Directory that the file names in column 0 are relative to.
            transform: Optional callable applied to the decoded image.
            target_transform: Optional callable applied to the label.
        """
        # NOTE: with pandas defaults the first CSV line is consumed as the header
        # row, so a header-less annotations file would lose its first sample.
        self.image_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.image_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for the row at position ``idx``.

        The image is whatever ``decode_image`` returns for the joined path, and
        the label is the raw value from column 1; each is passed through its
        transform when one was supplied.
        """
        img_path = os.path.join(self.img_dir, self.image_labels.iloc[idx, 0])
        image = decode_image(img_path)  # decodes the image file into a tensor
        label = self.image_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    # Backward-compatible alias: the method used to be (mis)named __get_item__,
    # which Python never invokes for ``dataset[idx]``. Kept so explicit calls to
    # the old name keep working.
    __get_item__ = __getitem__

## Transforms
We can create our own transforms with `Lambda`.

The lambda here creates a one-hot encoded representation of the label: it creates a zero tensor of shape `[10]` (one entry per FashionMNIST class) and then sets the entry at the label's index to 1.

In [None]:
from torchvision import datasets
from torchvision.transforms import Lambda

In [None]:
# FashionMNIST training split, stored under ./data (download=True asks
# torchvision to fetch it).
train_ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    # Convert each image with ToTensor.
    transform=ToTensor(),
    # Integer label -> one-hot float vector of length 10: start from zeros and
    # write a 1 at the position given by the label.
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
    ),
)