In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
import time
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_float32_matmul_precision('high')

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
    transforms.RandomErasing(p=0.5)
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

def get_resnet18():
    model = resnet18(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(512, 10)
    return model.to(device)

100%|██████████| 170M/170M [00:14<00:00, 12.0MB/s]


In [10]:
lr = 0.1
epochs = 200
model = get_resnet18()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

scaler = torch.amp.GradScaler('cuda')

start_time = time.time()

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    scheduler.step()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Epoch [{epoch+1}] | Current lr: {scheduler.get_last_lr()[0]:.4f} | Loss: {running_loss/len(trainloader):.3f} | Acc: {accuracy:.2f}%")

torch.save(model.state_dict(), "saved_models/regular_resnet18.pt")

total_time = time.time() - start_time
print(f"\nTotal training time: {total_time:.2f} seconds")

Epoch [1] | Current lr: 0.1000 | Loss: 2.221 | Acc: 30.69%
Epoch [2] | Current lr: 0.1000 | Loss: 1.838 | Acc: 44.01%
Epoch [3] | Current lr: 0.0999 | Loss: 1.616 | Acc: 43.40%
Epoch [4] | Current lr: 0.0999 | Loss: 1.417 | Acc: 51.13%
Epoch [5] | Current lr: 0.0998 | Loss: 1.270 | Acc: 64.35%
Epoch [6] | Current lr: 0.0998 | Loss: 1.144 | Acc: 70.49%
Epoch [7] | Current lr: 0.0997 | Loss: 1.053 | Acc: 71.89%
Epoch [8] | Current lr: 0.0996 | Loss: 0.983 | Acc: 73.95%
Epoch [9] | Current lr: 0.0995 | Loss: 0.939 | Acc: 75.04%
Epoch [10] | Current lr: 0.0994 | Loss: 0.897 | Acc: 75.21%
Epoch [11] | Current lr: 0.0993 | Loss: 0.879 | Acc: 72.02%
Epoch [12] | Current lr: 0.0991 | Loss: 0.844 | Acc: 73.83%
Epoch [13] | Current lr: 0.0990 | Loss: 0.825 | Acc: 79.32%
Epoch [14] | Current lr: 0.0988 | Loss: 0.799 | Acc: 74.95%
Epoch [15] | Current lr: 0.0986 | Loss: 0.790 | Acc: 74.61%
Epoch [16] | Current lr: 0.0984 | Loss: 0.779 | Acc: 77.27%
Epoch [17] | Current lr: 0.0982 | Loss: 0.767 | A