In [1]:
import pickle
import matplotlib.pyplot as plt

from stheno import GP, EQ

In [2]:
# file = open('../experiments/synthetic/toy-data/weakly-periodic-100/data/seed-0/dim-2/train-data.pkl', 'rb')

# train_data = pickle.load(file)

In [3]:
# for i in range(3):
#     for j in range(3):

#         epoch = train_data[i]
#         batch = epoch[j]

#         x_context = batch["x_context"]
#         y_context = batch["y_context"]
#         x_target = batch["x_target"]
#         y_target = batch["y_target"]
        
#         print("==============================================================")
#         print(x_context.min(), x_context.max(), x_target.min(), x_target.max())
#         print(y_context.min(), y_context.max(), y_target.min(), y_target.max())
#         print(x_context.shape, x_context.shape, y_target.shape, y_target.shape)

In [4]:
# epoch = train_data[0]
# batch = epoch[0]

# x_context = batch["x_context"]
# y_context = batch["y_context"]
# x_target = batch["x_target"]
# y_target = batch["y_target"]

# idx = 10

# plt.scatter(x_context[idx, :, 0], x_context[idx, :, 1])
# plt.scatter(x_target[idx, :, 0], x_target[idx, :, 1])
# plt.show()

# plt.hist(y_target[idx, :, 0].numpy().flatten())
# plt.show()

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime
import pickle
import time
import sys

from copy import deepcopy

from cnp.experiment import (
    generate_root,
    WorkingDirectory,
    save_checkpoint,
    log_args
)

from cnp.cnp import (
    StandardGNP,
    StandardAGNP,
    StandardConvGNP,
    FullConvGNP
)

from cnp.lnp import (
    StandardANP,
    StandardConvNP,
    StandardHalfUNetConvNP
)

from cnp.cov import (
    InnerProdCov,
    KvvCov,
    MeanFieldCov,
    AddHomoNoise,
    AddHeteroNoise,
    AddNoNoise
)

from cnp.oracle import (
    eq_cov,
    mat_cov,
    nm_cov,
    wp_cov,
    gp_loglik
)

from cnp.utils import (
    plot_samples_and_data,
    make_generator,
    Logger
)

import torch
from torch.distributions import MultivariateNormal
from torch.utils.tensorboard import SummaryWriter


# =============================================================================
# Training epoch helper
# =============================================================================


def train(data,
          model,
          optimiser,
          log_every,
          device,
          writer,
          iteration):
    
    for step, batch in enumerate(data):

        nll = model.loss(batch['x_context'].to(device),
                         batch['y_context'].to(device),
                         batch['x_target'].to(device),
                         batch['y_target'].to(device))

        if step % log_every == 0:
            print(f"Training   neg. log-lik: {nll:.2f}")

        # Compute gradients and apply them
        nll.backward()
        optimiser.step()
        optimiser.zero_grad()

        # Write to tensorboard
        writer.add_scalar('Train log-lik.', - nll, iteration)
        
        iteration = iteration + 1
        
    return iteration


# =============================================================================
# Validation helper
# =============================================================================


def validate(data,
             oracle_cov,
             model,
             args_np_val_samples,
             device,
             writer,
             latent_model):
    
    # Lists for logging model's training NLL and oracle NLL
    nll_list = []
    oracle_nll_list = []
    
    # If training a latent model, set the number of latent samples accordingly
    loss_kwargs = {'num_samples' : args_np_val_samples} if latent_model else {}
    
    with torch.no_grad():
        
        for step, batch in enumerate(data):
            
            nll = model.loss(batch['x_context'].to(device),
                             batch['y_context'].to(device),
                             batch['x_target'].to(device),
                             batch['y_target'].to(device),
                             **loss_kwargs)
            
            oracle_nll = torch.tensor(0.)

            # Oracle loss exists only for GP-generated data, not sawtooth
            if oracle_cov is not None:
                for b in range(batch['x_context'].shape[0]):
                    oracle_nll = oracle_nll - gp_loglik(batch['x_context'][b],
                                                            batch['y_context'][b],
                                                            batch['x_target'][b],
                                                            batch['y_target'][b],
                                                            oracle_cov)[0]
                        

            # Scale by the average number of target points
            nll_list.append(nll.item())
            oracle_nll_list.append(oracle_nll.item() / batch['x_context'].shape[0])

    mean_nll = np.mean(nll_list)
    std_nll = np.var(nll_list)**0.5
    
    mean_oracle_nll = np.mean(oracle_nll_list)
    std_oracle_nll = np.var(oracle_nll_list)**0.5

    # Print validation loss and oracle loss
    print(f"Validation neg. log-lik: "
          f"{mean_nll:.2f}")

    print(f"Oracle     neg. log-lik: "
          f"{mean_oracle_nll:.2f}")

    return mean_nll, std_nll, mean_oracle_nll, std_oracle_nll

args_data = 'eq-100'
args_x_dim = 2
args_seed = 0
args_validate_every = 10
args_model = 'GNP'
args_covtype = 'kvv-homo'
args_np_loss_samples = 20
args_np_val_samples = 8
args_num_basis_dim = 512
args_learning_rate = 5e-4
args_weight_decay = 0.
args_num_params = False
args_gpu = 1
    
# =============================================================================
# Set random seed, device and tensorboard writer
# =============================================================================

# Set seed
np.random.seed(args_seed)
torch.manual_seed(args_seed)

# Set device
if torch.cuda.is_available():
    torch.cuda.set_device(args_gpu)
    
use_cpu = not torch.cuda.is_available() and args_gpu == 0
device = torch.device('cpu') if use_cpu else torch.device('cuda')

root = '../experiments/synthetic'

# Working directory for saving results
experiment_name = os.path.join(f'{root}',
                               f'results',
                               f'{args_data}',
                               f'models',
                               f'{args_model}',
                               f'{args_covtype}',
                               f'seed-{args_seed}',
                               f'dim-{args_x_dim}')
working_directory = WorkingDirectory(root=experiment_name)

# Data directory for loading data
data_root = os.path.join(f'{root}',
                         f'toy-data',
                         f'{args_data}',
                         f'data',
                         f'seed-{args_seed}',
                         f'dim-{args_x_dim}')
data_directory = WorkingDirectory(root=data_root)

log_path = f'{root}/logs'
log_filename = f'{args_data}-{args_model}-{args_covtype}-{args_seed}'
log_directory = WorkingDirectory(root=log_path)
# sys.stdout = Logger(log_directory=log_directory, log_filename=log_filename)

# Tensorboard writer
writer = SummaryWriter(f'{experiment_name}/log')
    
file = open(working_directory.file('data_location.txt'), 'w')
file.write(data_directory.root)
file.close()

# =============================================================================
# Load data and validation oracle generator
# =============================================================================
    
file = open(data_directory.file('train-data.pkl'), 'rb')
data_train = pickle.load(file)
file.close()

file = open(data_directory.file('valid-data.pkl'), 'rb')
data_val = pickle.load(file)
file.close()

oracle_cov = None

Root: ../experiments/synthetic/results/eq-100/models/GNP/kvv-homo/seed-0/dim-2
Root: ../experiments/synthetic/toy-data/eq-100/data/seed-0/dim-2
Root: ../experiments/synthetic/logs


In [6]:
import numpy as np
import torch
import torch.nn as nn

from torch.distributions import MultivariateNormal

from cnp.encoders import (
    StandardEncoder,
    StandardANPEncoder,
    StandardConvNPEncoder
)

from cnp.decoders import (
    StandardDecoder,
    ConvDecoder
)

from cnp.architectures import (
    UNet,
    HalfUNet,
    StandardDepthwiseSeparableCNN
)

# =============================================================================
# General Latent Neural Process
# =============================================================================


class LatentNeuralProcess(nn.Module):
    
    
    def __init__(self, encoder, decoder, add_noise, num_samples):
        
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.add_noise = add_noise
        self.num_samples = num_samples

    
    def forward(self, x_context, y_context, x_target, num_samples=None):
        
        num_samples = self.num_samples if num_samples is None else num_samples
        
        # Pass context set and target inputs through the encoder to obtain
        # the encoder output, as expected by encoder.sample
        encoder_forward_output = self.encoder(x_context, y_context, x_target)
        
        means = []
        noise_vars = []
        
        for i in range(num_samples):
            
            r = self.encoder.sample(encoder_forward_output)
            output = self.decoder(r, x_context, y_context, x_target)
            
            assert (len(output.shape) == 3) and (output.shape[2] == 2)
            
            mean = output[:, :, :1]
            noise_var = torch.exp(output[:, :, 1])
            noise_var = torch.diag_embed(noise_var)
            noise_var = self.add_noise(noise_var, None)
            
            means.append(mean)
            noise_vars.append(noise_var)
            
        means = torch.stack(means, dim=0)
        noise_vars = torch.stack(noise_vars, dim=0)
        
        return means, noise_vars
    
    
    def loss(self, x_context, y_context, x_target, y_target, num_samples=None):
        
        B = y_target.shape[0]
        
        num_samples = self.num_samples if num_samples is None else num_samples
        
        # Compute mean and variance tensors, each of shape (S, B, N, D)
        means, noise_vars = self.forward(x_context,
                                         y_context,
                                         x_target,
                                         num_samples=num_samples)
        
        means = means[:, :, :, 0]
        idx = torch.arange(noise_vars.shape[2])
        noise_vars = noise_vars[:, :, idx, idx]
        
        logprobs = []
        
        for mean, noise_var in zip(means, noise_vars):
            
            distribution = torch.distributions.Normal(loc=mean,
                                                      scale=noise_var**0.5)
            logprob = torch.sum(distribution.log_prob(y_target[:, :, 0]),
                                axis=-1)
            
            logprobs.append(logprob)
            
        logprobs = torch.stack(logprobs, axis=1)
        logprob = 0
        
        for i, batch_logprobs in enumerate(logprobs):
            
            max_batch_logprob = torch.max(batch_logprobs)
        
            batch_logprobs = batch_logprobs - max_batch_logprob
            
            batch_mix_logprob = torch.log(torch.mean(torch.exp(batch_logprobs)))
            batch_mix_logprob = batch_mix_logprob + max_batch_logprob
            
            logprob = logprob + batch_mix_logprob
        
        return - logprob / B
    

    def mean_and_marginals(self, x_context, y_context, x_target):
        raise NotImplementedError


    @property
    def num_params(self):
        """Number of parameters."""
    
        return np.sum([torch.tensor(param.shape).prod() \
                       for param in self.parameters()])



# =============================================================================
# Attentive Latent Neural Process
# =============================================================================


class StandardANP(LatentNeuralProcess):
    
    
    def __init__(self, input_dim, add_noise, num_samples):
        
        # Standard input/output dim and latent representation dim
        # latent_dim is common to stochastic and deterministic paths, and
        # these are concatenated, producing a (2 * latent_dim) representation
        output_dim = 2
        latent_dim = 128
        
        # Decoder output dimension
        decoder_output_dim = output_dim

        # Construct the standard encoder
        encoder = StandardANPEncoder(input_dim=input_dim,
                                     latent_dim=latent_dim)
        
        # Construct the standard decoder
        decoder = StandardDecoder(input_dim=input_dim,
                                  latent_dim=2*latent_dim,
                                  output_dim=decoder_output_dim)

        super().__init__(encoder=encoder,
                         decoder=decoder,
                         add_noise=add_noise,
                         num_samples=num_samples)
        
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim

In [13]:
def gp_loglik(xc, yc, xt, yt, covariance, noise=0.05**2):
    
    p = GP(covariance)
    p_post = p | (p(xc, noise), yc)
    
    loglik = pred.logpdf(y_target)
    
    return loglik

In [14]:
# =============================================================================
# Create model
# =============================================================================

# Create covariance method
cov = MeanFieldCov(num_basis_dim=1)
noise = AddNoNoise()

# cov = InnerProdCov(args_num_basis_dim)
# noise = AddHomoNoise()
    
# cov = KvvCov(args_num_basis_dim)
# noise = AddHomoNoise()
    
# Create model architecture

if args_model == 'GNP':
    model = StandardGNP(input_dim=args_x_dim,
                        covariance=cov,
                        add_noise=noise)

elif args_model == 'AGNP':
    model = StandardAGNP(input_dim=args_x_dim,
                         covariance=cov,
                         add_noise=noise)

elif args_model == 'convGNP':
    model = StandardConvGNP(input_dim=args_x_dim,
                            covariance=cov,
                            add_noise=noise)

elif args_model == 'ANP':
    
    noise = AddHomoNoise()
    model = StandardANP(input_dim=args_x_dim,
                        add_noise=noise,
                        num_samples=args_np_loss_samples)
    
elif args_model == 'convNP':
    
    noise = AddHomoNoise()
    model = StandardConvNP(input_dim=args_x_dim,
                           add_noise=noise,
                           num_samples=args_np_loss_samples)
    
else:
    raise ValueError(f'Unknown model {args_model}.')


print(f'{args_model} '
      f'{args_covtype} '
      f'{args_num_basis_dim}: '
      f'{model.num_params}')

with open(working_directory.file('num_params.txt'), 'w') as f:
    f.write(f'{model.num_params}')
        
if args_num_params:
    exit()
    
    
# Load model to appropriate device
model = model.to(device)

latent_model = args_model in ['ANP', 'convNP', 'convNPHalfUNet']

# if 'eq' in args_data:
#     oracle_cov = eq_cov(lengthscale=1.,
#                         coefficient=1.,
#                         noise=5e-2)

# elif 'matern' in args_data:
#     oracle_cov = mat_cov(lengthscale=1.,
#                          coefficient=1.,
#                          noise=5e-2)

# elif 'noisy-mixture' in args_data:
#     oracle_cov = nm_cov(lengthscale1=1.,
#                         lengthscale2=0.25,
#                         coefficient=1.,
#                         noise=5e-2)

# elif 'weakly-periodic' in args_data:
#     oracle_cov = wp_cov(period=0.25,
#                         lengthscale=1.,
#                         coefficient=1.,
#                         noise=5e-2)

if 'eq' in args_data:
    oracle_cov = EQ().stretch(1.)

        
# =============================================================================
# Train or test model
# =============================================================================

# Number of epochs between validations
train_iteration = 0
log_every = 500

# Create optimiser
optimiser = torch.optim.Adam(model.parameters(),
                         args_learning_rate,
                         weight_decay=args_weight_decay)

# Run the training loop, maintaining the best objective value
best_nll = np.inf

epochs = len(data_train)

start_time = time.time()
for epoch in range(epochs):

    print('\nEpoch: {}/{}'.format(epoch + 1, epochs))

    if epoch % args_validate_every == 0:

        valid_epoch = data_val[epoch // args_validate_every]

        # Compute validation negative log-likelihood
        val_nll, _, val_oracle, _ = validate(valid_epoch,
                                             oracle_cov,
                                             model,
                                             args_np_val_samples,
                                             device,
                                             writer,
                                             latent_model)

        # Log information to tensorboard
        writer.add_scalar('Valid log-lik.',
                          -val_nll,
                          epoch)

        writer.add_scalar('Valid oracle log-lik.',
                          -val_oracle,
                          epoch)

        writer.add_scalar('Oracle minus valid log-lik.',
                          -val_oracle + val_nll,
                          epoch)

        # Update the best objective value and checkpoint the model
        is_best, best_obj = (True, val_nll) if val_nll < best_nll else \
                            (False, best_nll)

        plot_marginals = args_covtype == 'meanfield'

        if args_x_dim == 1:
            
            plot_samples_and_data(model=model,
                                  valid_epoch=valid_epoch,
                                  x_plot_min=-3.,
                                  x_plot_max=3.,
                                  root=working_directory.root,
                                  epoch=epoch,
                                  latent_model=latent_model,
                                  plot_marginals=plot_marginals,
                                  device=device)


    train_epoch = data_train[epoch]

    # Compute training negative log-likelihood
    train_iteration = train(train_epoch,
                            model,
                            optimiser,
                            log_every,
                            device,
                            writer,
                            train_iteration)

GNP kvv-homo 512: 116610

Epoch: 1/101


TypeError: __call__() takes 2 positional arguments but 3 were given

In [2]:
import torch
import numpy as np

from stheno import *
from cnp.oracle import oracle_loglik

In [9]:
lengthscale = 1.
noise = 5e-2

cov = EQ().stretch(lengthscale)
p = GP(cov)

np.random.seed(0)

x_ctx = np.random.rand(3, 2)
y_ctx = np.random.rand(3)
x_trg = np.random.rand(10, 2)
y_trg = np.random.rand(10)

oracle_loglik(x_ctx, y_ctx, x_trg, y_trg, cov, noise)

-2.5705792788326036