In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
################# CIFAR 10 ###################

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Per-channel normalisation statistics applied after ToTensor().
# NOTE(review): these std values are a commonly copied set; confirm them against
# statistics computed on the CIFAR-10 training split if exact standardisation matters.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

BATCH_SIZE = 64
NUM_WORKERS = 6

# Training-time augmentation: random flip, padded random crop back to 32x32, mild colour jitter.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Evaluation applies no augmentation, only the same tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)
# train=False selects the held-out test split. Using train=True here would make every
# "test" metric below a measurement on images the model is trained on.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform
)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)



In [None]:
# Peek at the first training sample; train_transform (augmentation + normalisation) is applied on access.
x, label = train_dataset[0]
x.size()  # equivalent to x.shape; expected torch.Size([3, 32, 32]) given RandomCrop(32) — confirm

In [None]:
class CNN_Cifar10(nn.Module):
    """Small CNN for 3-channel CIFAR-10 images that outputs 10 class logits.

    Two feature blocks, each made of three (3x3 Conv -> BatchNorm -> ReLU) stages
    and closed by a 2x2 max-pool, followed by a pooled two-layer MLP classifier.
    """

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return [Conv2d(3x3, stride 1, padding 1), BatchNorm2d, ReLU] as a flat list."""
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def __init__(self):
        super().__init__()
        # block_1: channels 3 -> 16 -> 32 -> 64, then halve the spatial size (32x32 -> 16x16 for CIFAR input).
        stem_layers = []
        for c_in, c_out in ((3, 16), (16, 32), (32, 64)):
            stem_layers.extend(self._conv_bn_relu(c_in, c_out))
        self.block_1 = nn.Sequential(*stem_layers, nn.MaxPool2d(kernel_size=2, stride=2))

        # block_2: channels 64 -> 128 -> 128 -> 128, then halve the spatial size again.
        body_layers = []
        for c_in, c_out in ((64, 128), (128, 128), (128, 128)):
            body_layers.extend(self._conv_bn_relu(c_in, c_out))
        self.block_2 = nn.Sequential(*body_layers, nn.MaxPool2d(kernel_size=2, stride=2))

        # Classifier: pool to a fixed 4x4 grid so the first Linear always sees 128*4*4 features.
        self.clf = nn.Sequential(
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),
            nn.Linear(in_features=128 * 4 * 4, out_features=256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Map a batch of images to unnormalised class scores (logits)."""
        features = self.block_2(self.block_1(x))
        return self.clf(features)


In [None]:
# Seed torch's global RNG so the model's weight initialisation is reproducible.
torch.manual_seed(42)
# Move the model to `device` before any optimizer is built over its parameters,
# as recommended by the torch.optim documentation. Module.to returns the module itself.
model = CNN_Cifar10().to(device)


In [None]:
# Adam over all model parameters (lr 5e-3), trained against cross-entropy on the raw logits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-3)
loss_fn = nn.CrossEntropyLoss()

In [None]:

from helper_functions import accuracy_fn # Note: could also use torchmetrics.Accuracy(task = 'multiclass', num_classes=len(class_names)).to(device)



In [None]:
# Each epoch: one optimisation pass over train_loader, then evaluation on test_loader.
n_epochs = 10  # consider more (e.g. 30): data augmentation slows convergence
model.to(device)  # no-op if the model already lives on `device`
from torchmetrics.classification import Accuracy
accuracy_metric = Accuracy(task="multiclass", num_classes=10).to(device)
for epoch in range(n_epochs):
    # ---- train ----
    model.train()  # enable dropout / batch-statistics BatchNorm once per epoch, not per batch
    train_loss_sum, train_seen = 0.0, 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; summing the loss tensor itself would chain every
        # batch's autograd history together for the whole epoch.
        # loss_fn returns the batch mean, so re-weight by batch size for an exact per-sample average.
        train_loss_sum += loss.item() * y.size(0)
        train_seen += y.size(0)
    train_loss = train_loss_sum / train_seen

    # ---- evaluate ----
    model.eval()
    accuracy_metric.reset()  # start each epoch's accuracy from zero
    test_loss_sum, test_seen = 0.0, 0
    with torch.inference_mode():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            test_pred = model(X)
            # Same sample-weighting as training so a short final batch is not over-counted.
            test_loss_sum += loss_fn(test_pred, y).item() * y.size(0)
            test_seen += y.size(0)
            accuracy_metric.update(test_pred.argmax(dim=1), y)
    test_loss = test_loss_sum / test_seen
    # Fraction of correctly classified test samples over the whole test set.
    test_acc = accuracy_metric.compute().item()

    print(f"Epoch {epoch+1}:")
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc*100:.2f}%\n")
