## ConvMixer

In [27]:
#@title **Install required packages**

%%capture
! pip install torchinfo

In [36]:
#@title **Importing libraries**
import torch # 2.5.1+cu121
import torch.nn as nn
import torchinfo #1.8.0

In [37]:
# Not every package exposes a `__version__` attribute, so fall back to "unknown"
# instead of raising AttributeError.
for module in (torch, torchinfo):
  print(f"{module.__name__} version: {getattr(module, '__version__', 'unknown')}")

torch version: 2.5.1+cu121
torchinfo version: 1.8.0


**ConvMixer architecture code**


In [38]:
class Residual(nn.Module):
  """Add an identity skip connection around a module: forward(x) = fn(x) + x.

  `fn` must return a tensor that can be added to its input (same shape, or broadcastable).
  """

  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x):
    return self.fn(x) + x


def ConvMixer(dim, depth, kernel_size = 9, patch_size = 7, n_classes = 1000, in_channels = 3):
  """Build a ConvMixer image classifier as an `nn.Sequential`.

  Structure, in order:
    1. Patch embedding: a conv with kernel and stride `patch_size` mapping
       `in_channels` -> `dim` channels, then GELU and BatchNorm.
    2. `depth` mixer layers, each made of
         - a residual depthwise conv (`groups=dim`, `padding="same"`) + GELU + BatchNorm,
         - a pointwise (1x1) conv + GELU + BatchNorm.
    3. Global average pooling to 1x1, flatten, and a linear head producing `n_classes` logits.

  Args:
    dim: Number of hidden channels used throughout the network.
    depth: Number of mixer layers.
    kernel_size: Kernel size of the depthwise convolution in each mixer layer.
    patch_size: Kernel size and stride of the patch-embedding convolution.
    n_classes: Number of output logits.
    in_channels: Number of channels of the input image (default 3, the previously
      hard-coded value).

  Returns:
    An `nn.Sequential` mapping a (batch, in_channels, H, W) tensor to (batch, n_classes) logits.
  """
  return nn.Sequential(
      nn.Conv2d(in_channels, dim, kernel_size = patch_size, stride = patch_size),
      nn.GELU(),
      nn.BatchNorm2d(dim),
      *[nn.Sequential(
          # Spatial mixing: depthwise conv (one filter per channel); "same" padding
          # keeps H and W unchanged so the residual addition is shape-compatible.
          Residual(nn.Sequential(
              nn.Conv2d(dim, dim, kernel_size, groups = dim, padding = "same"),
              nn.GELU(),
              nn.BatchNorm2d(dim)
          )),
          # Channel mixing: 1x1 pointwise conv across all channels.
          nn.Conv2d(dim, dim, kernel_size = 1),
          nn.GELU(),
          nn.BatchNorm2d(dim)
      ) for _ in range(depth)],
      nn.AdaptiveAvgPool2d((1, 1)),
      nn.Flatten(),
      nn.Linear(dim, n_classes)
  )

In [41]:
# dim=2048, depth=8 configuration. With patch_size=1 the patch-embedding conv is 1x1
# with stride 1, so it does not downsample the input spatially.
model = ConvMixer(
    dim=2048,
    depth=8,
    kernel_size=9,
    patch_size=1,
    n_classes=1000,
)
torchinfo.summary(model)

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            8,192
├─GELU: 1-2                              --
├─BatchNorm2d: 1-3                       4,096
├─Sequential: 1-4                        --
│    └─Residual: 2-1                     --
│    │    └─Sequential: 3-1              172,032
│    └─Conv2d: 2-2                       4,196,352
│    └─GELU: 2-3                         --
│    └─BatchNorm2d: 2-4                  4,096
├─Sequential: 1-5                        --
│    └─Residual: 2-5                     --
│    │    └─Sequential: 3-2              172,032
│    └─Conv2d: 2-6                       4,196,352
│    └─GELU: 2-7                         --
│    └─BatchNorm2d: 2-8                  4,096
├─Sequential: 1-6                        --
│    └─Residual: 2-9                     --
│    │    └─Sequential: 3-3              172,032
│    └─Conv2d: 2-10                      4,196,352
│    └─GELU: 2-11      