In [8]:
import os
from typing import List, Dict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 4

print(f"DEVICE: {device}")
start_epoch = 0
end_epoch = 5
best_acc = 0 
if not os.path.exists('./checkpoint_resnet18/resnet18_model_best.pth'):
  resnet18 = models.resnet18(pretrained=True)
  resnet18 = resnet18.to(device)
else:
  resnet18 = torch.load('./checkpoint_resnet18/resnet18_model_best.pth')
  print(f"Loaded saved model ./checkpoint_resnet18/resnet18_model_best.pth")

if os.path.exists('./checkpoint_resnet18/resnet18_checkpoint_best.pth'):
  checkpoint = torch.load('./checkpoint_resnet18/resnet18_checkpoint_best.pth')
  start_epoch = checkpoint['epoch']+1
  best_acc = checkpoint['acc']
  print(f"Loaded checkpoint, starting from epoch {start_epoch}")


normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
transform_cifar10 = transforms.Compose([transforms.RandomResizedCrop(size=(224, 224), antialias=True),
                                         transforms.RandomHorizontalFlip(p=0.5),
                                         transforms.ToTensor(),normalize])
train_dataset = datasets.CIFAR10(
        root = './data',
        train = True,
        download = True,
        transform=transform_cifar10)

test_dataset = datasets.CIFAR10(
        root = './data',
        train = False,
        download = True,
        transform=transform_cifar10)  

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size= BATCH_SIZE, shuffle=True,
        num_workers= 2, pin_memory=True)

test_loader = torch.utils.data.DataLoader(
        test_dataset,batch_size=BATCH_SIZE, shuffle=False,
        num_workers=2, pin_memory=True)

def unfreeze_layers(model, n_layers_to_unfreeze):
    total_layers = len(list(model.parameters()))
    start_idx = total_layers - n_layers_to_unfreeze
    for i, param in enumerate(model.parameters()):
        if i >= start_idx:
            param.requires_grad = True
        else:
            param.requires_grad = False

def test(model, criterion, test_loader, epoch):
  model.eval() 
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for images, targets in test_loader:
      images, targets = images.to(device), targets.to(device)
      output = model(images)
      _, pred = torch.max(output,1)
      test_loss += criterion(output, targets).item()
      pred = output.data.max(1, keepdim=True)[1] 
      correct += (pred == targets).sum().item()
  
  test_loss /= len(test_loader.dataset)
  print(f'Test result on epoch {epoch}: Avg loss is {test_loss}, Accuracy: {100.*correct/len(test_loader.dataset)}%')
  return 100.*correct/len(test_loader.dataset)

def train(model, criterion, train_loader, epoch):
  model.train() # we need to set the mode for our model

  for batch_idx, (images, targets) in enumerate(train_loader):
    images, targets = images.to(device), targets.to(device)
    optimizer.zero_grad()
    output = model(images)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()
    if batch_idx % 10000 == 0:
      print(f'Epoch {epoch}: [{batch_idx*len(images)}/{len(train_loader.dataset)}] Loss: {loss.item()}')


for param in resnet18.parameters():
    param.requires_grad = False

unfreeze_layers(resnet18, n_layers_to_unfreeze=2)

num_classes = 10
resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18.parameters(), lr=0.001, momentum=0.8)

for epoch in range(start_epoch, end_epoch):
    train(resnet18,criterion,test_loader,epoch)
    acc = test(resnet18,criterion,test_loader,epoch)
    if acc > best_acc:
        print('Saving..')
        state = {
            'state_dict': resnet18.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint_resnet18'):
            os.mkdir('checkpoint_resnet18')
        torch.save(state, './checkpoint_resnet18/resnet18_checkpoint_best.pth')
        torch.save(resnet18, './checkpoint_resnet18/resnet18_model_best.pth')
        best_acc = acc


DEVICE: cuda
Loaded saved model ./checkpoint_resnet18/resnet18_model.pth
Loaded checkpoint, starting from epoch 3
Files already downloaded and verified
Files already downloaded and verified
Epoch 3: [0/10000] Loss: 2.541553497314453
Test result on epoch 3: Avg loss is 0.3287594131141901, Accuracy: 85.39%
Saving..
Epoch 4: [0/10000] Loss: 1.3127593994140625
Test result on epoch 4: Avg loss is 0.31796397340744736, Accuracy: 85.59%
Saving..
