In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from nn_utils import Net, DEVICE, TRAINLOADER, train_nn, test_nn, freeze_parameters

torch.cuda.empty_cache()

PATH = './nn-models/cifar10-nn-model'

# load the pretrained NN model
net = Net()
net.load_state_dict(torch.load(PATH))
net.to(device=DEVICE)

Files already downloaded and verified
Files already downloaded and verified


Net(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchNorm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchNorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchNorm3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchNorm4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchNorm5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchNorm6): BatchNorm2d(512, eps=1e-0

In [2]:
import sam

In [3]:

# print(f'Accuracy before freezing and randomizing: {test_nn(net=net, verbose=False)}')
freeze_parameters(net=net)
# print(f'Accuracy after freezing and randomizing: {test_nn(net=net, verbose=False)}')

eta = 0.01 # learning rate
rho = 2 # neigborhood size

base_optimizer = torch.optim.Adam
sam_optimizer = sam.SAM(net.parameters(), base_optimizer=base_optimizer)

accuracy_per_epoch_track = []
loss_per_epoch_track = []

# loop over the dataset multiple times
for epoch in range(50):
    running_loss = 0
    # loop over the dataset by mini-batch
    for mini_batch in TRAINLOADER:
        images = mini_batch[0].to(DEVICE)
        labels = mini_batch[1].to(DEVICE)

        preds = net(images) # forward mini-batch
        loss = F.cross_entropy(preds, labels)
        loss.backward()
        sam_optimizer.first_step(zero_grad=True)

        preds = net(images) # forward mini-batch
        loss = F.cross_entropy(preds, labels)
        loss.backward()
        sam_optimizer.second_step(zero_grad=True)

        running_loss += loss.item()

    accuracy = test_nn(net=net, verbose=False)
        
    # track
    accuracy_per_epoch_track.append(accuracy)
    loss_per_epoch_track.append(running_loss)

    print(f'Epoch: {epoch} -- Loss: {loss_per_epoch_track[-1]} -- Accuracy: {accuracy_per_epoch_track[-1]}')

Epoch: 0 -- Loss: 1802.0426481962204 -- Accuracy: 57
Epoch: 1 -- Loss: 573.5988614559174 -- Accuracy: 67
Epoch: 2 -- Loss: 321.6241240054369 -- Accuracy: 72
Epoch: 3 -- Loss: 211.65936678647995 -- Accuracy: 74
Epoch: 4 -- Loss: 148.8118852674961 -- Accuracy: 75
Epoch: 5 -- Loss: 109.58709641546011 -- Accuracy: 77
Epoch: 6 -- Loss: 85.61287180706859 -- Accuracy: 77
Epoch: 7 -- Loss: 67.4212055914104 -- Accuracy: 78
Epoch: 8 -- Loss: 54.6455521825701 -- Accuracy: 79
Epoch: 9 -- Loss: 43.0169026828371 -- Accuracy: 79
Epoch: 10 -- Loss: 38.028302820399404 -- Accuracy: 79
Epoch: 11 -- Loss: 30.778166708536446 -- Accuracy: 80
Epoch: 12 -- Loss: 26.979544115252793 -- Accuracy: 80
Epoch: 13 -- Loss: 22.319208657252602 -- Accuracy: 80
Epoch: 14 -- Loss: 18.8590752807213 -- Accuracy: 80
Epoch: 15 -- Loss: 17.548065606039017 -- Accuracy: 80
Epoch: 16 -- Loss: 13.532572646159679 -- Accuracy: 80
Epoch: 17 -- Loss: 12.58299790439196 -- Accuracy: 81
Epoch: 18 -- Loss: 11.146079835569253 -- Accuracy: 