In [None]:
import torch
import gpytorch
import matplotlib.pyplot as plt

import seaborn as sns

# Plot styling. `sns.set(...)` resets style AND palette to seaborn defaults,
# so a `set_palette` call placed before it is silently overwritten. Pass every
# option in a single call so the "bright" palette actually takes effect.
sns.set(style="whitegrid", palette="bright", font_scale=2.0)

# Double precision for all new tensors (presumably for numerical stability of
# the GP linear algebra).
torch.set_default_dtype(torch.double)
# Seed once, before any random draw, so data, inducing-point init and training
# are reproducible. Trailing `;` hides the returned Generator repr.
torch.random.manual_seed(0);

In [None]:
from volatilitygp.models import SingleTaskVariationalGP
from volatilitygp.likelihoods import PoissonLikelihood

In [None]:
# 150 standard-normal training inputs.
train_x = torch.randn(150)


def fn(x):
    """Ground-truth latent function f(x) = 2 * sin(4 * x)."""
    return 2. * torch.sin(4. * x)


# Noise-free latent values at the inputs, then observations sampled from the
# Poisson likelihood evaluated at those latent values.
latent = fn(train_x)
observation_dist = PoissonLikelihood()(latent)
train_y = observation_dist.sample()

In [None]:
# Sampled observations (default colour) vs. the noise-free latent function at
# the same inputs (small green dots).
plt.scatter(x=train_x, y=train_y)
plt.scatter(x=train_x, y=latent, s=4, c="green")

In [None]:
# Initial inducing locations: 25 points drawn as 3 * N(0, 1). With
# use_piv_chol_init=False these are presumably used as-is (no pivoted-Cholesky
# selection) — confirm in SingleTaskVariationalGP.
inducing_init = 3. * torch.randn(25, 1)

# Variational GP: zero prior mean, scaled Matern kernel, Poisson likelihood,
# whitened variational strategy and learnable inducing locations.
model = SingleTaskVariationalGP(
    mean_module=gpytorch.means.ZeroMean(),
    init_points=inducing_init,
    likelihood=PoissonLikelihood(),
    use_piv_chol_init=False,
    covar_module=gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel()),
    use_whitened_var_strat=True,
    learn_inducing_locations=True,
)

In [None]:
# model.variational_strategy.inducing_points = torch.nn.Parameter(
#     torch.linspace(-2, 4.5, 25).view(-1,1), requires_grad = False
# )

In [None]:
# Training mode. NOTE(review): this relies on the likelihood being registered
# as a submodule of `model`, so that `model.train()` and `model.parameters()`
# also cover it — confirm in SingleTaskVariationalGP.
model.train()

# A single parameter group holding every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

# Variational ELBO objective; num_data is the total number of training
# observations, as gpytorch's VariationalELBO requires.
mll = gpytorch.mlls.VariationalELBO(
    model.likelihood,
    model,
    num_data=len(train_y),
)

In [None]:
# 250 Adam steps minimising the negative ELBO; report the loss every 50 steps.
for step in range(250):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()

    if not step % 50:
        print("loss: ", loss.item())

In [None]:
# Evaluation mode for the model and then its likelihood (same order as before).
for module in (model, model.likelihood):
    module.eval()

# Dense prediction grid on [-4, 7].
test_x = torch.linspace(-4, 7, 100)

# Latent posterior on the grid; no autograd graph needed for plotting.
with torch.no_grad():
    pred = model(test_x)

In [None]:
# True latent values at the training inputs overlaid with the posterior mean
# and confidence band of the trained model.
lower, upper = (bound.detach() for bound in pred.confidence_region())
plt.scatter(train_x, latent, color="maroon")
plt.plot(test_x, pred.mean.detach())
plt.fill_between(test_x, lower, upper, alpha=0.3)

In [None]:
# 150 fantasy inputs drawn uniformly from [1.5, 4.5), shaped (150, 1), plus
# observations sampled from the Poisson likelihood at the true latent values.
test_points = 1.5 + 3. * torch.rand(150, 1)
test_latent = fn(test_points)
test_values = PoissonLikelihood()(test_latent).sample()

In [None]:
# Inspect the shape of the sampled fantasy observations.
test_values.size()

In [None]:
# Training observations (maroon), fantasy observations (red), and the learned
# inducing locations drawn as a row of markers at height 5.
inducing_points = model.variational_strategy.inducing_points.detach()
plt.scatter(train_x, train_y, color="maroon")
plt.scatter(test_points, test_values, color="red")
# ones_like keeps the marker row the same shape as the inducing points instead
# of hard-coding 25, so changing the number of inducing points cannot break
# this cell.
plt.scatter(inducing_points, 5. * torch.ones_like(inducing_points))

In [None]:
%pdb

In [None]:
# Condition the trained model on the fantasy observations. They are sampled
# from a Poisson likelihood, hence targets_are_gaussian=False. squeeze() drops
# singleton dims; test_values is presumably (150, 1) here, giving 1-D targets.
fant_model = model.get_fantasy_model(
    test_points,
    test_values.squeeze(),
    targets_are_gaussian=False,
)

In [None]:
# Read back the inputs/targets stored on the fantasy model for plotting.
# NOTE(review): the `ind_` names suggest inducing-point pseudo-data — confirm
# against get_fantasy_model. Only the first entry of train_inputs is used.
ind_points = fant_model.train_inputs[0]
ind_responses = fant_model.train_targets

In [None]:
# True latent values (maroon) vs. the targets stored on the fantasy model.
plt.scatter(x=train_x, y=latent, color="maroon")
plt.scatter(x=ind_points.detach(), y=ind_responses.detach())


In [None]:
# plt.scatter(fant_model.covar_module.inducing_points.data, torch.ones(25))

In [None]:
# Evaluation mode for the fantasy model and then its likelihood.
for module in (fant_model, fant_model.likelihood):
    module.eval()

# Fantasy-model latent posterior on the same grid as the original model.
with torch.no_grad():
    fant_pred = fant_model(test_x)

In [None]:
# Latent posterior variance of the fantasy vs. the original model. Plot order
# is unchanged (fantasy first), so the default colour cycle assigns the same
# colours as before. Labels and a legend make the two curves distinguishable.
plt.plot(test_x, fant_pred.variance.detach(), label="Fantasy Model")
plt.plot(test_x, pred.variance.detach(), label="Original Model")

# Rug-style markers at height 0.1 showing where training (maroon) and
# fantasy (red) inputs lie.
plt.scatter(train_x, 0.1 * torch.ones_like(train_x), color="maroon", label="Training Points")
plt.scatter(test_points, 0.1 * torch.ones_like(test_points), color="red", label="Fantasy Points")

plt.xlabel("x")
plt.ylabel("Posterior variance")
plt.legend()

In [None]:
# 10-colour "Paired" palette in reverse order. The final figure indexes it by
# position (palette[0], [2], [4], [-2]).
palette = sns.color_palette("Paired", n_colors=10)
palette.reverse()

In [None]:
from matplotlib.ticker import MaxNLocator

# Final figure: observations on the left axis (integer ticks) and latent-
# function posteriors on a twin right axis. Saved to fantasization_poisson.pdf.
fig, ax = plt.subplots(1, 1, figsize=(8, 6.1))
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.yaxis.set_major_locator(MaxNLocator(integer=True))

ax2 = ax.twinx()

# Posterior mean and confidence region for the original and fantasy models.
for model_label, posterior, model_color in (
    ("Original Model", pred, palette[-2]),
    ("Fantasy Model", fant_pred, palette[0]),
):
    ax2.plot(test_x, posterior.mean.detach(), label=model_label,
             color=model_color, linewidth=3, alpha=0.8)
    ax2.fill_between(test_x, *[b.detach() for b in posterior.confidence_region()],
                     alpha=0.2, color=model_color)

# True latent function, drawn explicitly on the latent axis. Previously this
# went through plt.plot, which only landed on ax2 because twinx() makes the
# twin the current axes. test_x is the same linspace(-4, 7, 100) grid, so it
# is reused rather than rebuilt. zorder=0 keeps it beneath ax2's other artists.
ax2.plot(test_x, fn(test_x), linestyle="--", color=palette[4], linewidth=3, zorder=0)
ax2.set_ylabel("Latent")

ax.scatter(train_x, train_y, color=palette[4], label="Training Points",
           marker="x", s=100, alpha=0.3)
ax.scatter(test_points, test_values, color=palette[2], label="Fantasy Points",
           s=100, alpha=0.3)

ax.set_xlabel("x")
ax.set_ylabel("y")
# No-argument grid() call kept exactly as in the original (matplotlib documents
# it as toggling grid visibility on this axis).
ax.grid()

fig.savefig("fantasization_poisson.pdf", bbox_inches="tight")

In [None]:
# Fantasy inputs in ascending order. test_points has shape (150, 1), and
# torch.sort sorts along the last dimension by default (size 1 here), so the
# points are flattened first; otherwise the call is a no-op.
torch.sort(test_points.view(-1)).values