In [56]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torchvision
import random
from tqdm import tqdm
import seaborn as sns

In [57]:
import os
# The `utility.*` / `methods.*` imports in later cells appear to rely on this working directory.
# Set MODELS_METHODS_DIR to run on another machine; the original path remains the default.
os.chdir(os.environ.get(
    "MODELS_METHODS_DIR",
    "/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/Code/models_methods",
))

In [58]:
from utility.bypass_bn import enable_running_stats, disable_running_stats
from utility.initialize import initialize
from utility.early_stopping import EarlyStopping

In [59]:
def train_step(model, data_loader, optimizer, loss_fn, device, SAM=False, smoothing=0.1, verbose=False, log_interval=10):
    """Run one training epoch over ``data_loader``.

    Args:
        model: ``nn.Module`` to train; switched to train mode here.
        data_loader: iterable of ``(inputs, targets)`` batches. It must support ``len()``
            when ``verbose`` is True.
        optimizer: a plain ``torch.optim`` optimizer. When ``SAM`` is True it must instead
            expose ``first_step(zero_grad=...)`` / ``second_step(zero_grad=...)``.
        loss_fn: called as ``loss_fn(outputs, targets)``. In SAM mode it is called as
            ``loss_fn(outputs, targets, smoothing=smoothing)``. It may return a scalar
            batch mean or a per-sample vector; both are averaged correctly.
        device: device the batches are moved to.
        SAM: use the two-pass SAM update. (This parameter name shadows the ``SAM``
            class inside this function.)
        smoothing: label-smoothing value forwarded to ``loss_fn`` in SAM mode.
        verbose: print running metrics every ``log_interval`` batches.
        log_interval: batch interval for verbose printing.

    Returns:
        ``(mean_loss_per_sample, accuracy_percent)``. In SAM mode both numbers come from
        the first forward pass, i.e. at the weights before the SAM perturbation.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    samples = 0
    cumulative_loss = 0.0
    cumulative_correct = 0

    model.train()

    for batch_idx, (inputs, targets) in enumerate(data_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.shape[0]

        if SAM:
            # First forward-backward pass: BatchNorm running stats are enabled here
            # (per the helper's name in utility.bypass_bn).
            enable_running_stats(model)
            outputs = model(inputs)
            loss = loss_fn(outputs, targets, smoothing=smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # Second forward-backward pass at the perturbed weights. Running stats are
            # disabled so BatchNorm is updated only once per batch. The second loss gets
            # its own name so the reported loss stays the first-pass one.
            disable_running_stats(model)
            perturbed_loss = loss_fn(model(inputs), targets, smoothing=smoothing)
            perturbed_loss.mean().backward()
            optimizer.second_step(zero_grad=True)
        else:
            # Clear stale gradients before backward so they cannot leak into this step.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.mean().backward()
            optimizer.step()

        samples += batch_size
        # loss.mean() * batch_size is the batch *sum* whether loss_fn returns a scalar mean
        # or per-sample values. Dividing by `samples` then gives a true per-sample average.
        cumulative_loss += loss.mean().item() * batch_size
        _, predicted = outputs.max(dim=1)
        cumulative_correct += predicted.eq(targets).sum().item()

        if verbose and batch_idx % log_interval == 0:
            current_loss = cumulative_loss / samples
            current_accuracy = cumulative_correct / samples * 100
            print(f'Batch {batch_idx}/{len(data_loader)}, Loss: {current_loss:.4f}, Accuracy: {current_accuracy:.2f}%', end='\r')

    if samples == 0:
        raise ValueError("data_loader yielded no samples")
    return cumulative_loss / samples, cumulative_correct / samples * 100

In [60]:
def test_step(model, data_loader, loss_fn, device):
    samples = 0.
    cumulative_loss = 0.
    cumulative_accuracy = 0.

    model.eval()

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)

            loss = loss_fn(outputs, targets)

            samples += inputs.shape[0]
            cumulative_loss += loss.mean().item() 
            _, predicted = outputs.max(1)

            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss / samples, cumulative_accuracy / samples * 100

In [61]:
# tensorboard logging utilities
def log_values(writer, step, loss, accuracy, prefix):
    """Record ``loss`` then ``accuracy`` at ``step`` under the ``<prefix>/`` tag group."""
    for metric_name, metric_value in (("loss", loss), ("accuracy", accuracy)):
        writer.add_scalar(f"{prefix}/{metric_name}", metric_value, step)

In [117]:
def _evaluate_and_log(model, data_loaders, loss_fn, test_step, device, writer, step):
    """Evaluate on the train/val/test loaders, log each to TensorBoard at ``step`` and print a summary."""
    splits = (
        ("Train", "train_loader", "Training"),
        ("Validation", "val_loader", "Validation"),
        ("Test", "test_loader", "Test"),
    )
    results = [
        (prefix, label, test_step(model, data_loaders[loader_key], loss_fn, device=device))
        for prefix, loader_key, label in splits
    ]
    for prefix, _, (loss, accuracy) in results:
        log_values(writer, step, loss, accuracy, prefix)
    for _, label, (loss, accuracy) in results:
        print(f"\t{label} loss {loss:.5f}, {label} accuracy {accuracy:.2f}")
    print("-----------------------------------------------------")


def main(model,
         optimizer,
         loss_fn,
         data_loaders: dict,
         train_step: callable,
         test_step: callable,
         device,
         epochs=10,
         exp_name=None,
         exp_path="/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/Code/experiments/",
         use_early_stopping=True,
         patience=5,
         delta=1e-3,
         scheduler=None,
         verbose_steps=True, # print after log_interval-learning steps
         log_interval=10,
         use_SAM=False, # if SAM=True then loss_fn must be smooth_cross_entropy with smoothing >= 0.07
         smoothing=0.1,
         sam_kwargs=None):
    """Train ``model`` for up to ``epochs`` epochs with TensorBoard logging and return it.

    Train/val/test metrics are computed and logged before training (step -1) and after
    training. Per-epoch train/validation metrics are logged at step ``epoch``.

    Args:
        model: network to train; moved to ``device``.
        optimizer: a ready ``torch.optim`` optimizer. When ``use_SAM`` is True it is either
            an optimizer *class* (e.g. ``torch.optim.SGD``) to be wrapped by ``SAM``, or an
            existing ``SAM`` instance, which is used as-is.
        loss_fn: loss callable. It must be ``smooth_crossentropy`` when ``use_SAM`` is True.
        data_loaders: dict with keys ``"train_loader"``, ``"val_loader"``, ``"test_loader"``.
        train_step, test_step: per-epoch training / evaluation callables. See the
            functions of the same names in this notebook for the expected signatures.
        device: torch device to train on.
        epochs: maximum number of training epochs.
        exp_name: run sub-directory under ``exp_path``. Defaults to a timestamped name.
        exp_path: existing directory that holds all experiment runs.
        use_early_stopping, patience, delta: configure ``EarlyStopping``. It is given the
            path ``<exp_path>/<exp_name>/checkpoint.pt``.
        scheduler: optional LR scheduler; ``scheduler.step()`` is called once per epoch,
            after the training step.
        verbose_steps, log_interval: forwarded to ``train_step`` for per-batch printing.
        use_SAM: wrap ``optimizer`` with SAM (Sharpness-Aware Minimization).
        smoothing: label smoothing forwarded to ``train_step``. Must be >= 0.07 with SAM.
        sam_kwargs: keyword arguments for the ``SAM`` constructor. Defaults to
            ``rho=2, adaptive=True, lr=0.1, momentum=0.9, weight_decay=0.0005``.

    Returns:
        The trained model (the same object as ``model``).

    NOTE(review): when early stopping triggers, the returned model holds the last-epoch
    weights, not necessarily the best ones. Reload ``checkpoint.pt`` if EarlyStopping
    saves the best model there (confirm against utility.early_stopping).
    """
    assert os.path.exists(exp_path), "Experiment path does not exist"

    if exp_name is None:
        # `exp_path + None` used to raise TypeError; fall back to a unique run name.
        import time
        exp_name = time.strftime("run_%Y%m%d-%H%M%S")
    log_dir = os.path.join(exp_path, exp_name)

    if use_SAM:
        assert smoothing >= 0.07, "smoothing must be >= 0.07 when using SAM"
        assert loss_fn == smooth_crossentropy, "loss function must be smooth_crossentropy when using SAM"
        if not isinstance(optimizer, SAM):
            if sam_kwargs is None:
                sam_kwargs = dict(rho=2, adaptive=True, lr=0.1, momentum=0.9, weight_decay=0.0005)
            optimizer = SAM(model.parameters(), optimizer, **sam_kwargs)

    # Create a logger for the experiment; always closed, even if training raises.
    writer = SummaryWriter(log_dir=log_dir)
    try:
        early_stopping = None
        if use_early_stopping:
            early_stopping = EarlyStopping(patience=patience,
                                           delta=delta,
                                           path=os.path.join(log_dir, "checkpoint.pt"))

        model.to(device)

        print("Before training:")
        _evaluate_and_log(model, data_loaders, loss_fn, test_step, device, writer, step=-1)

        epochs_run = 0
        pbar = tqdm(range(epochs), desc="Training")
        for e in pbar:
            train_loss, train_accuracy = train_step(model, data_loaders["train_loader"], optimizer, loss_fn,
                                                    device=device, SAM=use_SAM, smoothing=smoothing,
                                                    verbose=verbose_steps, log_interval=log_interval)
            if scheduler is not None:
                scheduler.step()
            val_loss, val_accuracy = test_step(model, data_loaders["val_loader"], loss_fn, device=device)
            epochs_run = e + 1

            print("-----------------------------------------------------")

            log_values(writer, e, train_loss, train_accuracy, "Train")
            log_values(writer, e, val_loss, val_accuracy, "Validation")

            pbar.set_postfix(train_loss=train_loss, train_accuracy=train_accuracy, val_loss=val_loss, val_accuracy=val_accuracy)

            if early_stopping is not None:
                early_stopping(val_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    break

        # Log the final evaluation right after the last epoch actually run
        # (not at `epochs + 1`, which leaves a gap when early stopping fires).
        print("After training:")
        _evaluate_and_log(model, data_loaders, loss_fn, test_step, device, writer, step=epochs_run)
    finally:
        writer.close()

    return model

-----

# Toy Example

In [65]:
from methods.SAM.sam import SAM
from utility.smooth_cross_entropy import smooth_crossentropy

In [66]:
# Prefer the first GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [67]:
# get data
import torch
import torchvision
import torchvision.transforms as transforms

# Define transformations to apply to the dataset:
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

root = "/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/data"
# Seed for the train/validation split, so every run validates on the same 10k images.
SPLIT_SEED = 42

# Load the full MNIST training dataset
full_trainset = torchvision.datasets.MNIST(root=root, train=True,
                                           download=True, transform=transform)

# Split the training set into training and validation sets (reproducibly)
trainset, valset = torch.utils.data.random_split(
    full_trainset, [50000, 10000],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Load MNIST testing dataset
testset = torchvision.datasets.MNIST(root=root, train=False,
                                     download=True, transform=transform)

# Create data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(valset, batch_size=64,
                                        shuffle=False, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

# Define the classes in MNIST
classes = tuple(str(i) for i in range(10))

data_loaders = {
    "train_loader": trainloader,
    "val_loader": valloader,
    "test_loader": testloader
}

In [68]:
# create simple model 
import torch.nn.functional as F
class SimpleCNN(nn.Module):
    """Two-conv-layer CNN for 1x28x28 inputs that returns 10 class logits.

    Shape flow: (N, 1, 28, 28) -> conv1 + pool -> (N, 16, 14, 14)
    -> conv2 + pool -> (N, 32, 7, 7) -> flatten -> (N, 1568) -> fc1 -> (N, 128) -> fc2 -> (N, 10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample and keep the batch dim. The old view(-1, 32 * 7 * 7) could
        # silently fold extra spatial elements of non-28x28 inputs into the batch dim.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [114]:
# Fresh model plus a SAM optimizer; the base optimizer is handed over as a class, not an instance.
model = SimpleCNN()
base_optimizer = torch.optim.SGD
sam_hparams = dict(rho=2, adaptive=True, lr=0.1, momentum=0.9, weight_decay=0.0005)
optimizer = SAM(model.parameters(), base_optimizer, **sam_hparams)

In [112]:
# Plain (non-SAM) SGD baseline. It gets its own name: rebinding `base_optimizer` to an SGD
# *instance* here would break the SAM run below, which passes `base_optimizer` to SAM as a class.
model = SimpleCNN()
sgd_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)

In [105]:
# Sanity check: the optimizer built above is a SAM optimizer (isinstance also accepts subclasses).
isinstance(optimizer, SAM)

True

In [104]:
# Show which SAM class is in scope (imported from methods.SAM.sam at the top of this section).
SAM

methods.SAM.sam.SAM

In [None]:
# Train with SAM; logs and the early-stopping checkpoint go under <exp_path>/sam_test_4.
sam_run_config = dict(
    train_step=train_step,
    test_step=test_step,
    data_loaders=data_loaders,
    use_SAM=True,
    device=device,
    exp_name="sam_test_4",
    epochs=5,
    use_early_stopping=True,
    delta=1e-3,
    verbose_steps=True,
    log_interval=100,
    smoothing=0.1,
)
main(model, base_optimizer, smooth_crossentropy, **sam_run_config)

In [None]:
# visualize with tensorboard
%reload_ext tensorboard
%tensorboard --logdir=/home/peppe/01_Study/01_University/Semester/2/Intro_to_ML/Project/Code/experiments/sam_test_2 # experiment path