In [None]:
import os
import sys
import pickle
import time

from copy import deepcopy

from cnp.experiment import WorkingDirectory

from cnp.cnp import GaussianNeuralProcess, StandardConvGNP
from cnp.lnp import StandardConvNP
from cnp.architectures import UNet

from cnp.cov import (
    OutputLayer,
    MeanFieldGaussianLayer,
    InnerprodGaussianLayer,
    KvvGaussianLayer,
    LogLogitCopulaLayer
)

import numpy as np
import matplotlib.pyplot as plt

import torch

from cnp.encoders import (
    StandardEncoder,
    ConvEncoder,
    ConvPDEncoder,
)

from cnp.decoders import (
    StandardDecoder,
    ConvDecoder,
    ConvPDDecoder,
)

import torch
import torch.nn as nn

from torch.distributions import (
    Normal,
    MultivariateNormal,
    LowRankMultivariateNormal
)

from cnp.cov import GaussianLayer

In [None]:
# Experiment configuration read by the cells below.

# Name of the simulated dataset; used to build the data and results paths
args_data = 'sim-pred-prey-16-50-100-100-0'

# Output layer / model selection
args_cov_type = 'meanfield'      # 'meanfield', 'innerprod' or 'kvv'
args_noise_type = 'homo'         # only passed to the non-meanfield layers
args_marginal_type = 'loglogit'  # 'loglogit' wraps the Gaussian layer
args_model = 'convGNP'           # 'convGNP' or 'convNP'
args_num_basis_dim = 32          # num_embedding of non-meanfield layers

# Number of latent samples used by the 'convNP' model for the training loss
# and for validation. These names are read by the model-creation cell and by
# validate() but were never defined, so selecting 'convNP' raised a NameError.
# NOTE(review): 16 is a placeholder value -- confirm against the settings
# used by the original training script.
args_np_loss_samples = 16
args_np_val_samples = 16

# Optimisation and reproducibility settings
args_seed = 0
args_learning_rate = 1e-4
args_weight_decay = 0.
args_validate_every = 1          # validate once every this many epochs
args_jitter = 1e-4               # jitter passed to the covariance layer

In [None]:
# Seed numpy and torch so that repeated runs draw the same random numbers
np.random.seed(args_seed)
torch.manual_seed(args_seed)

# Root of the predator-prey experiment tree (machine specific)
# root = '/Users/stratis/repos/kernelcnp/kernelcnp/experiments/predator-prey'
root = '/scratches/cblgpu07/em626/kernelcnp/kernelcnp/experiments/predator-prey'

# Working directory for saving results, one sub-directory per setting
experiment_name = os.path.join(root,
                               'results',
                               args_data,
                               'models',
                               args_model,
                               args_cov_type,
                               args_noise_type,
                               args_marginal_type,
                               f'seed-{args_seed}')
working_directory = WorkingDirectory(root=experiment_name)

# Data directory for loading data
data_root = os.path.join(root, 'simulated-data', args_data)
data_directory = WorkingDirectory(root=data_root)

# Record which dataset these results were produced from. The context manager
# closes the file even if the write raises.
with open(working_directory.file('data_location.txt'), 'w') as file:
    file.write(data_directory.root)

# =============================================================================
# Load training and validation data
# =============================================================================

# SECURITY: pickle.load can execute arbitrary code while unpickling; only
# point data_directory at files that you generated yourself or fully trust.
with open(data_directory.file('train-data.pkl'), 'rb') as file:
    data_train = pickle.load(file)

with open(data_directory.file('valid-data.pkl'), 'rb') as file:
    data_val = pickle.load(file)

In [None]:
# =============================================================================
# Training epoch helper
# =============================================================================

def train(data,
          model,
          optimiser,
          log_every,
          device,
          writer,
          iteration):
    """
    Run one pass over `data`, taking one optimiser step per batch.

    Arguments:
        data      : iterable of batches, each a mapping with the tensor
                    entries 'x_context', 'y_context', 'x_target', 'y_target'
        model     : object exposing loss(x_context, y_context, x_target,
                    y_target), encoder.sigma and decoder.sigma
        optimiser : optimiser over the model parameters
        log_every : int, print the loss and scales every `log_every` steps
        device    : torch.device the batch tensors are moved to
        writer    : unused, kept so that existing calls keep working
        iteration : int, value of the global step counter before this pass

    Returns:
        iteration : int, incoming counter plus the number of batches seen
    """

    for step, batch in enumerate(data):

        # Clear gradients before the backward pass, so that a stale .grad
        # (e.g. left over from an interrupted run of this cell) cannot leak
        # into this update
        optimiser.zero_grad()

        # x gets a trailing feature axis. For y, index 0 of axis 1 is
        # selected, a trailing axis is added and the values are mapped to
        # y / 100 + 1e-2 (presumably to keep them strictly positive for the
        # log-logit output layer -- confirm).
        nll = model.loss(batch['x_context'][:, :, None].to(device),
                         batch['y_context'][:, 0, :, None].to(device) / 100 + 1e-2,
                         batch['x_target'][:, :, None].to(device),
                         batch['y_target'][:, 0, :, None].to(device) / 100 + 1e-2)

        if step % log_every == 0:

            # Only copy the scales to the host when they are printed
            encoder_scale = torch.exp(model.encoder.sigma)
            encoder_scale = encoder_scale.detach().cpu().numpy().squeeze()

            decoder_scale = torch.exp(model.decoder.sigma)
            decoder_scale = decoder_scale.detach().cpu().numpy().squeeze()

            print(f"Training   neg. log-lik: {nll:.2f}, "
                  f"Encoder/decoder scales {encoder_scale:.3f} "
                  f"{decoder_scale:.3f}")

        # Compute gradients and apply them
        nll.backward()
        optimiser.step()

        iteration = iteration + 1

    return iteration


# =============================================================================
# Validation helper
# =============================================================================


def validate(data,
             model,
             device,
             writer,
             latent_model):
    """
    Evaluate the model loss on every batch of `data` without tracking
    gradients, and summarise the per-batch values.

    Arguments:
        data         : iterable of batches, each a mapping with the tensor
                       entries 'x_context', 'y_context', 'x_target',
                       'y_target'
        model        : object exposing loss(x_context, y_context, x_target,
                       y_target, **kwargs)
        device       : torch.device the batch tensors are moved to
        writer       : unused, kept so that existing calls keep working
        latent_model : bool, if True the loss is called with
                       num_samples=args_np_val_samples (a notebook-level
                       setting that must be defined before this is used)

    Returns:
        mean_nll : mean of the per-batch losses
        std_nll  : standard deviation of the per-batch losses
    """

    # If validating a latent model, set the number of latent samples
    loss_kwargs = {'num_samples' : args_np_val_samples} \
                  if latent_model else {}

    # Per-batch loss values
    nll_list = []

    with torch.no_grad():
        for batch in data:

            # Same preprocessing as in training: x gets a trailing feature
            # axis, y is taken at index 0 of axis 1 and mapped to
            # y / 100 + 1e-2
            nll = model.loss(batch['x_context'][:, :, None].to(device),
                             batch['y_context'][:, 0, :, None].to(device) / 100 + 1e-2,
                             batch['x_target'][:, :, None].to(device),
                             batch['y_target'][:, 0, :, None].to(device) / 100 + 1e-2,
                             **loss_kwargs)

            # The loss is stored exactly as returned by the model
            nll_list.append(nll.item())

    mean_nll = np.mean(nll_list)
    std_nll = np.std(nll_list)

    # Print validation loss
    print(f"Validation neg. log-lik: "
          f"{mean_nll:.2f}")

    return mean_nll, std_nll

In [None]:
class StandardConvGNP(GaussianNeuralProcess):
    """
    Gaussian neural process built from a UNet, a ConvEncoder and a
    ConvDecoder with hyperparameters fixed inside the constructor.

    Defining this class in the notebook shadows the StandardConvGNP name
    imported from cnp.cnp at the top of the file.

    Arguments:
        input_dim    : int, input dimension passed to the UNet, the encoder
                       and the decoder
        output_layer : output layer; its num_features sets the number of
                       decoder output channels
    """

    def __init__(self, input_dim, output_layer):

        # Fixed settings: output dimension, discretisation density, number
        # of channels, initial length scale and grid margin
        output_dim = 1
        points_per_unit = 64
        num_channels = 8
        init_length_scale = 1e-1 # 1.0 / points_per_unit
        grid_margin = 5.

        # Convolutional architecture; it is created before the encoder and
        # the decoder, as in the order of the arguments it feeds into below
        unet = UNet(input_dim=input_dim,
                    in_channels=num_channels,
                    out_channels=num_channels)

        # Keyword arguments shared by the encoder and the decoder
        grid_kwargs = dict(input_dim=input_dim,
                           init_length_scale=init_length_scale,
                           points_per_unit=points_per_unit,
                           grid_multiplier=2 ** unet.num_halving_layers,
                           grid_margin=grid_margin)

        # Convolutional encoder
        encoder = ConvEncoder(out_channels=num_channels, **grid_kwargs)

        # Convolutional decoder, with one output channel per feature
        # required by the output layer
        decoder = ConvDecoder(conv_architecture=unet,
                              conv_out_channels=unet.out_channels,
                              out_channels=output_layer.num_features,
                              **grid_kwargs)

        super().__init__(encoder=encoder,
                         decoder=decoder,
                         output_layer=output_layer)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.conv_architecture = unet

In [None]:
# =============================================================================
# Log-logit copula output layer
# =============================================================================

class LogLogitCopulaLayer(OutputLayer):
    """
    Output layer which gives a Gaussian output layer log-logistic marginals.

    With F the log-logistic CDF (scale a, shape b) and Phi the standard
    normal CDF, an observation y is mapped to v = Phi^-1(F(y)), and v is
    modelled by the wrapped Gaussian layer. Samples are mapped back through
    y = F^-1(Phi(v)).

    Defining this class in the notebook shadows the LogLogitCopulaLayer name
    imported from cnp.cov at the top of the file.
    """


    def __init__(self, gaussian_layer, device):
        """
        Arguments:
            gaussian_layer : Gaussian output layer exposing num_features,
                             loglik and sample
            device         : torch.device on which the helper standard
                             normal distribution is created
        """

        super().__init__()

        # Initialise Gaussian layer
        self.gaussian_layer = gaussian_layer

        # Number of features equal to number of Gaussian layer features plus
        # two additional features for the log-logistic parameters a and b
        self.num_features = self.gaussian_layer.num_features + 2

        # Set device
        self.device = device


    def loglik(self, tensor, y_target):
        """
        Log-likelihood of y_target: the Gaussian log-likelihood of the
        transformed targets plus the log-Jacobian of the transformation.

        Arguments:
            tensor   : torch.tensor, (B, T, C)
            y_target : torch.tensor, (B, T), strictly positive

        Returns:
            tensor : torch.tensor, same shape as the value returned by the
                     Gaussian layer's loglik (the Jacobian term is summed
                     over the last axis and asserted to match that shape)
        """

        # Unpack parameters and apply inverse transformation
        tensor, a, b = self.unpack_parameters(tensor=tensor)
        v_target = self.inverse_marginal_transformation(x=y_target,
                                                        a=a,
                                                        b=b)

        # Log-likelihood of transformed variables under Gaussian
        loglik = self.gaussian_layer.loglik(tensor=tensor, y_target=v_target)

        # Compute change-of-variables contribution (Jacobian is diagonal)
        grad = self.inverse_marginal_transformation(x=y_target,
                                                    a=a,
                                                    b=b,
                                                    grad=True)
        jacobian_term = torch.sum(torch.log(torch.abs(grad)), dim=-1)

        # Ensure shapes are compatible
        assert loglik.shape == jacobian_term.shape

        return loglik + jacobian_term


    def sample(self, tensor, num_samples, noiseless, double=False):
        """
        Arguments:
            tensor      : torch.tensor, (B, T, C)
            num_samples : int, number of samples to draw
            noiseless   : bool, passed on to the Gaussian layer's sample
            double      : bool, passed on to the Gaussian layer's sample

        Returns:
            tensor : torch.tensor, (num_samples, B, T)
        """

        # Unpack the Gaussian features and the log-logistic parameters
        tensor, a, b = self.unpack_parameters(tensor=tensor)

        # Draw samples from the Gaussian layer
        v_samples = self.gaussian_layer.sample(tensor=tensor,
                                               num_samples=num_samples,
                                               noiseless=noiseless,
                                               double=double)

        # Repeat a and b, (num_samples, B, T)
        a = a[None, :, :].repeat(num_samples, 1, 1)
        b = b[None, :, :].repeat(num_samples, 1, 1)

        # Apply marginal transformation to Gaussian samples
        samples = self.marginal_transformation(v_samples, a=a, b=b)

        return samples


    def unpack_parameters(self, tensor):
        """
        Split the features into the Gaussian layer features and the two
        log-logistic parameters.

        Arguments:
            tensor : torch.tensor, (B, T, C)

        Returns:
            tensor : torch.tensor, (B, T, C-2), remaining features divided
                     by 1e2
            a      : torch.tensor, (B, T), scale, currently fixed to 1
            b      : torch.tensor, (B, T), shape, always larger than 1.01
        """

        epsilon = 1e-2

        # Check tensor has correct number of features
        assert (len(tensor.shape) == 3) and \
               (tensor.shape[-1] == self.num_features)

        # Scale a is held at one; multiplying the feature by zero keeps the
        # shape, dtype and device of the input
        a = 0. * tensor[:, :, 0] + 1. #torch.nn.Softplus()(tensor[:, :, 0]) + epsilon

        # Shape b is a softplus of the (down-scaled) feature, offset by
        # 1 + epsilon
        b = torch.nn.Softplus()(1e-2 * tensor[:, :, 1]) + 1e0 + epsilon

        # Slice out the Gaussian layer features and scale them down
        tensor = tensor[:, :, 2:] / 1e2

        return tensor, a, b


    def pdf(self, x, a, b):
        """
        Probability density function of the log-logistic distribution.

            PDF(x) = (b/a) * (x/a)^(b-1) / (1 + (x/a)^b)^2

        Arguments:
            x : torch.tensor, strictly positive
            a : torch.tensor, same shape as x
            b : torch.tensor, same shape as x

        Returns:
            tensor : torch.tensor, same shape as x
        """

        # Check shapes are compatible, all x values are positive
        assert x.shape == a.shape == b.shape
        assert torch.all(x > 0.)

        return (b/a) * (x/a)**(b-1) / (1+(x/a)**b)**2


    def cdf(self, x, a, b):
        """
        Cumulative distribution function of the log-logistic distribution.

            CDF(x) = 1 / (1 + (x/a)^-b)

        Arguments:
            x : torch.tensor, strictly positive
            a : torch.tensor, same shape as x
            b : torch.tensor, same shape as x

        Returns:
            tensor : torch.tensor, same shape as x, cast to float
        """

        # Check shapes are compatible, all x values are positive
        assert x.shape == a.shape == b.shape
        assert torch.all(x > 0.)

        # Evaluate in double precision, return in single precision
        x = x.double()
        a = a.double()
        b = b.double()

        cdf = 1 / (1+(x/a)**-b)
        cdf = cdf.float()

        return cdf


    def icdf(self, x, a, b):
        """
        Inverse cumulative distribution function of the log-logistic
        distribution.

            CDF^-1(x) = a * (x^-1 - 1)^(-1/b)

        Arguments:
            x : torch.tensor, probabilities, strictly positive
            a : torch.tensor, same shape as x
            b : torch.tensor, same shape as x

        Returns:
            tensor : torch.tensor, same shape as x, cast to float
        """

        # Check shapes are compatible, all x values are positive
        assert x.shape == a.shape == b.shape
        assert torch.all(x > 0.)

        # Evaluate in double precision, return in single precision
        x = x.double()
        a = a.double()
        b = b.double()

        icdf = a * (x**-1 - 1) ** (-1/b)
        icdf = icdf.float()

        return icdf


    def _standard_normal(self, shape):
        """
        Standard normal distribution with double precision parameters of
        the given shape, created on self.device.
        """

        zeros = torch.zeros(shape, dtype=torch.double, device=self.device)
        ones = torch.ones(shape, dtype=torch.double, device=self.device)

        return Normal(loc=zeros, scale=ones)


    def marginal_transformation(self, x, a, b):
        """
        Map Gaussian values to log-logistic values, y = F^-1(Phi(x)).

        Arguments:
            x : torch.tensor, Gaussian values
            a : torch.tensor, same shape as x
            b : torch.tensor, same shape as x

        Returns:
            tensor : torch.tensor, same shape as x
        """

        # Check shapes are compatible
        assert x.shape == a.shape == b.shape

        gaussian = self._standard_normal(x.shape)

        x = gaussian.cdf(x)
        x = self.icdf(x, a, b)

        return x


    def inverse_marginal_transformation(self, x, a, b, grad=False):
        """
        Map log-logistic values to Gaussian values, v = Phi^-1(F(x)), or
        return the derivative of this map.

        Arguments:
            x    : torch.tensor, strictly positive
            a    : torch.tensor, same shape as x
            b    : torch.tensor, same shape as x
            grad : bool, if True return dv/dx instead of v

        Returns:
            x : torch.tensor, same shape as the input x
        """

        # Check shapes are compatible, all x values are positive
        assert x.shape == a.shape == b.shape
        assert torch.all(x > 0.)

        gaussian = self._standard_normal(x.shape)

        # Gaussian value corresponding to x
        v = gaussian.icdf(self.cdf(x, a, b))

        if grad:
            # Chain rule on Phi(v) = F(x) gives dv/dx = f(x) / phi(v), where
            # f is the log-logistic density and phi the standard normal
            # density. The denominator is the density phi(v), not v itself.
            return self.pdf(x, a, b) / torch.exp(gaussian.log_prob(v))

        return v

In [None]:
# =============================================================================
# Create model
# =============================================================================
    
# Set device: use the first GPU when CUDA is available and use_cpu is not
# set, otherwise fall back to the CPU instead of requesting a CUDA device
# that does not exist
use_cpu = False

if torch.cuda.is_available() and not use_cpu:
    torch.cuda.set_device(0)
    device = torch.device('cuda')

else:
    device = torch.device('cpu')

# Gaussian output layers, keyed by covariance type
cov_types = {
    'meanfield' : MeanFieldGaussianLayer,
    'innerprod' : InnerprodGaussianLayer,
    'kvv'       : KvvGaussianLayer
}

# The mean-field layer only takes the jitter; the other layers also take the
# embedding size and the noise type
if args_cov_type == 'meanfield':
    output_layer = MeanFieldGaussianLayer(jitter=args_jitter)

else:
    output_layer = cov_types[args_cov_type](num_embedding=args_num_basis_dim,
                                            noise_type=args_noise_type,
                                            jitter=args_jitter)

# Optionally wrap the Gaussian layer in the log-logit copula layer; any other
# marginal type leaves the Gaussian layer as it is
if args_marginal_type == 'loglogit':
    output_layer = LogLogitCopulaLayer(gaussian_layer=output_layer,
                                       device=device)

# Create model architecture
if args_model == 'convGNP':
    model = StandardConvGNP(input_dim=1, output_layer=output_layer)

elif args_model == 'convNP':
    # NOTE: args_np_loss_samples has to be defined together with the other
    # args_* settings before this branch can be used
    model = StandardConvNP(input_dim=1,
                           num_samples=args_np_loss_samples)

else:
    raise ValueError(f'Unknown model {args_model}.')

model.to(device)

# The latent model takes an extra num_samples argument during validation
latent_model = args_model == 'convNP'

print(f'{args_data} '
      f'{args_model} '
      f'{args_cov_type} '
      f'{args_noise_type} '
      f'{args_marginal_type} '
      f'{args_num_basis_dim}: '
      f'{model.num_params}')

In [None]:
# =============================================================================
# Train or test model
# =============================================================================

# Global step counter threaded through train(), and number of steps between
# printed training losses
train_iteration = 0
log_every = 100

# Create optimiser
optimiser = torch.optim.Adam(model.parameters(),
                             args_learning_rate,
                             weight_decay=args_weight_decay)

# Best validation objective seen so far
best_nll = np.inf

# data_train holds one entry per epoch. Train for at most 30 of them, and
# never index past the end of the available data.
epochs = len(data_train)
num_train_epochs = min(30, epochs)

# Validation is currently switched off; set to True to validate every
# args_validate_every epochs
run_validation = False

# Wall-clock time at the start of training (not read in this cell)
start_time = time.time()

for epoch in range(num_train_epochs):

    print('\nEpoch: {}/{}'.format(epoch+1, num_train_epochs))

    if run_validation and epoch % args_validate_every == 0:

        valid_epoch = data_val[epoch // args_validate_every]

        # Compute negative log-likelihood on validation data
        val_nll, _ = validate(valid_epoch,
                              model,
                              device,
                              None,
                              latent_model)

        # Update the best objective value, so that later epochs are compared
        # against the running best rather than the initial value
        is_best = val_nll < best_nll
        best_nll = min(val_nll, best_nll)
        best_obj = best_nll

    train_epoch = data_train[epoch]

    # Compute training negative log-likelihood
    train_iteration = train(train_epoch,
                            model,
                            optimiser,
                            log_every,
                            device,
                            None,
                            train_iteration)

    # Checkpointing is disabled: save_checkpoint is not imported in the
    # cells shown here, and is_best / best_obj only exist when validation
    # has run.
#     save_checkpoint(working_directory,
#                     {'epoch'         : epoch + 1,
#                      'state_dict'    : model.state_dict(),
#                      'best_acc_top1' : best_obj,
#                      'optimizer'     : optimiser.state_dict()},
#                     is_best=is_best,
#                     epoch=epoch)

In [None]:
# Which task to plot: entry of data_train, batch within that entry, and task
# within that batch
e_idx = 0
i_idx = 3
b_idx = 0

train_epoch = data_train[e_idx]
plot_batch = train_epoch[i_idx]

# Plotting inputs: 500 evenly spaced locations which extend 3 units beyond
# the targets of the selected task, repeated once per task in the batch
x_min = torch.min(plot_batch['x_target'][b_idx, :])
x_max = torch.max(plot_batch['x_target'][b_idx, :])
x_plot = torch.linspace(x_min-3., x_max+3., 500)[None, :, None].to(device)
x_plot = x_plot.repeat(plot_batch['x_context'].shape[0], 1, 1)

# Draw samples with the same preprocessing as in training (trailing feature
# axis on x, y mapped to y / 100 + 1e-2). No gradients are needed for
# plotting, so do not build an autograd graph for the samples.
with torch.no_grad():
    samples = model.sample(plot_batch['x_context'][:, :, None].to(device),
                           plot_batch['y_context'][:, 0, :, None].to(device) / 100 + 1e-2,
                           x_plot,
                           num_samples=100,
                           noiseless=False,
                           double=True)

x_plot_np = x_plot[0, :, 0].cpu().numpy()
samples_np = samples[:, b_idx, :].cpu().numpy()

plt.figure(figsize=(12, 6))

# Individual samples
plt.plot(x_plot_np,
         samples_np.T,
         color='green',
         alpha=0.02,
         zorder=1)

# Mean over the samples
plt.plot(x_plot_np,
         samples_np.mean(axis=0),
         color='black',
         alpha=0.5,
         zorder=2)

# Context (black) and target (red) points, shown on the same scale as the
# values which the model was given, i.e. y / 100 + 1e-2
plt.scatter(plot_batch['x_context'][b_idx, :].detach().numpy(),
            plot_batch['y_context'][b_idx, 0, :].detach().numpy() / 100 + 1e-2,
            color='black',
            zorder=2)

plt.scatter(plot_batch['x_target'][b_idx, :].detach().numpy(),
            plot_batch['y_target'][b_idx, 0, :].detach().numpy() / 100 + 1e-2,
            color='red',
            zorder=2)

plt.xlabel('x')
plt.ylabel('y / 100 + 0.01')
plt.ylim([0, 10])
plt.show()