In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from models import model as MODEL
from dataset import GaussianBumpDataset

In [5]:
# ---- image geometry ----
grid_size = 32 # size of the image grid
sigma = 1.5 # spread of the spot

# ---- run configuration ----
# Index 1 selects 'regularised'; use index 0 for the unregularised 'full' run.
model_mode = ['full', 'regularised'][1]
save_name = str(model_mode) # tag embedded in every output file name
hidden_layers = 2 # number of hidden layers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- mask regions ----
# masks[0] accepts everything (the full dataset); masks[1] is the L-shaped band
# where either coordinate is below rect_thre; masks[2] is its complement.
# NOTE(review): how GaussianBumpDataset applies a mask (keep vs. drop) is not
# visible here -- confirm in dataset.py.
rect_thre = 12
masks = [
    lambda x, y: True,
    lambda x, y: (x < rect_thre) | (y < rect_thre),
    lambda x, y: not ((x < rect_thre) | (y < rect_thre)),
]

# ---- datasets and loaders ----
bump_kwargs = dict(sigma = sigma, grid_size = grid_size)
loader_kwargs = dict(batch_size = 32, shuffle = True)

# Full set: materialised once and kept on the compute device.
set_FU = GaussianBumpDataset(mask = masks[0], **bump_kwargs)
inputs_FU, targets_FU = (tensor.to(device) for tensor in set_FU.get_all_data())

# In-distribution (training) and out-of-distribution (held-out) subsets.
set_ID = GaussianBumpDataset(mask = masks[1], **bump_kwargs)
set_OD = GaussianBumpDataset(mask = masks[2], **bump_kwargs)

dataloader_ID = DataLoader(set_ID, **loader_kwargs)
dataloader_OD = DataLoader(set_OD, **loader_kwargs)

In [6]:
num_epochs = 1000

# Seed torch's global RNG so repeated runs of this cell start from the same state.
torch.manual_seed(42)
criterion = nn.MSELoss()

# Build the model and move it to the compute device BEFORE creating the
# optimizer, so the optimizer is guaranteed to hold references to the
# parameters that are actually trained on that device.
model = MODEL(grid_size = grid_size, hidden_layers = hidden_layers)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 1e-4)
total_params = sum(p.numel() for p in model.parameters())
print('total model parameters is: {0}'.format(total_params))

# training
# Per-epoch [in-distribution, out-of-distribution] average losses are appended here.
LOSS = []

# Regularisation strengths: regLam scales the entropy penalties and regVec the
# variance penalties applied in the training loop; both are off in 'full' mode.
if model_mode == 'full':
    regLam, regVec = 0, 0
else:
    regLam = 0.1
    regVec = 0.01

def get_entropy(x, axis = -1):
    """Summed Shannon entropy (in nats) of the squared, normalised entries of ``x``.

    The entries of ``x`` are squared and normalised to sum to one along
    ``axis``; each normalised value is floored at 1e-8 so the logarithm stays
    finite. The entropy terms are then summed over *all* elements, so for a
    multi-dimensional input the result is the total over every slice.

    Args:
        x: tensor of real values.
        axis: dimension along which the squared values are normalised.

    Returns:
        A scalar tensor holding the summed entropy.

    A slice whose entries are all zero is left un-normalised (divided by 1)
    rather than divided by 0, so it yields a finite value instead of NaN.
    """
    power = x ** 2
    total = torch.sum(power, dim = axis, keepdim = True)
    # Guard against 0 / 0 -> NaN for all-zero slices; non-zero totals are untouched.
    total = torch.where(total == 0, torch.ones_like(total), total)
    prob = torch.clamp(power / total, min = 1e-8)
    return - torch.sum(prob * torch.log(prob))

# Create the output directory before the long training run so that the final
# checkpoint save cannot fail on a missing folder after all epochs are done.
os.makedirs("results", exist_ok = True)

for epoch in range(num_epochs):
    # ---- optimisation pass over the in-distribution loader ----
    model.train()
    running_loss_ID = 0.0
    for inputs_ID, targets_ID in dataloader_ID:
        inputs_ID, targets_ID = inputs_ID.to(device), targets_ID.to(device)
        optimizer.zero_grad()
        outputs_ID = model(inputs_ID)
        loss = criterion(outputs_ID, targets_ID)
        # Only the plain MSE is logged; the penalties below are optimised but
        # deliberately kept out of the reported curve.
        running_loss_ID += loss.item()

        if regLam > 0:
            # Entropy penalties on model.eigValX / eigValY and variance
            # penalties on the LeftVe* / RihtVe* tensors.
            # NOTE(review): these attributes are read from the model object;
            # presumably MODEL populates them -- confirm in models.py.
            penalties = (
                regLam * get_entropy(model.eigValX, axis = -1),
                regLam * get_entropy(model.eigValY, axis = -1),
                regVec * torch.sum(torch.var(model.LeftVeX, dim = -2)),
                regVec * torch.sum(torch.var(model.RihtVeX, dim = -1)),
                regVec * torch.sum(torch.var(model.LeftVeY, dim = -2)),
                regVec * torch.sum(torch.var(model.RihtVeY, dim = -1)),
            )
            # Out-of-place accumulation (same order as before) avoids mutating
            # the criterion's output tensor in place.
            for penalty in penalties:
                loss = loss + penalty

        loss.backward()
        optimizer.step()

    # Mean of the per-batch MSE values for this epoch.
    avg_loss_ID = running_loss_ID / len(dataloader_ID)

    # ---- evaluation pass over the out-of-distribution loader ----
    model.eval()  # Set the model to evaluation mode
    running_loss_OD = 0.0
    with torch.no_grad():  # No gradients needed during validation
        for inputs_OD, targets_OD in dataloader_OD:
            inputs_OD, targets_OD = inputs_OD.to(device), targets_OD.to(device)
            outputs_OD = model(inputs_OD)  # Forward pass
            loss = criterion(outputs_OD, targets_OD)  # Compute the validation loss
            running_loss_OD += loss.item()
    avg_loss_OD = running_loss_OD / len(dataloader_OD)

    LOSS.append([avg_loss_ID, avg_loss_OD])
    print(f"Epoch [{epoch+1}/{num_epochs}], ID Loss: {avg_loss_ID:.6f}, OOD Loss : {avg_loss_OD:.6f}")

# Column 0 is the in-distribution loss, column 1 the out-of-distribution loss.
LOSS = np.array(LOSS)

torch.save(model.state_dict(), "results/model_{0}.pth".format(save_name))

total model parameters is: 133218
Epoch [1/1000], ID Loss: 0.006403, OOD Loss : 0.006469
Epoch [2/1000], ID Loss: 0.006380, OOD Loss : 0.006461
Epoch [3/1000], ID Loss: 0.006374, OOD Loss : 0.006462
Epoch [4/1000], ID Loss: 0.006376, OOD Loss : 0.006471
Epoch [5/1000], ID Loss: 0.006366, OOD Loss : 0.006467
Epoch [6/1000], ID Loss: 0.006352, OOD Loss : 0.006475
Epoch [7/1000], ID Loss: 0.006364, OOD Loss : 0.006470
Epoch [8/1000], ID Loss: 0.006366, OOD Loss : 0.006466
Epoch [9/1000], ID Loss: 0.006349, OOD Loss : 0.006460
Epoch [10/1000], ID Loss: 0.006367, OOD Loss : 0.006474
Epoch [11/1000], ID Loss: 0.006356, OOD Loss : 0.006473
Epoch [12/1000], ID Loss: 0.006362, OOD Loss : 0.006474
Epoch [13/1000], ID Loss: 0.006361, OOD Loss : 0.006468
Epoch [14/1000], ID Loss: 0.006361, OOD Loss : 0.006477
Epoch [15/1000], ID Loss: 0.006356, OOD Loss : 0.006476
Epoch [16/1000], ID Loss: 0.006353, OOD Loss : 0.006461
Epoch [17/1000], ID Loss: 0.006348, OOD Loss : 0.006476
Epoch [18/1000], ID Los

In [7]:
# Learning curves: column 0 of LOSS is the in-distribution loss and column 1
# the out-of-distribution loss, one row per epoch, drawn on a log-scaled y axis.
fig_loss, ax_loss = plt.subplots(figsize = (5, 3))
for column, curve_label in enumerate(('in-dist', 'out-of-dist')):
    ax_loss.plot(LOSS[:, column], label = curve_label)
ax_loss.set_yscale('log')
ax_loss.set_ylim(1e-6, 1e0)
ax_loss.set_xlabel('epoch')
ax_loss.set_ylabel('MSE loss')
ax_loss.legend()
fig_loss.tight_layout()
fig_loss.savefig('results/loss_{0}.png'.format(save_name))
# The figure is written to disk only; close it to release memory.
plt.close(fig_loss)

def plot_examples(inputs, targets, out_path, n_rows = 8):
    """Save a grid comparing model predictions (left) with targets (right).

    Each row shows one sample: the model output for ``inputs[i]`` in the left
    column and ``targets[i]`` in the right column, both reshaped to
    ``(grid_size, grid_size)`` and drawn on a fixed 0..1 colour scale.

    Args:
        inputs: batch of model inputs, indexed along the first dimension.
        targets: batch of target images matching ``inputs``.
        out_path: file path the figure is written to.
        n_rows: maximum number of samples to show; capped by the batch size so
            a short batch does not raise an IndexError.

    Raises:
        ValueError: if ``inputs`` contains no samples.
    """
    n_rows = min(n_rows, len(inputs))
    if n_rows < 1:
        raise ValueError('plot_examples needs at least one sample')

    # squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
    fig, axes = plt.subplots(n_rows, 2, figsize = (6, 3 * n_rows), squeeze = False)

    model.eval()
    with torch.no_grad():  # inference only: skip building the autograd graph
        for i in range(n_rows):
            # Samples are passed one at a time with a leading batch dimension of 1.
            res = model(inputs[i, :].unsqueeze(0)).cpu().numpy()
            tar = targets[i, :].cpu().numpy()
            axes[i, 0].imshow(np.reshape(res, (grid_size, grid_size)), origin = 'lower', vmin = 0, vmax = 1)
            axes[i, 1].imshow(np.reshape(tar, (grid_size, grid_size)), origin = 'lower', vmin = 0, vmax = 1)

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


# inputs_ID / targets_ID and inputs_OD / targets_OD are the last batches left
# over from the training cell, so that cell must have been run first.
plot_examples(inputs_ID, targets_ID, 'results/examples_ID_{0}.png'.format(save_name))
plot_examples(inputs_OD, targets_OD, 'results/examples_OD_{0}.png'.format(save_name))