In [4]:
## Imports
import torch
import torch.nn as nn
import torch.nn.functional as F

## ConvNeXt Block

In [5]:
# ConvNeXt Block
class ConvNeXtBlock(nn.Module):
    """Residual ConvNeXt-style block built from a depthwise conv and a pointwise MLP.

    Pipeline: depthwise conv -> LayerNorm over channels -> Linear (expand)
    -> GELU -> Linear (project back) -> add the input (residual).
    The output has the same shape as the input.

    Args:
        dim: Number of input channels; the block also outputs ``dim`` channels.
        kernel_size: Size of the depthwise convolution kernel. Both odd and
            even sizes keep the spatial resolution unchanged.
        expansion: Width multiplier of the hidden layer in the pointwise MLP
            (hidden width is ``expansion * dim``).

    Shape:
        - Input: ``(B, dim, H, W)``
        - Output: ``(B, dim, H, W)``
    """

    def __init__(self, dim, kernel_size=7, expansion=4):
        super().__init__()
        # Depthwise conv (groups == channels). padding="same" keeps H x W for any
        # kernel size at stride 1; a fixed kernel_size // 2 padding grows the map
        # by one pixel for even kernels and would break the residual addition.
        self.dw_conv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding="same", groups=dim)
        # Normalizes over the last dimension, hence the channels-last permute in forward().
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # The pointwise (1x1) convolutions are written as Linear layers applied
        # to the channels-last tensor.
        self.pw_conv1 = nn.Linear(dim, expansion * dim)  # expand: dim -> expansion * dim
        self.act = nn.GELU()
        self.pw_conv2 = nn.Linear(expansion * dim, dim)  # project: expansion * dim -> dim

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, C, H, W) and return a tensor of the same shape."""
        shortcut = x  # kept for the residual connection
        x = self.dw_conv(x)
        x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C) so LayerNorm/Linear act on channels
        x = self.norm(x)
        x = self.pw_conv1(x)
        x = self.act(x)
        x = self.pw_conv2(x)
        x = x.permute(0, 3, 1, 2)  # back to (B, C, H, W)
        return x + shortcut  # residual connection

In [6]:
# Smoke test: the block ends in a residual addition, so the output shape
# must equal the input shape. Assert it instead of relying on reading the print.
model = ConvNeXtBlock(dim=3)
test_input = torch.randn(1, 3, 224, 224)
output = model(test_input)
assert output.shape == test_input.shape, (
    f"expected {tuple(test_input.shape)}, got {tuple(output.shape)}"
)
print("ConvNeXt Output Shape:", output.shape)

ConvNeXt Output Shape: torch.Size([1, 3, 224, 224])
