In [1]:
import os
# Directory the notebook switches into before anything else runs; presumably
# the project-local imports in the next cell (utility.*, data.*) and the
# relative sys.path.append("..") are resolved against it.
# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.
MODELS_METHODS_DIR = "/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/Code/models_methods"
os.chdir(MODELS_METHODS_DIR)

In [6]:
import argparse
import torch

from utility.smooth_cross_entropy import smooth_crossentropy
from data.cifar import Cifar
from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
from torchvision.models import efficientnet_b0

import sys; sys.path.append("..")
from methods.SAM.sam import SAM

In [10]:
def get_model(weights="Default", num_classes=10):
    """Build an EfficientNet-B0 with a frozen backbone and a fresh trainable head.

    Args:
        weights: Which weights torchvision should load. The string "default"
            (any capitalisation, so the historical value "Default" keeps
            working) or ``True`` selects torchvision's default pretrained
            weights; ``None``, ``False`` or an empty string builds the network
            without pretrained weights; any other value is handed to
            torchvision's ``weights=`` argument unchanged.
        num_classes: Number of outputs of the replacement classification layer.

    Returns:
        The model. Every pre-existing parameter has ``requires_grad=False``;
        only the weight and bias of the new final linear layer are trainable.
    """
    # Translate the accepted spellings into what torchvision's ``weights=``
    # keyword expects, instead of relying on the deprecated ``pretrained=``
    # keyword treating any truthy value as "use the default weights".
    if isinstance(weights, str) and weights.lower() == "default":
        resolved_weights = "DEFAULT"
    elif weights is True:
        resolved_weights = "DEFAULT"
    elif weights is None or weights is False or weights == "":
        resolved_weights = None
    else:
        resolved_weights = weights

    model = efficientnet_b0(weights=resolved_weights)

    # Freeze everything that came with the network ...
    for param in model.parameters():
        param.requires_grad = False

    # ... then swap in a new head. A freshly created Linear layer requires
    # gradients by default, so it is the only trainable part of the model.
    # Its input width is read from the layer it replaces rather than hard-coded.
    in_features = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(in_features, num_classes, bias=True)
    return model

In [None]:
if __name__ == "__main__":
    # Train the model returned by get_model() on the Cifar dataset with the
    # two-step SAM update, evaluating on dataset.test after every epoch.
    epochs = 1
    learning_rate = 0.1
    smoothing = 0.1  # if smoothing=0.0, it's the same as nn.CrossEntropyLoss

    # initialize() and the device choice come before the model is built, so
    # the model can be placed on the device the batches are sent to and any
    # seeding done by initialize() also covers model construction.
    initialize(seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # The model must live on `device`: every batch is moved there below, and a
    # CPU model fed CUDA tensors fails on the first forward pass.
    model = get_model(weights="Default").to(device)  # ImageNet1K_V1 weights
    log = Log(log_each=10)
    base_optimizer = torch.optim.SGD
    # NOTE(review): rho=2 combined with adaptive=False looks unusually large;
    # confirm against the values recommended by the SAM implementation in use.
    optimizer = SAM(model.parameters(),
                    base_optimizer,
                    rho=2,
                    adaptive=False,  # True if you want to use the Adaptive SAM.
                    lr=learning_rate, momentum=0.9, weight_decay=0.0005)
    scheduler = StepLR(optimizer, learning_rate=learning_rate, total_epochs=epochs)

    dataset = Cifar(batch_size=128,  # Batch size used in the training and validation loop
                    threads=2)  # Number of CPU threads for dataloaders

    for epoch in range(epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_running_stats(model)
            smooth_crossentropy(model(inputs), targets, smoothing=smoothing).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                # Accuracy and loss are reported from the first forward pass.
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

    log.flush()


In [14]:
if __name__ == "__main__":
    # Train SimpleCNN on the MNIST loaders with the two-step SAM update,
    # evaluating on data_loaders["test_loader"] after every epoch.
    # SimpleCNN and data_loaders are defined in cells further down in this
    # notebook; those cells have to be executed before this one.
    epochs = 10
    learning_rate = 0.1
    smoothing = 0.1  # if smoothing=0.0, it's the same as nn.CrossEntropyLoss

    # initialize() and the device choice come before the model is built, so
    # the model can be placed on the device the batches are sent to and any
    # seeding done by initialize() also covers the random weight init.
    initialize(seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # The model must live on `device`: every batch is moved there below, and a
    # CPU model fed CUDA tensors fails on the first forward pass.
    model = SimpleCNN().to(device)
    log = Log(log_each=10)
    base_optimizer = torch.optim.SGD
    # NOTE(review): rho=2 combined with adaptive=False looks unusually large;
    # confirm against the values recommended by the SAM implementation in use.
    optimizer = SAM(model.parameters(),
                    base_optimizer,
                    rho=2,
                    adaptive=False,  # True if you want to use the Adaptive SAM.
                    lr=learning_rate, momentum=0.9, weight_decay=0.0005)
    scheduler = StepLR(optimizer, learning_rate=learning_rate, total_epochs=epochs)

    train_loader = data_loaders["train_loader"]
    # NOTE(review): per-epoch evaluation uses the test split while
    # data_loaders["val_loader"] is never read -- confirm this is intended.
    test_loader = data_loaders["test_loader"]

    for epoch in range(epochs):
        model.train()
        log.train(len_dataset=len(train_loader))

        for batch in train_loader:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_running_stats(model)
            smooth_crossentropy(model(inputs), targets, smoothing=smoothing).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                # Accuracy and loss are reported from the first forward pass.
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(test_loader))

        with torch.no_grad():
            for batch in test_loader:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

    log.flush()


┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓
┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃
┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃
┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨
┃           0  ┃      1.7643  │     10.38 %  ┃   1.000e-01  │   00:09 min  ┃┈███████████████████████████┈┨      1.7642  │     10.09 %  ┃
┃           1  ┃      1.7643  │     10.08 %  ┃   1.000e-01  │   00:10 min  ┃┈███████████████████████████┈┨      1.7752  │      9.82 %  ┃
┃           2  ┃      1.7653  │     10.15 %  ┃   1.000e-01  │   00:11 min  ┃┈███████████████████████████┈┨      1.7600  │     10.10 %  ┃
┃           3  ┃      1.7584  │     10.79 %  ┃   2.000e-02  │   00:11 min  ┃┈███████████████████████████┈┨      1.7583  │     10.28 %  ┃
┃           4  ┃      1.

-----

# Test

In [3]:
import torch
import torchvision
import torchvision.transforms as transforms

# Settings of the data pipeline, gathered in one place so they are easy to tune.
BATCH_SIZE = 64
NUM_WORKERS = 2
SPLIT_SEED = 42  # same value the training cells pass to initialize()

# Define transformations to apply to the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.
root = "/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/data"
# Load MNIST training dataset
full_trainset = torchvision.datasets.MNIST(root=root, train=True,
                                           download=True, transform=transform)

# Split the training set into training and validation sets. The explicit,
# seeded generator makes the partition identical on every run instead of
# depending on whatever state the global RNG happens to be in.
trainset, valset = torch.utils.data.random_split(
    full_trainset,
    [50000, 10000],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Load MNIST testing dataset
testset = torchvision.datasets.MNIST(root=root, train=False,
                                     download=True, transform=transform)


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the shared batch size and worker count."""
    return torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                       shuffle=shuffle, num_workers=NUM_WORKERS)


# Create data loaders (only the training data is shuffled)
trainloader = _make_loader(trainset, shuffle=True)
valloader = _make_loader(valset, shuffle=False)
testloader = _make_loader(testset, shuffle=False)

# Define the classes in MNIST
classes = tuple(str(i) for i in range(10))

data_loaders = {
    "train_loader": trainloader,
    "val_loader": valloader,
    "test_loader": testloader
}

In [4]:
import torch.nn.functional as F
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Two-convolution classifier for single-channel 28x28 images with 10 outputs.

    Each 3x3 convolution uses padding 1 (spatial size preserved) and is
    followed by ReLU and a 2x2 max-pool, so a 28x28 input reaches the first
    fully connected layer as 32 feature maps of 7x7.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the raw class scores for a batch of images.

        Args:
            x: Tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) with unnormalised scores.

        Raises:
            RuntimeError: If the spatial size of `x` does not yield 32*7*7
                features per sample (raised by the first linear layer).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Reshaping with
        # view(-1, 32 * 7 * 7) would silently move surplus spatial elements
        # into the batch dimension for inputs of the wrong size; this keeps
        # the batch size fixed so such inputs fail loudly in fc1 instead.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)