In [102]:
import os
import numpy as np
import torchvision.transforms

In [103]:
import pickle

# Directory holding the extracted CIFAR-10 pickled batch files. Set the
# CIFAR10_DATA_DIR environment variable to point elsewhere without editing
# the notebook; the default keeps the original relative path.
DATA_DIR = os.environ.get('CIFAR10_DATA_DIR', 'data/cifar-10-batches-py')
LABEL_FILE = 'batches.meta'  # pickled metadata dict; holds b'label_names'
TEST_DATA = 'test_batch'     # pickled test split (b'data', b'labels')

def unpickle(file):
    """Deserialize a single pickled CIFAR-10 file and return the loaded object.

    ``encoding='bytes'`` makes Python-2 ``str`` objects stored in the pickle
    (for example the batch dictionaries' keys) come back as ``bytes``.

    WARNING: ``pickle.load`` can execute arbitrary code while loading; only
    call this on files obtained from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def load_cifar_batches(data_dir, num_batches=5):
    """Load the CIFAR-10 training batches, the test batch and the class names.

    Reads ``data_batch_1`` ... ``data_batch_<num_batches>``, the test file named
    by ``TEST_DATA`` and the metadata file named by ``LABEL_FILE`` from
    ``data_dir`` using ``unpickle``.

    Args:
        data_dir: Directory containing the pickled batch files.
        num_batches: Number of training batch files to read (default 5, i.e.
            ``data_batch_1`` through ``data_batch_5``).

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test, class_names)`` where
        ``X_train`` is every training batch's ``b'data'`` concatenated along
        axis 0, ``y_train`` / ``y_test`` are NumPy arrays built from the
        ``b'labels'`` lists, ``X_test`` is the test batch's ``b'data'`` as
        stored, and ``class_names`` is a list of ``str`` decoded (UTF-8) from
        the metadata's ``b'label_names'``.

    Raises:
        ValueError: If ``num_batches`` is less than 1.
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")

    train_data = []
    train_labels = []
    for i in range(1, num_batches + 1):
        batch = unpickle(os.path.join(data_dir, f'data_batch_{i}'))
        # unpickle() loads with encoding='bytes', so the dict keys are bytes
        # objects and must be indexed with b'...' literals (nothing is decoded).
        train_data.append(batch[b'data'])
        train_labels.extend(batch[b'labels'])

    X_train = np.concatenate(train_data)
    y_train = np.array(train_labels)

    # Test split: same key layout as the training batches.
    test_batch = unpickle(os.path.join(data_dir, TEST_DATA))
    X_test = test_batch[b'data']
    y_test = np.array(test_batch[b'labels'])

    # Class names are stored as bytes; decode them to str for display.
    meta = unpickle(os.path.join(data_dir, LABEL_FILE))
    class_names = [name.decode('utf-8') for name in meta[b'label_names']]

    return X_train, y_train, X_test, y_test, class_names

X_train, y_train, X_test, y_test, class_names = load_cifar_batches(DATA_DIR)

# Fail fast on a partial or corrupted download: later cells reshape each row
# to (3, 32, 32) and index class_names with the labels.
assert X_train.shape == (len(y_train), 3072), X_train.shape
assert X_test.shape == (len(y_test), 3072), X_test.shape
assert 0 <= min(y_train.min(), y_test.min()), "negative label found"
assert max(y_train.max(), y_test.max()) < len(class_names), "label outside class_names"

In [104]:
# Summarize what was loaded: array shapes for both splits and the class count.
print(
    f"Dataset X_train shape: {X_train.shape}",
    f"Dataset y_train shape: {y_train.shape}",
    f"Test dataset X_test shape: {X_test.shape}",
    f"Test dataset y_test shape: {y_test.shape}",
    f"Classes number: {len(class_names)}",
    sep="\n",
)

Dataset X_train shape: (50000, 3072)
Dataset y_train shape: (50000,)
Test dataset X_test shape: (10000, 3072)
Test dataset y_test shape: (10000,)
Classes number: 10


In [105]:
print(
    f"Total training samples: {X_train.shape[0]}",
    f"Shape of one image: {X_train.shape[1]} (32 * 32 * 3)",
    f"Class names: {class_names}",
    sep="\n",
)

# Each row holds 3072 values laid out channel by channel: 1024 red values,
# then 1024 green, then 1024 blue (one 32x32 plane each). Reshape to
# (C, H, W), then move the channel axis last to obtain (H, W, C).
example_image = np.transpose(X_train[0].reshape(3, 32, 32), (1, 2, 0))
print(f"Shape of first image after reshaping: {example_image.shape}")
print(f"Label of first image: {class_names[y_train[0].item()]}")

Total training samples: 50000
Shape of one image: 3072 (32 * 32 * 3)
Class names: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Shape of first image after reshaping: (32, 32, 3)
Label of first image: frog


In [106]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class Cifar10RawDataset(Dataset):
    """Map-style dataset over flattened CIFAR-10 image rows.

    Each ``data[idx]`` is reshaped to ``(3, 32, 32)`` (channel, height, width)
    and transposed to ``(32, 32, 3)`` (height, width, channel) before
    ``transform`` is applied, so transforms receive an HWC NumPy array.

    Args:
        data: Indexable of flattened images; every ``data[idx]`` must support
            ``.reshape(3, 32, 32)`` (3072 values stored channel by channel).
        labels: Sequence or array of integer class indices, one per image.
        transform: Optional callable applied to each HWC image.

    Raises:
        ValueError: If ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels, transform=None):
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, got {len(data)} and {len(labels)}"
            )
        self.data = data
        # Copied once into an int64 (torch.long) tensor; torch.tensor copies,
        # so later edits to the caller's label array do not leak in.
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Without a transform, ``image`` is the ``(32, 32, 3)`` NumPy array;
        otherwise it is whatever ``transform`` returns. ``label`` is a 0-d
        ``torch.long`` tensor.
        """
        # (3072,) -> (3, 32, 32) channel planes -> (32, 32, 3) HWC, the layout
        # ToPILImage / ToTensor expect for NumPy input.
        image = self.data[idx].reshape(3, 32, 32).transpose(1, 2, 0)

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]

In [107]:
# CIFAR-10 per-channel (R, G, B) mean and standard deviation used by Normalize.
# NOTE(review): these std values are the ones commonly copied for CIFAR-10;
# the std computed over all training pixels is often quoted as roughly
# [0.247, 0.243, 0.261] instead. Confirm which convention you want.
CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_STD = [0.2023, 0.1994, 0.2010]

# Training pipeline: random flip and padded random crop for augmentation, then
# conversion to a float tensor and per-channel standardization.
train_transform = transforms.Compose([
    transforms.ToPILImage(),                      # HWC NumPy array -> PIL image
    transforms.RandomHorizontalFlip(),            # augmentation: mirror left/right
    transforms.RandomCrop(32, padding=4),         # augmentation: pad by 4, crop 32x32
    transforms.ToTensor(),                        # PIL -> CHW float tensor scaled to [0, 1]
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),  # (x - mean) / std per channel
])

# Testing/validation pipeline: same conversion and normalization, no augmentation.
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

In [108]:
# Alternative augmentation pipeline. It is not used below (the datasets are
# built with train_transform / test_transform); swap it in to compare.
transform_train = torchvision.transforms.Compose([
    # Cifar10RawDataset yields HWC NumPy arrays, but Resize only accepts PIL
    # images or tensors, so convert first; without this the pipeline fails.
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize(40),  # upscale 32x32 -> 40x40
    # Square crop covering 64-100% of the 40x40 area, resized back to 32x32.
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),  # accepts the HWC NumPy array directly
    torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

In [109]:
# Create Datasets
train_dataset = Cifar10RawDataset(X_train, y_train, train_transform)
test_dataset = Cifar10RawDataset(X_test, y_test, test_transform)

# Create DataLoaders
BATCH_SIZE = 64
NUM_OF_WORKERS = 4
SEED = 42

# Seed the global torch RNG (used by augmentations when num_workers == 0) and
# give the training loader its own seeded generator: the shuffle order comes
# from it, and with num_workers > 0 the workers' base seed is drawn from it,
# so Restart & Run All reproduces the same batches and augmentations.
torch.manual_seed(SEED)
loader_generator = torch.Generator().manual_seed(SEED)

# Pinned host memory only helps host->GPU copies, so enable it only with CUDA.
PIN_MEMORY = torch.cuda.is_available()

# NOTE(review): with the "spawn" start method (Windows, and macOS by default),
# worker processes must be able to import Cifar10RawDataset; classes defined in
# a notebook usually are not. If iterating the loaders fails there, set
# NUM_OF_WORKERS = 0 or move the dataset class into a .py module.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_OF_WORKERS,
    generator=loader_generator,
    pin_memory=PIN_MEMORY,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_OF_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [110]:
# Sanity check: batches per epoch and the size of the dataset behind the loader.
print(
    f"Number of training batches: {len(train_loader)}",
    f"Number of training data points: {len(train_loader.dataset)}",
    sep="\n",
)

Number of training batches: 782
Number of training data points: 50000
