In [None]:
import math
import torch
import gpytorch
import spectralgp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

In [None]:
class SpectralModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a ``spectralgp`` spectral kernel.

    Any extra keyword arguments are forwarded twice: once to the
    ``SpectralGPKernel`` constructor and once to the kernel's
    ``initialize_from_data`` call.
    """

    def __init__(self, train_x, train_y, likelihood, **kwargs):
        """Build the mean and covariance modules and initialise the kernel.

        Args:
            train_x: training inputs handed to ``ExactGP``.
            train_y: training targets handed to ``ExactGP``.
            likelihood: likelihood handed to ``ExactGP``.
            **kwargs: forwarded to ``SpectralGPKernel(...)`` and to
                ``initialize_from_data(...)``.
        """
        super().__init__(train_x, train_y, likelihood)

        self.mean_module = gpytorch.means.ConstantMean()

        spectral_kernel = spectralgp.kernels.SpectralGPKernel(**kwargs)
        self.covar_module = spectral_kernel
        # The kernel is registered on the model first, then initialised from
        # the training data with the same keyword arguments.
        spectral_kernel.initialize_from_data(train_x, train_y, **kwargs)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


In [None]:
# Gaussian observation noise whose prior is a smoothed box over [1e-8, 1e-4].
noise_prior = gpytorch.priors.SmoothedBoxPrior(1e-8, 1e-4)
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)

# `nomg=100` is forwarded through **kwargs to the spectral kernel.
# NOTE(review): train_x / train_y are not defined in any cell above this one;
# they must come from a data-loading cell that is missing from this notebook.
model = SpectralModel(
    train_x,
    train_y,
    likelihood,
    nomg=100,
)

In [None]:
# Sampler budget, passed to spectralgp's AlternatingSampler below as
# totalSamples, numInnerSamples and numOuterSamples respectively.
n_iters, ess_iters, optim_iters = 10, 5, 5

In [None]:
# Thompson-sampling trials: each trial draws one function sample over train_x,
# picks the index of its maximum, records an observation at that index, and
# re-runs the alternating sampler on the updated training data.
T = 20  # number of trials
for t in range(T):
    # Draw the sample from the distribution returned by calling the model;
    # the model object itself exposes no `sample` method. Eval mode makes
    # model(train_x) the posterior rather than the prior (gpytorch may warn
    # that the inputs equal the training inputs; that is intended here).
    model.eval()
    with torch.no_grad():
        # torch.Size needs an iterable: torch.Size(1,) is torch.Size(1) and
        # raises TypeError.
        Ft = model(train_x).sample(sample_shape=torch.Size((1,))).numpy()
    # Back to training mode, the mode the sampler ran in before this change.
    model.train()

    # Ft has shape (1, len(train_x)); argmax() flattens, so At indexes train_x.
    At = int(Ft.argmax())

    # OBSERVE Yt, Yt = Ft + Et
    # NOTE(review): the original left this step as a placeholder, so Yt was
    # undefined. This follows the placeholder literally: the sampled value at
    # At plus Gaussian noise at the likelihood's current noise variance.
    # Replace it with a query of the real objective at train_x[At] if one
    # exists; confirm intent.
    Et = math.sqrt(likelihood.noise.item()) * torch.randn(()).item()
    Yt = float(Ft.ravel()[At]) + Et

    train_y[At] = Yt
    model.set_train_data(train_x, train_y, strict=False)

    # Re-sample the spectral posterior for the updated data. The sampler from
    # the final trial is the one read by the prediction cell further down.
    alt_sampler = spectralgp.samplers.AlternatingSampler(
        [model], [likelihood],
        spectralgp.sampling_factories.ss_factory,
        [spectralgp.sampling_factories.ess_factory],
        totalSamples=n_iters,
        numInnerSamples=ess_iters,
        numOuterSamples=optim_iters,
    )
    alt_sampler.run()

In [None]:
# Predict on full_x once per retained spectrum sample.
model.eval()
n_samples = 10  # how many of the most recent sampler draws to predict with

# Keep the last n_samples entries along the final axis of the sampler output;
# each column is passed to set_latent_params below.
# NOTE(review): presumably the final axis of gsampled[0] indexes sampler
# draws - confirm against spectralgp.
spectrum_samples = alt_sampler.gsampled[0][0, :, -n_samples:].detach()

# One column per spectrum sample. The buffers are sized from n_samples (not a
# literal 10) so they stay in step with the slice above and the loop below.
predictions = torch.zeros(len(full_x), n_samples)  # predictive mean per sample
upper_bds = torch.zeros(len(full_x), n_samples)    # upper conf. bd per sample
lower_bds = torch.zeros(len(full_x), n_samples)    # lower conf. bd per sample

with torch.no_grad():
    for ii in range(n_samples):
        model.covar_module.set_latent_params(spectrum_samples[:, ii])
        model.set_train_data(train_x, train_y)  # to clear out the cache
        pred_dist = model(full_x)
        lower_bds[:, ii], upper_bds[:, ii] = pred_dist.confidence_region()
        predictions[:, ii] = pred_dist.mean

In [None]:
# Plot every per-sample prediction with its confidence band, plus the data.
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the equivalent that is still available.
colors = plt.get_cmap("tab10")

## plot the predictions ##
# Column 0 is drawn on its own first so the legend gets exactly one
# "Predictions" entry; the second call draws every column, unlabelled.
plt.plot(full_x.numpy(), predictions[:, 0].detach().numpy(), label="Predictions",
         color=colors(0), linewidth=2)
plt.plot(full_x.numpy(), predictions.detach().numpy(), linewidth=2,
         color=colors(0))

## Shade region +/- 2 SD around the mean ##
# Same legend trick: one labelled band, then one faint band per sample so
# overlapping bands accumulate opacity.
plt.fill_between(full_x.numpy(), lower_bds[:, 0].detach().numpy(),
                 upper_bds[:, 0].detach().numpy(),
                 color=colors(0), alpha=0.03, label=r"$\pm 2$ SD")
for ii in range(n_samples):
    plt.fill_between(full_x.numpy(), lower_bds[:, ii].detach().numpy(),
                     upper_bds[:, ii].detach().numpy(),
                     color=colors(0), alpha=0.03)

## plot data ##
plt.plot(train_x.numpy(), train_y.numpy(), color=colors(1),
         linewidth=2, label="Train Data")
plt.plot(test_x.numpy(), test_y.numpy(), color=colors(1),
         linestyle="None", marker=".", markersize=12,
         label="Test Data")
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Predictions and Data")
plt.legend()
plt.show()

In [None]:
# Plot the retained spectrum samples against the kernel's omega values.
# The samples are exponentiated before plotting.
plt.plot(model.covar_module.omega.numpy(), spectrum_samples.exp().numpy(),
         label='Posterior Samples')
plt.xlabel('Omega')
plt.ylabel('Density')
plt.xlim((0, 7))
plt.ylim((0, 1))
# Reference line at 2 / (2*pi), using math.pi instead of a hand-typed 3.14159.
# NOTE(review): the x-axis is Omega, so this line marks a frequency although
# the label says "True Period" - confirm the intended wording.
plt.vlines(2 / (2 * math.pi), ymin=0, ymax=10, label='True Period')
plt.legend()