In [None]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression
from torch import nn

from data import get_loaders, load_dataset, train_test_split
from training import test_step, train_step, train_model, get_test_predictions
from utils import print_losses
from swag import sample_from_SWAG, run_SWAG
from definitions import DATASETS
from models import create_model, load_model
from metrics import RMSE
from swag_mod import calculate_coeffs

In [None]:
def train_step_mod(
    batch_x: torch.Tensor,
    batch_y: torch.Tensor,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion,
    lr_multipliers,
):
    """Perform one optimizer step with per-weight gradient scaling.

    The gradients produced by the backward pass are rescaled through
    ``multitiply_grads`` before ``optimizer.step()`` is called.

    Args:
        batch_x: Input batch fed to ``model``.
        batch_y: Targets handed to ``criterion`` together with the model output.
        model: Network being trained.
        optimizer: Optimizer that owns the parameters of ``model``.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar loss tensor.
        lr_multipliers: Flat multiplier vector forwarded to ``multitiply_grads``.

    Returns:
        The batch loss as a Python float.
    """
    optimizer.zero_grad()
    batch_loss = criterion(model(batch_x), batch_y)
    batch_loss.backward()
    # Rescale the fresh gradients before the optimizer consumes them.
    multitiply_grads(model, lr_multipliers)
    optimizer.step()
    return batch_loss.item()

def train_SWAG_mod(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_test: np.ndarray,
    y_test: np.ndarray,
    model: nn.Module,
    K,
    lr_multipliers,
    epochs: int = 100,
    batch_size: int = 100,
    lr: float = 0.1,
    verbose: bool = True,
    c=1,
    momentum=0,
    weight_decay=0,
):
    """Train ``model`` with SGD and scaled gradients while collecting SWAG moments.

    Every ``c`` epochs the flattened parameter vector is snapshotted and folded
    into running averages of the parameters and of their squares. The starting
    parameters count as the first sample of both averages.

    Args:
        x_train, y_train, x_test, y_test: Data forwarded to ``get_loaders``.
        model: Network to train. On return its parameters are set to the
            averaged parameter vector.
        K: Maximum number of deviation columns kept in ``D`` (must be >= 2).
        lr_multipliers: Flat per-weight multipliers applied to the gradients
            on every training step (see ``train_step_mod``).
        epochs: Number of training epochs.
        batch_size: Batch size forwarded to ``get_loaders``.
        lr: SGD learning rate.
        verbose: If true, print the losses every epoch and a note on every
            moment update.
        c: Snapshot period in epochs (must be >= 1).
        momentum: SGD momentum.
        weight_decay: SGD weight decay.

    Returns:
        A tuple ``(theta, sigma_diag, D, thetas)``:
        ``theta`` is the running mean of the parameter vectors;
        ``sigma_diag`` is ``mean(theta^2) - mean(theta)^2`` (not clamped here,
        so it may contain small negative values);
        ``D`` has one column per retained deviation ``snapshot - running mean``
        (at most ``K`` columns), or is ``None`` if no snapshot was taken;
        ``thetas`` is the list of all snapshots in collection order.
    """
    assert c >= 1 and K >= 2
    train_loader, _, test_loader = get_loaders(x_train, y_train, x_test, y_test, batch_size, val_loader=True)

    # parameters_to_vector(...).detach().cpu().clone() yields an independent
    # CPU copy, so the tensors derived from it never alias the live weights.
    theta_epoch = torch.nn.utils.parameters_to_vector(model.parameters()).detach().cpu().clone()
    theta = theta_epoch.clone()
    theta_square = theta_epoch ** 2
    D = None

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    # Baseline losses before any update (reported as epoch 0).
    train_losses = [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in train_loader]
    test_losses = [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in test_loader]
    if verbose:
        print_losses(0, train_losses, test_losses)

    thetas = []
    for epoch in range(1, epochs + 1):
        train_losses = [train_step_mod(batch_x, batch_y, model, optimizer, criterion, lr_multipliers)
                        for batch_x, batch_y in train_loader]
        test_losses = [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in test_loader]
        if verbose:
            print_losses(epoch, train_losses, test_losses)
        if epoch % c == 0:
            if verbose:
                print("SWAG moment update")
            # Number of samples already in the running averages
            # (the initial parameters count as the first one).
            n = epoch // c
            theta_epoch = torch.nn.utils.parameters_to_vector(model.parameters()).detach().cpu().clone()
            thetas.append(theta_epoch)
            theta = (n * theta + theta_epoch) / (n + 1)
            theta_square = (n * theta_square + theta_epoch ** 2) / (n + 1)
            # Deviation of this snapshot from the *updated* running mean.
            deviations = (theta_epoch - theta).reshape(-1, 1)
            if D is None:
                D = deviations
            else:
                if D.shape[1] == K:
                    # Drop the oldest column so D never exceeds K columns.
                    D = D[:, 1:]
                D = torch.cat((D, deviations), dim=1)
    sigma_diag = theta_square - theta ** 2

    # vector_to_parameters makes every parameter a *view* into the vector it is
    # given. Load a clone so later in-place updates of the model (optimizer
    # steps, load_state_dict, ...) cannot silently modify the returned theta.
    torch.nn.utils.vector_to_parameters(theta.clone(), model.parameters())

    # NOTE(review): the result of this final evaluation is currently unused
    # (its print is disabled); the call is kept in case test_step has side
    # effects on the model that callers rely on.
    test_losses = [test_step(batch_x, batch_y, model, criterion) for batch_x, batch_y in test_loader]
    # print(f"Finished SWAG.     Best test loss: {np.mean(test_losses):.5f}")
    return theta, sigma_diag, D, thetas

def multitiply_grads(model, lr_multipliers):
    """Scale every parameter gradient element-wise by its slice of ``lr_multipliers``.

    ``lr_multipliers`` is a flat vector laid out in the same order as
    ``model.parameters()``: the first ``numel`` entries belong to the first
    parameter, the next ones to the second parameter, and so on.

    Parameters without a gradient (``grad is None``, e.g. frozen or unused
    weights) are left untouched, but their slice of the vector is still
    consumed so the remaining parameters stay aligned.

    Args:
        model: Module whose ``.grad`` tensors are rescaled.
        lr_multipliers: Flat sequence of multipliers with one entry per scalar
            weight of ``model``.
    """
    offset = 0
    for param in model.parameters():
        count = param.numel()
        chunk = lr_multipliers[offset:offset + count]
        # Advance before any early exit so skipped parameters keep the layout aligned.
        offset += count

        grad = param.grad
        if grad is None:
            continue

        # Match the gradient's device/dtype so the product can be assigned back to .grad.
        scale = torch.as_tensor(chunk, device=grad.device, dtype=grad.dtype).reshape(param.shape)
        param.grad = grad * scale

In [None]:
# Datasets the cells below iterate over; commented-out names are not part of the list.
DSELECTED = [
    "boston_housing",
    "concrete",
    "energy_heating_load",
    # "kin8nm",
    # "naval_compressor_decay",
    "power",
    # "protein",
    "wine",
    "yacht",
    # "year_prediction_msd",
]
DSELECTED

In [None]:
# SWAG LOOP
# Calls run_SWAG once per dataset in DSELECTED, passing weight_decay=0.

# Alternative: iterate over every dataset listed in DATASETS instead.
# for dataset_name in DATASETS:
for dataset_name in DSELECTED:
    run_SWAG(dataset_name, weight_decay=0)

In [None]:
# --- Experiment configuration ---
K = 10            # forwarded to train_SWAG_mod (max. deviation columns) and sample_from_SWAG
S = 500           # forwarded to sample_from_SWAG (presumably the number of samples - TODO confirm)
weight_decay = 0  # SGD weight decay used by train_SWAG_mod
multiplier = 2    # factor applied to the lr multiplier of every flagged weight
# `tolerance` is defined by the sweep below; a weight is flagged when |coeff| > tolerance.

# To restrict the run, replace DSELECTED below with e.g. ['boston_housing'] or DATASETS[:3].
for tolerance in np.linspace(0.03, 0.07, 5):
    print(f"""========================================================================================
    
    TOLERANCE: {tolerance}
    """)
    for dataset_name in DSELECTED:
        # Baseline run_SWAG call on the same dataset, without re-training the base model.
        run_SWAG(dataset_name, weight_decay=0, train_model_flag=False)
        x_train, y_train, x_test, y_test, _, _ = load_dataset(dataset_name, verbose=False)
        batch_size = x_train.shape[0]//9
        model = create_model(x_train, layer_dims=[50], verbose=False)
        # The checkpoint loaded below must already exist. For reference, a training call:
        # train_model(x_train, y_train, x_test, y_test, model, dataset_name, lr=0.001, epochs=50000,
        #             verbose=False, batch_size=batch_size, weight_decay=weight_decay)
        model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
        y_pred = get_test_predictions(x_train, y_train, x_test, y_test, model)
        # print(f"SGD RMSE: {RMSE(y_pred, y_test):.3f}")
        print("---MOD---")
        # Start every dataset with neutral multipliers (one entry per scalar weight).
        theta_epoch = torch.nn.utils.parameters_to_vector(model.parameters()).detach().cpu().clone()
        lr_multipliers = torch.ones_like(theta_epoch).float()
        for rounds in range(15):
            try:
                # Each round restarts from the saved weights; only lr_multipliers carry over.
                model = load_model(model, f"best_model_weights-{dataset_name}.pth", verbose=False)
                theta_swa, sigma_diag, D, thetas = train_SWAG_mod(
                    x_train, y_train, x_test, y_test, model, K, lr_multipliers,
                    verbose=False, lr=0.001, batch_size=batch_size,
                    weight_decay=weight_decay, epochs=50,
                )
                # Standardise each weight's trajectory across snapshots (zero mean, unit std).
                weight_series = torch.stack(thetas).numpy()
                step_plot = weight_series.shape[1] // 5
                weight_series -= weight_series.mean(axis=0, keepdims=True)
                weight_series /= weight_series.std(axis=0, keepdims=True) + 1e-10
                # Optional visual check of a few trajectories:
                # plt.plot(weight_series[:,::step_plot], alpha=0.3)
                # plt.show()
                coeffs = calculate_coeffs(weight_series, False)
                # Raise the multiplier of every weight whose coefficient exceeds the tolerance.
                lr_multipliers[np.abs(coeffs) > tolerance] *= multiplier
                # Validation subset taken from the training data
                # (assumes train_test_split returns x_a, x_b, y_a, y_b - TODO confirm).
                _, x_val, _, y_val = train_test_split(x_train, y_train, test_size=0.1)
                y_val_pred = model(torch.from_numpy(x_val).float()).detach().cpu().numpy()
                # The moment-based variance can be slightly negative; keep it strictly positive.
                sigma_diag = torch.clamp(sigma_diag, min=1e-10)
                samples = sample_from_SWAG(x_train, y_train, x_test, y_test, model, theta_swa, sigma_diag, D, K, S)
                samples_array = np.concatenate(samples, axis=1)
                y_pred = samples_array.mean(axis=1, keepdims=True)
                # Central 95% interval of the sampled predictions.
                y_l = np.percentile(samples_array, 2.5, axis=1, keepdims=True)
                y_u = np.percentile(samples_array, 97.5, axis=1, keepdims=True)
                print(f"RMSE: {RMSE(y_pred, y_test):.3f}, PICP: {np.mean((y_l < y_test) & (y_test < y_u)):.3f}, MPIW:{np.mean(y_u - y_l):.3f}, val RMSE: {RMSE(y_val_pred, y_val):.3f}")
            except Exception as e:
                # Best effort: keep the sweep running, but say what failed
                # instead of hiding the cause behind a generic message.
                print(f"Some error (dataset={dataset_name}, round={rounds}): {type(e).__name__}: {e}")
