# Interactive CIFAR-10 Classification Tutorial

This notebook provides an interactive guide to training and evaluating models on the CIFAR-10 dataset. You'll learn how to:
1. Load and visualize CIFAR-10 data
2. Build and train a CNN model
3. Monitor training with Weights & Biases (W&B)
4. Evaluate and visualize results

## Table of Contents
1. [Setup and Data Loading](#setup)
2. [Data Visualization](#visualization)
3. [Model Architecture](#model)
4. [Training Loop](#training)
5. [Model Evaluation](#evaluation)
6. [Interactive Predictions](#predictions)

## 1. Setup and Data Loading <a name="setup"></a>

First, let's import the necessary libraries and set up our environment:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import wandb

# Seed PyTorch's random number generators so repeated runs start from the same
# RNG state (weight initialization, shuffling, random augmentation).
SEED = 42
use_cuda = torch.cuda.is_available()
torch.manual_seed(SEED)
if use_cuda:
    torch.cuda.manual_seed(SEED)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Using device: {device}")

# Human-readable CIFAR-10 label names, indexed by integer class id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

### Load and preprocess the CIFAR-10 dataset:

In [None]:
# Per-channel statistics used by Normalize for both splits.
# NOTE(review): these std values are widely copied for CIFAR-10; some sources
# report (0.2470, 0.2435, 0.2616) for the full training set. Confirm before
# changing them, since any display code must invert exactly these values.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: augment with a random 32x32 crop (from a 4-pixel-padded
# image) and a random horizontal flip, then convert to a tensor and normalize.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Fetch both splits into ./data (download=True) and wrap each in a DataLoader.
data_root = './data'

trainset = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root=data_root, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

## 2. Data Visualization <a name="visualization"></a>

Let's visualize some sample images from the dataset:

In [None]:
# Normalization statistics used by transform_train / transform_test.
_DISPLAY_MEAN = (0.4914, 0.4822, 0.4465)
_DISPLAY_STD = (0.2023, 0.1994, 0.2010)


def to_display_array(img, mean=_DISPLAY_MEAN, std=_DISPLAY_STD):
    """Convert a normalized ``(C, H, W)`` tensor into an ``(H, W, C)`` array in [0, 1].

    Args:
        img: Image tensor of shape ``(C, H, W)`` that was normalized with
            ``transforms.Normalize(mean, std)``; ``C`` must equal ``len(mean)``.
        mean: Per-channel mean that was subtracted during normalization.
        std: Per-channel std that was divided out during normalization.

    Returns:
        A NumPy array of shape ``(H, W, C)`` with values clipped to [0, 1].
    """
    arr = img.detach().cpu().numpy()
    channel_mean = np.asarray(mean, dtype=arr.dtype).reshape(-1, 1, 1)
    channel_std = np.asarray(std, dtype=arr.dtype).reshape(-1, 1, 1)
    # Invert Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean.
    # Clip explicitly so out-of-range values are clamped here rather than by
    # matplotlib, which warns when it has to clip float RGB data.
    arr = np.clip(arr * channel_std + channel_mean, 0.0, 1.0)
    return np.transpose(arr, (1, 2, 0))


def imshow(img, mean=_DISPLAY_MEAN, std=_DISPLAY_STD):
    """Draw a normalized ``(C, H, W)`` image tensor on the current matplotlib axes.

    The image is un-normalized with the same ``mean``/``std`` used by the
    dataset transforms, so the colors match the original picture.

    Args:
        img: Normalized image tensor of shape ``(C, H, W)``, e.g. a single
            image or the output of ``torchvision.utils.make_grid``.
        mean: Per-channel normalization mean (defaults to the CIFAR-10 values).
        std: Per-channel normalization std (defaults to the CIFAR-10 values).
    """
    plt.imshow(to_display_array(img, mean, std))

# Grab one shuffled batch from the training loader.
dataiter = iter(trainloader)
images, labels = next(dataiter)
n_preview = 16

# Tile the first n_preview images into a single grid figure with axes hidden.
preview_grid = torchvision.utils.make_grid(images[:n_preview])
plt.figure(figsize=(10, 10))
imshow(preview_grid)
plt.axis('off')

# Matching class names, each left-aligned in a 5-character field.
label_names = [f'{classes[labels[j]]:5s}' for j in range(n_preview)]
print(' '.join(label_names))

## 3. Model Architecture <a name="model"></a>

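We'll create a CNN for CIFAR-10 classification. Three convolution → batch-norm → ReLU → max-pool stages reduce each 32×32 image to 128 feature maps of size 4×4. These are followed by a 512-unit fully connected layer with dropout and a 10-way output layer: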

In [None]:
class CIFAR10Net(nn.Module):
    """Small CNN for 32x32 RGB images such as CIFAR-10.

    Three stages of 3x3 convolution (padding 1) -> BatchNorm -> ReLU -> 2x2
    max-pool take a ``(3, 32, 32)`` input to ``(128, 4, 4)``. These are followed
    by a 512-unit fully connected layer with ReLU and dropout, and a final
    linear layer producing raw class logits.

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Three 2x2 pools shrink 32x32 inputs to 4x4 maps with 128 channels.
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.25)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.batch_norm3 = nn.BatchNorm2d(128)

    def forward(self, x):
        """Map a ``(batch, 3, 32, 32)`` tensor to ``(batch, num_classes)`` logits."""
        x = self.pool(torch.relu(self.batch_norm1(self.conv1(x))))
        x = self.pool(torch.relu(self.batch_norm2(self.conv2(x))))
        x = self.pool(torch.relu(self.batch_norm3(self.conv3(x))))
        # Flatten per sample, keeping the batch dimension. With view(-1, 2048),
        # a wrongly sized input was silently reshaped into a different batch
        # size; now fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)

# Build the network and move its parameters to the selected device.
model = CIFAR10Net()
model = model.to(device)

# Cross-entropy loss on raw logits, optimized with Adam (learning rate 1e-3).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Print the module hierarchy as a quick architecture summary.
print(model)

## 4. Training Loop <a name="training"></a>

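Let's train the model with Weights & Biases (W&B) logging. Training loss and accuracy are logged every 100 mini-batches, and test accuracy after every epoch. You may be prompted to log in to W&B the first time you run this cell: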

In [None]:
# Start a W&B run. Learning rate and batch size are read from the optimizer and
# training loader created earlier, so the logged config cannot drift from the
# values actually used. "epochs" is a literal and must match the number of
# epochs train_model() is run for.
wandb.init(
    project="cifar10-interactive",
    config={
        "learning_rate": optimizer.param_groups[0]["lr"],
        "epochs": 10,
        "batch_size": trainloader.batch_size,
        "architecture": "CNN",
        "dataset": "CIFAR-10"
    }
)

def compute_accuracy(net, loader):
    """Return the top-1 accuracy of ``net`` over ``loader``, in percent.

    Kept in this cell so the training loop does not depend on
    ``evaluate_model``. That function is defined in a later cell, and calling
    it here raises NameError when the notebook runs top to bottom.

    Args:
        net: A ``torch.nn.Module`` classifier returning per-class scores.
        loader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        Accuracy in percent over every sample yielded by ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    was_training = net.training
    # Evaluate on whatever device the model's weights live on.
    dev = next(net.parameters()).device
    correct = 0
    total = 0
    net.eval()
    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(dev), targets.to(dev)
                correct += (net(inputs).argmax(dim=1) == targets).sum().item()
                total += targets.size(0)
    finally:
        # Restore the caller's train/eval mode.
        net.train(was_training)
    if total == 0:
        raise ValueError('loader yielded no samples')
    return 100 * correct / total


def train_model(epochs=10):
    """Train the notebook-level ``model`` on ``trainloader`` and log to W&B.

    Uses the global ``optimizer``, ``criterion`` and ``device``. Every 100
    mini-batches, the mean loss and the accuracy over those 100 batches are
    printed and logged. Batches after the last complete 100-batch window of an
    epoch are not reported. After each epoch, accuracy on ``testloader`` is
    logged as ``test_accuracy``. The W&B run is finished when training ends,
    including when it raises or is interrupted.

    Args:
        epochs: Number of full passes over the training set.
    """
    log_every = 100
    try:
        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for i, (inputs, labels) in enumerate(trainloader):
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                if i % log_every == log_every - 1:
                    avg_loss = running_loss / log_every
                    accuracy = 100 * correct / total
                    wandb.log({
                        "epoch": epoch,
                        "loss": avg_loss,
                        "accuracy": accuracy
                    })
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {avg_loss:.3f}, accuracy: {accuracy:.2f}%')
                    running_loss = 0.0
                    correct = 0
                    total = 0

            # Evaluate on the test set after each epoch.
            test_accuracy = compute_accuracy(model, testloader)
            wandb.log({"test_accuracy": test_accuracy})
            print(f'Epoch {epoch + 1} Test Accuracy: {test_accuracy:.2f}%')

        print('Finished Training')
    finally:
        # Close the W&B run even if training fails or is interrupted, so it is
        # not left open.
        wandb.finish()

# Run the full training schedule (10 epochs, the same value recorded in the W&B config).
train_model(epochs=10)

## 5. Model Evaluation <a name="evaluation"></a>

Let's evaluate our model's performance:

In [None]:
def evaluate_model(net=None, loader=None, class_names=None):
    """Print per-class accuracy and return the overall top-1 accuracy (percent).

    Args:
        net: Classifier to evaluate. Defaults to the notebook-level ``model``.
        loader: Iterable of ``(images, labels)`` batches whose integer labels
            lie in ``[0, len(class_names))``. Defaults to ``testloader``.
        class_names: Sequence of class-name strings indexed by label.
            Defaults to ``classes``.

    Returns:
        Accuracy in percent over every sample yielded by ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.

    Note:
        Puts ``net`` into eval mode and leaves it there. Callers that continue
        training must call ``net.train()``.
    """
    net = model if net is None else net
    loader = testloader if loader is None else loader
    class_names = classes if class_names is None else class_names
    num_classes = len(class_names)
    # Run inputs on whatever device the model's weights live on.
    dev = next(net.parameters()).device

    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    net.eval()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(dev), labels.to(dev)
            predicted = net(images).argmax(dim=1)
            hits = predicted == labels
            # Tally samples and correct predictions per class in one vectorized
            # step. This also handles a batch of size 1, where a squeeze()-based
            # approach yields a 0-d tensor that cannot be indexed.
            class_total += torch.bincount(labels, minlength=num_classes).cpu()
            class_correct += torch.bincount(labels[hits], minlength=num_classes).cpu()

    total = int(class_total.sum())
    if total == 0:
        raise ValueError('loader yielded no samples')

    # Print accuracy for each class. Classes with no samples are reported as n/a
    # instead of dividing by zero.
    for name, n_correct, n_seen in zip(class_names, class_correct.tolist(), class_total.tolist()):
        if n_seen:
            print(f'Accuracy of {name:5s}: {100 * n_correct / n_seen:.2f}%')
        else:
            print(f'Accuracy of {name:5s}: n/a (no samples)')

    return 100 * int(class_correct.sum()) / total

# Evaluate on the test set (this also prints per-class results), then report the total.
overall_accuracy = evaluate_model()
print('Overall Test Accuracy: {:.2f}%'.format(overall_accuracy))

## 6. Interactive Predictions <a name="predictions"></a>

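Let's visualize the model's predictions on a few test images. Correct predictions are framed in green; incorrect ones are framed in red and also show the true label: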

In [None]:
def visualize_predictions(num_images=5):
    """Plot test images alongside the model's predicted classes.

    Images are taken from the first batch of ``testloader``. Each subplot is
    titled with the predicted class. Correct predictions get a green frame;
    incorrect ones get a red frame and a red title that also shows the true
    class.

    Args:
        num_images: How many images to show. Capped at the size of the first
            test batch.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f'num_images must be >= 1, got {num_images}')

    images, labels = next(iter(testloader))
    num_images = min(num_images, images.size(0))
    images, labels = images[:num_images], labels[:num_images]

    model.eval()
    # Inference only: skip building an autograd graph. Bring predictions back to
    # the CPU so they compare directly with the CPU labels.
    with torch.no_grad():
        predicted = model(images.to(device)).argmax(dim=1).cpu()

    # One row of subplots with ticks hidden, color-coded by correctness.
    fig = plt.figure(figsize=(15, 3))
    for idx in range(num_images):
        ax = fig.add_subplot(1, num_images, idx + 1, xticks=[], yticks=[])
        imshow(images[idx])
        pred_name = classes[predicted[idx].item()]
        true_name = classes[labels[idx].item()]

        if predicted[idx] == labels[idx]:
            ax.set_title(f'Pred: {pred_name}')
            plt.setp(ax.spines.values(), color='green', linewidth=2)
        else:
            ax.set_title(f'Pred: {pred_name}\nTrue: {true_name}', color='red')
            plt.setp(ax.spines.values(), color='red', linewidth=2)

    plt.tight_layout()
    plt.show()

# Plot predictions for the first five test images.
visualize_predictions(num_images=5)

## Next Steps

Now that you've completed this interactive tutorial, you can:
1. Experiment with different model architectures
2. Try different hyperparameters
3. Try additional data augmentation techniques (this notebook already uses random crops and horizontal flips)
4. Use transfer learning with pre-trained models

Check your W&B dashboard to see detailed training metrics and visualizations!