In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Data preprocessing
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
# 32 x 32 x 3 images, 10 classes
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

# Define CNN
# Architecture: 32x32 -> 16x
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc_layers = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(4*4*256, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, verbose=True)

def train(epochs):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0
        
        # Validation
        val_loss = evaluate()
        scheduler.step(val_loss)

def evaluate():
    model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    accuracy = 100. * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    return val_loss / len(testloader)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# train(10)

In [4]:
def save_model(model, path='cifar10_cnn.pth'):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(path='cifar10_cnn.pth'):
    model = CNN().to(device)
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

In [7]:
save_model(model)

Model saved to cifar10_cnn.pth


In [7]:
# Load and use the model
loaded_model = load_model()
test_input = next(iter(testloader))[0].to(device)
with torch.no_grad():
    prediction = loaded_model(test_input)

In [8]:
test_input = next(iter(testloader))[0].to(device)

In [9]:
model(test_input)

tensor([[ 0.3135, -0.1740,  0.0978,  ...,  0.1255,  0.2784,  0.1732],
        [ 0.1293, -0.1733,  0.1706,  ...,  0.2687,  0.0901,  0.3126],
        [ 0.2627, -0.1726,  0.0231,  ...,  0.2350,  0.1472,  0.1028],
        ...,
        [ 0.2356, -0.0915,  0.1479,  ...,  0.0126,  0.2976,  0.1019],
        [ 0.1083, -0.1578,  0.0432,  ...,  0.0848,  0.2639,  0.2720],
        [ 0.1324, -0.2091,  0.0388,  ...,  0.0231,  0.1942,  0.2245]],
       grad_fn=<AddmmBackward0>)