In [1]:
import numpy as np
import torch
from torch import nn
from utils import print_losses, load_dataset, train_model, train_step, test_step
from utils import load_model, create_model, get_loaders, DATA_DIR, DATASETS

In [2]:
def train_SWAG(
    x_train: np.array,
    y_train: np.array,
    x_test: np.array,
    y_test: np.array,
    model: nn.Module,
    dataset_name,
    epochs: int = 20,
    batch_size: int = 100,
    lr: float = 0.02,
    verbose: bool = True,
    c = 2,
    K = 5,
):
    assert c >=1 and K >=2
    train_loader, test_loader = get_loaders(
        x_train, y_train, x_test, y_test, batch_size
    )
    theta_epoch = torch.nn.utils.parameters_to_vector(model.parameters()).detach().cpu().clone()
    theta = theta_epoch
    theta_square = theta_epoch**2
    D = None

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    train_losses = [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in train_loader]
    test_losses = [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in test_loader]
    if verbose:
        print_losses(0, train_losses, test_losses)

    for epoch in range(1, epochs + 1):
        train_losses = [train_step(batch_x, batch_y, model, optimizer, criterion) for batch_x, batch_y in train_loader]
        test_losses = [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in test_loader]
        if verbose:
            print_losses(epoch, train_losses, test_losses)
        if epoch % c == 0:
            if verbose:
                print("SWAG moment update")
            n = epoch / c
            theta_epoch = torch.nn.utils.parameters_to_vector(model.parameters()).detach().cpu().clone()
            theta = (n * theta + theta_epoch) / (n + 1)
            theta_square = (n * theta_square + theta_epoch**2) / (n + 1)
            deviations = (theta_epoch - theta).reshape(-1,1)
            if D is None:
                D = deviations
            else:
                if D.shape[1] == K:
                    D = D[:,1:]
                D = torch.cat((D, deviations), dim=1)
    sigma_diag = theta_square - theta**2
    torch.nn.utils.vector_to_parameters(theta, model.parameters())
    test_losses = [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in test_loader]
    print_losses(0, train_losses, test_losses)
    return theta, sigma_diag, D


In [3]:
# for dataset_name in ['boston_housing']:
for dataset_name in DATASETS:
    print("=" * 88)
    x_train, y_train, x_test, y_test, _, _ = load_dataset(dataset_name)
    model = create_model(x_train, layer_dims=[5], verbose=False)
    train_model(x_train, y_train, x_test, y_test, model, dataset_name, verbose=False)
    model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
    theta, sigma_diag, D = train_SWAG(x_train, y_train, x_test, y_test, model, dataset_name, verbose=False)

dataset: boston_housing, rows: 506, columns: 13, range of x: [0.0, 711.0], range of y: [5.0, 50.0]
Early Stopping. Best test loss: 0.1256062537431717 in epoch 87
Epoch:   0, Train loss: 0.218, Test loss: 0.127
dataset: concrete, rows: 1030, columns: 8, range of x: [0.0, 1145.0], range of y: [2.33, 82.6]
Early Stopping. Best test loss: 0.12056347727775574 in epoch 78
Epoch:   0, Train loss: 0.249, Test loss: 0.119
dataset: energy_heating_load, rows: 768, columns: 8, range of x: [0.0, 808.5], range of y: [6.01, 43.1]
Early Stopping. Best test loss: 0.0735061764717102 in epoch 38
Epoch:   0, Train loss: 0.107, Test loss: 0.073
dataset: kin8nm, rows: 8192, columns: 8, range of x: [-1.5706812, 1.5707529], range of y: [0.040165378, 1.4585206]
Early Stopping. Best test loss: 0.2447793202267753 in epoch 181
Epoch:   0, Train loss: 0.235, Test loss: 0.244
dataset: naval_compressor_decay, rows: 11934, columns: 16, range of x: [0.0, 72784.872], range of y: [0.95, 1.0]
Early Stopping. Best test lo

In [4]:
print(list(model.parameters()))
theta = torch.nn.utils.parameters_to_vector(model.parameters())
torch.nn.utils.vector_to_parameters(theta, model.parameters())
print(list(model.parameters()))
theta = torch.nn.utils.parameters_to_vector(model.parameters()) * 100
torch.nn.utils.vector_to_parameters(theta, model.parameters())
print(list(model.parameters()))


[Parameter containing:
tensor([[-6.8719e-02,  3.3007e-01, -1.4184e-02, -2.1600e-01,  7.3773e-02,
          4.7634e-01,  1.8083e-01,  6.4695e-02,  6.7116e-02, -4.4226e-02,
          7.4374e-02, -4.5007e-02,  9.8322e-02, -2.3441e-01,  4.4935e-02,
         -2.0055e-01,  3.1837e-02,  4.6269e-02,  5.8313e-03, -1.6894e-01,
         -7.8161e-02,  1.4711e-01, -2.8028e-01, -1.1366e-01,  9.9529e-02,
         -3.3127e-03, -2.2686e-01, -7.3052e-02,  1.1638e-02, -2.1022e-02,
         -4.9833e-02, -1.1117e-02,  7.1722e-02, -6.8154e-03,  3.2765e-02,
          3.9085e-01,  2.3103e-02, -5.5059e-02, -8.9089e-02, -4.8600e-02,
          9.6552e-03,  4.7307e-02, -8.2469e-02,  3.9732e-02, -8.1195e-03,
         -1.3822e-01,  6.2903e-03, -1.7652e-02,  6.4481e-02,  8.1603e-02,
         -8.2623e-02, -1.2330e-02,  7.2972e-03,  6.7775e-03,  8.9973e-02,
         -3.3690e-02,  4.6575e-02, -2.2559e-01,  7.9643e-03,  4.5148e-02,
         -4.5836e-02,  2.4264e-02, -4.2504e-02, -9.2413e-02,  3.5833e-02,
         -2.510