## Exercise for Convolutional Neural Networks

In [None]:
#Load required packages  -- MIGHT HAVE TO PIP INSTALL "TORCHINFO"
from ISLP import load_data
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np
from IPython.display import Image
from torchinfo import summary
from torch.optim import RMSprop
from torch.utils.data import TensorDataset
from pytorch_lightning.loggers import CSVLogger

from pytorch_lightning import seed_everything
# Fix the random seeds (including those of dataloader worker processes) so the
# run is reproducible, and request deterministic torch kernels -- falling back
# to a warning, not an error, for ops that have no deterministic implementation.
seed_everything(seed=0, workers=True)
torch.use_deterministic_algorithms(mode=True, warn_only=True)

In [None]:
from ISLP.torch.imdb import (load_lookup,
                             load_tensor,
                             load_sparse,
                             load_sequential)
from ISLP.torch import (SimpleDataModule,
                        SimpleModule,
                        ErrorTracker,
                        rec_num_workers)

from torch.optim import RMSprop
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import Trainer

In [None]:
from torchvision.datasets import CIFAR100

# Load the CIFAR-100 training and test splits from ./data, downloading them on first use.
cifar_train = CIFAR100(root="data", train=True, download=True)
cifar_test = CIFAR100(root="data", train=False, download=True)

In [None]:
import numpy as np 
import matplotlib.pyplot as plt


# --- Basic stats ---
print("Train set size:", len(cifar_train))
print("Test set size:", len(cifar_test))
print("Number of classes:", len(cifar_train.classes))
print("Example classes:", cifar_train.classes[:10])

# Count samples per class in training set.
# Read the dataset's label list directly: iterating ``cifar_train`` would call
# ``__getitem__`` for every sample (building an image each time) only to keep
# the label.
from collections import Counter
train_labels = list(cifar_train.targets)
label_counts = Counter(train_labels)
print("\nSamples per class (train):", label_counts)

# --- Visualization: show a grid of example images ---
def show_samples(dataset, n=16):
    """Show ``n`` randomly drawn images from ``dataset`` in a near-square grid.

    Parameters
    ----------
    dataset
        Indexable dataset supporting ``len(dataset)`` and
        ``dataset[i] -> (image, label)`` and exposing a ``classes`` list that
        maps a label to its class name (as torchvision's ``CIFAR100`` does).
    n : int, default 16
        Number of images to draw; indices are sampled independently, so the
        same image can appear twice. The default gives the original 4x4 grid.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")

    ncols = int(np.ceil(np.sqrt(n)))
    nrows = int(np.ceil(n / ncols))
    # squeeze=False keeps ``axes`` 2-D even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows),
                             squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i >= n:
            continue  # leftover cells of the grid stay blank
        idx = np.random.randint(0, len(dataset))
        img, label = dataset[idx]
        # The image converts straight to an (H, W, C) array that imshow accepts;
        # no axis transposition is needed.
        ax.imshow(np.asarray(img))
        ax.set_title(dataset.classes[label], fontsize=8)
    plt.tight_layout()
    plt.show()

show_samples(cifar_train)

# --- Class distribution bar plot ---
# Index the bars by the dataset's class list rather than by len(label_counts):
# that way the bars always line up with the tick labels, and any class with no
# samples is drawn as a zero-height bar (Counter returns 0 for missing keys).
class_ids = range(len(cifar_train.classes))
plt.figure(figsize=(12, 4))
plt.bar(class_ids, [label_counts[i] for i in class_ids])
plt.xticks(class_ids, cifar_train.classes, rotation=90)
plt.title("CIFAR-100 Training Set Class Distribution")
plt.xlabel("Class")
plt.ylabel("Count")
plt.show()

In [None]:
from torchvision.transforms import ToTensor

# Per-image transform (uint8 (H, W, C) -> float (C, H, W) in [0, 1]); kept as a
# module-level name so later cells can still apply it to individual images.
transform = ToTensor()


def _images_to_tensor(images):
    """Convert a uint8 (N, H, W, C) image array to an (N, C, H, W) float tensor in [0, 1].

    Performs the same arithmetic as ``ToTensor`` (cast to the default float
    dtype, then divide by 255), but on the whole batch at once instead of one
    Python call per image. Assumes uint8 input -- ToTensor itself skips the
    division for non-uint8 arrays.
    """
    batch = torch.from_numpy(np.asarray(images))
    batch = batch.permute(0, 3, 1, 2).contiguous()
    return batch.to(torch.get_default_dtype()).div(255)


# transform test and train predictors
cifar_train_X = _images_to_tensor(cifar_train.data)
cifar_test_X = _images_to_tensor(cifar_test.data)

# create test and training datasets (the right-hand sides still read the
# torchvision datasets before the names are rebound to TensorDatasets)
cifar_train = TensorDataset(cifar_train_X,
                            torch.tensor(cifar_train.targets))
cifar_test = TensorDataset(cifar_test_X,
                           torch.tensor(cifar_test.targets))

In [None]:
# Number of dataloader workers suggested by ISLP's helper for this machine.
max_num_workers = rec_num_workers()

# Lightning data module over the tensor datasets. ``validation=0.2`` presumably
# holds out a 20% share of the training data for validation -- confirm against
# ISLP's SimpleDataModule docs.
cifar_dm = SimpleDataModule(
    cifar_train,
    cifar_test,
    batch_size=128,
    num_workers=max_num_workers,
    validation=0.2,
)

In [None]:
class BuildingBlock(nn.Module):
    """One CNN stage: 3x3 convolution ('same' padding) -> ReLU -> 2x2 max-pool.

    The 'same' padding keeps the spatial size through the convolution, and the
    2x2 pool (stride defaults to the kernel size) halves height and width, so
    (N, in_channels, H, W) becomes (N, out_channels, H // 2, W // 2).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Sub-module names (conv / activation / pool) define the state_dict keys.
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=3, padding='same')
        self.activation = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x):
        features = self.conv(x)
        features = self.activation(features)
        return self.pool(features)

In [None]:
class CIFARModel(nn.Module):
    """Four-stage CNN classifier with 100 output logits.

    Each ``BuildingBlock`` halves the spatial size. With 32x32 inputs the head
    therefore receives 256 channels of 2x2 features, which are flattened and
    passed through dropout and a two-layer MLP.
    """

    def __init__(self):
        super().__init__()
        channels = [3, 32, 64, 128, 256]
        # Consecutive channel pairs: (3, 32), (32, 64), (64, 128), (128, 256).
        stages = [BuildingBlock(c_in, c_out)
                  for c_in, c_out in zip(channels[:-1], channels[1:])]
        self.conv = nn.Sequential(*stages)
        flat_features = 256 * 2 * 2  # 32 / 2**4 = 2 after four poolings
        self.output = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, 100),
        )

    def forward(self, x):
        flat = self.conv(x).flatten(start_dim=1)
        return self.output(flat)

In [None]:
# create the model
cifar_model = CIFARModel()

# describe the model architecture using one mini-batch (same size as the data
# module's batch_size); passing all of cifar_train_X would push every training
# image, with all intermediate activations, through the network just to
# print layer shapes
summary(cifar_model,
        input_data=cifar_train_X[:128],
        col_names=['input_size', 'output_size', 'num_params'])

In [None]:
def summary_plot(results, ax, col='loss', valid_legend='Validation',
                 training_legend='Training', ylabel='Loss', fontsize=20):
    """Plot per-epoch training and validation curves for one metric.

    Parameters
    ----------
    results : pandas.DataFrame
        Logged metrics; must contain an ``epoch`` column plus the columns
        ``train_{col}_epoch`` and ``valid_{col}``.
    ax : matplotlib Axes
        Axes to draw on.
    col : str
        Metric name, e.g. ``'loss'`` or ``'accuracy'``.
    valid_legend, training_legend : str
        Legend labels for the validation and training curves.
    ylabel : str
        Label for the y axis.
    fontsize : int
        Font size for both axis labels.

    Returns
    -------
    The same ``ax``, so calls can be chained.
    """
    curves = zip([f'train_{col}_epoch', f'valid_{col}'],
                 ['black', 'red'],
                 [training_legend, valid_legend])
    for column, color, label in curves:
        # In the Lightning CSV log the training and validation metrics are
        # typically written on different rows (NaN elsewhere); dropping the
        # NaNs lets each series render as a connected line, not isolated markers.
        series = results.dropna(subset=[column])
        series.plot(x='epoch', y=column, label=label, marker='o',
                    color=color, ax=ax)

    # label axes
    ax.set_xlabel('Epoch', fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    return ax

In [None]:
# CSV logger: metrics for this run are written under logs/CIFAR100/.
cifar_logger = CSVLogger('logs', name='CIFAR100')

# RMSprop over every parameter of the CNN.
cifar_optimizer = RMSprop(cifar_model.parameters(), lr=0.001)

# Wrap the network in ISLP's classification module for 100 classes.
cifar_module = SimpleModule.classification(
    cifar_model,
    num_classes=100,
    optimizer=cifar_optimizer,
)

# 30 epochs with deterministic kernels and no progress bar; ErrorTracker is an
# ISLP callback.
cifar_trainer = Trainer(
    max_epochs=30,
    deterministic=True,
    enable_progress_bar=False,
    logger=cifar_logger,
    callbacks=[ErrorTracker()],
)

# Train the CNN on the data module's training split.
cifar_trainer.fit(cifar_module, datamodule=cifar_dm)

In [None]:
# import plotting library
import matplotlib.pyplot as plt

# read the metrics the CSV logger wrote during training
log_path = cifar_logger.experiment.metrics_file_path
cifar_results = pd.read_csv(log_path)

# plot the training and validation accuracy per epoch (summary_plot already
# sets the y-axis label)
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
summary_plot(cifar_results, ax, col='accuracy', ylabel='Accuracy')
# derive the tick range from the trainer's configured epoch count rather than
# hard-coding 30, so the axis stays correct if max_epochs is changed
ax.set_xticks(np.linspace(0, cifar_trainer.max_epochs, 6).astype(int))
ax.set_ylim([0, 1])
plt.show()

In [None]:
# Evaluate the trained module on the data module's held-out test split.
cifar_trainer.test(model=cifar_module, datamodule=cifar_dm)