In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
import time

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data preprocessing and augmentation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

# Load CIFAR-100 dataset
trainset = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=5120, shuffle=True, num_workers=64)

testset = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=4)

# Define the ResNet-50 model
model = resnet50(num_classes=100)
model = model.to(device)

if device == 'cuda':
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Training
def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        print(f'Epoch: {epoch} | Batch: {batch_idx+1}/{len(trainloader)} | Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})')

# Testing
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(f'Test | Batch: {batch_idx+1}/{len(testloader)} | Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})')

    # Save checkpoint
    # torch.save(model.state_dict(), f'./checkpoint/resnet50_epoch{epoch}.pth')

# Main training loop
num_epochs = 5
start = time.time()
for epoch in range(num_epochs):
    train(epoch)
    # test(epoch)
    scheduler.step()

end = time.time()
print("Training finished.",end-start)


Files already downloaded and verified
Files already downloaded and verified
Epoch: 0 | Batch: 1/10 | Loss: 4.975 | Acc: 0.801% (41/5120)
Epoch: 0 | Batch: 2/10 | Loss: 4.980 | Acc: 0.938% (96/10240)
Epoch: 0 | Batch: 3/10 | Loss: 5.080 | Acc: 1.107% (170/15360)
Epoch: 0 | Batch: 4/10 | Loss: 5.213 | Acc: 1.235% (253/20480)
Epoch: 0 | Batch: 5/10 | Loss: 5.489 | Acc: 1.199% (307/25600)
Epoch: 0 | Batch: 6/10 | Loss: 5.853 | Acc: 1.221% (375/30720)
Epoch: 0 | Batch: 7/10 | Loss: 6.164 | Acc: 1.183% (424/35840)
Epoch: 0 | Batch: 8/10 | Loss: 6.422 | Acc: 1.174% (481/40960)
Epoch: 0 | Batch: 9/10 | Loss: 6.563 | Acc: 1.187% (547/46080)
Epoch: 0 | Batch: 10/10 | Loss: 6.706 | Acc: 1.186% (593/50000)
Epoch: 1 | Batch: 1/10 | Loss: 7.107 | Acc: 1.855% (95/5120)
Epoch: 1 | Batch: 2/10 | Loss: 6.799 | Acc: 1.914% (196/10240)
Epoch: 1 | Batch: 3/10 | Loss: 6.617 | Acc: 1.901% (292/15360)
Epoch: 1 | Batch: 4/10 | Loss: 6.341 | Acc: 1.914% (392/20480)
Epoch: 1 | Batch: 5/10 | Loss: 6.172 | Acc: 1.