In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

# CIFAR10 

In [2]:
batch_size = 128

# Define data transformations for training and testing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Standard normalization for CIFAR-10
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load CIFAR-10 training data
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

# Split the training set into training and validation subsets (e.g., 90% training, 10% validation)
train_size = int(0.9 * len(trainset))
val_size = len(trainset) - train_size
train_subset, val_subset = random_split(trainset, [train_size, val_size])

trainloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
valloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)

# Load CIFAR-10 test data
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

print("Data loaded successfully!")

Files already downloaded and verified
Files already downloaded and verified
Data loaded successfully!


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# --- Squeeze-and-Excitation Block (same as before) ---
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, in_channels // reduction)
        self.fc2 = nn.Linear(in_channels // reduction, in_channels)

    def forward(self, x):
        b, c, _, _ = x.shape
        squeeze = self.global_pool(x).view(b, c)
        excitation = torch.sigmoid(self.fc2(F.relu(self.fc1(squeeze))))
        excitation = excitation.view(b, c, 1, 1)
        return x * excitation

# --- ResNeXt Block with SE ---
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, out_channels, cardinality=32, stride=1, downsample=False):
        super(ResNeXtBlock, self).__init__()
        # Bottleneck channels: typically a reduction factor is applied
        mid_channels = out_channels // 2
        
        # 1x1 convolution for dimension reduction
        self.conv_reduce = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn_reduce = nn.BatchNorm2d(mid_channels)
        
        # 3x3 grouped convolution: using the specified cardinality (number of groups)
        self.conv_conv = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride,
                                   padding=1, groups=cardinality, bias=False)
        self.bn_conv = nn.BatchNorm2d(mid_channels)
        
        # 1x1 convolution for dimension restoration
        self.conv_expand = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn_expand = nn.BatchNorm2d(out_channels)
        
        # SE module to recalibrate features
        self.se = SEBlock(out_channels)
        
        # Shortcut connection in case of dimension mismatch or stride > 1
        self.downsample = None
        if downsample or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        
        self.activation = nn.ReLU(inplace=True)
        
    def forward(self, x):
        identity = x
        
        out = self.conv_reduce(x)
        out = self.bn_reduce(out)
        out = self.activation(out)
        
        out = self.conv_conv(out)
        out = self.bn_conv(out)
        out = self.activation(out)
        
        out = self.conv_expand(out)
        out = self.bn_expand(out)
        
        # Apply SE attention
        out = self.se(out)
        
        # Adjust shortcut if needed
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.activation(out)
        return out

# --- Advanced ResNeXt-SE Model for CIFAR-10 ---
class ResNeXtSE_CIFAR10(nn.Module):
    def __init__(self, num_classes=10, cardinality=32):
        super(ResNeXtSE_CIFAR10, self).__init__()
        self.in_channels = 64
        # Initial convolution: for 32x32 RGB images
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.activation = nn.ReLU(inplace=True)
        
        # Create layers with increasing feature dimensions.
        # You can adjust the number of blocks per layer to balance depth and computation.
        self.layer1 = self._make_layer(64, num_blocks=3, cardinality=cardinality, stride=1)
        self.layer2 = self._make_layer(128, num_blocks=4, cardinality=cardinality, stride=2)
        self.layer3 = self._make_layer(256, num_blocks=6, cardinality=cardinality, stride=2)
        self.layer4 = self._make_layer(512, num_blocks=3, cardinality=cardinality, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        
    def _make_layer(self, out_channels, num_blocks, cardinality, stride):
        layers = []
        # First block may downsample and increase dimensions
        layers.append(ResNeXtBlock(self.in_channels, out_channels, cardinality=cardinality,
                                   stride=stride, downsample=True))
        self.in_channels = out_channels
        # The remaining blocks maintain dimensions
        for _ in range(1, num_blocks):
            layers.append(ResNeXtBlock(out_channels, out_channels, cardinality=cardinality, stride=1))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# --- Early Stopping Implementation ---
class EarlyStopping:
    def __init__(self, patience=10, verbose=False, delta=0.01):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...")
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss

# --- Training Loop with Early Stopping ---
def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler, n_epochs=50, patience=10, device='cpu'):
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    
    for epoch in range(1, n_epochs+1):
        model.train()
        train_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.size(0)
            
        scheduler.step()
        train_loss /= len(train_loader.dataset)
        
        # Validation phase
        model.eval()
        valid_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in valid_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item() * data.size(0)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        valid_loss /= len(valid_loader.dataset)
        valid_acc = correct / len(valid_loader.dataset)
        
        print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}")
        
        # Check early stopping criteria
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

Using device: mps


In [7]:
model = ResNeXtSE_CIFAR10(num_classes=10, cardinality=32).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [8]:
train_model(model, trainloader, valloader, criterion, optimizer, scheduler, n_epochs=50, patience=10, device=device)

Epoch 1, Train Loss: 1.5732, Valid Loss: 1.3181, Valid Acc: 0.5210
Validation loss decreased (inf --> 1.318115). Saving model...
Epoch 2, Train Loss: 1.1814, Valid Loss: 1.1604, Valid Acc: 0.5862
Validation loss decreased (1.318115 --> 1.160378). Saving model...
Epoch 3, Train Loss: 0.9707, Valid Loss: 0.9100, Valid Acc: 0.6804
Validation loss decreased (1.160378 --> 0.910014). Saving model...
Epoch 4, Train Loss: 0.7931, Valid Loss: 0.7751, Valid Acc: 0.7196
Validation loss decreased (0.910014 --> 0.775052). Saving model...
Epoch 5, Train Loss: 0.7093, Valid Loss: 0.7311, Valid Acc: 0.7388
Validation loss decreased (0.775052 --> 0.731124). Saving model...
Epoch 6, Train Loss: 0.6521, Valid Loss: 0.6698, Valid Acc: 0.7620
Validation loss decreased (0.731124 --> 0.669793). Saving model...
Epoch 7, Train Loss: 0.5663, Valid Loss: 0.6206, Valid Acc: 0.7830
Validation loss decreased (0.669793 --> 0.620552). Saving model...
Epoch 8, Train Loss: 0.5330, Valid Loss: 0.5892, Valid Acc: 0.7910


In [9]:
model.load_state_dict(torch.load('checkpoint.pt'))
model.eval()  # Set the model to evaluation mode

correct = 0
total = 0

with torch.no_grad():
    for data, target in testloader:
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        # Get predictions by finding the class with the maximum logit
        _, predicted = outputs.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy:.2f}%")

  model.load_state_dict(torch.load('checkpoint.pt'))


Test Accuracy: 82.93%
