In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConvNeXtBlock(nn.Module):
    """Residual block: grouped conv -> LayerNorm -> GELU -> 1x1 conv -> LayerNorm.

    The first convolution is grouped with ``groups=in_channels`` and emits
    ``in_channels * expansion`` channels, i.e. each input channel produces
    ``expansion`` feature maps. A pointwise (1x1) convolution then projects
    to ``out_channels``. Both LayerNorms normalize over the channel axis.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        kernel_size: Spatial size of the grouped convolution. Padding is
            ``kernel_size // 2``, so only odd sizes keep the main path and
            the skip path at the same spatial resolution for every input.
        stride: Stride of the grouped convolution (and of the skip
            projection when one is needed).
        expansion: Channel multiplier applied by the grouped convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, stride=1, expansion=4):
        super().__init__()
        hidden_channels = in_channels * expansion

        # Grouped spatial convolution followed by a pointwise projection.
        self.conv1 = nn.Conv2d(
            in_channels,
            hidden_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=in_channels,
        )
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

        # LayerNorm normalizes the trailing dimension, so forward() moves the
        # channel axis last before applying these.
        self.ln1 = nn.LayerNorm(hidden_channels)
        self.ln2 = nn.LayerNorm(out_channels)

        self.gelu = nn.GELU()

        # A plain identity is only valid when neither the channel count nor
        # the spatial resolution changes; otherwise project with a strided
        # 1x1 convolution so the residual add sees matching shapes.
        if in_channels != out_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.skip = nn.Identity()

    @staticmethod
    def _channel_norm(norm, x):
        """Apply ``norm`` over the channel axis of a (..., C, H, W) tensor."""
        # (..., C, H, W) -> (..., H, W, C) -> normalize -> (..., C, H, W)
        return norm(x.movedim(-3, -1)).movedim(-1, -3)

    def forward(self, x):
        """Run the block on ``x`` of shape (..., in_channels, H, W).

        Returns:
            Tensor with ``out_channels`` channels; spatial size is reduced
            according to ``stride``.
        """
        identity = x

        x = self.conv1(x)
        x = self._channel_norm(self.ln1, x)
        x = self.gelu(x)
        x = self.conv2(x)
        x = self._channel_norm(self.ln2, x)

        # Residual connection (projected when shapes differ).
        return x + self.skip(identity)

In [None]:
class ConvNeXt(nn.Module):
    """Image classifier: patchify stem -> stacked ConvNeXtBlocks -> pooled linear head.

    The ``depth`` blocks are split as evenly as possible across at most
    ``num_stages`` stages. Blocks within a stage keep the channel width and
    resolution; the first block of every stage after the first doubles the
    width and downsamples with stride 2. Widening once per stage (rather
    than multiplying the width on every block) keeps the parameter count
    finite for the default ``depth``.

    Args:
        num_classes: Size of the output logits vector.
        depth: Total number of ConvNeXtBlocks. ``0`` yields stem + head only.
        base_channels: Channel width produced by the stem and used by the
            first stage.
        expansion: Internal channel multiplier forwarded to each block.
        num_stages: Maximum number of resolution stages.

    Raises:
        ValueError: If ``depth`` is negative or ``num_stages`` is below 1.
    """

    def __init__(self, num_classes=1000, depth=16, base_channels=64, expansion=4, num_stages=4):
        super().__init__()
        if depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")
        if num_stages < 1:
            raise ValueError(f"num_stages must be at least 1, got {num_stages}")

        # Stem: non-overlapping 4x4 patches of a 3-channel image.
        self.stem = nn.Conv2d(3, base_channels, kernel_size=4, stride=4)

        # Never create more stages than there are blocks to fill them.
        stages = min(num_stages, depth)
        blocks = []
        channels = base_channels
        for stage in range(stages):
            # Even split; the earliest stages absorb the remainder.
            blocks_in_stage = depth // stages + (1 if stage < depth % stages else 0)
            for position in range(blocks_in_stage):
                opens_stage = stage > 0 and position == 0
                next_channels = channels * 2 if opens_stage else channels
                blocks.append(
                    ConvNeXtBlock(
                        channels,
                        next_channels,
                        stride=2 if opens_stage else 1,
                        expansion=expansion,
                    )
                )
                channels = next_channels
        self.blocks = nn.ModuleList(blocks)

        # Classifier head: global average pooling + fully connected layer.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.stem(x)

        for block in self.blocks:
            x = block(x)

        # (N, C, H, W) -> (N, C, 1, 1) -> (N, C)
        x = self.pool(x)
        x = torch.flatten(x, 1)

        return self.fc(x)