In [8]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import pickle
import time

from copy import deepcopy

from cnp.experiment import (
    generate_root,
    WorkingDirectory,
    save_checkpoint,
    log_args
)

from cnp.cnp import (
    StandardPredPreyConvGNP,
    FullConvGNP
)

from cnp.lnp import (
    StandardANP,
    StandardConvNP
)

from cnp.cov import (
    MeanFieldGaussianLayer,
    InnerprodGaussianLayer,
    KvvGaussianLayer,
    ExponentialCopulaLayer
)

from cnp.utils import make_generator

import torch
from torch.distributions import MultivariateNormal
from torch.utils.tensorboard import SummaryWriter


# =============================================================================
# Validation helper
# =============================================================================


def test(data,
         model,
         device,
         args_np_test_samples,
         latent_model):
    
    print('In test')
    
    # Lists for logging model's training NLL and oracle NLL
    nll_list = []
    oracle_nll_list = []
    
    # If training a latent model, set the number of latent samples accordingly
    loss_kwargs = {'num_samples' : args_np_test_samples} if latent_model else \
                  {}
    
    with torch.no_grad():
        
        for step, batch in enumerate(data):

            nll = model.loss(batch['x_context'][:, :, None].to(device),
                             batch['y_context'][:, 0, :, None].to(device) / 100 + 1e-2,
                             batch['x_target'][:, :, None].to(device),
                             batch['y_target'][:, 0, :, None].to(device) / 100 + 1e-2)

            # Scale by the number of target points
            nll_list.append(nll.item() / 100.)

#             mean, var, var_plus_noise = model.means_and_marginals(batch['x_context'][:, :, None].to(device),
#                              batch['y_context'][:, 0, :, None].to(device) / 100 + 1e-2,
#                              batch['x_target'][:, :, None].to(device),
#                              batch['y_target'][:, 0, :, None].to(device) / 100 + 1e-2)
            
#             print(mean.shape)
#             raise Exception

            if step % 100 == 0:
                print(f"Validation neg. log-lik, {step+1}: "
                      f"{np.mean(nll_list):.2f} +/- "
                      f"{np.var(nll_list)**0.5  / (step+1)**0.5:.2f}")

            
    mean_nll = np.mean(nll_list)
    std_nll = np.var(nll_list)**0.5 / np.sqrt(step + 1)
    
    return mean_nll, std_nll

In [9]:
# Experiment configuration: everything a reader might want to change lives here
args_train_data = 'sim-pred-prey-100-16-50-100-100-0'
args_test_data = 'sim-pred-prey-100-16-50-100-100-0'
args_seed = 0
args_model = 'convGNP'
args_cov_type = 'meanfield'
args_noise_type = 'hetero'
args_marginal_type = 'identity'
args_np_loss_samples = 16
args_np_test_samples = 512
args_num_basis_dim = 32
args_gpu = 0

# Set seed
np.random.seed(args_seed)
torch.manual_seed(args_seed)

# Set device: use GPU number args_gpu whenever CUDA is available and fall back
# to the CPU otherwise. The fallback no longer depends on the value of
# args_gpu, which previously selected 'cuda' on CUDA-less machines whenever
# args_gpu was non-zero.
use_cpu = not torch.cuda.is_available()

if use_cpu:
    device = torch.device('cpu')

else:
    torch.cuda.set_device(args_gpu)
    device = torch.device('cuda')

root = '../../experiments/predator-prey'

# Working directory for saving results
experiment_name = os.path.join(root,
                               'results',
                               args_train_data,
                               'models',
                               args_model,
                               args_cov_type,
                               args_noise_type,
                               args_marginal_type,
                               f'seed-{args_seed}')
working_directory = WorkingDirectory(root=experiment_name)

# Data directory for loading data
data_root = os.path.join(root,
                         'simulated-data',
                         args_test_data)
data_directory = WorkingDirectory(root=data_root)

# Directory and file name for logs
log_path = f'{root}/logs'
log_filename = f'test-{args_train_data}-{args_test_data}-{args_model}-{args_cov_type}-{args_seed}'
log_directory = WorkingDirectory(root=log_path)

    
# =============================================================================
# Create model
# =============================================================================

# Gaussian output layers selectable through args_cov_type
cov_types = {
    'meanfield' : MeanFieldGaussianLayer,
    'innerprod' : InnerprodGaussianLayer,
    'kvv'       : KvvGaussianLayer
}

# Report a mistyped covariance type the same way as a mistyped model name,
# instead of surfacing a bare KeyError from the dictionary lookup below.
if args_cov_type not in cov_types:
    raise ValueError(f'Unknown covariance type {args_cov_type}.')

if args_cov_type == 'meanfield':
    output_layer = MeanFieldGaussianLayer()

else:
    output_layer = cov_types[args_cov_type](num_embedding=args_num_basis_dim,
                                            noise_type=args_noise_type)

# Optionally wrap the Gaussian layer in an exponential copula layer
if args_marginal_type == 'exponential':
    output_layer = ExponentialCopulaLayer(gaussian_layer=output_layer,
                                          scale=3.,
                                          device=device)

# Create model architecture
if args_model == 'GNP':
    # StandardGNP is not among the notebook's top-level imports, so this
    # branch used to raise NameError.
    # NOTE(review): presumably defined in cnp.cnp next to the other GNP
    # models -- confirm.
    from cnp.cnp import StandardGNP
    model = StandardGNP(input_dim=1, output_layer=output_layer)

elif args_model == 'AGNP':
    # Same situation as StandardGNP above.
    # NOTE(review): presumably defined in cnp.cnp -- confirm.
    from cnp.cnp import StandardAGNP
    model = StandardAGNP(input_dim=1, output_layer=output_layer)

elif args_model == 'convGNP':
    model = StandardPredPreyConvGNP(input_dim=1, output_layer=output_layer)

elif args_model == 'FullConvGNP':
    model = FullConvGNP()

elif args_model == 'ANP':
    model = StandardANP(input_dim=1,
                        num_samples=args_np_loss_samples)

elif args_model == 'convNP':
    model = StandardConvNP(input_dim=1,
                           num_samples=args_np_loss_samples)

else:
    raise ValueError(f'Unknown model {args_model}.')


print(f'{args_model} '
      f'{args_cov_type} '
      f'{args_noise_type} '
      f'{args_marginal_type} '
      f'{args_num_basis_dim}: '
      f'{model.num_params}')

# Record the parameter count next to the other results of this experiment
with open(working_directory.file('num_params.txt'), 'w') as f:
    f.write(f'{model.num_params}')

# Load model to appropriate device
model = model.to(device)

# ANP and convNP are the latent-variable models
latent_model = args_model in ['ANP', 'convNP']

# Load model from saved state. map_location remaps the stored tensors onto the
# current device, so a checkpoint written on a GPU still loads on a CPU host.
load_dict = torch.load(working_directory.file('checkpoint.pth.tar', exists=True),
                       map_location=device)
model.load_state_dict(load_dict['state_dict'])

# =============================================================================
# Load test data
# =============================================================================

# NOTE: pickle.load can execute arbitrary code while unpickling, so only load
# test-data.pkl files that were produced by a trusted source.
# The with-block closes the file even if unpickling raises.
with open(data_directory.file('test-data.pkl'), 'rb') as f:
    data_test = pickle.load(f)


# =============================================================================
# Test model
# =============================================================================
print("Starting testing...")

# Arguments are passed by keyword: passing them by position previously put
# args_np_test_samples in the device slot of test (and device in the
# args_np_test_samples slot).
start_time = time.time()
test_mean_nll, test_std_nll = test(data=data_test,
                                   model=model,
                                   device=device,
                                   args_np_test_samples=args_np_test_samples,
                                   latent_model=latent_model)
stop_time = time.time()
elapsed_time = stop_time - start_time

print("finished testing.")

# Each result goes to its own text file in the working directory. The values
# returned by test are negative log-likelihoods; the file names are left
# unchanged so anything reading these files keeps working.
results_to_save = {
    'test_log_likelihood.txt'                : test_mean_nll,
    'test_log_likelihood_standard_error.txt' : test_std_nll,
    'test_time.txt'                          : elapsed_time
}

for result_filename, result_value in results_to_save.items():
    with open(working_directory.file(result_filename), 'w') as f:
        f.write(str(result_value))

Root: ../../experiments/predator-prey/results/sim-pred-prey-100-16-50-100-100-0/models/convGNP/meanfield/hetero/identity/seed-0
Root: ../../experiments/predator-prey/simulated-data/sim-pred-prey-100-16-50-100-100-0
Root: ../../experiments/predator-prey/logs
convGNP meanfield hetero identity 32: 50701.0
Starting testing...
In test
Validation neg. log-lik, 1: -2.06 +/- 0.00
Validation neg. log-lik, 101: -1.37 +/- 0.07
Validation neg. log-lik, 201: -1.41 +/- 0.05
Validation neg. log-lik, 301: -1.38 +/- 0.04
Validation neg. log-lik, 401: -1.42 +/- 0.03
Validation neg. log-lik, 501: -1.43 +/- 0.03
Validation neg. log-lik, 601: -1.43 +/- 0.03
Validation neg. log-lik, 701: -1.44 +/- 0.02
Validation neg. log-lik, 801: -1.44 +/- 0.02
Validation neg. log-lik, 901: -1.43 +/- 0.02
Validation neg. log-lik, 1001: -1.45 +/- 0.02
finished testing.
