In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data_loader import load_cifar
import numpy as np

# Load the data
data = load_cifar()
train, test = data["train"], data["test"]
print(f"Train set: {len(train)} samples")
print(f"Test set: {len(test)} samples")

Train set: 50000 samples
Test set: 10000 samples


In [11]:
class CIFARNet(nn.Module):
    def __init__(self):
        super().__init__()
        # First convolutional layer
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        # Second convolutional layer
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        # Third convolutional layer
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # First conv block
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        
        # Second conv block
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        
        # Third conv block
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2)
        
        # Flatten and fully connected layers
        x = x.reshape(-1, 128 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Initialize the model
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
net = CIFARNet().to(device)
print(f"Using device: {device}")

Using device: mps


In [12]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

print_every = 10000

# Training loop
def train_epoch(net, train_data, criterion, optimizer, device):
    net.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for i, data in enumerate(train_data):
        # Get the inputs and labels
        inputs = torch.tensor(data["image"], dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)
        labels = torch.tensor([data["label"]], dtype=torch.long)
        
        # Move to device
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        if i % print_every == print_every - 1:
            print(f'[{i + 1}] loss: {running_loss / print_every:.3f}, accuracy: {100. * correct / total:.2f}%')
            running_loss = 0.0
            correct = 0
            total = 0

# Train for multiple epochs
num_epochs = 10
for epoch in range(num_epochs):
    print(f'\nEpoch {epoch + 1}/{num_epochs}')
    train_epoch(net, train, criterion, optimizer, device)

print('Finished Training')


Epoch 1/10


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
# Evaluate on test set
def evaluate(net, test_data, device):
    net.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data in test_data:
            inputs = torch.tensor(data["image"], dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)
            labels = torch.tensor([data["label"]], dtype=torch.long)
            
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    accuracy = 100. * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

evaluate(net, test, device)