In [None]:
import torch
import gpytorch
import matplotlib.pyplot as plt

import seaborn as sns
# Configure the whole plotting theme in one call. `sns.set(...)` is an alias of
# `sns.set_theme`, which resets style *and* palette to their defaults, so the old
# sequence set_style/set_palette("bright") -> set(font_scale=2.0) silently discarded
# the "bright" palette. Passing everything to set_theme keeps all three settings.
sns.set_theme(style="whitegrid", palette="bright", font_scale=2.0)

# Use float64 as the default tensor dtype and seed torch's global RNG so the random
# training/test draws below are reproducible. The trailing `;` hides the Generator repr.
torch.set_default_dtype(torch.double)
torch.random.manual_seed(0);

In [None]:
# Purple sequential palette for the SVGP curves, reversed so index 0 is the base colour.
# Later cells index palette[0] and palette[3], so it needs at least 4 shades; with the
# previous n_colors=2, palette[3] raised IndexError. 10 shades mirrors `exact_palette`.
palette = sns.light_palette("#57068c", n_colors=10, reverse=True)
palette

In [None]:
# Blue sequential palette (10 shades, reversed) used for the exact-GP curves;
# the bare name on the last line renders the swatches in the notebook.
exact_palette = sns.light_palette("#28619e", n_colors=10, reverse=True)
exact_palette

In [None]:
from volatilitygp.models import SingleTaskVariationalGP

In [None]:
# 100 standard-normal training inputs.
train_x = torch.randn(100)


def fn(x):
    """Ground-truth latent function sin(2|x| + x^2 / 2), applied elementwise."""
    return torch.sin(2.0 * x.abs() + x ** 2 / 2)


# Noise-free training targets.
train_y = fn(train_x)

In [None]:
# Quick look at the raw training data.
plt.scatter(x=train_x, y=train_y)

In [None]:
# Sparse variational GP (volatilitygp) with a zero prior mean and 25 inducing-point
# initial locations drawn as 3 * N(0, 1). The two `use_*` flags are volatilitygp
# options -- NOTE(review): confirm their exact semantics in that package.
model = SingleTaskVariationalGP(
    mean_module=gpytorch.means.ZeroMean(),
    init_points=3.0 * torch.randn(25, 1),
    likelihood=gpytorch.likelihoods.GaussianLikelihood(),
    use_piv_chol_init=True,
    use_whitened_var_strat=True,
)

# Pin the observation noise at 0.01 and stop gradients to its raw parameter so the
# optimizer below leaves it untouched.
model.likelihood.noise = 0.01
model.likelihood.raw_noise.detach_()

In [None]:
# Training mode for the SVGP before optimisation.
model.train()

# One Adam parameter group over every parameter exposed by the model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Variational ELBO objective; num_data is the size of the training set.
mll = gpytorch.mlls.VariationalELBO(
    model.likelihood,
    model,
    num_data=train_y.size(0),
)

In [None]:
# Maximise the ELBO (minimise its negative) with 350 full-batch Adam steps,
# reporting the loss every 50 steps.
for step in range(350):
    optimizer.zero_grad()
    neg_elbo = -mll(model(train_x), train_y)
    neg_elbo.backward()
    optimizer.step()

    if not step % 50:
        print("loss: ", neg_elbo.item())

In [None]:
# Switch the SVGP and its likelihood to evaluation mode for prediction.
for module in (model, model.likelihood):
    module.eval()

# Prediction grid over [-4, 7], wider than the training inputs.
test_x = torch.linspace(-4, 7, 100)

# Predictive distribution of observations: latent posterior pushed through the likelihood.
with torch.no_grad():
    latent_posterior = model(test_x)
    pred = model.likelihood(latent_posterior)

In [None]:
# Training data, SVGP predictive mean, and its confidence band.
lower, upper = (bound.detach() for bound in pred.confidence_region())
plt.scatter(train_x, train_y, color="maroon")
plt.plot(test_x, pred.mean.detach())
plt.fill_between(test_x, lower, upper, alpha=0.3)

In [None]:
# Display the variational distribution held by the trained model's variational strategy.
model.variational_strategy.variational_distribution

In [None]:
# 25 new observation locations drawn uniformly from [2.5, 5.5), shaped (25, 1),
# mostly to the right of the standard-normal training inputs.
test_points = 2.5 + 3.0 * torch.rand(25, 1)
# Noise-free targets at those locations, flattened to shape (25,).
test_values = fn(test_points).reshape(-1)

In [None]:
# Sanity check: the new targets were flattened with `.view(-1)`, so this should be 1-D.
test_values.shape

In [None]:
# Original training data (maroon) overlaid with the new observations (red).
plt.scatter(train_x, train_y, color="maroon")
plt.scatter(test_points, test_values, color="red")

In [None]:
# Debugging aid removed: a bare `%pdb` *toggles* automatic post-mortem debugging,
# so re-running this cell silently switched it back off. Run `%pdb on` manually
# in a scratch cell if you need to debug the conditioning step below.

In [None]:
# Condition the trained SVGP on the new observations.
# NOTE(review): `condition_into_sgpr=False` is a volatilitygp option -- confirm which
# conditioning path it selects.
fant_model = model.condition_on_observations(
    test_points,
    test_values,
    condition_into_sgpr=False,
)

In [None]:
# Evaluation mode for the conditioned model and its likelihood.
for module in (fant_model, fant_model.likelihood):
    module.eval()

# Predictive distribution of the conditioned model on the same test grid.
with torch.no_grad():
    fant_latent = fant_model(test_x)
    fant_pred = fant_model.likelihood(fant_latent)

In [None]:
# Split the conditioned model's targets at the number of new observations instead of
# a hard-coded 25, so changing the size of `test_points` cannot silently mis-split them.
# NOTE(review): the layout appears to be [new observations, pseudo-responses at the
# inducing points] (the leading entries are plotted against test_points) -- confirm
# against volatilitygp's condition_on_observations.
n_new = test_points.shape[0]
ind_points = model.variational_strategy.inducing_points[:, 0].detach()
pseudo_responses = fant_model.train_targets[n_new:].detach()

In [None]:
# Display the largest pseudo-response value.
pseudo_responses.max()

In [None]:
# Pseudo-responses at the inducing points, plus the conditioned model's leading
# targets against the new input locations. The slice length is derived from
# `test_points` rather than hard-coded as 25.
n_new = test_points.shape[0]
plt.scatter(ind_points, pseudo_responses)
plt.scatter(test_points, fant_model.train_targets[:n_new].detach())

In [None]:
# Predictive variance after vs. before conditioning. Markers at a constant height of
# 0.1 show where the original training inputs (maroon) and new inputs (red) lie.
# Labels and a legend are added because the two curves were otherwise indistinguishable.
plt.plot(test_x, fant_pred.variance.detach(), label="Conditioned SVGP")
plt.plot(test_x, pred.variance.detach(), label="Original SVGP")

plt.scatter(train_x, 0.1 * torch.ones_like(train_x), color="maroon", label="Training inputs")
plt.scatter(test_points, 0.1 * torch.ones_like(test_points), color="red", label="New inputs")

plt.xlabel("x")
plt.ylabel("predictive variance")
plt.legend();

In [None]:
# Overlay of original vs. conditioned predictive means for the exact GP and the SVGP,
# saved to fantasization_label.pdf.
#
# `exact_pred` / `fant_exact_pred` are produced by the exact-GP cells further down the
# notebook. On Restart & Run All this cell would die with a bare NameError halfway
# through drawing, so check up front and say which cells to run.
_missing = [name for name in ("exact_pred", "fant_exact_pred") if name not in globals()]
if _missing:
    raise NameError(
        f"{', '.join(_missing)} not defined: run the exact GP cells "
        "(SingleTaskGP fit and condition_on_observations) before this figure."
    )

fig, ax = plt.subplots(1, 1, figsize=(8, 6))

ax.plot(test_x, exact_pred.mean.detach(), label="Original GP",
        color=exact_palette[3], linewidth=6, alpha=0.8)
ax.plot(test_x, fant_exact_pred.mean.detach(), label="Conditioned GP",
        color=exact_palette[0], linewidth=6, alpha=0.8)
ax.plot(test_x, pred.mean.detach(), label="Original SVGP",
        color=palette[3], linewidth=6, alpha=0.8)
ax.plot(test_x, fant_pred.mean.detach(), label="Conditioned SVGP",
        color=palette[0], linewidth=6, alpha=0.8)

# Ground-truth latent on the prediction grid (test_x is linspace(-4, 7, 100)),
# drawn beneath everything else.
ax.plot(test_x, fn(test_x), color="#6d6d6d", linestyle="--", linewidth=6,
        label="True Latent", zorder=0)

ax.scatter(train_x, train_y, color="#d71e5e", label="Training Points", marker="x",
           s=400, alpha=0.5, linewidths=6)
ax.scatter(test_points, test_values, color="#d71e5e", label="New Points",
           s=400, marker="x", linewidths=6)

# One legend row for all seven entries, placed below the axes.
ax.legend(ncol=7, loc="upper center", bbox_to_anchor=(0.5, -0.2))
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_xlim((-4, 7))
fig.savefig("fantasization_label.pdf", bbox_inches="tight")


In [None]:
fig, ax = plt.subplots(2, 1, figsize=(8, 6), sharex=True, sharey=True, dpi=300)

# Dashed ground-truth reference curve, identical on both panels.
latent_grid = torch.linspace(-4, 7, 100)
latent_vals = fn(latent_grid)

# Top panel: SVGP before conditioning; bottom panel: after conditioning.
# Each entry is (axis, predictive distribution, label, colour).
svgp_panels = [
    (ax[0], pred, "Original Model", palette[3]),
    (ax[1], fant_pred, "Conditioned Model", palette[0]),
]
for panel, dist, title, shade in svgp_panels:
    panel.plot(test_x, dist.mean.detach(), label=title,
               color=shade, linewidth=4, alpha=0.8)
    panel.fill_between(test_x, *(bound.detach() for bound in dist.confidence_region()),
                       alpha=0.2, color=shade)
    panel.plot(latent_grid, latent_vals, color="#6d6d6d", linestyle="--",
               linewidth=3, label="True Latent", zorder=0)

# Training data is faded on the conditioned panel so the new points stand out.
ax[0].scatter(train_x, train_y, color="#d71e5e", label="Training Points", marker="x", s=100, zorder=30)
ax[1].scatter(train_x, train_y, color="#d71e5e", label="Training Points", marker="x", s=100,
              alpha=0.2, zorder=300)
ax[1].scatter(test_points, test_values, color="#d71e5e", marker="x", label="New Points", s=100, zorder=30)

for panel in ax:
    panel.set_xlabel("x")
    panel.set_ylabel("y")
plt.xlim((-4, 7))
plt.ylim((-2, 2))
plt.savefig("fantasization_svgp_gaussian.pdf", bbox_inches="tight")

In [None]:
from botorch.models import SingleTaskGP
from botorch.optim.fit import fit_gpytorch_torch
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.kernels import ScaleKernel, RBFKernel

In [None]:
# Exact-GP baseline on the same data: column-shaped inputs/targets, a Gaussian
# likelihood, and a scaled RBF kernel.
exact_model = SingleTaskGP(
    train_x.reshape(-1, 1),
    train_y.reshape(-1, 1),
    likelihood=GaussianLikelihood(),
    covar_module=ScaleKernel(RBFKernel()),
)

# Same fixed noise level as the SVGP, excluded from hyperparameter fitting.
exact_model.likelihood.noise = 0.01
exact_model.likelihood.raw_noise.detach_()

In [None]:
# Fit the exact GP's hyperparameters by maximising the exact marginal log likelihood.
# Use a dedicated name: reusing `mll` clobbered the SVGP's VariationalELBO, so re-running
# the SVGP training loop afterwards would pick up the wrong objective.
# NOTE(review): newer BoTorch releases deprecate `fit_gpytorch_torch` in favour of
# `fit_gpytorch_mll` -- confirm against the pinned BoTorch version.
exact_mll = ExactMarginalLogLikelihood(exact_model.likelihood, exact_model)
fit_gpytorch_torch(exact_mll);

In [None]:
# Evaluation mode, then the exact GP's predictive distribution on the test grid.
exact_model.eval()
exact_model.likelihood.eval()

with torch.no_grad():
    exact_pred = exact_model.likelihood(exact_model(test_x))

# Condition the exact GP on the new observations. BoTorch expects targets with an
# explicit trailing output dimension (n x m), matching the (n, 1) targets the model
# was built with; `test_values` is flat (n,), so add that dimension explicitly.
fant_exact_model = exact_model.condition_on_observations(
    test_points, test_values.unsqueeze(-1)
)

In [None]:
# Evaluation mode for the conditioned exact GP and its likelihood.
for module in (fant_exact_model, fant_exact_model.likelihood):
    module.eval()

# Predictive distribution of the conditioned exact GP on the test grid.
with torch.no_grad():
    fant_exact_latent = fant_exact_model(test_x)
    fant_exact_pred = fant_exact_model.likelihood(fant_exact_latent)

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(8, 6), sharex=True, sharey=True, dpi=300)

# Top panel: exact GP before conditioning; bottom panel: after conditioning.
# Each entry is (axis, predictive distribution, label, colour).
exact_panels = [
    (ax[0], exact_pred, "Original Model", exact_palette[3]),
    (ax[1], fant_exact_pred, "Conditioned Model", exact_palette[0]),
]
for panel, dist, title, shade in exact_panels:
    panel.plot(test_x, dist.mean.detach(), label=title,
               color=shade, linewidth=4, alpha=0.8)
    panel.fill_between(test_x, *(bound.detach() for bound in dist.confidence_region()),
                       alpha=0.2, color=shade)

# Training data is faded on the conditioned panel so the new points stand out.
ax[0].scatter(train_x, train_y, color="#d71e5e", label="Training Points", marker="x", s=100, zorder=30)
ax[1].scatter(train_x, train_y, color="#d71e5e", label="Training Points", marker="x", s=100,
              alpha=0.2, zorder=300)
ax[1].scatter(test_points, test_values, color="#d71e5e", marker="x", label="New Points", s=100, zorder=30)

# Dashed ground-truth reference curve on both panels, drawn beneath everything else.
latent_grid = torch.linspace(-4, 7, 100)
for panel in ax:
    panel.plot(latent_grid, fn(latent_grid), color="#6d6d6d", linestyle="--",
               linewidth=3, label="True Latent", zorder=0)

for panel in ax:
    panel.set_xlabel("x")
    panel.set_ylabel("y")
plt.xlim((-4, 7))
plt.ylim((-2, 2))
plt.savefig("fantasization_exact_gaussian.pdf", bbox_inches="tight")