# Assignment 5 - Vision Transformers

## Part 1 - Load CIFAR100 Dataset

In [4]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from matplotlib import pyplot as plt

import os
import pickle
import numpy as np
from PIL import Image
from torchvision.io import read_image
from torch.utils.data import DataLoader
import torchvision


In [5]:
# Download CIFAR100 data (train and test splits) into ./data
torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
# Last expression of the cell, so its repr is what the notebook displays
torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

100%|██████████| 169M/169M [00:04<00:00, 36.6MB/s] 


Dataset CIFAR100
    Number of datapoints: 10000
    Root location: ./data
    Split: Test

In [6]:
# Define dataset class
class CustomImageDataset(Dataset):
    """
    Map-style dataset over the pickled CIFAR100 python files.

    Reads ``<data_path>/cifar-100-python/train`` (or ``test``), keeps the
    images in ``self.data`` reshaped to (N, 3, 32, 32) and the fine-grained
    labels in ``self.labels``.
    """
    def __init__(self, data_path, train=True, transform=None, target_transform=None):
        """
        Load one split from disk.

        Args:
            data_path: Directory that contains the ``cifar-100-python`` folder.
            train: Read the ``train`` file when True, the ``test`` file otherwise.
            transform: Optional callable applied to each image in ``__getitem__``.
            target_transform: Optional callable applied to each label in
                ``__getitem__``.

        Raises:
            OSError: If the split file cannot be opened.
            KeyError: If the pickled dict lacks ``b"data"`` or ``b"fine_labels"``.
        """
        self.transform = transform
        self.target_transform = target_transform

        split_file = os.path.join(
            data_path, "cifar-100-python", "train" if train else "test"
        )

        # NOTE(review): pickle.load can execute arbitrary code while unpickling;
        # only point data_path at files from a trusted source.
        with open(split_file, "rb") as f:
            batch_dict = pickle.load(f, encoding="bytes")

        # Each row of b"data" is one flattened image; reshape to (N, 3, 32, 32)
        self.data = np.asarray(batch_dict[b"data"]).reshape(-1, 3, 32, 32)
        # Fine-grained labels, one entry per image
        self.labels = list(batch_dict[b"fine_labels"])

    def __len__(self):
        """
        Return the number of samples (one per stored label).
        """
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Return the ``(image, label)`` pair stored at position ``idx``.

        The image is converted to (H, W, C) layout before ``transform`` is
        applied; without a transform it is returned in that layout as-is.
        The label is passed through ``target_transform`` when one was given.
        """
        image = self.data[idx].transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
        label = self.labels[idx]

        # Compare against None rather than relying on truthiness, so a callable
        # that happens to evaluate as False is still applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

In [None]:
# TODO: split the dataset into train, validation, and test sets

#

In [None]:
# Build the train and test splits from the files under ./data/
# NOTE(review): no transform is passed here, so samples come back exactly as
# CustomImageDataset.__getitem__ produces them - confirm this is intended.
train_data = CustomImageDataset(data_path="./data/", train=True)
test_data = CustomImageDataset(data_path="./data/", train=False)

In [None]:
# Init Dataloaders
# Training batches are reshuffled every epoch; the test loader keeps the
# dataset order so evaluation sees the samples in a reproducible order.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

In [None]:
#