In [None]:
import numpy as np
import torch

from ignite.engine import Events, Engine
from ignite.metrics import Average, Loss
from ignite.contrib.handlers import ProgressBar

import gpytorch
from gpytorch.mlls import VariationalELBO
from gpytorch.likelihoods import GaussianLikelihood

from vduq.dkl import DKL_GP
from fc_resnet import FCResNet

import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default plotting theme, then switch the color cycle to
# seaborn's "colorblind" palette.
# NOTE(review): set_palette is called without color_codes=True; confirm whether
# single-letter colors such as 'b' used in later plots should follow this palette.
sns.set()
sns.set_palette("colorblind")

In [None]:
def make_data(n_samples=1000, noise=0.05, seed=2, n_features=30):
    """Generate a 1-D toy regression dataset from a random sum of cosines.

    ``x`` is drawn from two clusters centred at -5 and +5: each point is
    ``5 * sign(z1) + clip(z2, -2, 2)`` with standard-normal ``z1``, ``z2``.
    The target is ``y = sum_k cos(W_k * x + b_k) + noise * eps``.

    A private ``np.random.RandomState(seed)`` is used instead of reseeding the
    global NumPy RNG. Calling this helper therefore no longer overrides a seed
    set elsewhere, such as the ``np.random.seed(0)`` in the setup cell. The
    draws are identical to the former global-seed implementation.

    Parameters
    ----------
    n_samples : int
        Number of (x, y) pairs to generate.
    noise : float
        Standard deviation of the additive Gaussian noise on ``y``.
    seed : int
        Seed for the private RNG; the same seed always gives the same data.
    n_features : int
        Number of random cosine components (30 reproduces the original data).

    Returns
    -------
    x : ndarray, shape (n_samples, 1)
    y : ndarray, shape (n_samples,)
    """
    rng = np.random.RandomState(seed)

    # Random frequencies and phases for the cosine components.
    W = rng.randn(n_features, 1)
    b = rng.rand(n_features, 1) * 2 * np.pi

    x = 5 * np.sign(rng.randn(n_samples)) + rng.randn(n_samples).clip(-2, 2)
    # (n_features, 1) * (n_samples,) broadcasts to (n_features, n_samples);
    # summing over axis 0 gives one target per sample.
    y = np.cos(W * x + b).sum(0) + noise * rng.randn(n_samples)
    return x[..., None], y

In [None]:
# Quick look at the raw training data: x falls into two clusters around -5 and +5.
x, y = make_data()
plt.scatter(x.ravel(), y)

In [None]:
# Seed NumPy and PyTorch before any random work in this cell.
np.random.seed(0)
torch.manual_seed(0)

input_dim = 1
batch_size = 64

X_train, y_train = make_data()
# NOTE: the "test" split is the training data itself, so the evaluator below
# measures fit on seen points rather than generalization to new ones.
X_test, y_test = X_train, y_train
domain = 15  # half-width of the x-range used for the prediction grid and plots

ds_train = torch.utils.data.TensorDataset(
    torch.from_numpy(X_train).float(),
    torch.from_numpy(y_train).float(),
)
dl_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

ds_test = torch.utils.data.TensorDataset(
    torch.from_numpy(X_test).float(),
    torch.from_numpy(y_test).float(),
)
dl_test = torch.utils.data.DataLoader(
    ds_test,
    batch_size=200,
    shuffle=False,
)

# ---- Model hyperparameters -------------------------------------------------
n_inducing_points = 20

features = 128
depth = 4
kernel = "Matern12"
spectral_normalization = True
num_classes = 1  # regression with 1D output

coeff = 0.95
n_power_iterations = 1
dropout_rate = 0.0

# ---- Model: feature extractor feeding the GP part (`model.gp`) -------------
feature_extractor = FCResNet(
    input_dim=input_dim,
    features=features,
    depth=depth,
    spectral_normalization=spectral_normalization,
    coeff=coeff,
    n_power_iterations=n_power_iterations,
    dropout_rate=dropout_rate,
)
model = DKL_GP(
    feature_extractor=feature_extractor,
    num_classes=num_classes,
    train_dataset=ds_train,
    kernel=kernel,
    n_inducing_points=n_inducing_points,
)

# ---- Likelihood and training objective -------------------------------------
likelihood = GaussianLikelihood()
# num_data is the total number of training points (batches are subsets of it).
loss_fn = VariationalELBO(likelihood, model.gp, num_data=len(ds_train))

if torch.cuda.is_available():
    model, likelihood = model.cuda(), likelihood.cuda()

# ---- Optimizer -------------------------------------------------------------
lr = 1e-3

# One parameter group per component, all sharing the same learning rate.
parameters = [
    {"params": module.parameters(), "lr": lr}
    for module in (model.feature_extractor, model.gp, likelihood)
]

optimizer = torch.optim.Adam(parameters, weight_decay=5e-4)
pbar = ProgressBar()

def step(engine, batch):
    """Run one optimization step on a training batch.

    Parameters
    ----------
    engine : ignite.engine.Engine
        The training engine; unused, but required by the Engine callback signature.
    batch : tuple
        ``(inputs, targets)`` tensors from the training DataLoader.

    Returns
    -------
    float
        This batch's training loss, i.e. the negated variational ELBO.
    """
    # Evaluation switches these to eval mode, so restore training mode each step.
    for module in (model, likelihood):
        module.train()

    optimizer.zero_grad()

    inputs, targets = batch
    if torch.cuda.is_available():
        inputs, targets = inputs.cuda(), targets.cuda()

    # The ELBO is to be maximized, so minimize its negation.
    loss = -loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    return loss.item()


def eval_step(engine, batch):
    """Compute the model output for one evaluation batch.

    The forward pass runs under ``torch.no_grad()``. Evaluation never calls
    ``backward()``, so building an autograd graph here would only waste
    memory and time.

    Parameters
    ----------
    engine : ignite.engine.Engine
        The evaluation engine; unused, but required by the Engine callback signature.
    batch : tuple
        ``(x, y)`` tensors from the evaluation DataLoader.

    Returns
    -------
    tuple
        ``(y_pred, y)``: the model output for ``x`` and the targets. This is
        the pair consumed by the ``Loss`` metric attached to the evaluator.
    """
    model.eval()
    likelihood.eval()

    x, y = batch
    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()

    with torch.no_grad():
        y_pred = model(x)

    return y_pred, y

    
trainer, evaluator = Engine(step), Engine(eval_step)


def _negative_elbo(y_pred, y):
    """Evaluation loss for the ``Loss`` metric: the negated variational ELBO."""
    return -loss_fn(y_pred, y)


# Trainer metric "loss": running mean of the batch losses returned by `step`.
metric = Average()
metric.attach(trainer, "loss")
pbar.attach(trainer)

# Evaluator metric "loss": negative ELBO averaged over the evaluation batches.
metric = Loss(_negative_elbo)
metric.attach(evaluator, "loss")

@trainer.on(Events.EPOCH_COMPLETED)
def log_results(trainer):
    """Every 20 epochs, evaluate on ``dl_test`` and print the train/eval losses.

    The epoch check comes first, so the full evaluation pass only runs on
    epochs that are actually reported. Previously it ran every epoch and its
    results were discarded 19 times out of 20.

    NOTE: ``dl_test`` is built from the training arrays, so the "Test" number
    is measured on seen data.
    """
    epoch = trainer.state.epoch
    if epoch % 20 != 0:
        return

    evaluator.run(dl_test)
    test_loss = evaluator.state.metrics["loss"]
    train_loss = trainer.state.metrics["loss"]
    # Both values are negative ELBOs (lower is better), not likelihoods.
    print(f"Test Results - Epoch: {epoch} "
          f"Test Loss (-ELBO): {test_loss:.2f} "
          f"Train Loss (-ELBO): {train_loss:.2f}")

In [None]:
# Train for a fixed number of epochs; progress is shown by the attached ProgressBar.
n_epochs = 500
trainer.run(dl_train, max_epochs=n_epochs)

In [None]:
# Switch both modules to inference mode before predicting.
model.eval()
likelihood.eval()

# Dense 1-D grid spanning [-domain, domain] on which to evaluate the model.
x_lin = np.linspace(-domain, domain, 100)

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(64):
    xx = torch.from_numpy(x_lin[:, None]).float()
    if torch.cuda.is_available():
        xx = xx.cuda()
    # Predictive over y: the model output passed through the Gaussian likelihood.
    ol = likelihood(model(xx))
    output, output_std = ol.mean.cpu(), ol.stddev.cpu()

In [None]:
plt.xlim(-domain, domain)

# Shade the ±1σ band first, then the ±2σ band, around the predictive mean.
for n_std in (1, 2):
    plt.fill_between(
        x_lin,
        output - n_std * output_std,
        output + n_std * output_std,
        alpha=0.2,
        color='b',
    )

# These empty scatters draw nothing; presumably they advance matplotlib's
# color cycle so the data points below get a different color.
plt.scatter([], [])
plt.scatter([], [])

X_vis, y_vis = make_data(n_samples=200)
plt.scatter(X_vis.squeeze(), y_vis)
plt.plot(x_lin, output, alpha=0.5)

In [None]:
plt.xlim(-domain, domain)

# Overlay 12 independent draws from the predictive distribution `ol`.
for i in range(12):
    draw = ol.rsample().cpu()
    plt.plot(x_lin, draw, alpha=0.3, color='b')

# Empty scatters draw nothing; presumably they only advance the color cycle.
plt.scatter([], [])
plt.scatter([], [])

X_vis, y_vis = make_data(n_samples=200)
plt.scatter(X_vis.squeeze(), y_vis, s=50)