In [None]:
!pip install timm



In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from timm.data import Mixup
from timm.models.layers import DropPath
from torch.utils.data import DataLoader
from torchvision import datasets



In [None]:
class ConvNeXtBlock(nn.Module):
    """One ConvNeXt residual block for (N, C, H, W) feature maps.

    Residual branch: 7x7 depthwise conv (padding 3, spatial size preserved)
    -> GroupNorm(1, C) -> 1x1 conv expanding C to 4C -> GELU -> 1x1 conv
    projecting back to C. The branch output goes through DropPath
    (stochastic depth) when ``drop_path_rate > 0`` and is then added to the
    block input.

    Note: ``nn.GroupNorm(1, C)`` normalises each sample jointly over
    (C, H, W). That is a LayerNorm over all non-batch dims, not the
    channels-only LayerNorm used in the reference ConvNeXt.

    Args:
        in_channels: number of input channels C (also the output channels).
        drop_path_rate: drop probability for the residual branch; when it is
            not > 0.0 the branch is wrapped in ``nn.Identity`` instead.
    """

    def __init__(self, in_channels, drop_path_rate=0.0):
        super().__init__()
        hidden_channels = 4 * in_channels  # inverted-bottleneck expansion ratio of 4

        # Submodules are created in this order on purpose: parameter init draws
        # from the global RNG, so the order determines the initial weights.
        self.dwconv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=7,
            padding=3,
            groups=in_channels,  # one filter per channel -> depthwise
        )
        self.norm = nn.GroupNorm(1, in_channels, eps=1e-6)
        self.pwconv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)

        if drop_path_rate > 0.0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        branch = self.pwconv2(self.act(self.pwconv1(self.norm(self.dwconv(x)))))
        # Residual connection; stochastic depth acts on the branch only.
        return self.drop_path(branch) + x


In [None]:
class ConvNeXt(nn.Module):
    """ConvNeXt image classifier assembled from ``ConvNeXtBlock`` stages.

    Layout:
      * stem: non-overlapping 4x4 conv with stride 4 (3 -> dims[0]) + norm;
      * one stage per entry of ``depths``: ``depths[i]`` ConvNeXtBlocks of width
        ``dims[i]``, with a norm + 2x2/stride-2 conv downsampling layer between
        consecutive stages;
      * head: global average pooling -> norm -> linear classifier.

    Every norm is ``nn.GroupNorm(1, C)``. In the head it acts on the pooled
    (N, C, 1, 1) features, where it reduces to a LayerNorm over channels.

    Args:
        num_classes: number of output logits.
        depths: number of blocks in each stage.
        dims: channel width of each stage; same length as ``depths``.
        drop_path_rate: maximum stochastic-depth rate. Per-block rates ramp
            linearly from 0 (first block) to this value (last block).

    Raises:
        ValueError: if ``depths``/``dims`` are empty or differ in length.
    """

    def __init__(self, num_classes=10, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), drop_path_rate=0.1):
        super(ConvNeXt, self).__init__()
        if len(depths) == 0 or len(depths) != len(dims):
            raise ValueError(
                f"depths and dims must be non-empty and of equal length, "
                f"got {len(depths)} and {len(dims)}"
            )
        num_stages = len(depths)
        self.downsample_layers = nn.ModuleList()
        self.stages = nn.ModuleList()

        # Stochastic-depth schedule: deeper blocks are dropped more often
        # (linear ramp 0 -> drop_path_rate over all blocks), instead of the
        # same rate for every block.
        dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        # stem (initial layer): patchify with a non-overlapping 4x4 convolution
        stem = nn.Sequential(
            nn.Conv2d(3, dims[0], kernel_size=4, stride=4),
            nn.GroupNorm(1, dims[0], eps=1e-6),
        )
        self.downsample_layers.append(stem)

        # stages, with a downsampling layer between consecutive stages
        block_offset = 0
        for i in range(num_stages):
            stage = nn.Sequential(*[
                ConvNeXtBlock(dims[i], drop_path_rate=dp_rates[block_offset + j])
                for j in range(depths[i])
            ])
            self.stages.append(stage)
            block_offset += depths[i]
            if i < num_stages - 1:
                downsample = nn.Sequential(
                    nn.GroupNorm(1, dims[i], eps=1e-6),
                    nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
                )
                self.downsample_layers.append(downsample)

        # head
        self.norm = nn.GroupNorm(1, dims[-1], eps=1e-6)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(dims[-1], num_classes)

    def forward(self, x):
        # downsample_layers[0] is the stem; each later entry precedes its stage.
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))

        x = self.avgpool(x)          # (N, C, 1, 1)
        x = self.norm(x)             # final norm on pooled features (was defined but never applied)
        x = torch.flatten(x, 1)      # (N, C)
        return self.fc(x)


In [None]:
# Data loading and augmentation.
# Random augmentation is applied to the TRAINING set only. The test set gets a
# deterministic pipeline; otherwise reported accuracy would be measured on
# randomly flipped/cropped/jittered/erased images and vary between runs.
DATA_ROOT = './data'
BATCH_SIZE = 64
NORM_MEAN = (0.5, 0.5, 0.5)  # maps [0, 1] pixels to [-1, 1]
NORM_STD = (0.5, 0.5, 0.5)

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomErasing(scale=(0.02, 0.33)),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation: no augmentation, same normalisation as training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

transform = train_transform  # previous name of the training pipeline, kept for compatibility

train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:12<00:00, 13.4MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:

######Model#######
# Seed the RNGs this run draws from: torch (weight init, DataLoader shuffling,
# torchvision augmentations) and NumPy (used by timm's Mixup for its mixing
# coefficients). cuDNN kernels may still introduce small nondeterminism.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

model = ConvNeXt(num_classes=10, depths=[3, 3, 9, 3], dims=[128, 256, 512, 1024], drop_path_rate=0.3)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

num_epochs = 200

# optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.05)
# A single cosine decay over the whole run (stepped once per epoch) reaches its
# minimum LR exactly at the last epoch. A warm-restart schedule with T_0=10,
# T_mult=2 restarts at epochs 10, 30, 70 and 150, so a 200-epoch run would stop
# 50 epochs into a 160-epoch cycle, at ~78% of the peak LR.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# KL divergence between the model's log-probabilities and Mixup's soft targets.
criterion = nn.KLDivLoss(reduction="batchmean")
mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=0.2, num_classes=10)


def train_one_epoch(model, loader, optimizer, criterion, mixup_fn, device):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Each batch is mixed by ``mixup_fn``, which returns soft label vectors
    (batch_size, num_classes). The loss is ``criterion(log_softmax(logits), soft_labels)``.
    """
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, labels = mixup_fn(inputs, labels)

        optimizer.zero_grad()
        # KLDivLoss expects log-probabilities as its input argument.
        log_probs = F.log_softmax(model(inputs), dim=1)
        loss = criterion(log_probs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(loader)


@torch.no_grad()
def evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` on ``loader`` in percent (NaN if the loader is empty)."""
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        predicted = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else float('nan')


for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(model, train_loader, optimizer, criterion, mixup_fn, device)
    scheduler.step()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

print(f"Accuracy: {evaluate(model, test_loader, device)}%")

Epoch 1/200, Loss: 1.6019795369309233
Epoch 2/200, Loss: 1.3248262809365607
Epoch 3/200, Loss: 1.2782708042113067
Epoch 4/200, Loss: 1.2404614353118955
Epoch 5/200, Loss: 1.1980447055738601
Epoch 6/200, Loss: 1.164890632528783
Epoch 7/200, Loss: 1.1290759212525605
Epoch 8/200, Loss: 1.1044695471863613
Epoch 9/200, Loss: 1.0774589347869843
Epoch 10/200, Loss: 1.0671927921302484
Epoch 11/200, Loss: 1.1593816279602782
Epoch 12/200, Loss: 1.1350113114585048
Epoch 13/200, Loss: 1.115943494736386
Epoch 14/200, Loss: 1.0917710039926611
Epoch 15/200, Loss: 1.0647903655648536
Epoch 16/200, Loss: 1.0423119709924664
Epoch 17/200, Loss: 1.0182935179346968
Epoch 18/200, Loss: 0.9974430994609432
Epoch 19/200, Loss: 0.9754678089447948
Epoch 20/200, Loss: 0.9569139348728882
Epoch 21/200, Loss: 0.9385593862027464
Epoch 22/200, Loss: 0.9118798636566953
Epoch 23/200, Loss: 0.8979873584816828
Epoch 24/200, Loss: 0.8928682937875123
Epoch 25/200, Loss: 0.8722035462594093
Epoch 26/200, Loss: 0.86177529528012