In [None]:
import torch
import matplotlib.pyplot as plt

import online_gp
import gpytorch

In [None]:
def get_data(N=50, shift=0):
    """Draw a toy 1-D regression set: y = sin(3x) + Gaussian noise.

    Args:
        N: number of points to draw.
        shift: offset added to the standard-normal inputs, used to move the
            input cloud (e.g. to simulate newly arriving, shifted data).

    Returns:
        (inputs, targets): inputs of shape (N, 1) and flattened targets of
        shape (N,).
    """
    inputs = shift + torch.randn(N, 1)
    noise = 0.1 * torch.randn_like(inputs)  # noise std 0.1
    targets = torch.sin(3. * inputs) + noise
    return inputs, targets.view(-1)

In [None]:
# Seed once before the first random draw so the data, the random inducing-point
# initialization and the random subset selection below all reproduce on Restart & Run All.
torch.manual_seed(0)
x, y = get_data()

In [None]:
# Quick look at the initial training set before any model is fit.
plt.scatter(x, y)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Training data: y = sin(3x) + noise");

In [None]:
def fit_model(mll, model, optimizer, x, y, num_steps=1000):
    """Fit ``model`` by maximizing ``mll`` with full-batch gradient steps.

    Args:
        mll: callable taking ``(model(x), y)`` and returning a scalar objective
            to *maximize* (e.g. a gpytorch ELBO); its negation is minimized.
        model: module to train; it is put into training mode here, so callers
            do not need to undo a previous ``eval()``.
        optimizer: torch optimizer over the parameters to fit.
        x: training inputs passed to ``model``.
        y: training targets passed to ``mll``.
        num_steps: number of optimization steps.

    Prints the loss roughly ten times over the run (every step when
    ``num_steps`` < 20, since the logging interval is at least 1).
    """
    # The plotting helper leaves the model in eval mode; training must not depend on cell order.
    # NOTE(review): a separately held likelihood is not switched here — confirm it does not need it.
    model.train()
    log_every = max(num_steps // 10, 1)
    for step in range(num_steps):
        # Clear gradients *before* backprop so grads left over from earlier cells
        # (e.g. a previous fit or a variational update) cannot leak into this step.
        optimizer.zero_grad()
        loss = -mll(model(x), y)
        loss.backward()
        optimizer.step()
        if step % log_every == 0:
            print("Loss: ", loss.item())

In [None]:
def make_basic_plot(model, x, y, old_x=None, old_y=None, bounds=(-6., 6.), num_points=100):
    """Plot a variational GP's predictive mean and band, its inducing points, and the data.

    Args:
        model: the model to plot. When called, it must return a distribution with
            ``.mean`` and ``.confidence_region()``, and it must expose
            ``variational_strategy.inducing_points``. It is left in eval mode, so
            switch it back to train mode before fitting again.
        x, y: current data, drawn in black.
        old_x, old_y: optional earlier data, drawn in grey when ``old_x`` is given.
        bounds: ``(low, high)`` range of the prediction grid.
        num_points: number of grid points. 100 matches the old implicit default of
            ``torch.linspace``; newer PyTorch requires ``steps`` explicitly.

    The band is over the latent function: the likelihood is never applied, so
    observation noise is not included.
    """
    model.eval()
    # Use the model passed in (the original read the global ``vargp_model`` here).
    inducing_points = model.variational_strategy.inducing_points.detach().view(-1, 1)
    with torch.no_grad():
        test_x = torch.linspace(*bounds, num_points).view(-1, 1)
        pred_dist = model(test_x)
        pred_induc = model(inducing_points)
        lower, upper = pred_dist.confidence_region()

    plt.plot(test_x.view(-1), pred_dist.mean, label="Predictive Mean")
    plt.fill_between(test_x.view(-1), lower, upper, alpha=0.3)

    plt.scatter(x, y, color="black", label="Current Data")
    plt.scatter(inducing_points.view(-1), pred_induc.mean,
                color="red", marker="x", label="Inducing Points")
    if old_x is not None:
        plt.scatter(old_x, old_y, color="grey", alpha=0.5, label="Old Data")

    plt.legend()

In [None]:
# Gaussian likelihood plus a (non-streaming) variational GP whose 25 inducing
# points start at random standard-normal locations.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
vargp_model = online_gp.models.VariationalGPModel(
    torch.randn(25, 1),
    streaming=False,
    likelihood=likelihood,
)
# ELBO scaled to the size of the initial training set.
mll = gpytorch.mlls.VariationalELBO(
    likelihood=likelihood,
    model=vargp_model,
    num_data=x.size(-2),
    beta=1.0,
)

In [None]:
# Adam over the model and likelihood parameters.
# NOTE(review): if VariationalGPModel registers the likelihood as a submodule, its
# parameters appear twice in this list — confirm.
optimizer = torch.optim.Adam([*vargp_model.parameters(), *likelihood.parameters()], lr=0.01)

In [None]:
# Initial fit on the first batch (1000 steps is fit_model's default, spelled out here).
fit_model(mll, vargp_model, optimizer, x, y, num_steps=1000)

In [None]:
make_basic_plot(model=vargp_model, x=x, y=y)

In [None]:
# A small new batch whose inputs are shifted away from the original data.
new_x, new_y = get_data(N=5, shift=6)

# Next inducing set: keep a random subset of the current inducing points and swap in
# a few of the new inputs, so the total count is unchanged. The count is derived from
# the model rather than hard-coded (the original assumed exactly 25).
num_inducing = vargp_model.variational_strategy.inducing_points.size(0)
num_new_inducing = min(3, new_x.size(0))
induc_to_keep = torch.randperm(num_inducing)[:num_inducing - num_new_inducing]
new_x_to_keep = torch.randperm(new_x.size(0))[:num_new_inducing]
new_inducing = torch.cat((
    new_x[new_x_to_keep],
    vargp_model.variational_strategy.inducing_points[induc_to_keep],
), dim=0).detach()

# NOTE: this mutates vargp_model in place, so re-running the cell applies a second update.
vargp_model.update_variational_parameters(new_x, new_y, new_inducing)

In [None]:
# (Removed a redundant vargp_model.zero_grad(): the following cell calls it again
# before re-fitting, so calling it here had no additional effect.)


In [None]:
# Prepare to re-fit on the new batch only: clear gradients and switch to training mode.
vargp_model.zero_grad()
vargp_model.train()

# The ELBO is rebuilt so num_data reflects the new batch size. The optimizer is rebuilt
# so it collects the model's parameters as they stand after the variational update.
mll = gpytorch.mlls.VariationalELBO(
    likelihood=likelihood,
    model=vargp_model,
    num_data=new_x.size(-2),
    beta=1.0,
    combine_terms=True,
)
optimizer = torch.optim.Adam(
    [*vargp_model.parameters(), *likelihood.parameters()],
    lr=0.1,
)

In [None]:
# A handful of steps on the new batch only.
fit_model(mll=mll, model=vargp_model, optimizer=optimizer, x=new_x, y=new_y, num_steps=5)

In [None]:
# Fit after the streaming update: new batch in black, original data greyed out.
make_basic_plot(vargp_model, new_x, new_y, old_x=x, old_y=y, bounds=(-3., 12.))
plt.ylim(-3., 3.)
plt.title("Model after streaming update on shifted data");

In [None]:
# How far each inducing point moved during the re-fit (the zero line means unmoved).
plt.scatter(new_inducing, new_inducing - vargp_model.variational_strategy.inducing_points.detach())
plt.axhline(0., color="grey", linestyle="--", linewidth=1)
plt.xlabel("Initial Inducing Points")
plt.ylabel("Initial Inducing Points - Current Inducing Points")
plt.title("Inducing point movement during re-fit");