Based on: https://www.kaggle.com/code/kmldas/cifar10-resnet-90-accuracy-less-than-5-min

And: https://github.com/GouMinghao/rgb_matters/blob/main/rgbd_graspnet/net/fastpose.py

In [None]:
# IPython setup: reload edited modules automatically before running code
# (mode 2 = reload all modules), and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable

# Seed both RNGs this notebook draws from: NumPy (random sample selection in
# the visualization cell) and torch (init of the replaced conv1/fc layers,
# DataLoader shuffling and the random crop/flip augmentations). This makes
# runs more repeatable; cuDNN kernels may still be nondeterministic on GPU.
SEED = 1
np.random.seed(SEED)

import torch  # also imported in the next cell; needed here so seeding happens up front

torch.manual_seed(SEED)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid

from torchinfo import summary



In [None]:
# Load in Data from disk and calculate mean and std
import os

# Type of the datasets used below: each item is an (image tensor, class index) pair.
DatasetType = Dataset[tuple[torch.Tensor, int]]

CIFAR = "CIFAR10"   # or "CIFAR100"
# Explicit lookup so an unsupported name (e.g. some other torchvision dataset
# that getattr() would happily resolve) fails loudly instead of silently
# getting num_classes = 100.
_NUM_CLASSES_BY_CIFAR = {"CIFAR10": 10, "CIFAR100": 100}
if CIFAR not in _NUM_CLASSES_BY_CIFAR:
    raise ValueError(f"CIFAR must be one of {sorted(_NUM_CLASSES_BY_CIFAR)}, got {CIFAR!r}")
num_classes = _NUM_CLASSES_BY_CIFAR[CIFAR]
CIFARDataset = getattr(datasets, CIFAR)

# Set the CIFAR_DATA_PATH environment variable to use another download directory.
data_path = os.environ.get(
    "CIFAR_DATA_PATH", "/home/bam/bam_ws/src/bam_brain/bam_gym/data_downloads"
)

# Load CIFAR without normalization so we can compute stats
dataset: DatasetType = CIFARDataset(
    root=data_path, train=True, download=True,
    transform=transforms.ToTensor()
)

def _channel_mean_std(data_loader: DataLoader) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the per-channel (mean, std) over every pixel of every image in *data_loader*.

    Accumulates sum(x) and sum(x**2) per channel and uses
    Var[x] = E[x**2] - E[x]**2. Accumulation is in float64 to limit round-off
    over tens of millions of pixels. Averaging each image's own std instead
    would ignore the variation *between* images, which underestimates the
    dataset-wide std that ``transforms.Normalize`` should divide by.

    Args:
        data_loader: yields ``(images, labels)`` with images shaped
            ``[batch, channels, height, width]``.

    Returns:
        ``(mean, std)``, each a float32 tensor of shape ``[channels]``.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    channel_sum = 0.0
    channel_sq_sum = 0.0
    n_pixels = 0
    for images, _ in data_loader:
        # [B, C, H, W] -> [C, B*H*W]: each row holds every pixel of one channel
        flat = images.transpose(0, 1).reshape(images.size(1), -1).double()
        channel_sum = channel_sum + flat.sum(dim=1)
        channel_sq_sum = channel_sq_sum + flat.square().sum(dim=1)
        n_pixels += flat.size(1)

    if n_pixels == 0:
        raise ValueError("cannot compute channel statistics of an empty dataset")

    channel_mean = channel_sum / n_pixels
    # clamp guards against tiny negative variances from floating-point round-off
    channel_var = (channel_sq_sum / n_pixels - channel_mean.square()).clamp_min(0.0)
    return channel_mean.float(), channel_var.sqrt().float()


loader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=False, num_workers=4)
mean, std = _channel_mean_std(loader)

print("Mean:", mean)
print("Std:", std)

In [None]:
# Create Data loaders with transforms

# One Normalize shared by both pipelines. inplace=True is safe because
# ToTensor() has just produced a fresh tensor that nothing else references,
# so overwriting it saves an allocation per sample.
normalize = transforms.Normalize(mean, std, inplace=True)

# Training: random crop (with 4px padding) + horizontal flip augmentation.
train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Test: no augmentation, same normalization.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_dataset = CIFARDataset(root=data_path, train=True, download=True, transform=train_transforms)
test_dataset  = CIFARDataset(root=data_path, train=False, download=True, transform=test_transforms)

# Pinned host memory only speeds up host->CUDA copies; on CPU-only machines
# requesting it is useless and PyTorch may warn about it.
pin_memory = torch.cuda.is_available()

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=pin_memory
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=pin_memory
)


In [None]:
def resnet_factory(num_layers=18, num_classes=10, weights="IMAGENET1K_V1") -> models.resnet.ResNet:
    """Build a torchvision ResNet adapted to 32x32 CIFAR images.

    The ImageNet stem (7x7 stride-2 conv followed by a stride-2 max-pool) would
    shrink a 32x32 input to 8x8 before the first residual stage. It is therefore
    replaced by a 3x3 stride-1 conv with no pooling. The classifier head is
    replaced so it outputs ``num_classes`` logits. Both replaced layers are
    freshly initialized; every other layer keeps the loaded ``weights``.

    Args:
        num_layers: ResNet depth; one of 18, 34, 50, 101, 152.
        num_classes: number of outputs of the new ``fc`` layer.
        weights: torchvision weights name passed to the model builder. The
            default ``"IMAGENET1K_V1"`` is what the legacy ``pretrained=True``
            flag loads; pass ``None`` to train from scratch.

    Raises:
        ValueError: if ``num_layers`` is not a supported depth.
    """
    supported_depths = (18, 34, 50, 101, 152)
    if num_layers not in supported_depths:
        raise ValueError(f"num_layers must be one of {supported_depths}, got {num_layers!r}")

    # Look up models.resnet18 / resnet34 / ... by name instead of eval().
    builder = getattr(models, f"resnet{num_layers}")
    model = builder(weights=weights)

    # Replace the 7x7 stride-2 conv + maxpool with 3x3 stride-1 and no pool:
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    # Adjust classifier head:
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


num_layers = 18
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet_factory(num_layers, num_classes).to(device)
model_stats = summary(model, input_size=(100, 3, 32, 32))
print(model_stats)

In [None]:
def _find_normalize(transform):
    """Return the first ``transforms.Normalize`` in *transform* (searching nested
    ``.transforms`` lists such as ``Compose``), or ``None`` if there is none."""
    if isinstance(transform, transforms.Normalize):
        return transform
    for inner in getattr(transform, "transforms", ()):
        found = _find_normalize(inner)
        if found is not None:
            return found
    return None


def show_batch(data_loader: DataLoader, max_images: int = 64, nrow: int = 8):
    """Display the first batch of *data_loader* as an image grid.

    ``imshow`` expects float images in [0, 1]. If the loader's dataset
    transform contains a ``transforms.Normalize``, that normalization is
    inverted first. The result is clipped to [0, 1] before plotting. Does
    nothing if the loader yields no batches.

    Args:
        data_loader: yields ``(images, labels)`` with images ``[B, C, H, W]``.
        max_images: show at most this many images from the batch.
        nrow: number of images per grid row.
    """
    batch = next(iter(data_loader), None)
    if batch is None:
        return
    images = batch[0][:max_images]

    normalize = _find_normalize(getattr(data_loader.dataset, "transform", None))
    if normalize is not None:
        # per-channel stats as [C, 1, 1] so they broadcast over [B, C, H, W]
        channel_mean = torch.as_tensor(normalize.mean, dtype=images.dtype).view(-1, 1, 1)
        channel_std = torch.as_tensor(normalize.std, dtype=images.dtype).view(-1, 1, 1)
        images = images * channel_std + channel_mean

    _, ax = plt.subplots(figsize=(12, 12))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(make_grid(images.clamp(0, 1), nrow=nrow).permute(1, 2, 0))

In [None]:
# Preview one (augmented) training batch.
show_batch(data_loader=train_dataloader)

In [None]:
# Visualize a few CIFAR training images (RGB). train_dataset applies
# train_transforms, so each sample is randomly cropped/flipped and normalized.
def _unnormalize(image: torch.Tensor) -> torch.Tensor:
    """Undo ``Normalize(mean, std)`` on a ``(C, H, W)`` tensor, returning a new tensor."""
    channel_std = torch.as_tensor(std).view(-1, 1, 1)
    channel_mean = torch.as_tensor(mean).view(-1, 1, 1)
    return image * channel_std + channel_mean


fig, axes = plt.subplots(1, 3, figsize=(10, 3))

# replace=False: three distinct samples (randint could repeat one)
sample_indices = np.random.choice(len(train_dataset), size=3, replace=False)
for ax, i in zip(axes, sample_indices):
    image, label = train_dataset[int(i)]
    img = _unnormalize(image).numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    ax.imshow(np.clip(img, 0, 1))
    ax.set_title(f'Label: {train_dataset.classes[label]}')
    ax.axis('off')

plt.tight_layout()
plt.show()

# Pixel-value histograms of one sample: as fed to the model, then with the
# normalization undone. This sample is augmented too, so pixels added by
# RandomCrop's padding (default fill 0) appear as extra zeros in the second plot.
img0 = train_dataset[0][0]

plt.figure()
plt.title('Histogram of Pixel Values (Normalized)')
plt.hist(img0.numpy().flatten(), bins=30, color='gray', alpha=0.7)
plt.xlabel('Pixel Value')
plt.ylabel('Count')
plt.show()

plt.figure()
plt.title('Histogram of Pixel Values (Raw)')
plt.hist(_unnormalize(img0).numpy().flatten(), bins=30, color='gray', alpha=0.7)
plt.xlabel('Pixel Value')
plt.ylabel('Count')
plt.show()

In [None]:
# Cross-entropy on the class logits, optimized with Adam at lr=0.01.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# LinearLR with its default arguments written out. The LR starts at lr/3 and
# rises linearly to lr over 5 calls to scheduler.step(). It stays at lr/3
# until scheduler.step() is actually called.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0 / 3,
    end_factor=1.0,
    total_iters=5,
)



In [None]:
# Per-epoch history filled in by the training loop (one entry per epoch).
train_losses, test_losses, test_accuracies = [], [], []


In [None]:
def train() -> dict[str, float]:
    """Run one training epoch over ``train_dataloader``.

    Uses the notebook globals ``model``, ``loss_fn``, ``optimizer`` and
    ``device``. Prints the current batch loss every 100 batches.

    Returns:
        ``{"avg_loss": ...}``: the mean of the per-batch losses for the epoch.
    """
    model.train()
    dataset_size = len(train_dataloader.dataset)
    loss_total = 0.0

    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        loss: torch.Tensor = loss_fn(model(inputs), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_total += loss_value

        if batch_idx % 100 == 0:
            seen = (batch_idx + 1) * len(inputs)
            print(f"loss: {loss_value:>7f}  [{seen:>5d}/{dataset_size:>5d}]")

    return {"avg_loss": loss_total / len(train_dataloader)}


In [None]:
def test() -> dict[str, float]:
    """Evaluate ``model`` on ``test_dataloader`` without tracking gradients.

    Uses the notebook globals ``model``, ``loss_fn`` and ``device``, and prints
    a short summary.

    Returns:
        ``{"test_loss": ..., "test_accuracy": ...}``, where ``test_loss`` is the
        mean of the per-batch losses and ``test_accuracy`` is the fraction
        (0-1) of correctly classified samples.
    """
    model.eval()
    loss_total, n_correct = 0, 0

    with torch.no_grad():
        for inputs, targets in test_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            loss_total += loss_fn(logits, targets).item()
            predicted = logits.argmax(1)
            n_correct += (predicted == targets).type(torch.float).sum().item()

    avg_loss = loss_total / len(test_dataloader)
    accuracy = n_correct / len(test_dataloader.dataset)
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")

    return {"test_loss": avg_loss, "test_accuracy": accuracy}


In [None]:
# Train and evaluate for num_epochs epochs, recording per-epoch metrics.
num_epochs = 100
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_info = train()
    test_info = test()
    # LinearLR only changes the learning rate on scheduler.step(). Without this
    # call the LR stays at its start value (lr * start_factor) for the whole
    # run. Step once per epoch, after the optimizer steps in train().
    scheduler.step()

    train_losses.append(train_info["avg_loss"])
    test_losses.append(test_info["test_loss"])
    test_accuracies.append(test_info["test_accuracy"])


In [None]:
# Learning curves. Epochs are numbered from 1 on the x-axis to match the
# "Epoch N" headers printed by the training loop.
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss")
plt.plot(range(1, len(test_losses) + 1), test_losses, label="Test Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, len(test_accuracies) + 1), test_accuracies, label="Test Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")  # fraction in [0, 1]: test() returns correct / size
plt.legend()

plt.tight_layout()
plt.show()