In [None]:
import torch
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
import seaborn as sns

# Configure seaborn in a single call. `sns.set` resets every theme option it is
# not given (style, palette, font scale), so separate `set_style` /
# `set_palette` calls placed before it are silently undone. Passing all options
# at once keeps the intended whitegrid style AND the "bright" palette.
sns.set(style="whitegrid", palette="bright", font_scale=2.0)



In [None]:
from online_gp.models.streaming_sgpr import StreamingSGPR, StreamingSGPRBound
from gpytorch import mlls, kernels

# Create all new floating-point tensors as float64 (presumably for numerical
# stability of the GP linear algebra). `set_default_dtype` is the supported
# replacement for the deprecated `set_default_tensor_type(torch.DoubleTensor)`.
torch.set_default_dtype(torch.float64)

In [None]:
def make_basic_plot(model, x, y, old_x=None, old_y=None, bounds=(-6., 6.), num_points=100):
    """Plot a model's predictive distribution with data and inducing points.

    Puts ``model`` in eval mode (it is left that way on return). It then
    evaluates ``model.likelihood(model(test_x))`` on an evenly spaced
    ``(num_points, 1)`` grid over ``bounds`` and draws:

    - the predictive mean and its confidence region;
    - the current data;
    - optional older data (faded);
    - the inducing points, placed at the mean of the variational distribution.

    Args:
        model: model exposing ``likelihood`` and ``variational_strategy``
            (with ``inducing_points`` and ``variational_distribution``).
        x, y: current observations, drawn in solid grey.
        old_x, old_y: optional earlier observations, drawn faded. They are only
            drawn when ``old_x`` is not None.
        bounds: ``(low, high)`` range of the prediction grid.
        num_points: number of grid points. ``torch.linspace`` needs an explicit
            step count in current PyTorch; 100 matches the implicit default of
            older releases.

    Returns:
        The created ``matplotlib.figure.Figure``. No legend is drawn, although
        every artist carries a label, so ``fig.legend()`` can add one.
    """
    model.eval()
    with torch.no_grad():
        test_x = torch.linspace(*bounds, num_points).view(-1, 1)
        pred_dist = model.likelihood(model(test_x))
        pred_induc = model.variational_strategy.variational_distribution.mean
        lower, upper = pred_dist.confidence_region()

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(test_x, pred_dist.mean, label="Predictive Mean", color="#3dbbdb", linewidth=4)
    ax.fill_between(test_x.view(-1), lower.detach(), upper.detach(),
                    alpha=0.3, color="#3dbbdb")

    ax.scatter(x, y, color="#6d6d6d", label="Current Data", s=50, zorder=1)
    # High zorder keeps the inducing-point stars on top of the data.
    ax.scatter(model.variational_strategy.inducing_points.data, pred_induc.detach(),
               color="#d71e5e", marker="*", label="Inducing Points", s=150, zorder=100)
    if old_x is not None:
        ax.scatter(old_x, old_y, color="#6d6d6d", alpha=0.3, label="Old Data", s=50, zorder=1)
    sns.despine(ax=ax)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$', rotation=0)
    fig.tight_layout()
    return fig

In [None]:
# Synthetic 1-D regression problem: noisy samples of sin(6x) on a regular grid.
# The first `num_init` points fit the initial model; the rest are streamed.
SEED = 0                      # fixed so Restart & Run All reproduces the same data
x_bounds = (-1, 1)
batch_size = 16               # points per streamed batch
num_init = num_z = 16         # initial training-set size and number of inducing points
num_steps = 1040              # total number of grid points (initial + streamed)
noise_std = math.sqrt(1e-2)   # observation-noise std (variance 1e-2)
shuffle = True                # stream points in random order instead of left to right

assert num_z <= num_init

torch.manual_seed(SEED)

x_train = torch.linspace(*x_bounds, num_steps)
y_train = torch.sin(6 * x_train) + noise_std * torch.randn(x_train.size(0))
if shuffle:
    row_perm = torch.randperm(x_train.size(0))
    x_train, y_train = x_train[row_perm], y_train[row_perm]

x_init, x_train = x_train[:num_init], x_train[num_init:]
y_init, y_train = y_train[:num_init], y_train[num_init:]

# Initial inducing locations: a random subset of the initial inputs.
x_perm = torch.randperm(x_init.size(0))[:num_z]
z_init = x_init[x_perm].clone()

In [None]:
# Fit the initial StreamingSGPR on the first `num_init` points by maximising the
# variational ELBO. The ExactMarginalLogLikelihood value is recorded each
# iteration only as a comparison curve; it is not optimised.
num_init_iters = 400
n_init = x_init.size(0)

covar_module = kernels.RBFKernel()
ssgp = StreamingSGPR(
    z_init,
    covar_module=covar_module,
    learn_inducing_locations=True,
    num_data=n_init,
    jitter=1e-3,
)
elbo = mlls.VariationalELBO(ssgp.likelihood, ssgp, num_data=n_init)
mll = mlls.ExactMarginalLogLikelihood(ssgp.likelihood, ssgp)

# Likelihood/kernel hyperparameters get a 10x larger step than the inducing
# locations and the variational distribution.
strategy = ssgp.variational_strategy
trainable_params = [
    {"params": ssgp.likelihood.parameters(), "lr": 1e-1},
    {"params": ssgp.covar_module.parameters(), "lr": 1e-1},
    {"params": strategy.inducing_points, "lr": 1e-2},
    {"params": strategy._variational_distribution.parameters(), "lr": 1e-2},
]
optimizer = torch.optim.Adam(trainable_params)

ssgp.train()
records = []
for _ in range(num_init_iters):
    optimizer.zero_grad()
    loss = -elbo(ssgp(x_init), y_init)
    loss.backward()
    optimizer.step()

    evidence = mll(ssgp(x_init), y_init)
    records.append({
        "elbo": -loss.item(),
        "evidence": evidence.item(),
        "noise": ssgp.likelihood.noise.item(),
    })

# Switch to eval mode and call the project's `disable_q_grad`
# (NOTE(review): presumably freezes q(u) for streaming — confirm in StreamingSGPR).
ssgp.eval()
ssgp.disable_q_grad()

# Sorted snapshot of the fitted inducing locations. The streaming loop measures
# relative inducing-point displacement against this.
z_init = ssgp.variational_strategy.inducing_points.detach().clone().sort(dim=0).values

In [None]:
# Curves from the initial fit: the ELBO versus the exact evidence (left) and the
# learned noise (right). They are followed by the fitted model on the initial data.
df = pd.DataFrame(records)
fig, (ax_objective, ax_noise) = plt.subplots(1, 2, figsize=(10, 3))

for column, label in (("evidence", "Evidence"), ("elbo", "ELBO")):
    ax_objective.plot(df[column], label=label)
ax_objective.legend()

ax_noise.plot(df["noise"])
ax_noise.set_ylabel('noise')

fig.tight_layout()

_ = make_basic_plot(ssgp, x_init, y_init, bounds=x_bounds)

In [None]:
# Streaming phase: the remaining points arrive in batches of `batch_size`.
# For each batch:
# 1. The model is first updated with `get_fantasy_model` on the *previous*
#    batch (the initial data on the first pass).
# 2. The likelihood, kernel and inducing locations are then tuned on the new
#    batch with the streaming bound.
# Diagnostics are recorded after every optimizer step, and a snapshot figure is
# saved roughly four times over the run.
num_update_steps = batch_size  # optimizer steps per incoming batch
num_chunks = x_train.size(0) // batch_size
assert num_chunks > 0, "need at least one full batch to stream"
# Plot interval. `max` guards num_chunks < 4, where `num_chunks // 4` would be 0
# and the modulus below would raise ZeroDivisionError.
plot_every = max(num_chunks // 4, 1)
x_seen = x_last = x_init
y_seen = y_last = y_init
records = []
for t, (new_x, new_y) in enumerate(zip(x_train.chunk(num_chunks), y_train.chunk(num_chunks))):

    ssgp.eval()
    # The prediction itself is never used. NOTE(review): the eval-mode forward
    # pass is kept because it may populate caches that `get_fantasy_model`
    # relies on — confirm in StreamingSGPR before removing.
    ssgp(new_x.view(-1))

    ssgp = ssgp.get_fantasy_model(x_last.view(-1, 1), y_last.view(-1, 1), resample_ratio=0.)
    elbo = StreamingSGPRBound(ssgp, combine_terms=False)
    # Exact evidence on all batches before the current one; it is recorded as a
    # diagnostic and not optimised.
    mll = mlls.ExactMarginalLogLikelihood(ssgp.likelihood, ssgp)
    trainable_params = [
        dict(params=ssgp.likelihood.parameters(), lr=1e-3),
        dict(params=ssgp.covar_module.parameters(), lr=1e-3),
        dict(params=ssgp.variational_strategy.inducing_points, lr=1e-4)
    ]
    optimizer = torch.optim.Adam(trainable_params)

    ssgp.train()
    for _ in range(num_update_steps):
        optimizer.zero_grad()
        # combine_terms=False returns the bound's pieces. t1/t2 are plotted later
        # as stacked trace components (presumably trace_term == t1 + t2 —
        # confirm in StreamingSGPRBound).
        logp_term, trace_term, t1, t2 = elbo(new_x.view(-1, 1), new_y.view(-1, 1))
        loss = -(logp_term + trace_term)
        loss.backward()
        optimizer.step()
        evidence = mll(ssgp(x_seen), y_seen)
        # Relative change of the sorted inducing locations, measured against the
        # locations after the initial fit.
        z_now = ssgp.variational_strategy.inducing_points.clone().detach()
        z_now = z_now.sort(dim=0).values
        z_disp = (z_init - z_now).norm() / z_init.norm()
        records.append(dict(elbo=-loss.item(), evidence=evidence.item(),
                            noise=ssgp.likelihood.noise.item(), logp_term=logp_term.item(),
                            trace_term=trace_term.item(), z_disp=z_disp.item(),
                            t1=t1.item(), t2=t2.item()))

    x_seen = torch.cat([x_seen, new_x])
    y_seen = torch.cat([y_seen, new_y])
    x_last, y_last = new_x, new_y

    if t % plot_every == plot_every - 1:
        fig = make_basic_plot(ssgp, x_seen, y_seen, bounds=x_bounds)
        # `plt.show` takes no figure argument: `plt.show(fig)` raises TypeError on
        # backends whose `show` is keyword-only. Save first, then display, then
        # close so figures do not accumulate across the loop.
        fig.savefig(f'sgpr_{batch_size}_sine_{(t+1) * batch_size}.pdf')
        plt.show()
        plt.close(fig)

In [None]:
# Streaming diagnostics; the x-axis is the optimizer step across all batches.
# Panels:
# - top-left: exact evidence versus the bound's log-likelihood term;
# - top-right: stacked trace components;
# - bottom-left: the learned noise;
# - bottom-right: relative inducing-point displacement from the post-initial-fit
#   locations.
df = pd.DataFrame(records)
fig, ((ax_fit, ax_trace), (ax_noise, ax_z)) = plt.subplots(2, 2, figsize=(12, 5))
steps = range(len(df))

ax_fit.plot(df.evidence, label='Evidence')
# Raw-string mathtext label. The old plain literal 'logp(\hat y)' contained the
# invalid escape '\h' (SyntaxWarning on Python 3.12+) and rendered verbatim.
ax_fit.plot(df.logp_term, label=r'$\log p(\hat{y})$')
ax_fit.legend()

ax_trace.plot(df.t1 + df.t2, label='trace', color='darkblue')
ax_trace.fill_between(steps, np.zeros_like(df.t1.values), df.t1, label='t1', color='blue', alpha=0.3)
ax_trace.fill_between(steps, df.t1, df.t1 + df.t2, label='t2', color='lightblue', alpha=0.3)
ax_trace.legend()

ax_noise.plot(df.noise)
ax_noise.set_ylabel('noise')

ax_z.plot(df.z_disp)
ax_z.set_ylabel('z displacement')

fig.tight_layout()

In [None]:
# Figure for the write-up: the two trace components of the streaming bound,
# stacked over every optimizer step, saved as a PDF.
fig, ax = plt.subplots(figsize=(8, 6))
steps = range(len(df))
trace_lower = df.t1
trace_total = df.t1 + df.t2

ax.plot(trace_total, label='trace', color='#330662')
ax.fill_between(steps, np.zeros_like(trace_lower.values), trace_lower,
                label=r'$\mathrm{trace}_1$', color='#57068c', alpha=0.3)
ax.fill_between(steps, trace_lower, trace_total,
                label=r'$\mathrm{trace}_2$', color='#8900e1', alpha=0.3)
ax.set_ylim((0, 0.3))
ax.set_xlabel('t')
ax.set_ylabel("Value")
ax.legend(loc='upper right')
sns.despine(ax=ax)
fig.tight_layout()
fig.savefig(f'sgpr_{batch_size}_sine_trace.pdf')