In [2]:
import torch
from torch.distributions import MultivariateNormal, StudentT
from attrdict import AttrDict
import math
from gp import gen_evalset
from tqdm import tqdm
import time
from utils.log import get_logger, RunningAverage

In [3]:
class GPSampler(object):
    """Draws batches of 1-D regression tasks from a zero-mean GP prior.

    The covariance is delegated to ``kernel``: a callable that receives the
    sampled inputs ``x`` of shape [B,N,1] and returns the tuple
    ``(cov, length, scale, noise_scale)``.
    """

    def __init__(self, kernel, t_noise=None, seed=None):
        """Remember the kernel / noise setting and optionally seed torch.

        kernel:  callable producing ``(cov, length, scale, noise_scale)`` from x.
        t_noise: ``None`` disables the Student-t perturbation of ``y``;
                 ``-1`` draws a random per-element magnitude in [0, 0.15);
                 any other value is used directly as the magnitude.
        seed:    when not ``None``, seeds the global torch CPU and CUDA RNGs.
        """
        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
        self.kernel = kernel
        self.t_noise = t_noise
        self.seed = seed

    def sample(self,
            batch_size=16,
            num_ctx=None,
            num_tar=None,
            max_num_points=50,
            x_range=(-2, 2),
            device='cpu'):
        """Sample one batch of context/target points.

        A falsy ``num_ctx`` / ``num_tar`` (``None`` or 0) is replaced by a
        random count: Nc from [3, max_num_points-3) and Nt from
        [3, max_num_points-Nc).

        Returns ``(batch, length, scale, noise_scale)`` where the last three
        values are passed through from the kernel and ``batch`` holds
            x:  [B,N,1]    xc: [B,Nc,1]    xt: [B,Nt,1]
            y:  [B,N,1]    yc: [B,Nc,1]    yt: [B,Nt,1]
        with N = Nc + Nt.
        """
        # Random draws happen in a fixed order (Nc, Nt, x, kernel, y, noise);
        # keep it that way so a given seed keeps producing the same data.
        if not num_ctx:
            num_ctx = torch.randint(low=3, high=max_num_points-3, size=[1]).item()  # Nc
        if not num_tar:
            num_tar = torch.randint(low=3, high=max_num_points-num_ctx, size=[1]).item()  # Nt
        num_points = num_ctx + num_tar  # N = Nc + Nt

        # Inputs are uniform on [x_range[0], x_range[1]).
        lower, upper = x_range[0], x_range[1]
        unit = torch.rand([batch_size, num_points, 1], device=device)  # [B,N,1]
        batch = AttrDict()
        batch.x = lower + (upper - lower) * unit  # [B,N,Dx=1]
        batch.xc = batch.x[:, :num_ctx]  # [B,Nc,1]
        batch.xt = batch.x[:, num_ctx:]  # [B,Nt,1]

        # The kernel also reports the hyper-parameters it used.
        cov, length, scale, noise_scale = self.kernel(batch.x)  # cov: [B,N,N]
        zero_mean = torch.zeros(batch_size, num_points, device=device)  # [B,N]
        prior = MultivariateNormal(zero_mean, cov)
        batch.y = prior.rsample().unsqueeze(-1)  # [B,N,Dy=1]
        batch.yc = batch.y[:, :num_ctx]  # [B,Nc,1]
        batch.yt = batch.y[:, num_ctx:]  # [B,Nt,1]

        if self.t_noise is None:
            return batch, length, scale, noise_scale

        # Optional heavy-tailed perturbation, added to y with `+=`
        # (yc / yt were sliced from y before this update).
        if self.t_noise == -1:
            magnitude = 0.15 * torch.rand(batch.y.shape).to(device)  # [B,N,1]
        else:
            magnitude = self.t_noise
        batch.y += magnitude * StudentT(2.1).rsample(batch.y.shape).to(device)
        return batch, length, scale, noise_scale

class RBFKernel(object):
    """RBF (squared-exponential) kernel with randomly drawn hyper-parameters.

    Every call draws a fresh length and output scale per batch element and
    returns them alongside the covariance.
    """

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        """Store the diagonal noise std and the upper bounds of the draws.

        sigma_eps:  std whose square is added to the covariance diagonal.
        max_length: upper bound of the uniform length draw (lower bound 0.1).
        max_scale:  upper bound of the uniform scale draw (lower bound 0.1).
        """
        self.sigma_eps, self.max_length, self.max_scale = sigma_eps, max_length, max_scale

    def __call__(self, x):
        """Build the covariance for inputs ``x`` of shape [B,N,Dx].

        Returns ``(cov, length, scale, sigma_eps)`` with cov [B,N,N],
        length [B,1,1,1], scale [B,1,1] and sigma_eps the stored float.
        """
        batch_dim = x.shape[0]

        # Draw order (length, then scale) is part of the seeded behaviour.
        length = 0.1 + (self.max_length-0.1) * torch.rand([batch_dim, 1, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) * torch.rand([batch_dim, 1, 1], device=x.device)

        # Length-normalised pairwise differences, then squared distance.
        pairwise = x.unsqueeze(-2) - x.unsqueeze(-3)  # [B,N,N,Dx]
        sq_dist = (pairwise / length).pow(2).sum(-1)  # [B,N,N]

        # sigma_eps^2 on the diagonal.
        jitter = self.sigma_eps**2 * torch.eye(x.shape[-2]).to(x.device)
        cov = scale.pow(2) * torch.exp(-0.5 * sq_dist) + jitter

        return cov, length, scale, self.sigma_eps  # cov: [B,N,N]

In [4]:
import torch
import gpytorch

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regressor with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        """Register the training data / likelihood and build mean and kernel."""
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # ScaleKernel wraps the RBF kernel; the inner kernel is reachable as
        # `covar_module.base_kernel`.
        rbf = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        prior_mean = self.mean_module(x)
        prior_cov = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(prior_mean, prior_cov)

In [None]:
def gen_evalset(num_batches=3000, batch_size=16, max_num_points=50, device='cuda', seed=0):
    """Generate a fixed evaluation set of GP batches from an RBF kernel.

    NOTE: this definition shadows the ``gen_evalset`` imported from ``gp`` in
    the notebook's import cell.

    num_batches:    number of batches to draw.
    batch_size:     number of functions per batch.
    max_num_points: forwarded to ``GPSampler.sample``.
    device:         torch device on which the batches are created.
    seed:           seed handed to ``GPSampler`` (seeds torch's global RNGs).

    Returns a list with one entry per batch; each entry is whatever
    ``GPSampler.sample`` returns, i.e. ``(batch, length, scale, noise_scale)``.
    Called without arguments it behaves as before (3000 batches of 16 on CUDA,
    seed 0).
    """
    sampler = GPSampler(RBFKernel(), t_noise=None, seed=seed)
    return [
        sampler.sample(
            batch_size=batch_size,
            max_num_points=max_num_points,
            device=device)
        for _ in tqdm(range(num_batches), ascii=True)
    ]

In [6]:
# Build the evaluation set once; `batches` is what the oracle-GP evaluation
# loop further down iterates over.
batches = gen_evalset()

100%|##########| 3000/3000 [00:08<00:00, 355.50it/s]


In [108]:
# import pickle
# with open('batches.pkl','wb') as f:
#     pickle.dump(batches,f)

In [109]:
# with open('batches.pkl','rb') as f:
#     data = pickle.load(f)

In [122]:
# Accumulators for the oracle GP evaluation: one averaged target
# log-likelihood and one averaged RMSE per batch.
lls, rmses = [], []

In [None]:
# Oracle evaluation: for every sampled function, condition an exact GP that
# uses the *true* generating hyper-parameters on the context points and score
# its predictions on the target points.

# Start from empty accumulators so that re-running this cell (e.g. after an
# interrupted run) cannot mix stale per-batch results into the averages.
lls = []
rmses = []

for item in tqdm(batches):
    batch, length, scale, noise_scale = item

    # gpytorch is parameterised by the noise variance (sigma^2), not sigma.
    # The constraint's lower bound is the same value that is assigned below.
    noise = noise_scale ** 2
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(noise))

    curr_ll = []
    curr_rmse = []

    # Use the batch's own leading dimension instead of a hard-coded 16 so the
    # loop stays correct if the evaluation set is generated with another size.
    for i in range(batch.xc.shape[0]):
        model = ExactGPModel(batch.xc[i], batch.yc[i].ravel(), likelihood)
        # Follow the data's device rather than assuming CUDA.
        model.to(batch.xc.device)

        # Plug in the hyper-parameters that generated this function.
        model.covar_module.base_kernel.lengthscale = length[i]  # the length itself, not squared
        model.covar_module.outputscale = scale[i] ** 2  # sigma^2, not sigma
        model.likelihood.noise = noise  # sigma^2, not sigma

        model.eval()
        likelihood.eval()
        # Only scalar metrics are kept, so the whole prediction + scoring runs
        # without autograd tracking.
        with torch.no_grad():
            observed_pred = likelihood(model(batch.xt[i]))
            # observed_pred is a multivariate normal, but the metric is the
            # point-wise log-likelihood, so score against its marginals.
            mean = observed_pred.mean
            stddev = observed_pred.stddev
            target = batch.yt[i].ravel()
            oracle_log_likelihood = torch.distributions.Normal(mean, stddev).log_prob(target).mean(-1)
            oracle_rmse = torch.sqrt(torch.mean((mean - target)**2))

        curr_ll.append(oracle_log_likelihood.item())
        curr_rmse.append(oracle_rmse.item())

    # Per-batch averages over the functions in the batch.
    lls.append(sum(curr_ll)/len(curr_ll))
    rmses.append(sum(curr_rmse)/len(curr_rmse))

100%|██████████| 3000/3000 [08:54<00:00,  5.62it/s]


In [112]:
# Mean over batches of the per-batch average target log-likelihood in `lls`.
sum(lls)/len(lls)

1.5369976595707293

In [113]:
# Mean over batches of the per-batch average RMSE in `rmses`.
sum(rmses)/len(rmses)

0.12042384587068228