## Training AlexNet on CIFAR-10 dataset 
- **[Full Architecture Explanation](https://github.com/sammmeeeer/From-LR-Transformers/blob/main/Deep-Neural-Networks/AlexNet.ipynb)**

In [7]:
import torch 
import torch.nn as nn 
import torch.optim as optim 
import time 
import torchvision 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader 
import numpy as np 
from tqdm import tqdm 

In [8]:
# The AlexNet Architecture customised for CIFAR-10 
# Will reduce conv layers 5 to 3 for faster training
# Also smaller kernel sizes (3 x 3)

# Training optimizations 
 # Adam optimizer (faster convergence)
 # 10 epochs
 # Larger batch_size = 256
 # Will remove Response Layer Norm 


class SimpleAlexNet(nn.Module):
    def __init__(self):
        super(SimpleAlexNet, self).__init__()
        
        self.features = nn.Sequential(
            # Simplified conv layers
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(192, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
         

In [9]:
def train():
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Simple transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # Load CIFAR-10
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=256, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=256, shuffle=False, num_workers=2)
    
    # Model, loss and optimizer
    model = SimpleAlexNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    num_epochs = 10
    print("Training started...")
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            pbar.set_postfix({
                'loss': running_loss/total,
                'acc': 100.*correct/total
            })
    
    # Test the model
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    acc = 100 * correct / total
    print(f'\nAccuracy on test set: {acc:.2f}%')
    
    return model

if __name__ == '__main__':
    train()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|█████████████████████████████████████████████████| 170498071/170498071 [02:24<00:00, 1180567.12it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Training started...


Epoch 1/10: 100%|██████████████████████████████| 196/196 [00:12<00:00, 15.97it/s, loss=0.00594, acc=44.6]
Epoch 2/10: 100%|██████████████████████████████| 196/196 [00:11<00:00, 16.64it/s, loss=0.00425, acc=61.2]
Epoch 3/10: 100%|███████████████████████████████| 196/196 [00:12<00:00, 15.92it/s, loss=0.0035, acc=68.4]
Epoch 4/10: 100%|██████████████████████████████| 196/196 [00:10<00:00, 17.99it/s, loss=0.00304, acc=72.5]
Epoch 5/10: 100%|██████████████████████████████| 196/196 [00:12<00:00, 15.42it/s, loss=0.00267, acc=76.1]
Epoch 6/10: 100%|██████████████████████████████| 196/196 [00:12<00:00, 15.81it/s, loss=0.00242, acc=78.2]
Epoch 7/10: 100%|██████████████████████████████| 196/196 [00:11<00:00, 16.83it/s, loss=0.00215, acc=80.7]
Epoch 8/10: 100%|██████████████████████████████| 196/196 [00:11<00:00, 16.77it/s, loss=0.00192, acc=82.7]
Epoch 9/10: 100%|██████████████████████████████| 196/196 [00:11<00:00, 16.77it/s, loss=0.00177, acc=84.2]
Epoch 10/10: 100%|████████████████████████████


Accuracy on test set: 80.38%
