# Convolution-based Model

In [None]:
#| default_exp models.conv

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [None]:
#| export
import torch.nn as nn
import torch

from lightning import LightningModule, Trainer
from lightning.pytorch.loggers import CSVLogger

from torch_lr_finder import LRFinder

from hydra.utils import instantiate
from omegaconf import OmegaConf

from matplotlib import pyplot as plt
import pandas as pd

from nimrod.utils import get_device
from nimrod.models.core import Classifier

import logging
logger = logging.getLogger(__name__)

## Conv Layer

Using a convolution with a stride of 2 instead of max pooling achieves the same goal: it downsamples the image by reducing its spatial dimensions. The key difference is that a strided convolution can learn how to combine features from overlapping regions, whereas max pooling only keeps the maximum value within each window and may lose the finer details of that region. For this reason, a strided convolution is often preferred when preserving spatial information in a neural network matters.

In [None]:
#| export
class ConvLayer(nn.Module):
    "Strided 2D convolution (stride 2) optionally followed by a ReLU"
    def __init__(self,
                in_channels:int=3, # input channels
                out_channels:int=16, # output channels
                kernel_size:int=3, # kernel size
                activation:bool=True # apply ReLU after the convolution
                ):

        super().__init__()
        self.activation = activation
        # downsampling is done by the convolution itself (stride 2) rather than
        # by a separate max/average pooling layer; padding is half the kernel size
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size//2,
        )
        # always created; only applied in forward when `activation` is True
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.conv(x)
        return self.relu(features) if self.activation else features

### Usage

In [None]:
# mock batch: (batch, channels, height, width)
B, C, H, W = (64, 1, 28, 28)
X = torch.rand((B, C, H, W))
c = ConvLayer(in_channels=1, out_channels=16, kernel_size=3, activation=True)
# keep the batch axis and collapse every other axis into one
Y = c(X).flatten(start_dim=1)
print(Y.shape)

torch.Size([64, 3136])


## Convnet Model
Simple convolution network for image recognition

In [None]:
#| export
class ConvNet(nn.Module):
    "Stack of five stride-2 `ConvLayer`s, each followed by batch norm, then flattened"
    def __init__(
            self,
            in_channels:int=1, # input channels
            out_channels:int=10 # num_classes
            ):
        super().__init__()
        # spatial sizes in the trailing comments are for a 28x28 input
        self.net = nn.Sequential(
            ConvLayer(in_channels, 8, kernel_size=5), #14x14
            nn.BatchNorm2d(8),
            ConvLayer(8, 16), #7x7
            nn.BatchNorm2d(16),
            ConvLayer(16, 32), #4x4
            nn.BatchNorm2d(32),
            ConvLayer(32, 64), #2x2
            nn.BatchNorm2d(64),
            # the last layer emits one channel per class; it used to be
            # hard-coded to 10, which silently ignored `out_channels`
            ConvLayer(64, out_channels, activation=False), #1x1
            nn.BatchNorm2d(out_channels),
            nn.Flatten()

        )

    def forward(self, x:torch.Tensor # input image tensor of dimension (B, C, H, W)
                ) -> torch.Tensor: # unnormalized class scores; (B, N_classes) for 28x28 inputs

        return self.net(x)

### Usage

#### Mock data

In [None]:
# mock image batch: (batch, channels, height, width)
B, C, H, W = (64, 1, 28, 28)
X = torch.rand((B, C, H, W))
X.size()

torch.Size([64, 1, 28, 28])

#### Model

In [None]:
# network built directly with its default arguments
convnet = ConvNet()
out = convnet(X)
print(out.shape)

# same network built from the `nnet` node of the model config
cfg = OmegaConf.load('../config/image/model/conv.yaml')
nnet_cfg = cfg.nnet
print(nnet_cfg)
convnet = instantiate(nnet_cfg)
print(convnet(X).shape)

torch.Size([64, 10])
{'_target_': 'nimrod.models.conv.ConvNet', 'in_channels': 1, 'out_channels': '${num_classes}'}
torch.Size([64, 10])


### Training

#### Dataloaders

In [None]:
#| notest

# data module config
cfg = OmegaConf.load('../config/image/data/mnist.yaml')
# Override the batch size in the config *before* instantiation, as the other
# training cells of this notebook do. Setting `datamodule.batch_size` after
# construction left `cfg.datamodule.batch_size` (used for the estimate at the
# bottom of this cell) at its original value.
cfg.datamodule.batch_size = 2048

datamodule = instantiate(cfg.datamodule)
# datamodule.prepare_data()
datamodule.setup()

# one data point
X,y = datamodule.data_test[0]
print("X (C,H,W): ", X.shape, "y: ", y)

# a batch of data via dataloader
XX,YY = next(iter(datamodule.test_dataloader()))
print("XX (B,C,H,W): ", XX.shape, "YY: ", YY.shape)

# training-set size and number of full batches per epoch (floor division)
print(len(datamodule.data_train))
print(len(datamodule.data_train)//cfg.datamodule.batch_size)

[18:04:01] INFO - Init MNIST DataModule
[18:04:01] INFO - MNISTDataset: init
[18:04:01] INFO - ImageDataset: init
[18:04:08] INFO - MNISTDataset: init
[18:04:08] INFO - ImageDataset: init


X (C,H,W):  torch.Size([1, 28, 28]) y:  0
XX (B,C,H,W):  torch.Size([64, 1, 28, 28]) YY:  torch.Size([64])
56000
875


#### Model & hardware

In [None]:
# device as reported by nimrod's `get_device`, then a fresh network moved onto it
device = get_device()
print(device)
model = ConvNet().to(device)

#### Loss, optimizer, scheduler

##### LR finder

In [None]:
#| notest

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)

# Initialize LR Finder
lr_finder = LRFinder(model, optimizer, criterion, device=device)

# Settings of the LR range test
range_test_kwargs = dict(
    start_lr=1e-6,      # Extremely small starting learning rate
    end_lr=10,          # Large ending learning rate
    num_iter=100,       # Number of iterations to test
    smooth_f=0.05,      # Smoothing factor for the loss
    diverge_th=5,
)

# Run LR range test on the training dataloader
lr_finder.range_test(datamodule.train_dataloader(), **range_test_kwargs)

# Plot the learning rate vs loss; the second returned value is kept as the suggested lr
_, lr_found = lr_finder.plot(log_lr=True)
print('Suggested lr:', lr_found)

lr_finder.reset()
    

    

##### 1-cycle training loop

In [None]:
#| notest

# data module config
cfg = OmegaConf.load('../config/image/data/mnist.yaml')
cfg.datamodule.batch_size = 2048
datamodule = instantiate(cfg.datamodule)
# datamodule.prepare_data()
datamodule.setup()

device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
print(device)
model = ConvNet()
model = model.to(device)


criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
steps_per_epoch = len(datamodule.data_train) // cfg.datamodule.batch_size
total_steps = steps_per_epoch* N_EPOCHS
print(len(datamodule.data_train), cfg.datamodule.batch_size, steps_per_epoch, total_steps)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=steps_per_epochs, epochs=1)
N_EPOCHS = 10
scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr_found,  # Peak learning rate
        # total_steps=len(datamodule.data_train) * N_EPOCHS,  # Total training iterations
        steps_per_epoch=steps_per_epoch,
        epochs=N_EPOCHS,
        pct_start=0.3,  # 30% of training increasing LR, 70% decreasing
        anneal_strategy='cos',  # Cosine annealing
        div_factor=10,  # Initial lr = max_lr / div_factor
        # final_div_factor=1e4,
        three_phase=False  # Two phase LR schedule (increase then decrease)
    )

%time
losses = []
lrs = []
current_step = 0

for epoch in range(N_EPOCHS):
    i = 0
    model.train()
    for images, labels in datamodule.train_dataloader():
        if current_step >= total_steps:
            print(f"Reached total steps: {current_step}/{total_steps}")
            break
        optimizer.zero_grad()
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)        
        loss.backward()
        optimizer.step()
        scheduler.step()    
        current_step += 1
    
        losses.append(loss.item())
        # current_lr = scheduler.get_last_lr()[0]
        current_lr = optimizer.param_groups[0]['lr']
        lrs.append(current_lr)
        if not (i % 100):
            print(f"Loss {loss.item():.4f}, Current LR: {current_lr:.10f}, Step: {current_step}/{total_steps}")
        i += 1
    
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in datamodule.test_dataloader():
            # model expects input (B,H*W)
            images = images.to(device)
            labels = labels.to(device)
            # Pass the input through the model
            outputs = model(images)
            # Get the predicted labels
            _, predicted = torch.max(outputs.data, 1)

            # Update the total and correct counts
            total += labels.size(0)
            correct += (predicted == labels).sum()

        # Print the accuracy
    print(f"Epoch {epoch + 1}: Loss {loss.item():.4f}, Accuracy = {100 * correct / total:.2f}%")
    # print(f'Current LR: {optimizer.param_groups[0]["lr"]:.5f}')



In [None]:
#| notest
plt.figure(1)

# top panel: training loss per step
plt.subplot(2, 1, 1)
plt.xlabel('step')
plt.ylabel('loss')
plt.plot(losses)

# bottom panel: learning rate per step
plt.subplot(2, 1, 2)
plt.xlabel('step')
plt.ylabel('lr')
plt.plot(lrs)

## ConvNetX

In [None]:
#| export

class ConvNetX(Classifier, LightningModule):
    "Lightning module wrapping a `ConvNet` on top of the `Classifier` base class"
    def __init__(
            self,
            nnet:ConvNet, # network used by `forward`
            num_classes:int, # forwarded to `Classifier`
            optimizer:torch.optim.Optimizer, # must expose `.keywords['lr']` (presumably a functools.partial -- confirm)
            scheduler:torch.optim.lr_scheduler, # forwarded to `Classifier`
            ):
        logger.info("ConvNetX: init")
        super().__init__(num_classes, optimizer, scheduler)
        # `nnet` is excluded from the saved hyperparameters
        self.save_hyperparameters(logger=False, ignore=['nnet'])
        self.lr = optimizer.keywords['lr'] # for lr finder
        self.nnet = nnet

    def forward(self, x:torch.Tensor)->torch.Tensor:
        return self.nnet(x)

    def _step(self, batch, batch_idx):
        # returns (loss, predicted class indices, targets)
        inputs, targets = batch
        scores = self.forward(inputs)
        return self.loss(scores, targets), scores.argmax(dim=1), targets

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # the second element of the batch is unpacked but not used
        inputs, _ = batch
        return self.forward(inputs).argmax(dim=1)

### Usage

In [None]:
# load the whole model config (not just its `nnet` node) and instantiate it
cfg = OmegaConf.load('../config/image/model/conv.yaml')
model = instantiate(cfg)

In [None]:
# forward a mock (batch, channels, height, width) batch through the model
B, C, H, W = (64, 1, 28, 28)
X = torch.rand((B, C, H, W))
print(model(X).shape)

### Nimrod training

In [None]:
# model: instantiate the object described by the whole model config
cfg = OmegaConf.load('../config/image/model/conv.yaml')
model = instantiate(cfg)

# data module config (`cfg` is rebound to the data config from here on)
cfg = OmegaConf.load('../config/image/data/mnist.yaml')
# overrides are written into the config before the datamodule is instantiated
cfg.datamodule.batch_size = 2048
cfg.datamodule.num_workers = 0
datamodule = instantiate(cfg.datamodule)
# datamodule.prepare_data()
datamodule.setup()

In [None]:
# metrics go to a CSV logger under "logs", experiment name "mnist_convnet"
csv_logger = CSVLogger("logs", name="mnist_convnet")
trainer = Trainer(accelerator="auto", max_epochs=3, logger=csv_logger)


In [None]:
#| notest
# fit on the train dataloader, validating on the val dataloader
trainer.fit(
    model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)


In [None]:
#| notest
# metrics file inside the trainer logger's log directory
csv_path = f"{trainer.logger.log_dir}/metrics.csv"
metrics = pd.read_csv(filepath_or_buffer=csv_path)
# first five rows
metrics.head()

In [None]:
#| notest
plt.figure()
steps = metrics['step']
# training loss in blue, validation loss in red, both against the logged step
for column, line_format in (('train/loss_step', 'b.-'), ('val/loss', 'r.-')):
    plt.plot(steps, metrics[column], line_format)
plt.show()

In [None]:
#| hide
# run nbdev's export step for this project
import nbdev; nbdev.nbdev_export()