# Setup

In [None]:
!pip install gpytorch



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gpytorch
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
import tqdm.notebook as tqdm
sns.set_style('darkgrid')

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook
# still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Environments

In [None]:
def default_transform(x):
    '''Collapse the last dimension of x by summing over it.'''
    return x.sum(dim=-1)

# Symmetric bound applied to transforms whose output can be unbounded.
clamp_val = 100

def csc_tranform(x):
    '''Sum 1 / sin(x) over the last dimension, clamped to [-clamp_val, clamp_val].'''
    csc_sum = torch.sin(x).reciprocal().sum(dim=-1)
    return csc_sum.clamp(min=-clamp_val, max=clamp_val)

def tan_transform(x):
    '''Sum tan(x) over the last dimension, clamped to [-clamp_val, clamp_val].'''
    tan_sum = x.tan().sum(dim=-1)
    return tan_sum.clamp(min=-clamp_val, max=clamp_val)

def atan_transform(x):
    '''Sum atan(1 / x) over the last dimension.'''
    inverse = x.reciprocal()
    return inverse.atan().sum(dim=-1)

def sin_inv_transform(x):
    '''Sum sin(1 / x) over the last dimension.'''
    inverse = x.reciprocal()
    return inverse.sin().sum(dim=-1)

def sum_square_transform(x):
    '''Square every entry, then sum over the last dimension.'''
    return torch.square(x).sum(dim=-1)

def square_sum_transform(x):
    '''Sum over the last dimension, then square the result.'''
    return torch.square(x.sum(dim=-1))

def sum_sin_transform(x):
    '''Sum sin(x) over the last dimension.'''
    return x.sin().sum(dim=-1)

def prod_transform(x):
    '''Multiply the entries of x along the last dimension.'''
    return x.prod(dim=-1)

def sign_transform(x):
    '''Product of the signs of the entries along the last dimension.'''
    return x.sign().prod(dim=-1)

def sum_floor_transform(x):
    '''Five times the sum of floor(x) over the last dimension.'''
    floored_total = x.floor().sum(dim=-1)
    return 5 * floored_total

def noisy_sum_floor_transform(x):
    '''Return 10 * sum(floor(x)) over the last dimension plus unit Gaussian noise.

    One independent N(0, 1) draw is added to every output entry, so repeated
    calls on the same x give different results.
    '''
    mean = 10 * torch.sum(torch.floor(x), dim=-1)
    # Sample on x's device directly instead of building zeros/ones on the CPU
    # and copying the result over.
    noise = torch.randn(mean.shape, device=x.device)
    return noise + mean

class FunctionTaskGenerator(nn.Module):
    def __init__(self, input_dim=1, latent_dim=1, lengthscale=0.5, transform=default_transform, train_size=None):
        '''Define distribution F over regression tasks. A task f~F is a function 
        f(X) = transform(Z) where Z~GP(X) is sampled from a multitask GP using 
        an RBF kernel with latent_dim tasks and transform is an arbitrary map.

        input_dim: size of the last dimension of the sampled inputs X.
        latent_dim: number of GP tasks, i.e. size of the last dimension of Z.
        lengthscale: lengthscale of the RBF kernel.
        transform: callable reducing the last dimension of Z to produce Y.
        train_size: accepted for interface compatibility; not used here.
        '''
        super().__init__()
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=latent_dim
        )
        rbf = gpytorch.kernels.RBFKernel()
        # Assign through the constrained property. Writing the value into
        # raw_lengthscale would store the pre-constraint parameter, so the
        # kernel's effective lengthscale would not equal `lengthscale`.
        rbf.lengthscale = lengthscale
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            rbf, num_tasks=latent_dim, rank=1
        )
        self.transform = transform
        self.input_dim = input_dim
        # Placeholder parameter whose only job is to track the module's device.
        self.dummy = nn.Parameter(torch.empty([]))
    
    def forward(self, batch, K=5, validation=False):
        '''Samples batch of regression tasks with K examples per task. 
        Returns:
        X: datapoints sampled from N(0, I) for batch of tasks 
            with shape [batch, K, input_dim]
        Y: labels f(X) for f~F for batch of tasks with shape [batch, K]

        validation is accepted for interface parity with the other task
        generators and is not read.
        '''
        shape = [batch, K, self.input_dim]
        with torch.no_grad():
            # Draw the inputs on the module's device rather than on the CPU
            # followed by a copy.
            X = torch.randn(shape, device=self.dummy.device)
            Z = gpytorch.distributions.MultitaskMultivariateNormal(
                self.mean_module(X), self.covar_module(X)
            ).sample()
            Y = self.transform(Z)
            # The transform must collapse the task dimension.
            assert Y.dim() == 2
        return X, Y


class SinusoidTaskGenerator(nn.Module):
    def __init__(self, a_min=0.1, a_max=5.0, p_min=0.0, p_max=2*np.pi, f_min=0.5, f_max=2.0, 
                 tr_min=-5.0, tr_max=5.0, te_min=-5.0, te_max=10.0, noise=0.01, 
                 transform=torch.sin, out_of_range_val=False):
        '''Store the sampling ranges that define the sinusoid task distribution.'''
        super().__init__()
        self.amp_range = [a_min, a_max]             # DKT: [0.1, 5.0], BMAML: [0.1, 5.0]
        self.phase_range = [p_min, p_max]           # DKT: [0.1, pi], BMAML: [0.0, 2*pi]
        self.freq_range = [f_min, f_max]            # DKT: [1.0, 1.0] (unused), BMAML: [0.5, 2.0]
        self.train_samp_range = [tr_min, tr_max]    # DKT: [-5.0, 5.0], BMAML: [-5.0, 5.0]
        self.test_samp_range = [te_min, te_max]     # DKT: Used for out-of-range testing, if needed
        self.noise_var = noise                      # DKT: 0.0 (unused), BMAML: 0.01.
        self.input_dim = 1
        self.transform = transform
        self.out_of_range_val = out_of_range_val
        # Placeholder parameter whose only job is to track the module's device.
        self.dummy = nn.Parameter(torch.empty([]))
    
    def forward(self, batch, K, validation=False):
        '''Draw a batch of sinusoid tasks with K points each.

        Per task i, an amplitude A_i, phase C_i and frequency W_i are drawn
        uniformly from amp_range, phase_range and freq_range.
        Returns:
        X: inputs of shape [batch, K, 1], uniform over train_samp_range, or over
            test_samp_range when both out_of_range_val and validation are set.
        Y: labels of shape [batch, K],
            Y[i][j] = A_i * transform(W_i * X[i][j] + C_i) + noise[i][j],
            where the noise is zero-mean Gaussian with standard deviation
            noise_var * A_i.
        '''
        target_device = self.dummy.device
        amplitudes = torch.empty(batch).uniform_(*self.amp_range)
        phases = torch.empty(batch).uniform_(*self.phase_range)
        freqs = torch.empty(batch).uniform_(*self.freq_range)
        shape = [batch, K]
        amp_col = amplitudes[:, None]
        noise = torch.normal(torch.zeros(shape), self.noise_var * amp_col.expand(shape))
        use_test_range = self.out_of_range_val and validation
        with torch.no_grad():
            low, high = self.test_samp_range if use_test_range else self.train_samp_range
            X = torch.empty(shape).uniform_(low, high)
            clean = amp_col * self.transform(freqs[:, None] * X + phases[:, None])
            Y = clean + noise
        return X.unsqueeze(-1).to(target_device), Y.to(target_device)


class StepFunctionTaskGenerator(nn.Module):
    def __init__(self, s_min=-2.5, s_max=2.5, samp_min=-5.0, samp_max=5.0, noise=0.03):
        '''Define distribution parameters for step function task.

        s_min, s_max: range the switch locations are drawn from.
        samp_min, samp_max: range the inputs are drawn from.
        noise: standard deviation of the Gaussian noise added to the labels.
        '''
        super().__init__()
        self.switch_range = [s_min, s_max]
        self.samp_range = [samp_min, samp_max]
        self.noise = noise
        self.input_dim = 1
        # Placeholder parameter whose only job is to track the module's device.
        self.dummy = nn.Parameter(torch.empty([]))

    def forward(self, batch, K, validation=False):
        '''Samples batch of step function tasks with K examples per task.

        Each task has three switch points drawn uniformly from switch_range.
        Returns:
        X: inputs of shape [batch, K, 1], uniform over samp_range.
        Y: labels of shape [batch, K]; +1 when an odd number of the task's
            switches lie below the input and -1 otherwise, plus Gaussian noise.

        validation is accepted for interface parity with the other task
        generators and is not read.
        '''
        # The per-call debug print of the switches was removed: it wrote to
        # stdout on every sampled batch.
        switches = torch.empty((batch, 3)).uniform_(*self.switch_range)
        shape = [batch, K]
        noise = torch.normal(torch.zeros(shape), self.noise)
        with torch.no_grad():
            X = torch.empty(shape).uniform_(*self.samp_range)
            # Count, for every point, how many of its task's switches it exceeds.
            num_greater = torch.sum(X[:, :, None] > switches[:, None, :], dim=2)
            # Parity of the count maps to -1 / +1.
            Y = 2 * (num_greater % 2) - 1 + noise
        return X.unsqueeze(-1).to(self.dummy.device), Y.to(self.dummy.device)


class ConstantTaskGenerator(nn.Module):
    def __init__(self, input_dim=2):
        '''Distribution over trivial tasks whose label is one constant per task.'''
        super().__init__()
        self.input_dim = input_dim
        # Placeholder parameter whose only job is to track the module's device.
        self.dummy = nn.Parameter(torch.empty([]))

    def forward(self, batch, K=5, validation=False):
        '''Draw a batch of constant tasks with K points each.
        Returns:
        X: inputs sampled from N(0, I) with shape [batch, K, input_dim]
        Y: labels with shape [batch, K]; one uniform [0, 1) value per task,
            repeated for all K points of that task.
        '''
        target_device = self.dummy.device
        x_shape = [batch, K, self.input_dim]
        with torch.no_grad():
            X = torch.normal(torch.zeros(x_shape), torch.ones(x_shape)).to(target_device)
            per_task_value = torch.rand([batch, 1]).to(target_device)
            Y = per_task_value.expand(-1, K)
        return X, Y

# Models

## Basic Models

In [None]:
class FunctionPriorModel(gpytorch.models.ExactGP):
    '''Exact GP with a constant mean and a scaled RBF kernel.

    The model is created without training data; data is attached later.
    '''
    def __init__(self, likelihood):
        super().__init__(None, None, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        '''Return the GP prior over function values at x.'''
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

    def clear(self):
        '''Forget any training data previously attached to the model.'''
        self.train_inputs = self.train_targets = None

class LinearRegression(gpytorch.models.ExactGP):
    '''Bayesian linear regression expressed as an exact GP with a linear mean
    and a linear kernel.

    The model is created without training data; data is attached later.
    '''
    def __init__(self, likelihood, input_dim):
        super().__init__(None, None, likelihood)
        self.mean_module = gpytorch.means.LinearMean(input_dim)
        self.covar_module = gpytorch.kernels.LinearKernel()

    def forward(self, x):
        '''Return the GP prior over function values at x.'''
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

    def clear(self):
        '''Forget any training data previously attached to the model.'''
        self.train_inputs = self.train_targets = None

class MLP(nn.Module):
    '''Feed-forward network: hidden_layers blocks of Linear -> BatchNorm1d -> ReLU
    followed by a final Linear -> BatchNorm1d.'''
    def __init__(self, input_units, hidden_units, output_units, hidden_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        # Width of the activations entering each Linear layer, in order.
        widths = [input_units] + [hidden_units] * hidden_layers
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(nn.Linear(fan_in, fan_out))
            self.layers.append(nn.BatchNorm1d(fan_out))
            self.layers.append(nn.ReLU())
        self.layers.append(nn.Linear(widths[-1], output_units))
        self.layers.append(nn.BatchNorm1d(output_units))

    def forward(self, x):
        '''Apply the network over the last dimension of x.

        All leading dimensions are flattened into one batch dimension for the
        layers and restored on the output.
        '''
        lead_shape = x.shape[:-1]
        hidden = x.flatten(end_dim=-2)
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden.reshape(*lead_shape, -1)
        
class LatentPriorModel(gpytorch.models.ExactGP):
    '''Exact multitask GP with output_dim tasks, a constant mean per task and
    a rank-1 multitask RBF kernel.

    If deep_kernel is given, inputs are passed through it before the mean and
    kernel are evaluated.
    '''
    def __init__(self, likelihood, output_dim, deep_kernel=None):
        super().__init__(None, None, likelihood)
        base_mean = gpytorch.means.ConstantMean()
        self.mean_module = gpytorch.means.MultitaskMean(base_mean, num_tasks=output_dim)
        base_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            base_kernel, num_tasks=output_dim, rank=1
        )
        self.deep_kernel = deep_kernel

    def forward(self, x):
        '''Return the multitask GP prior at x (after the optional embedding).'''
        features = x if self.deep_kernel is None else self.deep_kernel(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(
            self.mean_module(features), self.covar_module(features)
        )

    def clear(self):
        '''Forget any training data previously attached to the model.'''
        self.train_inputs = self.train_targets = None

class VariationalModel(gpytorch.models.ExactGP):
    '''Multitask variational posterior over latent_dim tasks.

    One MLP reads each (x, y) pair and emits both the posterior mean
    (latent_dim values) and an input_dim-sized embedding that is fed to a
    rank-1 multitask RBF kernel to form the covariance.
    '''
    def __init__(self, likelihood, input_dim, latent_dim, hidden_units, hidden_layers):
        super().__init__(None, None, likelihood)
        base_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            base_kernel, num_tasks=latent_dim, rank=1
        )
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # Input is x with y appended; output is [mean | kernel embedding].
        self.mlp = MLP(input_dim + 1, hidden_units, input_dim + latent_dim, hidden_layers)

    def forward(self, x, y):
        '''Return the multitask normal over the latents given inputs x and labels y.'''
        joint = torch.cat([x, y.unsqueeze(-1)], dim=-1)
        outputs = self.mlp(joint)
        split = self.latent_dim
        mean, embedding = outputs[..., :split], outputs[..., split:]
        assert embedding.size(-1) == self.input_dim
        covar = self.covar_module(embedding)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean, covar)

    def clear(self):
        '''Forget any training data previously attached to the model.'''
        self.train_inputs = self.train_targets = None


## VMGP

In [None]:
class VariationalMetaGP(nn.Module):
    '''Variational GP meta-learner with deep non-Gaussian likelihood.

    Built from three learned parts:
      variational_model / variational_posterior: distribution over latents Z
        computed from (X, Y).
      latent_prior_model / latent_prior: multitask GP over latents Z given X,
        optionally with a deep kernel.
      likelihood_transform: MLP mapping a latent_dim vector to one scalar
        prediction.

    out_var: downweighting of the KL-divergence term in the loss; represents
      the variance of the Gaussians in the mixture representing the model's
      posterior predictions
    deep_kernel_dim: (optional int): if set, use a deep kernel for the 
      latent prior p(z|x) represented as a learned projection from input_dim to
      deep_kernel_dim composed with an RBF kernel.
    '''
    def __init__(
        self,
        input_dim,
        hidden_units,
        latent_dim,
        hidden_layers,
        out_var,
        deep_kernel_dim=None,
    ):
        super().__init__()
        if deep_kernel_dim is None:
            kernel_dim = input_dim
            self.deep_kernel = None
        else:
            kernel_dim = deep_kernel_dim
            self.deep_kernel = MLP(input_dim, hidden_units, kernel_dim, hidden_layers)
        self.variational_posterior = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=latent_dim)
        self.variational_model = VariationalModel(
            self.variational_posterior, input_dim, latent_dim, hidden_units, hidden_layers
        )
        self.latent_prior = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=latent_dim)
        # The same deep_kernel module (or None) is handed to the prior model.
        self.latent_prior_model = LatentPriorModel(self.latent_prior, latent_dim, self.deep_kernel)
        self.likelihood_transform = MLP(latent_dim, hidden_units, 1, hidden_layers)
        self.latent_dim = latent_dim
        self.out_var = out_var
        
    def loss(self, X, Y):
        '''Compute loss over batch of tasks with datapoints X and labels Y.
        Returns the ELBO loss across all tasks.

        Y must have exactly one dimension fewer than X. The returned value is
        the negated ELBO, using a single reparameterized sample of Z.
        '''
        self.train()
        assert Y.dim() == X.dim() - 1
        # NOTE(review): eval() recurses into child modules, so the MLPs owned by
        # these two models (and their BatchNorm layers) presumably run in eval
        # mode here even though self.train() was just called - confirm intended.
        self.latent_prior_model.eval()
        self.variational_model.eval()
        p_Z_X = self.latent_prior(self.latent_prior_model(X))
        q_Z_Y = self.variational_posterior(self.variational_model(X, Y))
        # rsample keeps the sample differentiable w.r.t. the posterior parameters.
        Z_samp = q_Z_Y.rsample()
        Y_pred = self.likelihood_transform(Z_samp).squeeze(-1)
        # Reconstruction term: negative squared error summed over every entry.
        log_p_Y_Z_samp = -torch.pow(Y_pred - Y, 2).sum()
        dkl = torch.distributions.kl.kl_divergence(q_Z_Y, p_Z_X).sum()
        ELBO = log_p_Y_Z_samp - self.out_var * dkl
        return -ELBO

    def forward(self, X_train, Y_train, X_test, samples=100):
        '''Evaluate model predictions on a single meta-test task.
        Takes D_train = (X_train, Y_train) and returns the predicted Y_test 
        corresponding to X_test (in the form of samples from the prediction
        distribution along dimension 0 of the returned tensor).

        X_train must be 2-dimensional. The result stacks `samples` predictions
        and is detached from the autograd graph.
        '''
        assert X_train.dim() == 2
        self.eval()
        # Draw `samples` latent vectors for the training points.
        Z_samp = self.variational_posterior(
            self.variational_model(X_train, Y_train)
        ).sample(sample_shape=torch.Size([samples]))
        Y_pred_all = []

        for Z_task_samp in Z_samp:
            # Attach (X_train, sampled latents) as training data of the prior GP,
            # sample test latents from it, then map them to a prediction.
            self.latent_prior_model.set_train_data(X_train, Z_task_samp, strict=False)
            Z_test = self.latent_prior(self.latent_prior_model(X_test)).sample()
            Y_pred = self.likelihood_transform(Z_test).squeeze(-1)
            Y_pred_all.append(Y_pred)
            # Drop the attached data again before the next sample / next call.
            self.latent_prior_model.clear()
        
        return torch.stack(Y_pred_all).detach()

## MGP

In [None]:
class MetaGP(nn.Module):
    '''Baseline non-variational GP meta-learner.
    deep_kernel_dim: (optional int): if set, use a deep kernel for the 
      latent prior p(z|x) represented as a learned projection from input_dim to
      deep_kernel_dim composed with an RBF kernel.
    hidden_units, hidden_layers: size of the deep-kernel MLP; only read when
      deep_kernel_dim is set.
    '''
    def __init__(
        self,
        input_dim,
        deep_kernel_dim=None,
        hidden_units=None,
        hidden_layers=None,
    ):
        super().__init__()
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.gp = FunctionPriorModel(self.likelihood)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.gp)
        self.deep_kernel = (
            MLP(input_dim, hidden_units, deep_kernel_dim, hidden_layers)
            if deep_kernel_dim is not None
            else None
        )
        
    def loss(self, X, Y):
        '''Compute meta loss over batch of tasks with datapoints X and labels Y.
        Returns negative marginal log likelihood across all tasks.
        '''
        self.train()
        # The GP itself is evaluated in eval mode; only the rest of the module
        # is left in train mode.
        self.gp.eval()
        if self.deep_kernel is not None:
            X = self.deep_kernel(X)
        p_Y_X = self.gp(X)
        loss = -self.mll(p_Y_X, Y)
        return loss.sum()

    def forward(self, X_train, Y_train, X_test, samples=100):
        '''Evaluate model predictions on a single meta-test task.
        Takes D_train = (X_train, Y_train) and returns the predicted Y_test 
        corresponding to X_test.

        X_train must be 2-dimensional. `samples` draws are stacked along
        dimension 0 of the returned tensor, which is detached.
        '''
        assert X_train.dim() == 2
        self.eval()
        if self.deep_kernel is not None:
            X_train = self.deep_kernel(X_train)
            X_test = self.deep_kernel(X_test)
        # Attach the task's data, sample predictions, then detach the data again
        # so the next call starts clean.
        self.gp.set_train_data(X_train, Y_train, strict=False)
        Y_pred = self.likelihood(self.gp(X_test)).sample(sample_shape=torch.Size([samples]))
        self.gp.clear()
        return Y_pred.detach()

## Baseline Models

In [None]:
class FunctionalMLP(nn.Module):
    '''ReLU multilayer perceptron whose forward pass can run with an
    externally supplied parameter list instead of its own parameters.'''
    def __init__(self, input_units, hidden_units, output_units, hidden_layers):
        super().__init__()
        # weights is registered before biases so that parameters() yields all
        # weights first and all biases second; forward relies on that order.
        self.weights = nn.ParameterList()
        self.biases = nn.ParameterList()
        widths = [input_units] + [hidden_units] * hidden_layers + [output_units]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            weight = nn.Parameter(torch.empty([fan_in, fan_out]))
            nn.init.xavier_uniform_(weight)
            self.weights.append(weight)
            self.biases.append(nn.Parameter(torch.zeros([fan_out])))

    def forward(self, x, params=None):
        '''Apply the network over the last dimension of x.

        params: optional flat list whose first half holds the weight matrices
            and whose second half holds the biases; defaults to the module's
            own parameters.
        Leading dimensions of x are flattened for the computation and restored
        on the output.
        '''
        lead_shape = x.shape[:-1]
        hidden = x.flatten(end_dim=-2)

        if params is None:
            params = list(self.parameters())

        half = len(params) // 2
        weights, biases = params[:half], params[half:]
        last = len(weights) - 1

        for idx, (weight, bias) in enumerate(zip(weights, biases)):
            hidden = hidden @ weight + bias
            # No non-linearity after the output layer.
            if idx != last:
                hidden = F.relu(hidden)

        return hidden.reshape(*lead_shape, -1)

class EMAML(nn.Module):
    '''Ensemble of independently initialised MAML learners; the stacked member
    predictions can be read as samples from a predictive posterior.'''
    def __init__(self, support_size, query_size, input_dim, hidden_units, hidden_layers, num_mamls=20, inner_lr=0.1):
        super().__init__()
        self.population = num_mamls
        self.input_dim = input_dim
        self.support_size = support_size
        self.query_size = query_size
        members = []
        for _ in range(self.population):
            members.append(MAML(
                input_dim=self.input_dim, hidden_units=hidden_units,
                hidden_layers=hidden_layers, inner_lr=inner_lr,
            ))
        self.herd = nn.ModuleList(members)

    def loss(self, X, Y):
        '''Sum of the member meta-losses.

        The first support_size points along dimension 1 are the support set;
        the remaining points are the query set.
        '''
        split = self.support_size
        X_support, X_query = X[:, :split, :], X[:, split:, :]
        Y_support, Y_query = Y[:, :split], Y[:, split:]
        total = 0
        for member in self.herd:
            total = total + member.loss(X_support, Y_support, X_query, Y_query)
        return total

    def forward(self, X_train, Y_train, X_test, samples=None):
        '''Stack every member's adapted prediction for X_test along dimension 0.

        samples is accepted for interface parity and not read; the number of
        rows equals the ensemble size. The result is detached.
        '''
        member_preds = []
        for member in self.herd:
            member_preds.append(member.forward(X_train, Y_train, X_test))
        return torch.stack(member_preds).detach()


class MAML(nn.Module):
    '''MAML learner with a single inner gradient step on an MSE objective.'''
    def __init__(self, input_dim, hidden_units, hidden_layers, inner_lr):
        super().__init__()
        self.input_dim = input_dim
        self.net = FunctionalMLP(input_dim, hidden_units, 1, hidden_layers)
        self.inner_lr = inner_lr
        self.loss_func = nn.MSELoss()

    def inner_loop(self, init_weights, X_support, Y_support, X_query):
        '''Take one gradient step from init_weights on the support data and
        return the adapted network's predictions for X_query.

        NOTE(review): the step uses one MSE over everything in X_support, so
        when several tasks are passed at once they share a single adapted
        weight set - confirm this is intended.
        '''
        support_pred = self.net(X_support, init_weights).squeeze(-1)
        support_loss = self.loss_func(support_pred, Y_support)
        # create_graph=True keeps the step differentiable for the outer update.
        grads = torch.autograd.grad(support_loss, init_weights, create_graph=True)
        adapted = [w - self.inner_lr * g for w, g in zip(init_weights, grads)]
        return self.net(X_query, adapted).squeeze(-1)

    def loss(self, X_support, Y_support, X_query, Y_query):
        '''MSE of the adapted network on the query set.'''
        init_weights = list(self.net.parameters())
        query_pred = self.inner_loop(init_weights, X_support, Y_support, X_query)
        return self.loss_func(query_pred, Y_query).mean()

    def forward(self, X_train, Y_train, X_test):
        '''Adapt on (X_train, Y_train) and predict at X_test.'''
        init_weights = list(self.net.parameters())
        return self.inner_loop(init_weights, X_train, Y_train, X_test)


class Alpaca(MetaGP):
    '''Implementation of Alpaca algorithm, viewed as Bayesian linear regression.

    Identical to MetaGP except that the GP is a LinearRegression model; the
    likelihood and marginal log likelihood are rebuilt to match it.
    '''
    def __init__(
        self,
        input_dim,
        deep_kernel_dim=None,
        hidden_units=None,
        hidden_layers=None,
    ):
        # The parent constructor already builds self.deep_kernel from these same
        # arguments, so it is not constructed a second time here.
        super().__init__(input_dim, deep_kernel_dim, hidden_units, hidden_layers)
        # The regression runs on the deep-kernel features when a deep kernel is
        # used and on the raw inputs otherwise, so the linear mean must be sized
        # accordingly (previously this fell back to 1 regardless of input_dim).
        feature_dim = input_dim if deep_kernel_dim is None else deep_kernel_dim
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.gp = LinearRegression(self.likelihood, feature_dim)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.gp)

# Initialize Environment

In [None]:
# Task-sampling configuration.
train_size = None        # None: sample fresh tasks forever; int: fixed pool of that many batches
K, query_size = 5, 5     # support points and query points per task

data_generator = FunctionTaskGenerator(input_dim=1, transform=atan_transform).to(device)
input_dim = data_generator.input_dim

if train_size is not None:
    # NOTE(review): `batch` is not assigned in this cell; it must already
    # exist in the notebook namespace when train_size is set.
    training_tasks = [
        data_generator(batch=batch, K=K + query_size)
        for _ in range(train_size)
    ]

def get_training_tasks():
    '''Endless generator of (X, Y) training batches.

    With train_size unset, every batch is freshly sampled from data_generator.
    Otherwise the fixed training_tasks pool is reshuffled and replayed forever.
    '''
    if train_size is not None:
        while True:
            random.shuffle(training_tasks)
            yield from training_tasks
    while True:
        yield data_generator(batch=batch, K=K + query_size)

training_generator = iter(get_training_tasks())

#data_generator = FunctionTaskGenerator(
#    input_dim=1,
#    latent_dim=1,
#    lengthscale=0.5,
#    transform=csc_tranform,
#).to(device)

# Initialize Run

In [None]:
# Run configuration.
batch = 50             # tasks per training batch
val_interval = 1000    # validate every this many training iterations
val_trials = 50        # tasks sampled per validation round
val_samples = 20       # prediction samples requested from the model per task
itr = 0
learning_rate = 1e-3

# Build the selected meta-learner and move it to `device`.
model_name = 'Alpaca'
if model_name == 'VMGP':
    model = VariationalMetaGP(
        input_dim=input_dim,
        hidden_units=40,
        latent_dim=10,
        hidden_layers=2,
        out_var=1e-2,
        deep_kernel_dim=10,
    ).to(device)
elif model_name == 'MGP':
    model = MetaGP(
        input_dim=input_dim,
        deep_kernel_dim=10,
        hidden_units=40,
        hidden_layers=2,
    ).to(device)
elif model_name == 'EMAML':
    model = EMAML(
        input_dim=input_dim,
        hidden_units=40,
        hidden_layers=2,
        support_size=K,
        query_size=query_size
    ).to(device)
elif model_name == 'Alpaca':
    model = Alpaca(
        input_dim=input_dim,
        deep_kernel_dim=10,
        hidden_units=40,
        hidden_layers=2,
    ).to(device)
else:
    # Unknown names fail the same way a missing lookup key would.
    raise KeyError(model_name)

# Per-run histories and the optimizer.
train_losses, validation_nll, validation_mse = [], [], []
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

NameError: ignored

In [None]:
def print_stats():
    '''Print the mean of the last val_interval training losses and the most
    recent validation metrics, skipping whichever history is still empty.'''
    if train_losses:
        recent_losses = train_losses[-val_interval:]
        print(f'Iteration {itr} loss:', np.mean(recent_losses))
    if not validation_nll:
        return
    print(f'Iteration {itr} nll:', validation_nll[-1])
    print(f'Iteration {itr} mse:', validation_mse[-1])
    print()

def nll_metric(pred_y, test_y, out_var=0.1):
    '''Mixture-style negative log likelihood of test_y under sampled predictions.

    pred_y holds prediction samples along dimension 0. Each sample contributes
    exp(-(pred - test)^2 / out_var); the log of the sample average is taken per
    test point, averaged over test points and negated.
    '''
    num_samples = pred_y.size(0)
    scaled_sq_err = (pred_y - test_y.unsqueeze(0)).pow(2) / out_var
    log_mix = torch.logsumexp(-scaled_sq_err, dim=0)
    return -log_mix.mean() + np.log(num_samples)

def mse_metric(pred_y, test_y):
    '''Mean squared error between the sample-mean prediction (mean over
    dimension 0 of pred_y) and test_y.'''
    point_estimate = pred_y.mean(dim=0)
    return (point_estimate - test_y).pow(2).mean()

def validate_model(val_trials, query_size=1, out_of_range=False, return_se=False):
    '''Average NLL and MSE of `model` over val_trials freshly sampled tasks.

    Each task gives K support points and query_size query points. Returns
    (nll_mean, mse_mean), followed by the standard errors of both when
    return_se is set. out_of_range is accepted but not read here.
    '''
    nlls, mses = [], []
    for _ in range(val_trials):
        task_x, task_y = data_generator(batch=1, K=K + query_size, validation=True)
        xs, ys = task_x[0], task_y[0]
        target = ys[-query_size:]
        pred_y = model(xs[:K], ys[:K], xs[-query_size:], samples=val_samples)
        nlls.append(nll_metric(pred_y, target).item())
        mses.append(mse_metric(pred_y, target).item())
    nll_mean, mse_mean = np.mean(nlls), np.mean(mses)
    if not return_se:
        return nll_mean, mse_mean
    root_n = np.sqrt(val_trials)
    return nll_mean, mse_mean, np.std(nlls) / root_n, np.std(mses) / root_n

# Sanity Check

In [None]:
# Example validation task
# One task is sampled with K support points, query_size query points and
# 50000 extra points that are only used to draw the true function.
task_x, task_y = data_generator(batch=1, K=K + query_size + 50000)
train_x, train_y = task_x[0, :K], task_y[0, :K]
test_x, test_y = task_x[0, K:K+query_size], task_y[0, K:K+query_size]
plot_x, plot_y = task_x[0, K+query_size:], task_y[0, K+query_size:]

print('train_x:', train_x.tolist())
print('train_y:', train_y.tolist())
print('test_x:', test_x.tolist())
print('test_y:', test_y.tolist())
print()

# Mean and spread of 250 predictive samples at the query points.
pred_y = model(train_x, train_y, test_x, samples=250)
print('mu_test_y:', pred_y.mean(dim=0).tolist())
print('sigma_test_y:', pred_y.std(dim=0).tolist())
print()

nll = nll_metric(pred_y, test_y)
mse = mse_metric(pred_y, test_y)
print('nll:', nll.item())
print('mse:', mse.item())

# Dense grid (step 0.1) from one unit below the smallest sampled input to one
# unit above the largest; the same value is repeated across all input columns.
test_x_ = torch.arange(
    task_x.min() - 1., task_x.max() + 1., 1e-1, device=device
)[:, None].expand(-1, train_x.size(1))
pred_y_ = model(train_x, train_y, test_x_, samples=50)
pred_mu = pred_y_.mean(dim=0)
pred_sigma = pred_y_.std(dim=0)
# Histogram of the predictive samples at the first query point.
plt.figure()
plt.hist(pred_y[..., 0].cpu().numpy(), bins=50)
plt.show()
# Support points, true function, predictive mean with a one-sigma band, and
# query points. Only the first input column is used for the x axis.
plt.figure(dpi=150)
plt.scatter(train_x[:, 0].cpu().numpy(), train_y.cpu().numpy(), s=10, color='blue', label='training', zorder=3)
plt.scatter(plot_x.cpu().numpy(), plot_y.cpu().numpy(), s=2, color='limegreen', label='actual', zorder=1)
plt.plot(test_x_[:, 0].cpu().numpy(), pred_mu.cpu().numpy(), color='orange', label='prediction', zorder=2)
plt.fill_between(
    test_x_[:, 0].cpu().numpy(),
    (pred_mu - pred_sigma).cpu().numpy(), 
    (pred_mu + pred_sigma).cpu().numpy(),
    alpha=0.2,
    color='orange',
)
plt.scatter(test_x[:, 0].tolist(), test_y.tolist(), s=10, color='red', label='testing', zorder=4)
plt.ylabel('$y$')
plt.xlabel('$x$')
plt.legend()
plt.show()

NameError: ignored

# Training Loop

In [None]:
# Meta-training: 10000 optimizer steps, with a validation round before the
# update whenever itr is a multiple of val_interval (val_interval of 0 disables it).
for _ in tqdm.trange(10000):
    task_x, task_y = next(training_generator)
    if val_interval and not itr % val_interval:
        nll, mse = validate_model(val_trials, query_size=query_size)
        validation_nll.append(nll)
        validation_mse.append(mse)
        print_stats()
    opt.zero_grad()
    loss = model.loss(task_x, task_y)
    train_losses.append(loss.item())
    loss.backward()
    opt.step()
    itr += 1


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




NameError: ignored

# Results

In [None]:
def smooth(data, kernel, maxnorm=np.inf):
    '''Moving average of a 1-D sequence with a box window of `kernel` entries.

    Values are first clamped to [-maxnorm, maxnorm]. No padding is applied, so
    the result has len(data) - kernel + 1 entries. Returns a numpy array.
    '''
    signal = torch.tensor(data)[None, None, :].float()
    signal = signal.clamp(min=-maxnorm, max=maxnorm)
    box = torch.ones(1, 1, kernel) / kernel
    return F.conv1d(signal, box).flatten().numpy()

# Maps a run label to its (train_losses, validation_nll, validation_mse) lists.
results = dict()

def cache_results(name):
    '''Store the current history lists under `name`.

    The lists themselves are stored, not copies of them.
    '''
    histories = (train_losses, validation_nll, validation_mse)
    results[name] = histories

In [None]:
# Per-iteration training loss; a smoothing window of 1 leaves the values
# unaveraged.
plt.figure(dpi=100)
plt.plot(smooth(train_losses, 1), linewidth=0.3)
plt.title('Training Loss')
plt.xlabel('Iteration')
plt.ylabel('Training Loss')
plt.show()

In [None]:
# Validation NLL history. The x index counts validation rounds: one point is
# recorded every val_interval training iterations.
plt.figure(dpi=100)
plt.plot(smooth(validation_nll, 1), linewidth=0.8)
plt.title('Validation NLL')
plt.xlabel('Iteration')
plt.ylabel('NLL')
plt.show()

In [None]:
# Validation MSE history. The x index counts validation rounds: one point is
# recorded every val_interval training iterations.
plt.figure(dpi=100)
plt.plot(smooth(validation_mse, 1), linewidth=0.8)
plt.title('Validation MSE')
plt.xlabel('Iteration')
plt.ylabel('MSE')
plt.show()

In [None]:
# Store the current run's histories under a label for the comparison plots.
# NOTE(review): the label 'mgp_sin' does not match model_name = 'Alpaca' or the
# atan transform configured earlier - confirm the label before comparing runs.
cache_results('mgp_sin')

In [None]:
# Compare all cached runs: one figure for validation NLL, one for validation
# MSE. The first start_iter validation rounds of every run are left out.
plt.figure(dpi=100)
start_iter = 5

# Each cached entry is (train_losses, validation_nll, validation_mse).
for k, (_, nlls, mses) in results.items():
    plt.plot(np.arange(start_iter, len(nlls)), nlls[start_iter:], label=k)

plt.ylabel('NLL')
plt.xlabel('Iteration')
plt.legend()
plt.show()

plt.figure(dpi=100)
for k, (_, nlls, mses) in results.items():
    plt.plot(np.arange(start_iter, len(mses)), mses[start_iter:], label=k)
plt.ylabel('MSE')
plt.xlabel('Iteration')
plt.legend()
plt.show()

In [None]:
# Final evaluation over 2000 freshly sampled validation tasks, reporting the
# mean and standard error of both metrics.
# NOTE: out_of_range is accepted by validate_model but not read inside it.
nll_mean, mse_mean, nll_se, mse_se = validate_model(val_trials=2000, query_size=query_size, out_of_range=False, return_se=True)
dict(nll_mean=nll_mean, nll_se=nll_se, mse_mean=mse_mean, mse_se=mse_se)



{'mse_mean': 0.6015626240206184,
 'mse_se': 0.01294987489320119,
 'nll_mean': 0.9142115069031715,
 'nll_se': 0.011262755170894577}