In [2]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torchvision
import random
from tqdm import tqdm
import seaborn as sns

In [None]:
# Define the per-epoch training step and the evaluation (test) step
def training_step(model, data_loader, optimizer, loss_fn, device):
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    model.train()

    for batch_idx, (inputs, targets) in enumerate(data_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)

        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        samples += inputs.shape[0]
        cumulative_loss += loss.item()
        _, predicted = outputs.max(dim=1)

        cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss / samples, cumulative_accuracy / samples * 100

def test_step(model, data_loader, loss_fn, device):
    samples = 0.
    cumulative_loss = 0.
    cumulative_accuracy = 0.

    model.eval()

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)

            loss = loss_fn(outputs, targets)

            samples += inputs.shape[0]
            cumulative_loss += loss.item() 
            _, predicted = outputs.max(1)

            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss / samples, cumulative_accuracy / samples * 100

In [None]:
# tensorboard logging utilities
def log_values(writer, step, loss, accuracy, prefix):
    """Record a loss and an accuracy value for `prefix` at the given `step`.

    The two scalars are passed to ``writer.add_scalar`` under the tags
    ``"<prefix>/loss"`` and ``"<prefix>/accuracy"``, in that order.
    """
    metrics = (("loss", loss), ("accuracy", accuracy))
    for metric_name, metric_value in metrics:
        writer.add_scalar(f"{prefix}/{metric_name}", metric_value, step)

In [None]:
def _evaluate_splits(model, splits, loss_fn, device, writer, step):
    """Evaluate every split, then log and print its metrics at one step.

    `splits` is a sequence of ``(printed_label, tensorboard_prefix, loader)``
    tuples. All splits are evaluated first, then all are logged, then all are
    printed, followed by a separator line.
    """
    results = [
        (label, prefix, test_step(model, loader, loss_fn, device=device))
        for label, prefix, loader in splits
    ]

    for _, prefix, (loss, accuracy) in results:
        log_values(writer, step, loss, accuracy, prefix)

    for label, _, (loss, accuracy) in results:
        print(f"\t{label} loss {loss:.5f}, {label} accuracy {accuracy:.2f}")
    print("-----------------------------------------------------")


def main(model,
         optimizer,
         loss_fn,
         batch_size=128,
         device=device,
         epochs=10,
         exp_name="exp1"):
    """Train `model`, logging metrics to TensorBoard under ``runs/<exp_name>``.

    The train/validation/test loaders come from ``get_cifar``, called with a
    ``ToTensor`` transform and the notebook-level ``data_path``. Metrics on all
    three splits are logged and printed before training (step ``-1``) and after
    training (step ``epochs + 1``); train and validation metrics are logged
    after every epoch (steps ``0 .. epochs - 1``).

    NOTE(review): ``get_cifar``, ``data_path`` and the ``device`` used as the
    default argument are not defined in this cell. They must already exist in
    the notebook namespace; the ``device`` default is bound once, when this
    cell is executed.

    Args:
        model: Model to train; it is moved to `device` in place.
        optimizer: Optimizer forwarded to ``training_step``.
        loss_fn: Loss callable forwarded to the training and evaluation steps.
        batch_size: Batch size passed to ``get_cifar``.
        device: Device used for the model and for every batch.
        epochs: Number of training epochs.
        exp_name: Experiment name; selects the TensorBoard log directory.

    Returns:
        The same `model` object after training.
    """
    # Create a logger for the experiment
    writer = SummaryWriter(log_dir=f"runs/{exp_name}")

    # Close the writer even if loading, training or evaluation fails or the
    # run is interrupted from the notebook.
    try:
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor()
        ])

        train_loader, val_loader, test_loader = get_cifar(batch_size, transforms, data_path=data_path)

        # (label used in printed output, TensorBoard prefix, loader)
        splits = [
            ("Training", "Train", train_loader),
            ("Validation", "Validation", val_loader),
            ("Test", "Test", test_loader),
        ]

        model.to(device)

        # Computes evaluation results before training
        print("Before training:")
        _evaluate_splits(model, splits, loss_fn, device, writer, -1)

        pbar = tqdm(range(epochs), desc="Training")
        for e in pbar:
            train_loss, train_accuracy = training_step(model, train_loader, optimizer, loss_fn, device=device)
            val_loss, val_accuracy = test_step(model, val_loader, loss_fn, device=device)

            # Logs to TensorBoard
            log_values(writer, e, train_loss, train_accuracy, "Train")
            log_values(writer, e, val_loss, val_accuracy, "Validation")

            pbar.set_postfix(train_loss=train_loss, train_accuracy=train_accuracy, val_loss=val_loss, val_accuracy=val_accuracy)

        # Compute final evaluation results
        print("After training:")
        _evaluate_splits(model, splits, loss_fn, device, writer, epochs + 1)
    finally:
        # Closes the logger
        writer.close()

    # Let's return the net
    return model