In [None]:
#| default_exp models.conv

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

# Convolution-based Model

In [None]:
#| export
import torch.nn as nn
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pytorch_lightning import LightningModule, Trainer
from torchmetrics import Accuracy
from hydra.utils import instantiate
from omegaconf import OmegaConf

from nimrod.data.datasets import MNISTDataModule
from nimrod.utils import get_device

  from .autonotebook import tqdm as notebook_tqdm


## Convnet
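A simple convolutional network for image classification. By default it expects single-channel 28×28 inputs (e.g. MNIST) and outputs raw logits for 10 classes.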

In [None]:
#| export
class ConvNet(nn.Module):
    """Small two-layer convolutional classifier for 28x28 images (e.g. MNIST).

    Architecture: [conv3x3 -> ReLU -> maxpool2 -> dropout] x2 -> fc(128) -> ReLU -> fc(num_classes).
    The forward pass returns raw logits (no softmax is applied).
    """
    def __init__(self,
                 in_channels:int=1, # number of channels in the input image
                 num_classes:int=10, # number of output classes (size of the logit vector)
                ):
        super().__init__()

        # Convolutional layers; 3x3 kernels with padding=1 keep H and W unchanged
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Pooling and dropout; each 2x2 max-pool halves H and W (28 -> 14 -> 7)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Fully connected head; the 32 * 7 * 7 input size assumes a 28x28 input image
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x:torch.Tensor # input image tensor of shape (B, C, H, W), or unbatched (C, H, W)
                ) -> torch.Tensor: # output logits of shape (B, N_classes); unbatched input yields (1, N_classes)
        "Return unnormalised class scores (logits) for the images in `x`."
        # Feature extractor: conv -> ReLU -> pool -> dropout, twice.
        # ReLU is needed for the network to learn non-linear decision boundaries.
        x = self.dropout1(self.pool(torch.relu(self.conv1(x))))
        x = self.dropout2(self.pool(torch.relu(self.conv2(x))))

        # Flatten to (B, 32*7*7); the leading -1 also maps an unbatched (32, 7, 7) map to (1, 32*7*7)
        x = x.view(-1, 32 * 7 * 7)

        # Classifier head; without the ReLU, fc1 and fc2 would compose into a single linear map
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

### Usage

#### MNIST data

In [None]:
# Build the MNIST datamodule described by the data config, then call
# prepare_data() and setup() so its datasets (e.g. data_test, used below) are ready.
MNIST_CFG_PATH = '../config/data/image/mnist.yaml'  # relative to this notebook's directory
cfg = OmegaConf.load(MNIST_CFG_PATH)
datamodule = instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()

#### Model

In [None]:
# Model instantiation. Seed so the random weight init (and the printed values) are reproducible,
# and switch to eval mode so dropout is disabled during these demo forward passes.
torch.manual_seed(0)
convnet = ConvNet()
convnet.eval()

# One data point: an unbatched image tensor and its label
X, y = datamodule.data_test[0]
print("X (C,H,W): ", X.shape, y)
with torch.no_grad():  # inference only: no autograd graph needed
    y_hat = convnet(X)
print(y_hat)

# A batch of data via the dataloader
XX, YY = next(iter(datamodule.test_dataloader()))
with torch.no_grad():
    yy_hat = convnet(XX)
print("y_hat (B,N_classes):", yy_hat.shape)

X (C,H,W):  torch.Size([1, 28, 28]) 0
tensor([[ 0.1526, -0.0935,  0.2227, -0.1580, -0.1237,  0.0805, -0.1651, -0.1068,
         -0.0409,  0.2865]])
y (B,N_classes): torch.Size([64, 10])


In [None]:
#| hide
# Run nbdev's export so the `#| export` cells above are written out to the library module
# (`models.conv`, as set by `#| default_exp` at the top of this notebook).
import nbdev; nbdev.nbdev_export()