In [9]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression
from torch import nn

from data import get_loaders, load_dataset, train_test_split
from training import test_step, train_step, train_model, get_test_predictions
from utils import print_losses
from swag import sample_from_SWAG, run_SWAG, calculate_sigma
from definitions import DATASETS, device
from models import create_model, load_model
from metrics import RMSE
from swag_mod import calculate_coeffs

In [10]:
def train_step_mod(
    batch_x: torch.Tensor,
    batch_y: torch.Tensor,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion,
    lr_multipliers,
):
    """Run a single optimisation step with per-weight gradient rescaling.

    After back-propagation the gradients are multiplied element-wise by
    ``lr_multipliers`` (through ``multitiply_grads``) and only then handed to
    the optimizer.

    Args:
        batch_x: Input mini-batch; moved to ``device`` before the forward pass.
        batch_y: Target mini-batch; moved to ``device`` before the loss is computed.
        model: Network being trained.
        optimizer: Optimizer whose ``step()`` consumes the rescaled gradients.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss tensor.
        lr_multipliers: Multipliers forwarded unchanged to ``multitiply_grads``.

    Returns:
        The loss of this batch as a Python number (``loss.item()``).
    """
    inputs = batch_x.to(device)
    targets = batch_y.to(device)

    optimizer.zero_grad()
    batch_loss = criterion(model(inputs), targets)
    batch_loss.backward()
    # Rescale the fresh gradients before the optimizer gets to see them.
    multitiply_grads(model, lr_multipliers)
    optimizer.step()

    return batch_loss.item()

def train_SWAG_mod(
    x_train: np.array,
    y_train: np.array,
    x_test: np.array,
    y_test: np.array,
    model: nn.Module,
    K,
    lr_multipliers,
    epochs: int = 100,
    batch_size: int = 100,
    lr: float = 0.01,
    verbose: bool = True,
    c=1,
    momentum=0,
    weight_decay=0,
):
    """Train with SGD while collecting SWAG moments, using per-weight gradient scaling.

    Every ``c`` epochs the flattened model weights are snapshotted and folded
    into running averages of the weights and of their squares. The averages are
    initialised from the weights the model has on entry, so those count as the
    first sample. The deviation of each snapshot from the (already updated)
    running mean is appended as a column of ``D``; only the ``K`` most recent
    columns are kept.

    On exit the model parameters are overwritten with the averaged weights and
    the model is moved back to ``device``.

    Args:
        x_train, y_train, x_test, y_test: Data forwarded to ``get_loaders``.
        model: Network to train; modified in place.
        K: Maximum number of deviation columns retained in ``D`` (must be >= 2).
        lr_multipliers: Per-weight gradient multipliers passed to ``train_step_mod``.
        epochs: Number of training epochs.
        batch_size: Batch size forwarded to ``get_loaders``.
        lr: SGD learning rate.
        verbose: When true, print losses every epoch and a note on each moment update.
        c: Moment-update period in epochs (must be >= 1).
        momentum: SGD momentum.
        weight_decay: SGD weight decay.

    Returns:
        Tuple ``(theta, sigma_diag, D, thetas)``: the averaged weight vector, the
        element-wise ``mean(theta^2) - mean(theta)^2``, the deviation matrix, and
        the list of all weight snapshots taken at moment updates.
    """
    assert c >= 1 and K >= 2

    def snapshot_weights():
        # Detached CPU copy of all parameters flattened into one vector.
        return torch.nn.utils.parameters_to_vector(model.parameters()).detach().cpu().clone()

    def evaluate(loader):
        # Per-batch losses from test_step over the given loader.
        return [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in loader]

    # The middle loader returned with val_loader=True is not used here.
    train_loader, _, test_loader = get_loaders(
        x_train, y_train, x_test, y_test, batch_size, val_loader=True
    )

    theta_epoch = snapshot_weights()
    theta = theta_epoch.clone()           # running first moment
    theta_square = theta_epoch ** 2       # running second moment
    D = None

    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    criterion = nn.MSELoss()

    # Baseline losses before any training (reported as epoch 0).
    train_losses = evaluate(train_loader)
    test_losses = evaluate(test_loader)
    if verbose:
        print_losses(0, train_losses, test_losses)

    thetas = []
    for epoch in range(1, epochs + 1):
        train_losses = [
            train_step_mod(batch_x, batch_y, model, optimizer, criterion, lr_multipliers)
            for batch_x, batch_y in train_loader
        ]
        test_losses = evaluate(test_loader)
        if verbose:
            print_losses(epoch, train_losses, test_losses)

        if epoch % c != 0:
            continue

        if verbose:
            print("SWAG moment update")
        n = epoch / c
        theta_epoch = snapshot_weights()
        thetas.append(theta_epoch.clone())

        # Update the running moments first; the deviation below is measured
        # against the freshly updated mean.
        theta = (n * theta + theta_epoch) / (n + 1)
        theta_square = (n * theta_square + theta_epoch ** 2) / (n + 1)
        deviations = (theta_epoch - theta).reshape(-1, 1)

        if D is None:
            D = deviations
        else:
            # Drop the oldest column once the window already holds K of them.
            window = D[:, 1:] if D.shape[1] == K else D
            D = torch.cat((window, deviations), dim=1)

    sigma_diag = theta_square - theta ** 2

    # Load the averaged weights into the model, then put it back on the device.
    torch.nn.utils.vector_to_parameters(theta, model.parameters())
    model.to(device)
    # Final evaluation pass; its result is not used or returned.
    test_losses = evaluate(test_loader)
    return theta, sigma_diag, D, thetas

def multitiply_grads(model, lr_multipliers):
    """Scale each parameter's gradient element-wise by its slice of ``lr_multipliers``.

    ``lr_multipliers`` is indexed as one flat sequence laid out in the order of
    ``model.parameters()``: every parameter consumes ``numel()`` consecutive
    entries, which are reshaped to the parameter's shape.

    Parameters whose ``grad`` is ``None`` (for example frozen parameters, or
    ones that did not take part in the last backward pass) are left untouched,
    but their slice is still skipped so that later parameters stay aligned
    with their multipliers.

    Args:
        model: Module whose parameter gradients are rescaled.
        lr_multipliers: 1-D tensor with at least as many entries as the model
            has parameter elements.
    """
    start_ind = 0
    for params in model.parameters():
        end_ind = start_ind + params.numel()
        grad = params.grad
        if grad is not None:
            multipliers = lr_multipliers[start_ind:end_ind].reshape(params.shape)
            # Match the gradient's own device/dtype so the product can be
            # assigned back regardless of where the multipliers are stored.
            params.grad = grad * multipliers.to(device=grad.device, dtype=grad.dtype)
        start_ind = end_ind
        
def sample_and_get_metrics(x_train, y_train, x_test, y_test, model, theta_swa, sigma_diag, D, K, S):
    """Sample predictions from SWAG and summarise them as ``(rmse, picp, mpiw)``.

    The arrays returned by ``sample_from_SWAG`` are concatenated along axis 1
    (presumably one column per posterior draw; confirm against ``swag``).
    Per row, the mean is used as the point prediction and the 2.5 / 97.5
    percentiles as the interval bounds.

    Returns:
        rmse: ``RMSE`` between the mean prediction and ``y_test``.
        picp: Fraction of targets lying strictly inside the interval.
        mpiw: Mean width of the interval.
    """
    draws = np.concatenate(
        sample_from_SWAG(x_train, y_train, x_test, y_test, model, theta_swa, sigma_diag, D, K, S),
        axis=1,
    )
    point_estimate = np.mean(draws, axis=1, keepdims=True)
    # Both bounds in one call; unpacking splits along the leading percentile axis.
    lower, upper = np.percentile(draws, [2.5, 97.5], axis=1, keepdims=True)
    inside = np.logical_and(lower < y_test, y_test < upper)
    return RMSE(point_estimate, y_test), np.mean(inside), np.mean(upper - lower)


In [11]:
# Fix the global RNG seeds of both torch and NumPy with the same value.
seed = 556
torch.manual_seed(seed)
np.random.seed(seed)

In [12]:
# Hand-picked subset of dataset names; the last two entries are left disabled.
DSELECTED = [
    "boston_housing",
    "concrete",
    "energy_heating_load",
    "power",
    "wine",
    "yacht",
    "kin8nm",
    "naval_compressor_decay",
    # "protein",
    # "year_prediction_msd",
]
DSELECTED

['boston_housing',
 'concrete',
 'energy_heating_load',
 'power',
 'wine',
 'yacht',
 'kin8nm',
 'naval_compressor_decay']

In [13]:
# # SWAG LOOP

# # for dataset_name in DATASETS:
# for dataset_name in DSELECTED:
# # for dataset_name in ['boston_housing']:
#     run_SWAG(dataset_name, weight_decay=0)

In [14]:
# Per-dataset threshold for the fixed-tolerance experiment below: weights whose
# |coeff| (from calculate_coeffs) exceeds this value get their learning-rate
# multiplier scaled up. Datasets without an entry here are skipped by that loop.
TOLERANCE = {
    "boston_housing": 0.06,
    "concrete": 0.05,
    "energy_heating_load": 0.05,
    "kin8nm": 0.04,
    "naval_compressor_decay": 0.04,
    "power": 0.05,
    "protein": 0.05,
    "wine": 0.02,
    "yacht": 0.03,
    "year_prediction_msd": 0.05,
}

In [8]:
# Fixed-tolerance experiment.
# For every dataset that has an entry in TOLERANCE: start from the saved weights,
# run the modified SWAG training for several rounds, and after each round double
# the learning-rate multiplier of every weight whose |coeff| exceeds the
# tolerance. The round with the lowest validation RMSE is reported.
K = 10            # cap on the number of deviation columns kept by train_SWAG_mod
S = 500           # forwarded to sample_from_SWAG (presumably the number of draws)
weight_decay = 0
multiplier = 2    # growth factor applied to the selected lr multipliers
for dataset_name in DATASETS:
    if dataset_name not in TOLERANCE:
        continue
    run_SWAG(dataset_name, weight_decay=weight_decay, train_model_flag=False)
    tolerance = TOLERANCE[dataset_name]
    print(f"""-----> TOLERANCE: {tolerance}""")
    x_train, y_train, x_test, y_test, _, _ = load_dataset(dataset_name, verbose=False)
    batch_size = x_train.shape[0]//9
    model = create_model(x_train, layer_dims=[50], verbose=False)
    # The model is not trained here; weights are loaded from a file that is
    # expected to exist already.
    model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
    # NOTE(review): y_pred is not used below; the call is kept in case
    # get_test_predictions has side effects.
    y_pred = get_test_predictions(x_train, y_train, x_test, y_test, model)
    theta_epoch = torch.nn.utils.parameters_to_vector(model.parameters()).detach().cpu().clone()
    # One multiplier per weight, all starting at 1; they persist across rounds.
    lr_multipliers = torch.ones_like(theta_epoch).float()
    round_results = []
    for rounds in range(10):
        try:
            # Every round restarts from the same saved weights.
            model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
            theta_swa, sigma_diag, D, thetas = train_SWAG_mod(
                x_train, y_train, x_test, y_test, model, K, lr_multipliers,
                verbose=False, lr=0.01, batch_size=batch_size,
                weight_decay=weight_decay, epochs=50,
            )
            # Standardise each weight's trajectory (zero mean, unit std) before
            # computing the coefficients.
            weight_series = torch.stack(thetas).numpy()
            weight_series -= weight_series.mean(axis=0, keepdims=True)
            weight_series /= weight_series.std(axis=0, keepdims=True) + 1e-10
            coeffs = calculate_coeffs(weight_series, False)
            lr_multipliers[np.abs(coeffs) > tolerance] *= multiplier
            # A fresh validation subset is drawn from the training data each round.
            _, x_val, _, y_val = train_test_split(x_train, y_train, test_size=0.1)
            # NOTE(review): y_val_pred is not used below.
            y_val_pred = model(torch.from_numpy(x_val).float().to(device)).detach().cpu().numpy()
            # Keep the diagonal variance strictly positive before sampling.
            sigma_diag = torch.clamp(sigma_diag, min=1e-10)
            rmse_test, pcip_test, mpiw_test = sample_and_get_metrics(
                x_train, y_train, x_test, y_test, model, theta_swa, sigma_diag, D, K, S
            )
            rmse_val, pcip_val, mpiw_val = sample_and_get_metrics(
                x_train, y_train, x_val, y_val, model, theta_swa, sigma_diag, D, K, S
            )
            round_results.append([rmse_test, pcip_test, mpiw_test, rmse_val, pcip_val, mpiw_val])
        except Exception as e:
            # Best effort: a failed round must not abort the sweep, but it must
            # not disappear silently either.
            print(f"Some error in round {rounds} for {dataset_name}: {type(e).__name__}: {e}")
    if not round_results:
        # Every round failed; indexing an empty result table would raise.
        print(f"No successful rounds for {dataset_name}; skipping summary")
        continue
    # Column 3 holds the validation RMSE.
    best_ind = np.array(round_results)[:, 3].argmin()
    rmse_test, pcip_test, mpiw_test, rmse_val, pcip_val, mpiw_val = round_results[best_ind]
    print(f"Test RMSE: {rmse_test:.3f}, PICP: {pcip_test:.3f}, MPIW:{mpiw_test:.3f} ", end = "| ")
    print(f"Val RMSE: {rmse_val:.3f}, PICP: {pcip_val:.3f}, MPIW:{mpiw_val:.3f}")


dataset: boston_housing, rows: 506, columns: 13, range of x: [0.0, 711.0], range of y: [5.0, 50.0]
SGD RMSE: 0.334
SWAG_lr:  0.01000
RMSE: 0.306, PICP: 0.333, MPIW:0.280
-----> TOLERANCE: 0.06
Test RMSE: 0.338, PICP: 0.686, MPIW:0.834 | Val RMSE: 0.415, PICP: 0.848, MPIW:1.145
dataset: concrete, rows: 1030, columns: 8, range of x: [0.0, 1145.0], range of y: [2.33, 82.6]
SGD RMSE: 0.581
SWAG_lr:  0.20000
RMSE: 0.338, PICP: 0.825, MPIW:0.971
-----> TOLERANCE: 0.05
Test RMSE: 0.535, PICP: 0.718, MPIW:1.178 | Val RMSE: 0.552, PICP: 0.710, MPIW:1.185
dataset: energy_heating_load, rows: 768, columns: 8, range of x: [0.0, 808.5], range of y: [6.01, 43.1]
SGD RMSE: 0.232
SWAG_lr:  0.10000
RMSE: 0.203, PICP: 0.805, MPIW:0.472
-----> TOLERANCE: 0.05
Some error
Test RMSE: 0.172, PICP: 0.909, MPIW:0.744 | Val RMSE: 0.224, PICP: 0.843, MPIW:0.723
dataset: kin8nm, rows: 8192, columns: 8, range of x: [-1.5706812, 1.5707529], range of y: [0.040165378, 1.4585206]
SGD RMSE: 0.474
SWAG_lr:  0.10000
RMSE:

In [21]:
# Tolerance sweep.
# For every dataset, try 8 evenly spaced tolerances in [0.05, 0.08]. For each
# tolerance the lr multipliers are reset to 1 and then grown over 10 rounds of
# modified SWAG training, exactly as in the fixed-tolerance experiment.
K = 10            # cap on the number of deviation columns kept by train_SWAG_mod
S = 500           # forwarded to sample_from_SWAG (presumably the number of draws)
weight_decay = 0
multiplier = 2    # growth factor applied to the selected lr multipliers
for dataset_name in DATASETS:
    run_SWAG(dataset_name, weight_decay=weight_decay, train_model_flag=False)
    for tolerance in np.linspace(0.05, 0.08, 8):
        print(f"""-----> TOLERANCE: {tolerance}""")
        x_train, y_train, x_test, y_test, _, _ = load_dataset(dataset_name, verbose=False)
        batch_size = x_train.shape[0]//9
        model = create_model(x_train, layer_dims=[50], verbose=False)
        # The model is not trained here; weights are loaded from a file that is
        # expected to exist already.
        model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
        # NOTE(review): y_pred is not used below; the call is kept in case
        # get_test_predictions has side effects.
        y_pred = get_test_predictions(x_train, y_train, x_test, y_test, model)
        theta_epoch = torch.nn.utils.parameters_to_vector(model.parameters()).detach().cpu().clone()
        # One multiplier per weight, reset to 1 for every tolerance value.
        lr_multipliers = torch.ones_like(theta_epoch).float()
        for rounds in range(10):
            try:
                # Every round restarts from the same saved weights.
                model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
                theta_swa, sigma_diag, D, thetas = train_SWAG_mod(
                    x_train, y_train, x_test, y_test, model, K, lr_multipliers,
                    verbose=False, lr=0.01, batch_size=batch_size,
                    weight_decay=weight_decay, epochs=50,
                )
                # Mean of the diagonal of the matrix returned by calculate_sigma.
                print(torch.diagonal(calculate_sigma(sigma_diag, D, K)).numpy().mean())
                # Standardise each weight's trajectory (zero mean, unit std)
                # before computing the coefficients.
                weight_series = torch.stack(thetas).numpy()
                weight_series -= weight_series.mean(axis=0, keepdims=True)
                weight_series /= weight_series.std(axis=0, keepdims=True) + 1e-10
                coeffs = calculate_coeffs(weight_series, False)
                lr_multipliers[np.abs(coeffs) > tolerance] *= multiplier
                # A fresh validation subset is drawn from the training data each round.
                _, x_val, _, y_val = train_test_split(x_train, y_train, test_size=0.1)
                # NOTE(review): y_val_pred is not used below.
                y_val_pred = model(torch.from_numpy(x_val).float().to(device)).detach().cpu().numpy()
                # Keep the diagonal variance strictly positive before sampling.
                sigma_diag = torch.clamp(sigma_diag, min=1e-10)
                rmse_test, pcip_test, mpiw_test = sample_and_get_metrics(
                    x_train, y_train, x_test, y_test, model, theta_swa, sigma_diag, D, K, S
                )
                rmse_val, pcip_val, mpiw_val = sample_and_get_metrics(
                    x_train, y_train, x_val, y_val, model, theta_swa, sigma_diag, D, K, S
                )
                # Report the metrics; otherwise the sampling work above is discarded.
                print(f"Test RMSE: {rmse_test:.3f}, PICP: {pcip_test:.3f}, MPIW:{mpiw_test:.3f} ", end = "| ")
                print(f"Val RMSE: {rmse_val:.3f}, PICP: {pcip_val:.3f}, MPIW:{mpiw_val:.3f}")
            except Exception as e:
                # Best effort: keep sweeping, but say what went wrong and where.
                print(
                    f"Some error in round {rounds} for {dataset_name} "
                    f"(tolerance={tolerance}): {type(e).__name__}: {e}"
                )


dataset: boston_housing, rows: 506, columns: 13, range of x: [0.0, 711.0], range of y: [5.0, 50.0]
SGD RMSE: 0.652
SWAG_lr:  0.01000
RMSE: 0.388, PICP: 0.392, MPIW:0.324
-----> TOLERANCE: 0.05
0.00011692351


KeyboardInterrupt: 

In [None]:
# NOTE(review): scratch cell that appears never to have been executed. As written
# the call passes no `params` argument, which torch.optim.Adam requires, so it
# would raise instead of constructing an optimizer; confirm intent or remove.
torch.optim.Adam()