In [1]:
import os

# Directory to run the notebook from. The project-local packages imported
# below (`utility`, `data`) are presumably resolved relative to it -- confirm.
# The original machine-specific path stays as the default so existing runs are
# unaffected; set MODELS_METHODS_DIR to run the notebook from another checkout.
PROJECT_DIR = os.environ.get(
    "MODELS_METHODS_DIR",
    "/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/Code/models_methods",
)
os.chdir(PROJECT_DIR)

In [2]:
import argparse
import torch

from utility.smooth_cross_entropy import smooth_crossentropy
from data.cifar import Cifar
from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
from torchvision.models import efficientnet_b0

import sys; sys.path.append("..")
from methods.SAM.sam import SAM

In [10]:
def get_model(weights="Default", num_classes=10):
    """Build an EfficientNet-B0 with a frozen backbone and a fresh trainable head.

    Args:
        weights: Any truthy value (e.g. the default ``"Default"``) loads
            torchvision's default pretrained weights; a falsy value
            (``None``, ``False``, ``""``) leaves the backbone randomly
            initialised. This mirrors how the value was interpreted when it
            was forwarded through the legacy ``pretrained=`` flag.
        num_classes: Number of outputs of the replacement classification
            layer. Defaults to 10, the value previously hard-coded.

    Returns:
        The model, with ``requires_grad`` set to ``False`` on every parameter
        except the weight and bias of the new final linear layer.
    """
    # `pretrained=` is deprecated in favour of `weights=` (torchvision >= 0.13);
    # "DEFAULT" selects the same checkpoint the legacy flag resolved to.
    # NOTE(review): assumes torchvision >= 0.13 is installed -- confirm.
    model = efficientnet_b0(weights="DEFAULT" if weights else None)

    # Freeze everything that came with the backbone ...
    for param in model.parameters():
        param.requires_grad = False

    # ... then swap in a new head. A freshly constructed Linear layer has
    # requires_grad=True on its weight and bias, so only the head is trained.
    # Its input width is read from the layer being replaced rather than
    # hard-coded.
    head_in_features = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(head_in_features, num_classes, bias=True)
    return model

In [None]:
if __name__ == "__main__":
    # Hyperparameters that were previously repeated as literals. Keeping them
    # in one place stops the scheduler, the optimizer and the loops from
    # drifting apart.
    EPOCHS = 1
    LEARNING_RATE = 0.1
    LABEL_SMOOTHING = 0.1  # smoothing=0.0 is the same as nn.CrossEntropyLoss

    # Seed first: get_model() randomly initialises the new classification
    # head, so seeding after model creation left that draw unseeded.
    # (assumes `initialize` seeds the RNGs, as its `seed` argument suggests)
    initialize(seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataset = Cifar(batch_size=128,  # Batch size used in the training and validation loop
                    threads=2)  # Number of CPU threads for dataloaders
    log = Log(log_each=10)

    # The batches are moved to `device` below, so the model must live there
    # too; otherwise the first forward pass fails on any CUDA machine.
    model = get_model(weights="Default").to(device)  # ImageNet1K_V1 weights

    base_optimizer = torch.optim.SGD
    # NOTE(review): rho=2 is in the range usually paired with adaptive=True;
    # plain SAM is commonly run with a much smaller rho (e.g. 0.05). Left
    # unchanged here -- confirm this is intentional.
    optimizer = SAM(model.parameters(),
                    base_optimizer,
                    rho=2,
                    adaptive=False,  # True if you want to use the Adaptive SAM.
                    lr=LEARNING_RATE, momentum=0.9, weight_decay=0.0005)
    scheduler = StepLR(optimizer, learning_rate=LEARNING_RATE, total_epochs=EPOCHS)

    for epoch in range(EPOCHS):
        # ---- training ----
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=LABEL_SMOOTHING)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step on the same batch; running stats
            # are disabled for this pass before the optimizer's second step.
            disable_running_stats(model)
            smooth_crossentropy(model(inputs), targets, smoothing=LABEL_SMOOTHING).mean().backward()
            optimizer.second_step(zero_grad=True)

            # Bookkeeping only: the logged loss and accuracy come from the
            # first forward pass.
            with torch.no_grad():
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        # ---- evaluation ----
        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                # Uses smooth_crossentropy's own default smoothing, as before.
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

    log.flush()
