In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from nn_utils import Net, DEVICE, TRAINLOADER, train_nn, test_nn, freeze_parameters

# Release cached-but-unused GPU memory held by PyTorch's allocator before loading.
torch.cuda.empty_cache()

PATH = './nn-models/cifar10-nn-model'

# Load the pretrained NN model.
# map_location remaps the stored tensors onto DEVICE while loading, so a checkpoint
# saved from a GPU session also loads on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True for plain state_dicts).
net = Net()
net.load_state_dict(torch.load(PATH, map_location=DEVICE))
net.to(device=DEVICE)

In [None]:
import sam

In [None]:

# print(f'Accuracy before freezing and randomizing: {test_nn(net=net, verbose=False)}')
freeze_parameters(net=net)
# print(f'Accuracy after freezing and randomizing: {test_nn(net=net, verbose=False)}')

N_EPOCHS = 20  # number of passes over TRAINLOADER
eta = 0.01  # learning rate, forwarded to the base optimizer
rho = 2  # neighborhood size of the SAM weight perturbation
# NOTE(review): the reference SAM implementation recommends rho≈0.05 for plain SAM and
# rho≈2.0 only for adaptive SAM (adaptive=True) — confirm which variant is intended.

base_optimizer = torch.optim.Adam
# eta and rho must be passed explicitly; otherwise SAM falls back to its own default
# rho and Adam to its default lr, silently ignoring the values configured above.
# Assumes sam.SAM follows the reference signature
# SAM(params, base_optimizer, rho=..., **base_optimizer_kwargs) — TODO confirm.
sam_optimizer = sam.SAM(net.parameters(), base_optimizer=base_optimizer, rho=rho, lr=eta)


def _forward_backward(images, labels):
    """Forward one mini-batch through `net`, backpropagate cross-entropy, return the loss tensor."""
    preds = net(images)
    loss = F.cross_entropy(preds, labels)
    loss.backward()
    return loss


accuracy_per_epoch_track = []
loss_per_epoch_track = []  # per epoch: SUM of mini-batch losses (not the mean)

# loop over the dataset multiple times
for epoch in range(N_EPOCHS):
    # test_nn (called at the end of each epoch) may switch the model to eval mode;
    # re-enable training-mode behavior so every epoch trains the same way.
    net.train()
    running_loss = 0.0
    # loop over the dataset by mini-batch
    for mini_batch in TRAINLOADER:
        images = mini_batch[0].to(DEVICE)
        labels = mini_batch[1].to(DEVICE)

        # SAM step 1: gradient at the current weights, then (per the SAM algorithm)
        # perturb the weights towards the worst case within the rho-neighborhood.
        _forward_backward(images, labels)
        sam_optimizer.first_step(zero_grad=True)

        # SAM step 2: gradient at the perturbed weights, then restore the weights and
        # let the base optimizer update them with that gradient.
        # NOTE(review): if Net contains BatchNorm, this second forward pass updates the
        # running statistics a second time — the reference SAM repo disables them here.
        loss = _forward_backward(images, labels)
        sam_optimizer.second_step(zero_grad=True)

        # The tracked value is the loss at the perturbed weights (second pass).
        running_loss += loss.item()

    accuracy = test_nn(net=net, verbose=False)

    # track
    accuracy_per_epoch_track.append(accuracy)
    loss_per_epoch_track.append(running_loss)

    print(f'Epoch: {epoch} -- Loss: {loss_per_epoch_track[-1]} -- Accuracy: {accuracy_per_epoch_track[-1]}')