In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from ndlinear import NdLinear
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import autoaugment, transforms


train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    autoaugment.AutoAugment(autoaugment.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    transforms.RandomErasing(p=0.2)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True,
    download=True, 
    transform=train_transform
)

train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Validation
val_dataset.dataset.transform = test_transform

# Test dataset
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False,
    download=True, 
    transform=test_transform
)

# Create data loaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=128, 
    shuffle=True, 
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=128, 
    shuffle=False, 
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=128, 
    shuffle=False, 
    num_workers=2,
    pin_memory=True
)


In [None]:

class EnhancedNdLinearCNN(nn.Module):
    def __init__(self, dropout_rate=0.4):
        super(EnhancedNdLinearCNN, self).__init__()
        
        # First block with residual connection
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(64)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(64)
        self.shortcut1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=1),
            nn.BatchNorm2d(64)
        )
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(dropout_rate)
        
        # Second block with residual connection
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(128)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(128)
        self.shortcut2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1),
            nn.BatchNorm2d(128)
        )
        
        # Third block with residual connection
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(256)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(256)
        self.shortcut3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=1),
            nn.BatchNorm2d(256)
        )
        
        # Fourth block with residual connection
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4_1 = nn.BatchNorm2d(512)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn4_2 = nn.BatchNorm2d(512)
        self.shortcut4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=1),
            nn.BatchNorm2d(512)
        )
        
        # Global average pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        
        self.ndlinear1 = NdLinear(input_dims=(512,), hidden_size=(256,))
        self.bn1 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)
        
        self.ndlinear2 = NdLinear(input_dims=(256,), hidden_size=(128,))
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(dropout_rate)
        
        self.ndlinear3 = NdLinear(input_dims=(128,), hidden_size=(10,))

    def forward(self, x):
        # Block 1 with residual connection
        identity = self.shortcut1(x)
        x = F.relu(self.bn1_1(self.conv1_1(x)))
        x = self.bn1_2(self.conv1_2(x))
        x = F.relu(x + identity)
        x = self.pool(x)
        x = self.dropout(x)
        
        # Block 2 with residual connection
        identity = self.shortcut2(x)
        x = F.relu(self.bn2_1(self.conv2_1(x)))
        x = self.bn2_2(self.conv2_2(x))
        x = F.relu(x + identity)
        x = self.pool(x)
        x = self.dropout(x)
        
        # Block 3 with residual connection
        identity = self.shortcut3(x)
        x = F.relu(self.bn3_1(self.conv3_1(x)))
        x = self.bn3_2(self.conv3_2(x))
        x = F.relu(x + identity)
        x = self.pool(x)
        x = self.dropout(x)
        
        # Block 4 with residual connection
        identity = self.shortcut4(x)
        x = F.relu(self.bn4_1(self.conv4_1(x)))
        x = self.bn4_2(self.conv4_2(x))
        x = F.relu(x + identity)
        x = self.pool(x)
        x = self.dropout(x)
        
        # Keep your exact NdLinear structure
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)
        
        x = self.ndlinear1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)
        x = self.dropout1(x)
        
        x = self.ndlinear2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)
        x = self.dropout2(x)
        
        x = self.ndlinear3(x)
        return x
    
model = EnhancedNdLinearCNN()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)


def mixup_data(x, y, alpha=0.2):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [None]:
def ensemble_predict(models, inputs):
    outputs = [model(inputs) for model in models]
    outputs = torch.stack(outputs)
    return outputs.mean(dim=0)

# Train multiple models with different seeds
models = []
for seed in [42, 123, 456, 789, 101]:
    torch.manual_seed(seed)
    model = EnhancedNdLinearCNN().to(device)
    # Train model...
    models.append(model)

# Evaluate ensemble
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = ensemble_predict(models, inputs)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

print(f'Ensemble Accuracy: {100 * correct / total:.2f}%')

In [12]:
def train(model, dataloader, optimizer, criterion, device, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for i, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        if (i + 1) % 50 == 0:  # Print every 50 batches
            print(f'Epoch: {epoch+1}, Batch: {i+1}, Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total
    print(f'Training Loss: {epoch_loss:.4f}, Training Acc: {epoch_acc:.2f}%')
    
    # Make sure to return these values
    return epoch_loss, epoch_acc

def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    val_loss = running_loss / len(dataloader)
    val_acc = 100. * correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc:.2f}%')
    
    # Make sure to return these values
    return val_loss, val_acc

In [14]:
num_epochs = 100
best_acc = 0
patience = 20
counter = 0

for epoch in range(num_epochs):
    # Training phase
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device, epoch)
    
    # Validation phase
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    # Update learning rate
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Current learning rate: {current_lr:.6f}')
    
    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Model saved with accuracy: {best_acc:.2f}%')
        counter = 0
    else:
        counter += 1
        
    # Early stopping
    if counter >= patience:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break
    
    print(f'Epoch {epoch+1}/{num_epochs} completed')
    print(f'Best accuracy so far: {best_acc:.2f}%')

Epoch: 1, Batch: 50, Loss: 0.1368, Acc: 97.50%
Epoch: 1, Batch: 100, Loss: 0.0469, Acc: 97.38%
Epoch: 1, Batch: 150, Loss: 0.0572, Acc: 97.43%
Epoch: 1, Batch: 200, Loss: 0.0287, Acc: 97.36%
Epoch: 1, Batch: 250, Loss: 0.1514, Acc: 97.22%
Epoch: 1, Batch: 300, Loss: 0.1113, Acc: 97.16%
Epoch: 1, Batch: 350, Loss: 0.1414, Acc: 97.09%
Training Loss: 0.0868, Training Acc: 97.08%
Validation Loss: 0.2129, Validation Acc: 93.62%
Current learning rate: 0.000970
Model saved with accuracy: 93.62%
Epoch 1/100 completed
Best accuracy so far: 93.62%
Epoch: 2, Batch: 50, Loss: 0.1005, Acc: 97.45%
Epoch: 2, Batch: 100, Loss: 0.0817, Acc: 97.52%
Epoch: 2, Batch: 150, Loss: 0.0394, Acc: 97.46%
Epoch: 2, Batch: 200, Loss: 0.0546, Acc: 97.42%
Epoch: 2, Batch: 250, Loss: 0.1372, Acc: 97.33%
Epoch: 2, Batch: 300, Loss: 0.1575, Acc: 97.25%
Epoch: 2, Batch: 350, Loss: 0.0730, Acc: 97.17%
Training Loss: 0.0832, Training Acc: 97.16%
Validation Loss: 0.2288, Validation Acc: 93.68%
Current learning rate: 0.0009

In [None]:
def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    accuracy = 100. * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

# Load best model and evaluate
model.load_state_dict(torch.load('best_model.pth'))
test_acc = test(model, test_loader, device)