In [1]:
import numpy as np
import torch
from torch import nn
from utils import print_losses, load_dataset, train_model, train_step, test_step, train_SWAG
from utils import load_model, create_model, get_loaders, DATA_DIR, DATASETS
from torch.distributions.multivariate_normal import MultivariateNormal

In [2]:
def calculate_sigma(sigma_diag: torch.Tensor, D: torch.Tensor, K: int):
    assert K >= 2
    return 0.5 * torch.diag(sigma_diag) + 0.5 * torch.mm(D, D.T) / (K + 1)

def sample_posterior(theta_swa: torch.Tensor, sigma_diag: torch.Tensor, D: torch.Tensor, K: int):
    assert K >= 2
    mu = theta_swa
    sigma = calculate_sigma(sigma_diag, D, K)
    distribution = MultivariateNormal(loc=mu, covariance_matrix=sigma)
    return distribution.sample()

def get_test_predictions(x_train: np.array, y_train: np.array, 
                     x_test: np.array, y_test: np.array, 
                     model: nn.Module, batch_size: int = 100):
    train_loader, test_loader = get_loaders(x_train, y_train, x_test, y_test, batch_size)
    test_predictions = [model(batch_x).cpu().detach().numpy() for batch_x, _ in test_loader]
    return np.concatenate(test_predictions)

def sample_from_SWAG(x_train, y_train, x_test, y_test, model, theta_swa, sigma_diag, D, K, S):
    test_predictions_list = []
    for _ in range(S):
        sampled_weights = sample_posterior(theta_swa, sigma_diag, D, K)
        torch.nn.utils.vector_to_parameters(sampled_weights, model.parameters())
        test_predictions_list.append(get_test_predictions(x_train, y_train, x_test, y_test, model))
    return test_predictions_list

def RMSE(y_1, y_2):
    return np.sqrt(np.mean((y_1 - y_2)**2))

In [None]:
K = 10
S = 500
wd = 1e-6
# for dataset_name in ['boston_housing']:
# for dataset_name in ['yacht']:
for dataset_name in DATASETS:
    print("=" * 88)
    x_train, y_train, x_test, y_test, _, _ = load_dataset(dataset_name)
    batch_size = x_train.shape[0]//9
    model = create_model(x_train, layer_dims=[50], verbose=False)
    train_model(x_train, y_train, x_test, y_test, model, dataset_name, lr=0.001, epochs=50000, verbose=False, batch_size=batch_size, weight_decay=wd)
    model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
    y_pred = get_test_predictions(x_train, y_train, x_test, y_test, model)
    print(f"SGD RMSE: {RMSE(y_pred, y_test):.3f}")
    for SWAG_lr in [0.001, 0.01, 0.05, 0.1, 0.2]:
        print(f"SWAG_lr: {SWAG_lr: 0.5f}")
        for _ in range(5):
            try:
                model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
                theta_swa, sigma_diag, D = train_SWAG(x_train, y_train, x_test, y_test, model, K, verbose=False, lr=SWAG_lr, batch_size=batch_size, weight_decay=wd)
                sigma_diag = torch.clamp(sigma_diag, min=1e-10)
                samples = sample_from_SWAG(x_train, y_train, x_test, y_test, model, theta_swa, sigma_diag, D, K, S)
                samples_array = np.concatenate(samples, axis=1)
                y_pred = samples_array.mean(axis=1, keepdims=True)
                y_l = np.percentile(samples_array, 2.5, axis=1, keepdims=True)
                y_u = np.percentile(samples_array, 97.5, axis=1, keepdims=True)
                print(f"RMSE: {RMSE(y_pred, y_test):.3f}, PICP: {np.mean((y_l < y_test) & (y_test < y_u)):.3f}, MPIW:{np.mean(y_u - y_l):.3f}")
            except:
                print("Some error")
    

dataset: boston_housing, rows: 506, columns: 13, range of x: [0.0, 711.0], range of y: [5.0, 50.0]
Finished Training. Best test loss: 0.72753 in epoch 59
SGD RMSE: 0.345
SWAG_lr:  0.00100
RMSE: 0.328, PICP: 0.118, MPIW:0.136
RMSE: 0.327, PICP: 0.118, MPIW:0.137
RMSE: 0.327, PICP: 0.118, MPIW:0.138
RMSE: 0.328, PICP: 0.137, MPIW:0.147
RMSE: 0.328, PICP: 0.118, MPIW:0.132
SWAG_lr:  0.01000
RMSE: 0.333, PICP: 0.549, MPIW:0.411
RMSE: 0.329, PICP: 0.510, MPIW:0.438
RMSE: 0.336, PICP: 0.529, MPIW:0.419
RMSE: 0.328, PICP: 0.490, MPIW:0.400
RMSE: 0.328, PICP: 0.529, MPIW:0.409
SWAG_lr:  0.05000
RMSE: 0.434, PICP: 0.686, MPIW:0.776
RMSE: 0.435, PICP: 0.686, MPIW:0.782
RMSE: 0.453, PICP: 0.529, MPIW:0.657
RMSE: 0.445, PICP: 0.627, MPIW:0.709
RMSE: 0.434, PICP: 0.647, MPIW:0.686
SWAG_lr:  0.10000
RMSE: 0.549, PICP: 0.549, MPIW:0.747
RMSE: 0.573, PICP: 0.471, MPIW:0.748
RMSE: 0.506, PICP: 0.667, MPIW:0.995
RMSE: 0.548, PICP: 0.529, MPIW:0.791
RMSE: 0.496, PICP: 0.686, MPIW:1.038
SWAG_lr:  0.20000
