# In [None]:
import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

from utils import sample_datasets_from_gps, plot_samples_and_predictions

import sys
sys.path.append('../')
from kernelcnp.fc_gnp import FullyConnectedNetwork, TranslationEquivariantGaussianNeuralProcess

# Training

# In [None]:
# Use 32-bit floats as the default dtype. `torch.set_default_tensor_type` is
# deprecated since PyTorch 2.1; `set_default_dtype` is its supported
# replacement for the CPU float default used here.
torch.set_default_dtype(torch.float32)

# Seed both the torch and numpy RNGs for reproducibility.
torch.manual_seed(0)
np.random.seed(0)

input_dim = 1
output_dim = 1
context_rep_features = 128
embedding_features = 512
log_noise = 0.

# Encoder: consumes concatenated (x, y) context pairs and maps them to a
# `context_rep_features`-dimensional representation.
encoder_input_dim = input_dim + output_dim
encoder_output_dim = context_rep_features
encoder_hidden_dims = [128, 128]

# Decoder: consumes a target input together with the context representation.
# Its output width is output_dim + embedding_features; presumably a mean part
# plus a covariance embedding -- TODO confirm against kernelcnp.fc_gnp.
decoder_input_dim = input_dim + context_rep_features
decoder_output_dim = output_dim + embedding_features
decoder_hidden_dims = [128, 128]
nonlinearity = 'ReLU'

encoder_fcn = FullyConnectedNetwork(input_dim=encoder_input_dim,
                                    output_dim=encoder_output_dim,
                                    hidden_dims=encoder_hidden_dims,
                                    nonlinearity=nonlinearity)

decoder_fcn = FullyConnectedNetwork(input_dim=decoder_input_dim,
                                    output_dim=decoder_output_dim,
                                    hidden_dims=decoder_hidden_dims,
                                    nonlinearity=nonlinearity)

fc_gnp = TranslationEquivariantGaussianNeuralProcess(encoder=encoder_fcn,
                                                     decoder=decoder_fcn,
                                                     log_noise=log_noise)



# Dataset parameters. The x-range, batch counts and GP coefficients below are
# passed to both the training-data sampler and the plotting helper.
xmin, xmax = -3.0, 3.0
num_batches, batch_size = 16, 32
plot_batch_size = 16
scale = 1.0
cov_coeff = 1.0
noise_coeff = 0.1
num_samples = 2  # forwarded to fc_gnp.loss(..., num_samples=...)

# Training parameters and optimizer
num_train_steps = 100000
lr = 0.001

# Adam over every trainable parameter of the model.
optimizer = torch.optim.Adam(params=fc_gnp.parameters(), lr=lr)

# Initially empty list intended to hold the training-loss history.
losses: list = []

# Print the loss and plot predictions every `log_every` steps. An integer
# interval keeps the check in integer arithmetic (no float modulo).
log_every = 100

for step in range(num_train_steps):

    # Resample training data on every step.
    inputs, outputs = sample_datasets_from_gps(xmin,
                                               xmax,
                                               num_batches,
                                               batch_size,
                                               scale,
                                               cov_coeff,
                                               noise_coeff,
                                               as_tensor=True)

    loss = fc_gnp.loss(inputs, outputs, num_samples=num_samples)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # `backward()` without arguments requires a scalar loss, so `.item()` is
    # valid here. Storing a plain float keeps the history from holding on to
    # each step's autograd graph.
    loss_value = loss.item()
    losses.append(loss_value)

    if step % log_every == 0:

        print(f"step {step}: loss {loss_value:.4f}")

        plot_samples_and_predictions(fc_gnp,
                                     xmin,
                                     xmax,
                                     plot_batch_size,
                                     scale,
                                     cov_coeff,
                                     noise_coeff,
                                     step)