In [1]:
import pickle
import matplotlib.pyplot as plt

import stheno
from stheno import GP, EQ

import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime
import pickle
import time
import sys

from copy import deepcopy

from cnp.experiment import (
    generate_root,
    WorkingDirectory,
    save_checkpoint,
    log_args
)

from cnp.cnp import (
    StandardGNP,
    StandardAGNP,
    StandardConvGNP,
    FullConvGNP
)

from cnp.lnp import (
    StandardANP,
    StandardConvNP,
    StandardHalfUNetConvNP
)

from cnp.cov import (
    InnerProdCov,
    KvvCov,
    MeanFieldCov,
    AddHomoNoise,
    AddHeteroNoise,
    AddNoNoise
)

from cnp.oracle import (
    eq_cov,
    mat_cov,
    nm_cov,
    wp_cov,
    gp_loglik
)

from cnp.utils import (
    plot_samples_and_data,
    make_generator,
    Logger
)

import torch
from torch.distributions import MultivariateNormal
from torch.utils.tensorboard import SummaryWriter

In [2]:
def gp_loglik(xc, yc, xt, yt, covariance, noise=0.05**2):
    """Log-density of the targets under a GP conditioned on the context.

    Builds a `GP` with the given covariance, conditions it on the context
    observations `(xc, yc)` and evaluates the log-pdf of `yt` at `xt`. The
    same `noise` value is passed both when conditioning and when predicting.

    Note: this definition replaces the `gp_loglik` name imported from
    `cnp.oracle` at the top of the notebook.

    Arguments:
        xc         : context inputs
        yc         : context outputs
        xt         : target inputs
        yt         : target outputs
        covariance : kernel handed to `GP`
        noise      : noise argument passed to the GP at `xc` and `xt`

    Returns:
        Whatever `logpdf` of the posterior prediction returns for `yt`.
    """
    prior = GP(covariance)

    # Condition the prior on the noisy context observations
    posterior = prior | (prior(xc, noise), yc)

    return posterior(xt, noise).logpdf(yt)

In [3]:
# =============================================================================
# Training epoch helper
# =============================================================================


def train(data,
          model,
          optimiser,
          log_every,
          device,
          writer,
          iteration):
    """Run one pass over `data`, taking an optimiser step per batch.

    Arguments:
        data      : iterable of batches; each batch is indexed with the keys
                    'x_context', 'y_context', 'x_target' and 'y_target'
        model     : object exposing `loss(x_context, y_context, x_target,
                    y_target)`, whose result supports `.backward()`
        optimiser : object exposing `step()` and `zero_grad()`
        log_every : the loss is printed on every step divisible by this
        device    : device each batch entry is moved to via `.to(device)`
        writer    : object exposing `add_scalar(tag, value, step)`
        iteration : global step counter value at the start of this pass

    Returns:
        The global step counter after this pass (one increment per batch).
    """
    batch_keys = ('x_context', 'y_context', 'x_target', 'y_target')

    for step, batch in enumerate(data):

        # Move the batch to the device and evaluate the training loss
        inputs = [batch[key].to(device) for key in batch_keys]
        nll = model.loss(*inputs)

        if step % log_every == 0:
            print(f"Training   neg. log-lik: {nll:.2f}")

        # Backpropagate, update the parameters, then clear the gradients
        nll.backward()
        optimiser.step()
        optimiser.zero_grad()

        # Tensorboard records the log-likelihood, i.e. the negated loss
        writer.add_scalar('Train log-lik.', - nll, iteration)

        iteration = iteration + 1

    return iteration


# =============================================================================
# Validation helper
# =============================================================================


def validate(data,
             oracle_cov,
             model,
             args_np_val_samples,
             device,
             writer,
             latent_model):
    """Evaluate the model (and optionally a GP oracle) on validation batches.

    For every batch the model loss is computed without gradient tracking.
    When `oracle_cov` is not None, `gp_loglik` is additionally evaluated on
    each task of the batch (on CPU numpy copies of the data); the negated sum
    over tasks is divided by the leading dimension of `x_context`.

    Arguments:
        data                : iterable of batches indexed with the keys
                              'x_context', 'y_context', 'x_target', 'y_target'
        oracle_cov          : covariance passed to `gp_loglik`, or None to
                              skip the oracle (its loss is then reported as 0)
        model               : object exposing `loss(...)`
        args_np_val_samples : forwarded to the loss as `num_samples` when
                              `latent_model` is truthy
        device              : device each batch entry is moved to
        writer              : accepted for interface compatibility; unused
        latent_model        : whether to pass `num_samples` to the loss

    Returns:
        Tuple (mean_nll, std_nll, mean_oracle_nll, std_oracle_nll) computed
        over the per-batch values.
    """
    batch_keys = ('x_context', 'y_context', 'x_target', 'y_target')

    def _mean_and_std(values):
        # Standard deviation is taken as the square root of the variance
        return np.mean(values), np.var(values)**0.5

    # Latent models take the number of latent samples as a keyword argument
    loss_kwargs = {}
    if latent_model:
        loss_kwargs['num_samples'] = args_np_val_samples

    # Per-batch model NLL and oracle NLL
    nll_list, oracle_nll_list = [], []

    with torch.no_grad():

        for batch in data:

            on_device = [batch[key].to(device) for key in batch_keys]
            nll = model.loss(*on_device, **loss_kwargs)

            num_tasks = batch['x_context'].shape[0]
            oracle_nll = torch.tensor(0.)

            # Oracle loss exists only for GP-generated data, not sawtooth
            if oracle_cov is not None:
                for b in range(num_tasks):
                    task = [batch[key][b].detach().cpu().numpy()
                            for key in batch_keys]
                    oracle_nll = oracle_nll - gp_loglik(*task, oracle_cov)

            nll_list.append(nll.item())

            # Average the summed oracle NLL over the tasks in the batch
            oracle_nll_list.append(oracle_nll.item() / num_tasks)

    mean_nll, std_nll = _mean_and_std(nll_list)
    mean_oracle_nll, std_oracle_nll = _mean_and_std(oracle_nll_list)

    # Print validation loss and oracle loss
    print(f"Validation neg. log-lik: "
          f"{mean_nll:.2f}")

    print(f"Oracle     neg. log-lik: "
          f"{mean_oracle_nll:.2f}")

    return mean_nll, std_nll, mean_oracle_nll, std_oracle_nll

# -----------------------------------------------------------------------------
# Notebook configuration (stands in for command-line arguments)
# -----------------------------------------------------------------------------

# Data: these values form the name of the toy-data directory that is loaded
args_data = 'eq'                  # also selects the oracle covariance
args_x_dim = 2                    # also used as the model's input_dim
args_batch_size = 4
args_max_num_context = 50
args_min_num_target = 100
args_max_num_target = 100
args_seed = 0                     # also seeds numpy and torch

# Model
args_model = 'convNP'
args_covtype = 'meanfield'
args_num_basis_dim = 512          # passed to InnerProdCov / KvvCov
args_np_loss_samples = 10         # num_samples given to latent models
args_np_val_samples = 10          # num_samples used in the validation loss

# Optimisation and validation schedule
args_learning_rate = 5e-4
args_weight_decay = 0.
args_validate_every = 10          # validate on epochs divisible by this

# Runtime
args_num_params = False           # if True, exit after writing num_params.txt
args_gpu = 1                      # index passed to torch.cuda.set_device

In [4]:
# =============================================================================
# Set random seed, device and tensorboard writer
# =============================================================================

# Set seed
np.random.seed(args_seed)
torch.manual_seed(args_seed)

# Set device. Fall back to the CPU whenever CUDA is unavailable, whatever GPU
# index was requested: previously a non-zero args_gpu on a CUDA-less machine
# selected a 'cuda' device that could not be used.
use_cpu = not torch.cuda.is_available()

if use_cpu:
    device = torch.device('cpu')
else:
    # Make the requested GPU the current one, so that 'cuda' refers to it
    torch.cuda.set_device(args_gpu)
    device = torch.device('cuda')

root = '../experiments/synthetic'

# Working directory for saving results
experiment_name = os.path.join(f'{root}',
                               'results',
                               f'{args_data}',
                               'models',
                               f'{args_model}',
                               f'{args_covtype}',
                               f'seed-{args_seed}',
                               f'dim-{args_x_dim}')
working_directory = WorkingDirectory(root=experiment_name)

# Data directory for loading data; the suffix encodes the data settings
suffix = f'{args_x_dim}-'           + \
         f'{args_batch_size}-'      + \
         f'{args_max_num_context}-' + \
         f'{args_min_num_target}-'  + \
         f'{args_max_num_target}-'  + \
         f'{args_seed}'

data_root = os.path.join(root, 'toy-data', f'{args_data}-{suffix}')
data_directory = WorkingDirectory(root=data_root)

log_path = f'{root}/logs'
log_filename = f'{args_data}-{args_model}-{args_covtype}-{args_seed}'
log_directory = WorkingDirectory(root=log_path)
# sys.stdout = Logger(log_directory=log_directory, log_filename=log_filename)

# Tensorboard writer
writer = SummaryWriter(f'{experiment_name}/log')

# Record which data directory this experiment was run on. The context manager
# closes the file even if the write fails.
with open(working_directory.file('data_location.txt'), 'w') as file:
    file.write(data_directory.root)

# =============================================================================
# Load data and validation oracle generator
# =============================================================================

# NOTE: pickle.load can execute arbitrary code while unpickling. Only load
# data files that were generated locally or come from a trusted source.
with open(data_directory.file('train-data.pkl'), 'rb') as file:
    data_train = pickle.load(file)

with open(data_directory.file('valid-data.pkl'), 'rb') as file:
    data_val = pickle.load(file)

# No oracle by default; it is set later for GP-generated datasets
oracle_cov = None

Root: ../experiments/synthetic/results/eq/models/convNP/meanfield/seed-0/dim-2
Root: ../experiments/synthetic/toy-data/eq-2-4-50-100-100-0
Root: ../experiments/synthetic/logs


In [5]:
import numpy as np
import torch
import torch.nn as nn

from torch.distributions import MultivariateNormal

from cnp.encoders import (
    StandardEncoder,
    ConvEncoder,
    StandardANPEncoder,
    StandardConvNPEncoder
)

from cnp.decoders import (
    StandardDecoder,
    ConvDecoder
)

from cnp.architectures import (
    UNet,
    HalfUNet,
    StandardDepthwiseSeparableCNN
)



In [None]:
# =============================================================================
# Create model
# =============================================================================

# Covariance method: each entry builds the (covariance, noise) pair for one
# covariance type. Unknown types leave `cov` and `noise` untouched.
covariance_builders = {
    'meanfield'      : lambda: (MeanFieldCov(num_basis_dim=1), AddNoNoise()),
    'innerprod-homo' : lambda: (InnerProdCov(args_num_basis_dim),
                                AddHomoNoise()),
    'kvv-homo'       : lambda: (KvvCov(args_num_basis_dim), AddHomoNoise()),
}

if args_covtype in covariance_builders:
    cov, noise = covariance_builders[args_covtype]()

# Model architecture: the builders read `cov` and `noise` when called
model_builders = {
    'GNP'     : lambda: StandardGNP(input_dim=args_x_dim,
                                    covariance=cov,
                                    add_noise=noise),
    'AGNP'    : lambda: StandardAGNP(input_dim=args_x_dim,
                                     covariance=cov,
                                     add_noise=noise),
    'convGNP' : lambda: StandardConvGNP(input_dim=args_x_dim,
                                        covariance=cov,
                                        add_noise=noise),
    'ANP'     : lambda: StandardANP(input_dim=args_x_dim,
                                    add_noise=noise,
                                    num_samples=args_np_loss_samples),
    'convNP'  : lambda: StandardConvNP(input_dim=args_x_dim,
                                       add_noise=noise,
                                       num_samples=args_np_loss_samples),
}

if args_model not in model_builders:
    raise ValueError(f'Unknown model {args_model}.')

# The latent models always use homoscedastic noise, overriding the noise
# chosen by the covariance type above
if args_model in ('ANP', 'convNP'):
    noise = AddHomoNoise()

model = model_builders[args_model]()

print(f'{args_model} {args_covtype} {args_num_basis_dim}: {model.num_params}')

with open(working_directory.file('num_params.txt'), 'w') as f:
    f.write(f'{model.num_params}')

# Optionally stop once the parameter count has been reported
if args_num_params:
    exit()

# Load model to appropriate device
model = model.to(device)

latent_model = args_model in ('ANP', 'convNP', 'convNPHalfUNet')

# Oracle covariance: the first key contained in the dataset name wins; if no
# key matches, `oracle_cov` keeps its current value
oracle_builders = (
    ('eq',              lambda: EQ().stretch(1.)),
    ('matern',          lambda: stheno.Matern52().stretch(1.)),
    ('noisy-mixture',   lambda: stheno.EQ().stretch(1.) + \
                                stheno.EQ().stretch(0.25)),
    ('weakly-periodic', lambda: stheno.EQ().stretch(1.) * \
                                stheno.EQ().periodic(period=0.25)),
)

for data_key, build_oracle in oracle_builders:
    if data_key in args_data:
        oracle_cov = build_oracle()
        break
        
# =============================================================================
# Train or test model
# =============================================================================

# Global training step counter (carried across epochs) and the number of
# training steps between printed training losses
train_iteration = 0
log_every = 500

# Create optimiser
optimiser = torch.optim.Adam(model.parameters(),
                             args_learning_rate,
                             weight_decay=args_weight_decay)

# Run the training loop, maintaining the best objective value
best_nll = np.inf

# One entry of data_train is consumed per epoch
epochs = len(data_train)

start_time = time.time()
for epoch in range(epochs):

    print('\nEpoch: {}/{}'.format(epoch + 1, epochs))

    # Validate before training on every args_validate_every-th epoch
    if epoch % args_validate_every == 0:

        valid_epoch = data_val[epoch // args_validate_every]

        # Compute validation negative log-likelihood
        val_nll, _, val_oracle, _ = validate(valid_epoch,
                                             oracle_cov,
                                             model,
                                             args_np_val_samples,
                                             device,
                                             writer,
                                             latent_model)

        # Log information to tensorboard
        writer.add_scalar('Valid log-lik.',
                          -val_nll,
                          epoch)

        writer.add_scalar('Valid oracle log-lik.',
                          -val_oracle,
                          epoch)

        writer.add_scalar('Oracle minus valid log-lik.',
                          -val_oracle + val_nll,
                          epoch)

        # Update the best objective value. best_nll itself must be updated:
        # previously only best_obj was assigned, so best_nll stayed at
        # infinity and every validation was flagged as the best one.
        is_best = val_nll < best_nll
        if is_best:
            best_nll = val_nll
        best_obj = best_nll
        # NOTE(review): is_best is computed but no checkpoint is written in
        # this cell; save_checkpoint is imported but never called here.

        plot_marginals = args_covtype == 'meanfield'

        # Plots are only produced for one-dimensional inputs
        if args_x_dim == 1:

            plot_samples_and_data(model=model,
                                  valid_epoch=valid_epoch,
                                  x_plot_min=-3.,
                                  x_plot_max=3.,
                                  root=working_directory.root,
                                  epoch=epoch,
                                  latent_model=latent_model,
                                  plot_marginals=plot_marginals,
                                  device=device)


    train_epoch = data_train[epoch]

    # Train for one epoch, advancing the global step counter
    train_iteration = train(train_epoch,
                            model,
                            optimiser,
                            log_every,
                            device,
                            writer,
                            train_iteration)

convNP meanfield 512: 503415

Epoch: 1/101
Validation neg. log-lik: 381.38
Oracle     neg. log-lik: -109.75
Training   neg. log-lik: 382.25
Training   neg. log-lik: 104.07
Training   neg. log-lik: 81.04
Training   neg. log-lik: 65.14
Training   neg. log-lik: 51.86
Training   neg. log-lik: 47.87
Training   neg. log-lik: 77.57
Training   neg. log-lik: 41.30
Training   neg. log-lik: 20.99

Epoch: 2/101
Training   neg. log-lik: 18.13
Training   neg. log-lik: 4.63
Training   neg. log-lik: 60.51
Training   neg. log-lik: 35.91
