In [16]:
import torch
import torchvision
import os

In [17]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

root_dir='data'
trainset = torchvision.datasets.CIFAR10(root=root_dir, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.CIFAR10(root=root_dir, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
print(len(trainset),len(testset))
print(len(trainloader),len(testloader))

50000 10000
782 157


In [19]:
import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 16 * 16, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.ReLU()(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.ReLU()(x)
        x = self.conv3(x)
        x = nn.ReLU()(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = CNN()

In [20]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [21]:
checkpoint_dir = 'checkpoints_cnn'
os.makedirs(checkpoint_dir, exist_ok=True)

def find_last_checkpoint():
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('model_epoch_')]
    if not checkpoints:
        return None, 0
    checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True)
    last_checkpoint = checkpoints[0]
    last_epoch = int(last_checkpoint.split('_')[-1].split('.')[0])
    return os.path.join(checkpoint_dir, last_checkpoint), last_epoch

In [22]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
def train_model(model, trainloader, criterion, optimizer, num_epochs=10):
    checkpoint_path, start_epoch = find_last_checkpoint()
    if checkpoint_path:
        print(f"Resuming training from {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print("No checkpoint found. Starting from scratch.")
        start_epoch = 0
    for epoch in range(start_epoch, num_epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}")
        checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch + 1}.pth')
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved at {checkpoint_path}")
train_model(model, trainloader, criterion, optimizer, num_epochs=15)

No checkpoint found. Starting from scratch.
Epoch 1, Loss: 1.3559170074170204
Model saved at checkpoints_cnn/model_epoch_1.pth
Epoch 2, Loss: 0.9570983177255792
Model saved at checkpoints_cnn/model_epoch_2.pth
Epoch 3, Loss: 0.7832975501904402
Model saved at checkpoints_cnn/model_epoch_3.pth
Epoch 4, Loss: 0.6453608965782254
Model saved at checkpoints_cnn/model_epoch_4.pth
Epoch 5, Loss: 0.5180220688929034
Model saved at checkpoints_cnn/model_epoch_5.pth
Epoch 6, Loss: 0.4018580437354419
Model saved at checkpoints_cnn/model_epoch_6.pth
Epoch 7, Loss: 0.3021884016178148
Model saved at checkpoints_cnn/model_epoch_7.pth
Epoch 8, Loss: 0.23489788625760913
Model saved at checkpoints_cnn/model_epoch_8.pth
Epoch 9, Loss: 0.18700672425520237
Model saved at checkpoints_cnn/model_epoch_9.pth
Epoch 10, Loss: 0.15444398666863018
Model saved at checkpoints_cnn/model_epoch_10.pth
Epoch 11, Loss: 0.13865851874872948
Model saved at checkpoints_cnn/model_epoch_11.pth
Epoch 12, Loss: 0.12487255945287244

In [23]:
def evaluate_model(testloader):
    checkpoint_path, start_epoch = find_last_checkpoint()
    if checkpoint_path:
        print(f"Loading latest checkpoint from {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print("No checkpoint found. Evaluating with the current model state.")
        return "No checkpoints"
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Accuracy: {accuracy}%")
    return accuracy
evaluate_model(testloader=testloader)

Loading latest checkpoint from checkpoints_cnn/model_epoch_15.pth


  model.load_state_dict(torch.load(checkpoint_path))


Accuracy: 67.64%


67.64

In [24]:
print(model)

CNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=32768, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)
