In [8]:
import torch
import torchvision
import os

In [9]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

root_dir='data'
trainset = torchvision.datasets.CIFAR10(root=root_dir, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.CIFAR10(root=root_dir, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
print(len(trainset),len(testset))
print(len(trainloader),len(testloader))

50000 10000
782 157


In [11]:
model = torchvision.models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 10)

In [12]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

In [13]:
checkpoint_dir = 'checkpoints_resnet'
os.makedirs(checkpoint_dir, exist_ok=True)
def find_last_checkpoint():
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('model_epoch_')]
    if not checkpoints:
        return None, 0
    checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True)
    last_checkpoint = checkpoints[0]
    last_epoch = int(last_checkpoint.split('_')[-1].split('.')[0])
    return os.path.join(checkpoint_dir, last_checkpoint), last_epoch

In [14]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
def train_model(model, trainloader, criterion, optimizer, num_epochs=10):
    checkpoint_path, start_epoch = find_last_checkpoint()
    if checkpoint_path:
        print(f"Resuming training from {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print("No checkpoint found. Starting from scratch.")
        start_epoch = 0
    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader):.4f}")
        checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch + 1}.pth')
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved at {checkpoint_path}")

train_model(model, trainloader, criterion, optimizer, num_epochs=15)

Resuming training from checkpoints_resnet/model_epoch_12.pth


  model.load_state_dict(torch.load(checkpoint_path))


Epoch 13, Loss: 0.7214
Model saved at checkpoints_resnet/model_epoch_13.pth
Epoch 14, Loss: 0.7092
Model saved at checkpoints_resnet/model_epoch_14.pth
Epoch 15, Loss: 0.7011
Model saved at checkpoints_resnet/model_epoch_15.pth


In [15]:
def evaluate_model(testloader):
    checkpoint_path, start_epoch = find_last_checkpoint()
    if checkpoint_path:
        print(f"Loading latest checkpoint from {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print("No checkpoint found. Evaluating with the current model state.")
        return "No checkpoints"
    model.to('mps').eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Accuracy: {accuracy}%")
    return accuracy
evaluate_model(testloader=testloader)

Loading latest checkpoint from checkpoints_resnet/model_epoch_15.pth


  model.load_state_dict(torch.load(checkpoint_path))


Accuracy: 70.13%


70.13

In [16]:
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  