In [19]:
#| default_exp model

In [20]:
#| hide
from nbdev.showdoc import *

In [21]:
#| export 

import torch
from torch import nn
import torch.nn.functional as F

from typing import Any, Union, Tuple

import pytorch_lightning as pl

In [22]:
#| export
# Type alias for a 2-D size argument: either a single int or an (int, int) pair.
_size_2_t = Union[int, Tuple[int, int]]

In [23]:
# Random tensor of shape (2, 3, 32, 32), reused as the input of the
# smoke-test cells further down. Unseeded, so values differ on every run.
sample_data = torch.rand((2, 3, 32, 32))
sample_data.shape

torch.Size([2, 3, 32, 32])

In [24]:
# | export


class ConvBnRelu(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t | str = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device: Any | None = None,
        dtype: Any | None = None,
        eps: float = 0.00001,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        self.bn = nn.BatchNorm2d(
            num_features=out_channels,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x)

In [25]:
# Output size per spatial dim: (F_in - K + 2P) / S + 1 = (32 - 3 + 2*0) / 1 + 1 = 30
ConvBnRelu(3, 16, 3)(sample_data).shape

torch.Size([2, 16, 30, 30])

In [26]:
# | export

class ConvNet(nn.Module):
    """Small CNN classifier: two ConvBnRelu stages with channel dropout, then a linear head.

    Args:
        n_cls: Number of output classes (width of the final linear layer).
        in_channels: Number of channels of the input images. Defaults to 3.
        dropout: Probability given to both `nn.Dropout2d` layers. Defaults to 0.2.

    Note:
        The head is an `nn.LazyLinear`, so its input width is inferred from the
        first batch that passes through `forward`. Run one batch through the
        model before building an optimizer or otherwise using its parameters.
    """

    def __init__(self, n_cls: int, in_channels: int = 3, dropout: float = 0.2):
        super().__init__()
        # Stage 1: in_channels -> 16 feature maps, 3x3 kernel.
        self.conv1 = ConvBnRelu(in_channels, 16, 3)
        self.drop1 = nn.Dropout2d(p=dropout)
        # Stage 2: 16 -> 32 feature maps, 3x3 kernel.
        self.conv2 = ConvBnRelu(16, 32, 3)
        self.drop2 = nn.Dropout2d(p=dropout)
        self.flat = nn.Flatten()
        # Lazy head: in_features is filled in on the first forward pass, so the
        # network does not need to know the input resolution up front.
        self.linear = nn.LazyLinear(out_features=n_cls)

    def forward(self, x):
        """Map a batch shaped `(N, in_channels, H, W)` to scores shaped `(N, n_cls)`."""
        x = self.drop1(self.conv1(x))
        x = self.drop2(self.conv2(x))
        x = self.flat(x)
        return self.linear(x)

In [27]:
# Smoke test: the 2-sample batch yields one row of 10 class scores per sample.
# The model is freshly initialised and sample_data is unseeded, so the values vary per run.
ConvNet(n_cls=10)(sample_data)

tensor([[-0.5159,  0.0844, -0.0255,  0.2810, -0.4592,  0.0967, -0.3894,  0.9572,
          0.0648,  0.7900],
        [ 0.3965,  0.6567,  0.2903, -0.3211, -0.2582, -0.4820, -0.5397,  0.0930,
         -0.1613,  0.0652]], grad_fn=<AddmmBackward0>)

In [28]:
#| hide
# nbdev export step: writes the cells marked `#| export` to the library module
# (the target module is presumably the one named by `#| default_exp` — verify against nbdev settings).
import nbdev; nbdev.nbdev_export()