In [1]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Custom Dataset class for CIFAR-10 with lazy loading
class CIFAR10Dataset(Dataset):
    def __init__(self, data_path, batch_files, transform=None):
        self.data_path = data_path
        self.batch_files = batch_files
        self.transform = transform
        self.batch_data = None  # Only load the necessary batch when needed
        self.batch_labels = None
        self.batch_index = -1  # Track the currently loaded batch
        self.index_map = []  # Maps dataset index to batch index and in-batch index
        self._create_index_map()

    def _create_index_map(self):
        """Create a map of global indices to batch indices."""
        start_idx = 0
        for batch_num, batch_file in enumerate(self.batch_files):
            with open(os.path.join(self.data_path, batch_file), 'rb') as f:
                batch = pickle.load(f, encoding='bytes')
                batch_size = len(batch[b'labels'])
                self.index_map.extend([(batch_num, i) for i in range(batch_size)])
            start_idx += batch_size

    def _load_batch(self, batch_num):
        """Load a batch given its batch number."""
        batch_file = self.batch_files[batch_num]
        with open(os.path.join(self.data_path, batch_file), 'rb') as f:
            batch = pickle.load(f, encoding='bytes')
            self.batch_data = batch[b'data'].reshape(-1, 3, 32, 32)
            self.batch_labels = batch[b'labels']
        self.batch_index = batch_num  # Update currently loaded batch

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        # Map global index to batch number and in-batch index
        batch_num, in_batch_idx = self.index_map[idx]

        # Load the batch if it's not already loaded
        if batch_num != self.batch_index:
            self._load_batch(batch_num)

        # Fetch image and label from the loaded batch
        image = self.batch_data[in_batch_idx]
        label = self.batch_labels[in_batch_idx]

        # Convert to the expected format (H x W x C)
        image = image.transpose(1, 2, 0)

        if self.transform:
            image = self.transform(image)

        return image, label


# Transformations for CIFAR-10 (ResNet expects 224x224 images)
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # ResNet-50 requires 224x224 input size
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the images
])

In [3]:
# Path to the dataset in your Google Drive
data_path = '/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/cifar-10-python/cifar-10-batches-py/'

# Training and test batch file names
train_batches = [f'data_batch_{i}' for i in range(1, 6)]
test_batches = ['test_batch']

# Create Dataset instances
train_dataset = CIFAR10Dataset(data_path, train_batches, transform=transform)
test_dataset = CIFAR10Dataset(data_path, test_batches, transform=transform)

# Create DataLoader for train and test datasets
trainloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
testloader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)

# Model, Loss, and Optimizer setup
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet50(pretrained=True)

# Modify the fully connected layer to match CIFAR-10 (10 classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-2)



In [None]:
# Training loop with progress bar
def train_model(model, trainloader, criterion, optimizer, num_epochs=10):
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        # Create a progress bar for each epoch
        progress_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=f'Epoch {epoch+1}/{num_epochs}')

        for i, (inputs, labels) in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()  # Zero the gradients

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Accumulate statistics
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update the progress bar with loss and accuracy information
            progress_bar.set_postfix(loss=loss.item(), accuracy=correct / total * 100)

        # Calculate epoch statistics
        epoch_loss = running_loss / len(trainloader.dataset)
        epoch_acc = correct / total * 100
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

    print('Finished Training')


# Training the model
train_model(model, trainloader, criterion, optimizer, num_epochs=10)

# Testing loop (optional)
def test_model(model, testloader):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Test Accuracy: {correct / total * 100:.2f}%')

# Test the model
test_model(model, testloader)

Epoch 1/10:  67%|██████▋   | 132/196 [07:59<02:49,  2.64s/it, accuracy=80.1, loss=0.409]