In [9]:
import torch
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms
import torchvision

# Dataset root. The directory is named "cifar10" but CIFAR-100 is what gets
# downloaded into it (see the dataset class used below).
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
DATA_ROOT = "/home/eagle/Projects/dl_from_scratch/cifar10"

# Per-channel statistics handed to transforms.Normalize in both pipelines.
NORM_MEAN = (0.5071, 0.4867, 0.4408)
NORM_STD = (0.2675, 0.2565, 0.2761)

# Seed for the train/validation split so re-running the cell gives the same split.
SPLIT_SEED = 42

# Training pipeline: random augmentation followed by tensor conversion + normalisation.
train_transform = transforms.Compose([
    transforms.RandomRotation(degrees=15),  # Randomly rotate the image by up to 15 degrees
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Randomly translate the image
    #transforms.RandomResizedCrop(size=28, scale=(0.8, 1.0)),  # Randomly crop and resize
    #transforms.ColorJitter(brightness=0.1, contrast=0.1),  # Random brightness and contrast
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Evaluation pipeline: no random augmentation, so validation/test metrics are
# deterministic and comparable between epochs.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Kept so that any later cell referring to `transform` still works.
transform = train_transform

# Two views of the same training images: one augmented (for training) and one
# plain (for validation). The first call takes care of the download.
train_full_aug = torchvision.datasets.CIFAR100(DATA_ROOT, train=True, download=True, transform=train_transform)
train_full_plain = torchvision.datasets.CIFAR100(DATA_ROOT, train=True, download=False, transform=eval_transform)
test_ds = torchvision.datasets.CIFAR100(DATA_ROOT, train=False, download=True, transform=eval_transform)

train_size = int(0.8 * len(train_full_aug))  # 80% for training
val_size = len(train_full_aug) - train_size  # 20% for validation

# Split the indices once (seeded), then point the validation subset at the
# un-augmented view so that it shares no samples with the training subset.
train_ds, val_split = random_split(
    train_full_aug,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
val_ds = torch.utils.data.Subset(train_full_plain, val_split.indices)

batch_size = 512
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=6)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /home/eagle/Projects/dl_from_scratch/cifar10/cifar-100-python.tar.gz


100%|███████████████████████| 169001437/169001437 [00:02<00:00, 65443124.12it/s]


Extracting /home/eagle/Projects/dl_from_scratch/cifar10/cifar-100-python.tar.gz to /home/eagle/Projects/dl_from_scratch/cifar10
Files already downloaded and verified


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNextBlock(nn.Module):
    """Residual block: depthwise 7x7 conv -> LayerNorm -> 1x1 expand -> GELU -> 1x1 project.

    The projected output is passed through ``LayerScale`` and ``DropPath`` and
    then added to the block input.

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels of the output feature map. When it
            differs from ``in_channel`` the skip connection uses a 1x1
            convolution so the residual sum is well-defined; otherwise the
            skip connection is the identity.
    """

    def __init__(self, in_channel, out_channel):
        super(ConvNextBlock, self).__init__()
        # Depthwise convolution: groups == channels, spatial size preserved by padding=3.
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=7, padding=3, groups=in_channel)
        self.ln1 = nn.LayerNorm(in_channel)
        # Pointwise expansion to 4x the width, then projection to out_channel.
        self.conv2 = nn.Conv2d(in_channels=in_channel, out_channels=4 * in_channel, kernel_size=1)
        self.gelu = nn.GELU()
        self.conv3 = nn.Conv2d(in_channels=4 * in_channel, out_channels=out_channel, kernel_size=1)
        self.ls1 = LayerScale(out_channel)
        self.dp1 = DropPath()
        # Registered last so parameter order (and state_dict keys) of the
        # existing layers is unchanged; Identity adds no parameters.
        if in_channel == out_channel:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1)

    def forward(self, x):
        """Apply the block to ``x`` (channels-first, 4-D) and return the residual sum."""
        residual = self.shortcut(x)
        x = self.conv1(x)
        # LayerNorm normalises the last dimension, so move channels last and back.
        x = x.permute(0, 2, 3, 1)
        x = self.ln1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.conv2(x)
        x = self.gelu(x)
        x = self.conv3(x)
        x = self.ls1(x)
        x = self.dp1(x)
        # Out-of-place add: never mutates a tensor that a previous op returned.
        return x + residual

class LayerScale(nn.Module):
    """Learnable per-channel multiplier for channels-first 4-D feature maps.

    Args:
        channels: Number of channels; one scale value is learned per channel.
        init_value: Initial value of every scale entry.
    """

    def __init__(self, channels, init_value=1e-6):
        super().__init__()
        # Shape (1, C, 1, 1) broadcasts over batch and both spatial dimensions.
        initial_scale = torch.ones((1, channels, 1, 1)) * init_value
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the broadcast scale."""
        scaled = x * self.gamma
        return scaled

class DropPath(nn.Module):
    """Stochastic depth: randomly zero the whole input for individual samples.

    During training each sample (index along dim 0) is kept with probability
    ``1 - drop_prob`` and rescaled by ``1 / (1 - drop_prob)``, or zeroed
    otherwise. In eval mode, or when ``drop_prob`` is 0, the input is returned
    unchanged.

    Args:
        drop_prob: Probability of dropping a sample, in ``[0, 1]``.

    Raises:
        ValueError: If ``drop_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, drop_prob=0.1):
        super(DropPath, self).__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``x`` with per-sample drop applied (training) or untouched (eval)."""
        if self.drop_prob == 0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One random value per sample, broadcast over all remaining dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        # The result must be kept: Tensor.floor() is not in-place.
        mask = torch.floor(keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device))
        # Rescale kept samples so the expected value matches the input;
        # skipped when everything is dropped to avoid dividing by zero.
        if keep_prob > 0:
            mask = mask / keep_prob
        return x * mask

class DownsampleBlock(nn.Module):
    """LayerNorm over channels followed by a stride-2, kernel-2 convolution.

    The convolution doubles the channel count; with kernel 2 and stride 2 each
    spatial dimension is halved.

    Args:
        channels: Number of channels of the incoming feature map.
    """

    def __init__(self, channels):
        super().__init__()
        doubled = channels * 2
        self.l1 = nn.LayerNorm(channels)
        self.conv1 = nn.Conv2d(channels, doubled, kernel_size=2, stride=2)

    def forward(self, x):
        """Normalise ``x`` (channels-first, 4-D) across channels, then downsample."""
        # LayerNorm acts on the last dimension: go channels-last, normalise, come back.
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.l1(channels_last).permute(0, 3, 1, 2)
        return self.conv1(normed)


class ConvNext(nn.Module):
    """Four-stage convolutional classifier built from ``ConvNextBlock`` stages.

    Layout: 4x4/stride-4 stem -> LayerNorm -> stages of 3, 3, 9 and 3 blocks at
    widths 96, 192, 384 and 768, with a ``DownsampleBlock`` between consecutive
    stages -> spatial mean -> LayerNorm -> linear classifier.

    Args:
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        w1, w2, w3, w4 = 96, 192, 384, 768

        # Patch-style stem: non-overlapping 4x4 convolution.
        self.stem = nn.Sequential(
            nn.Conv2d(3, w1, kernel_size=4, stride=4),
        )
        self.stem_ln = nn.LayerNorm(w1, eps=1e-6)

        # Stages with block counts 3 / 3 / 9 / 3; each downsample doubles the width.
        self.stage1 = self._make_stage(ConvNextBlock, w1, w1, 3)
        self.downsample1 = DownsampleBlock(w1)

        self.stage2 = self._make_stage(ConvNextBlock, w2, w2, 3)
        self.downsample2 = DownsampleBlock(w2)

        self.stage3 = self._make_stage(ConvNextBlock, w3, w3, 9)
        self.downsample3 = DownsampleBlock(w3)

        self.stage4 = self._make_stage(ConvNextBlock, w4, w4, 3)

        # Classification head: final norm on pooled features, then linear layer.
        self.norm = nn.LayerNorm(w4, eps=1e-6)
        self.fc = nn.Linear(w4, num_classes)

    def _make_stage(self, block, in_channel, out_channel, num_blocks):
        """Return ``num_blocks`` instances of ``block(in_channel, out_channel)`` chained sequentially."""
        return nn.Sequential(*[block(in_channel, out_channel) for _ in range(num_blocks)])

    def forward(self, x):
        """Map a channels-first image batch to ``(batch, num_classes)`` logits."""
        feats = self.stem(x)

        # LayerNorm works on the last dimension: channels-last in, channels-first out.
        feats = self.stem_ln(feats.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        feats = self.downsample1(self.stage1(feats))
        feats = self.downsample2(self.stage2(feats))
        feats = self.downsample3(self.stage3(feats))
        feats = self.stage4(feats)

        # Global average pooling over the two spatial dimensions.
        pooled = feats.mean(dim=[2, 3])

        return self.fc(self.norm(pooled))

# Build the network with a 1000-way classification head.
model = ConvNext(num_classes=1000)

# Smoke test: push one random 3x224x224 image through the model and report
# the shape of the logits.
x = torch.randn((1, 3, 224, 224))
out = model(x)
print("Output shape:", out.size())

Output shape: torch.Size([1, 1000])


In [None]:
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
# Mixed precision (autocast + gradient scaling) is only switched on for CUDA;
# on CPU both are disabled so the same cell still runs.
use_amp = device == "cuda"

# The data cell loads CIFAR-100, so the classifier head needs 100 outputs.
NUM_CLASSES = 100
model = ConvNext(num_classes=NUM_CLASSES).to(device)
#model = torch.compile(model)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

num_epochs = 15

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for inputs, labels in train_loader:
        # The loaders use pinned memory, so the host-to-device copy can be asynchronous.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()

        # Forward pass with mixed precision
        with torch.autocast(device_type=device, enabled=use_amp):
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

        # Backward pass
        scaler.scale(loss).backward()  # Scale the loss for stable gradients
        scaler.step(optimizer)  # Update the parameters
        scaler.update()  # Update the scaler

        running_loss += loss.item()

    # Calculate average loss for the epoch (mean of per-batch losses)
    avg_loss = running_loss / len(train_loader)

    # Validation phase
    model.eval()  # Set the model to evaluation mode
    running_val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # Disable gradient calculation
        for inputs, labels in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with torch.autocast(device_type=device, enabled=use_amp):
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)

            running_val_loss += loss.item()

            # Calculate accuracy: predicted class is the arg-max over the logits
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate average validation loss and accuracy
    avg_val_loss = running_val_loss / len(val_loader)
    accuracy = correct / total * 100  # Convert to percentage

    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Training Loss: {avg_loss:.4f}, "
          f"Validation Loss: {avg_val_loss:.4f}, "
          f"Validation Accuracy: {accuracy:.2f}%")

print("Training complete.")