In [3]:
import os
os.chdir("/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/Code/models_methods")

In [4]:
import argparse
import torch

from utility.smooth_cross_entropy import smooth_crossentropy
from data.cifar import Cifar
from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
from torchvision.models import efficientnet_b0

import sys; sys.path.append("..")
from methods.SAM.sam import SAM

In [10]:
def get_model(weights="Default"):
    model = efficientnet_b0(pretrained=weights)
    model.classifier[1] = torch.nn.Linear(1280, 10, bias=True)
    for param in model.parameters():
        param.requires_grad = False
    model.classifier[1].weight.requires_grad = True
    model.classifier[1].bias.requires_grad = True
    return model

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# Define transformations to apply to the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

root = "/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/data"
# Load MNIST training dataset
trainset = torchvision.datasets.MNIST(root=root, train=True,
                                      download=True, transform=transform)

# Split the training set into training and validation sets
trainset, valset = torch.utils.data.random_split(trainset, [50000, 10000])

# Load MNIST testing dataset
testset = torchvision.datasets.MNIST(root=root, train=False,
                                     download=True, transform=transform)

# Create data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(valset, batch_size=64,
                                        shuffle=False, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

# Define the classes in MNIST
classes = tuple(str(i) for i in range(10))

data_loaders = {
    "train_loader": trainloader,
    "val_loader": valloader,
    "test_loader": testloader
}          

In [None]:
if __name__ == "__main__":
    
    initialize(seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    log = Log(log_each=10)
    model = SimpleCNN()
    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(), base_optimizer, rho=2.0, 
                    adaptive=True, lr=0.1, momentum=0.9, weight_decay=0.0005)
    scheduler = StepLR(optimizer, 0.1, 5)

    for epoch in range(5):
        model.train()
        log.train(len_dataset=len(data_loaders["train_loader"]))

        for batch in data_loaders["train_loader"]:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, 
                                       smoothing=0.1)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_running_stats(model)
            smooth_crossentropy(model(inputs), targets, smoothing=0.1).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(data_loaders["test_loader"]))

        with torch.no_grad():
            for batch in data_loaders["test_loader"]:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

    log.flush()


In [7]:
import argparse
import torch

from utility.smooth_cross_entropy import smooth_crossentropy
from data.cifar import Cifar
from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
from torchvision.models import efficientnet_b0

import sys; sys.path.append("..")

┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓
┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃
┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃
┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨
┃           0  ┃      0.2936  │     90.77 %  ┃   1.000e-01  │   00:12 min  ┃┈███████████████████████████┈┨      0.1142  │     97.48 %  ┃
┃           1  ┃      0.0946  │     98.10 %  ┃   1.000e-01  │   00:12 min  ┃┈███████████████████████████┈┨      0.0739  │     98.61 %  ┃
┃           2  ┃      0.0628  │     98.75 %  ┃   2.000e-02  │   00:04 min  ┠┈█████████┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┨

KeyboardInterrupt: 