In [None]:
# import libraries
import torch
import numpy as np
from dataloader import get_dataloader_vae
from dataloader import get_dataloader_OOD
from models import get_trained_model
import matplotlib.pyplot as plt

In [None]:
# Load the pretrained VAE and the matching train/test dataloaders.
# The dataset name is written once so both calls always refer to the same data.
DATASET_NAME = 'CIFAR10'
model = get_trained_model(DATASET_NAME)
train_dl, test_dl = get_dataloader_vae(DATASET_NAME)

In [None]:
# Device configuration: use the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# evaluation loop: measures the per-sample VAE loss on both loaders (no weights are updated)
def test_vae(model, num_epochs, train_loader, test_loader):
    # send model to device
    model = model.to(device)

    train_loss = []
    test_loss = []
    num_train = 0
    num_test = 0
    best_loss = np.inf

    for epoch in range(num_epochs):
        num_train = 0
        num_test = 0
        # eval mode
        model.eval()
        avg_train_loss = 0
        for i, (images, _) in enumerate(train_loader):
            num_train += images.shape[0]
            images = images.to(device)

            # Forward pass
            with torch.no_grad():
                recon_images, mu, logvar, _ = model(images)
                loss = 0.5 * (recon_images - images).pow(2).sum() - 0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
                avg_train_loss += loss
        
        avg_train_loss /= num_train         
        
        
        # eval mode
        model.eval()
        avg_val_loss = 0
        for i, (images, _) in enumerate(test_loader):
            num_test += images.shape[0]
            images = images.to(device)

            # Forward pass
            with torch.no_grad():
                recon_images, mu, logvar, _ = model(images)
                loss = 0.5 * (recon_images - images).pow(2).sum() - 0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
                avg_val_loss += loss
        
        avg_val_loss /= num_test

        # logging losses
        train_loss.append(avg_train_loss.detach().cpu().data.numpy())
        test_loss.append(avg_val_loss.detach().cpu().data.numpy())
        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss.item():.4f}, Validation Loss: {avg_val_loss.item():.4f}')

    return model, train_loss, test_loss

In [5]:
# run 100 evaluation passes over the CIFAR10 train/test loaders and keep the per-epoch losses
model, train_loss, test_loss = test_vae(
    model,
    num_epochs=100,
    train_loader=train_dl,
    test_loader=test_dl,
)

Epoch [6/100], Training Loss: 159.2156, Validation Loss: 157.0721
Epoch [7/100], Training Loss: 136.4739, Validation Loss: 134.6388


In [None]:
# plot the per-epoch training and validation losses on a single set of axes
fig, ax = plt.subplots()
ax.plot(range(len(train_loss)), train_loss, label="Training Loss")
ax.plot(range(len(test_loss)), test_loss, label="Validation Loss")
ax.set_title('Training and Validation loss per epoch for CIFAR10 VAE')
ax.set_xlabel('Epoch')
# both curves share this axis, so it is labelled generically rather than 'Val Loss'
ax.set_ylabel('Loss')
ax.legend()
plt.show()