In [None]:
import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

from utils import sample_datasets_from_gps, plot_samples_and_predictions
from torch.distributions import MultivariateNormal

In [None]:
# =============================================================================
# Fully Connected Neural Network
# =============================================================================


class FullyConnectedNetwork(nn.Module):
    """Multi-layer perceptron acting on the last axis of its input.

    Layer k computes ``x @ W[k] + b[k]``; the activation is applied after
    every layer except the final one.

    Args:
        input_dim: size of the input's last axis.
        output_dim: size of the output's last axis.
        hidden_dims: list of hidden-layer widths (may be empty).
        nonlinearity: name of a ``torch.nn`` activation class, e.g. 'ReLU'.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 hidden_dims,
                 nonlinearity):

        super().__init__()

        widths = [input_dim] + hidden_dims + [output_dim]
        self.num_linear = len(hidden_dims) + 1

        weights, biases = [], []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # Weights are scaled by 1/sqrt(fan_in); biases are standard normal.
            weights.append(nn.Parameter(torch.randn(size=(fan_in, fan_out)) / fan_in ** 0.5))
            biases.append(nn.Parameter(torch.randn(size=(fan_out,))))

        self.W = torch.nn.ParameterList(weights)
        self.b = torch.nn.ParameterList(biases)

        self.nonlinearity = getattr(nn, nonlinearity)()

    def forward(self, tensor):
        """Map ``tensor`` of shape (..., input_dim) to (..., output_dim)."""
        final_layer = self.num_linear - 1

        for layer in range(self.num_linear):
            tensor = torch.einsum('...i, ij -> ...j', tensor, self.W[layer])
            # The 1-D bias broadcasts over all leading axes.
            tensor = tensor + self.b[layer]

            if layer != final_layer:
                tensor = self.nonlinearity(tensor)

        return tensor


    
# =============================================================================
# Gaussian Neural Process
# =============================================================================
    
    
class GaussianNeuralProcess(nn.Module):
    """Gaussian neural process with a mean-pooled encoder.

    The encoder is applied to each concatenated context pair (x, y). Its
    outputs are averaged over axis 1 into one representation per task; this
    is appended to every target input and passed through the decoder. The
    decoder output is split as follows:
      * channel 0          -> predictive mean,
      * channels D:2D      -> not used,
      * channels 2D:       -> covariance "root" whose scaled inner products
                              give the target covariance.
    Here D is the size of the outputs' last axis.

    NOTE(review): inputs appear to be (batch, points, features) tensors, since
    axis 1 is pooled and ``trg_in.shape[1]`` is used as the target count —
    confirm with callers.

    Args:
        encoder: module applied to cat([ctx_in, ctx_out], -1).
        decoder: module applied to cat([trg_in, pooled_representation], -1).
        log_noise: initial log variance of the homoscedastic observation
            noise. Python ints/floats and tensors are accepted.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 log_noise):

        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        # An integer tensor cannot require gradients (e.g. log_noise=0), so
        # always store the noise parameter in the default float dtype.
        log_noise = torch.as_tensor(log_noise, dtype=torch.get_default_dtype())
        self.log_noise = nn.Parameter(log_noise.detach().clone())

    def forward(self,
                ctx_in,
                ctx_out,
                trg_in):
        """Predict the target mean and covariance given a context set.

        Returns:
            mean: (B, T, 1) predictive mean (first decoder channel only).
            cov: (B, T, T) covariance without observation noise.
            cov_plus_noise: ``cov`` plus ``exp(log_noise)`` on the diagonal.
        """
        D = ctx_out.shape[-1]

        # Encode every context pair and average into one vector per task,
        # then tile it across the target points.
        ctx = torch.cat([ctx_in, ctx_out], dim=-1)
        theta = self.encoder(ctx)
        theta = torch.mean(theta, dim=1)[:, None, :]
        theta = theta.repeat(1, trg_in.shape[1], 1)

        tensor = torch.cat([trg_in, theta], dim=-1)
        tensor = self.decoder(tensor)

        mean = tensor[:, :, :1]

        # Channels D:2D are skipped; the covariance root starts at 2D.
        cov_root = tensor[:, :, (2*D):]
        cov = torch.einsum('bni, bmi -> bnm', cov_root, cov_root) / cov_root.shape[-1]

        # Build the identity on cov's device/dtype so this also runs off-CPU.
        eye = torch.eye(cov.shape[1], dtype=cov.dtype, device=cov.device)
        cov_plus_noise = cov + torch.exp(self.log_noise) * eye[None, ...]

        return mean, cov, cov_plus_noise

    def _loss(self,
              ctx_in,
              ctx_out,
              trg_in,
              trg_out):
        """Negative joint Gaussian log-likelihood of the targets.

        Only output channel 0 is modelled. The log-likelihood is summed over
        target points (by the joint density) and averaged over the batch.
        """
        mean, cov, cov_plus_noise = self.forward(ctx_in, ctx_out, trg_in)

        pred_dist = MultivariateNormal(loc=mean[:, :, 0],
                                       covariance_matrix=cov_plus_noise)

        log_prob = pred_dist.log_prob(trg_out[:, :, 0])
        log_prob = torch.mean(log_prob)

        return - log_prob

    def loss(self,
             inputs,
             outputs,
             num_samples):
        """Average `_loss` over random context/target splits.

        For each of ``num_samples`` draws, a split point N is chosen uniformly
        from {1, ..., P - 1} using numpy's global RNG, where P = inputs.shape[1].
        The first N points form the context and the rest form the targets.
        The summed loss is divided by ``num_samples * P``.

        Raises:
            ValueError: if fewer than 2 points are given, because then no
                split with a non-empty context and target exists.
        """
        num_points = inputs.shape[1]
        if num_points < 2:
            raise ValueError("loss needs at least 2 points per task to form a "
                             f"non-empty context/target split, got {num_points}")

        loss = 0

        for _ in range(num_samples):

            N = np.random.choice(np.arange(1, num_points))

            ctx_in = inputs[:, :N]
            ctx_out = outputs[:, :N]
            trg_in = inputs[:, N:]
            trg_out = outputs[:, N:]

            loss = loss + self._loss(ctx_in,
                                     ctx_out,
                                     trg_in,
                                     trg_out)

        loss = loss / (num_samples * num_points)

        return loss


    
# =============================================================================
# Translation Equivariant Gaussian Neural Process
# =============================================================================

    
class TranslationEquivariantGaussianNeuralProcess(GaussianNeuralProcess):
    """Gaussian neural process whose encoder only sees input differences.

    For every target point t, the encoder is applied to each context element
    (ctx_in[c] - trg_in[t], ctx_out[c]). The results are averaged over the
    context axis and decoded into per-target features. Because inputs enter
    only through differences, shifting all inputs by the same amount leaves
    the prediction unchanged.

    Decoder-channel layout (differs from the parent class, which skips
    channels D:2D): channel 0 is the mean and channels 1: are the covariance
    root.

    Args:
        encoder: module applied to (B, T, C, Din + Dout) pair features.
        decoder: module applied to the pooled (B, T, R) representation.
        log_noise: initial log variance of the observation noise.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 log_noise):

        super().__init__(encoder=encoder,
                         decoder=decoder,
                         log_noise=log_noise)

    def forward(self,
                ctx_in,
                ctx_out,
                trg_in):
        """Return (mean, cov, cov_plus_noise) for the target inputs.

        Shapes: mean (B, T, 1), cov and cov_plus_noise (B, T, T).
        """
        # diff[b, t, c] = ctx_in[b, c] - trg_in[b, t]  -> (B, T, C, Din)
        diff = ctx_in[:, None, :, :] - trg_in[:, :, None, :]

        # Tile the context outputs across targets to pair with each diff.
        ctx_out_tiled = ctx_out[:, None, :, :].repeat(1, diff.shape[1], 1, 1)

        ctx = torch.cat([diff, ctx_out_tiled], dim=-1)

        tensor = self.encoder(ctx)
        tensor = torch.mean(tensor, dim=2)  # pool over context -> (B, T, R)

        tensor = self.decoder(tensor)

        mean = tensor[:, :, :1]
        cov_root = tensor[:, :, 1:]
        cov = torch.einsum('bni, bmi -> bnm', cov_root, cov_root) / cov_root.shape[-1]

        # Identity on cov's device/dtype so the model also runs off-CPU.
        eye = torch.eye(cov.shape[1], dtype=cov.dtype, device=cov.device)
        cov_plus_noise = cov + torch.exp(self.log_noise) * eye[None, :, :]

        return mean, cov, cov_plus_noise

In [None]:
class Covariance(nn.Module):
    """Base class for covariance heads fed by per-point decoder embeddings.

    Attributes:
        num_basis_dim: number of embedding channels used as basis features.
        extra_cov_dim: number of additional embedding channels the head needs.
    """

    def __init__(self, num_basis_dim, extra_cov_dim):

        super().__init__()

        self.num_basis_dim, self.extra_cov_dim = num_basis_dim, extra_cov_dim


class InnerProdCov(Covariance):
    """Covariance given by scaled inner products of the basis embeddings.

    cov[b, n, m] = <e[b, n, :k], e[b, m, :k]> / k, with k = num_basis_dim.
    """

    def __init__(self, num_basis_dim):
        # Only the basis channels are read, so no extra channels are requested.
        super().__init__(num_basis_dim, extra_cov_dim=0)

    def forward(self, embeddings):
        """Map (B, N, >= num_basis_dim) embeddings to a (B, N, N) covariance."""
        k = self.num_basis_dim
        basis = embeddings[:, :, :k]
        gram = torch.einsum('bpd, bqd -> bpq', basis, basis)
        return gram / k
    
    

class AddNoise(nn.Module):
    """Base class for modules that add observation noise to a covariance.

    Attributes:
        extra_noise_dim: number of extra decoder embedding channels the noise
            model needs.
    """

    def __init__(self, extra_noise_dim):

        super().__init__()

        self.extra_noise_dim = extra_noise_dim


class AddHomoNoise(AddNoise):
    """Add a learned, input-independent noise variance to the diagonal.

    The added variance is ``exp(noise_scale)``. ``noise_scale`` is a learnable
    scalar initialised to 0, i.e. unit variance.
    """

    def __init__(self):
        # The noise level is a free parameter, so no embedding channels are needed.
        super().__init__(extra_noise_dim=0)

        # Log noise variance; nn.Parameter requires grad by default.
        self.noise_scale = nn.Parameter(torch.zeros(1))

    def forward(self, cov, embeddings):
        """Return ``cov + exp(noise_scale) * I``.

        Args:
            cov: (B, N, N) covariance.
            embeddings: unused; accepted for interface compatibility with
                other noise models.
        """
        # Build the identity on cov's device/dtype so this also works off-CPU.
        eye = torch.eye(cov.shape[1], dtype=cov.dtype, device=cov.device)
        cov_plus_noise = cov + torch.exp(self.noise_scale) * eye[None, ...]

        return cov_plus_noise

In [None]:
class GaussianNeuralProcess(nn.Module):
    """Gaussian neural process assembled from pluggable components.

    NOTE(review): this class re-uses the name of the GaussianNeuralProcess
    defined earlier in the notebook and shadows it from here on.

    Args:
        encoder: callable ``(x_context, y_context, x_target) -> r``.
        decoder: callable ``(r, x_context, y_context, x_target) -> z``.
            Channel 0 of z's last axis is the mean; the remaining channels
            form the embedding.
        covariance: callable ``embedding -> cov``.
        add_noise: callable ``(cov, embedding) -> cov_plus_noise``.
    """

    def __init__(self, encoder, decoder, covariance, add_noise):

        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.covariance = covariance
        self.add_noise = add_noise

    def forward(self, x_context, y_context, x_target):
        """Return ``(mean, cov, cov_plus_noise)`` for the target inputs."""
        representation = self.encoder(x_context, y_context, x_target)
        features = self.decoder(representation, x_context, y_context, x_target)

        mean, embedding = features[..., :1], features[..., 1:]

        cov = self.covariance(embedding)
        return mean, cov, self.add_noise(cov, embedding)
    
    
class StandardFullyConnectedTEGNP(GaussianNeuralProcess):
    """Translation-equivariant GNP built from the standard fully connected parts.

    For each target point the decoder emits 1 mean channel, followed by the
    channels requested by ``covariance`` (``num_basis_dim + extra_cov_dim``)
    and by ``add_noise`` (``extra_noise_dim``).

    Args:
        covariance: covariance head exposing ``num_basis_dim`` and
            ``extra_cov_dim``.
        add_noise: noise module exposing ``extra_noise_dim``.
        input_dim: dimensionality of the inputs x (default 1).
        rep_dim: width of the encoder/decoder representation (default 128).
    """

    def __init__(self, covariance, add_noise, input_dim=1, rep_dim=128):

        # Single-output model: the mean slice and the likelihood in `_loss`
        # both use output channel 0 only.
        output_dim = 1
        embedding_dim = (output_dim
                         + covariance.num_basis_dim
                         + covariance.extra_cov_dim
                         + add_noise.extra_noise_dim)

        encoder = StandardFullyConnectedTEEncoder(input_dim=input_dim,
                                                  output_dim=output_dim,
                                                  rep_dim=rep_dim)

        decoder = StandardFullyConnectedTEDecoder(input_dim=input_dim,
                                                  output_dim=output_dim,
                                                  rep_dim=rep_dim,
                                                  embedding_dim=embedding_dim)

        super().__init__(encoder=encoder,
                         decoder=decoder,
                         covariance=covariance,
                         add_noise=add_noise)

    def _loss(self,
              ctx_in,
              ctx_out,
              trg_in,
              trg_out):
        """Negative joint Gaussian log-likelihood of the targets.

        Uses output channel 0. The density is joint over target points and
        the result is averaged over the batch.
        """
        mean, cov, cov_plus_noise = self.forward(ctx_in, ctx_out, trg_in)

        pred_dist = MultivariateNormal(loc=mean[:, :, 0],
                                       covariance_matrix=cov_plus_noise)

        log_prob = pred_dist.log_prob(trg_out[:, :, 0])
        log_prob = torch.mean(log_prob)

        return - log_prob

    def loss(self,
             inputs,
             outputs,
             num_samples):
        """Average `_loss` over random context/target splits.

        For each of ``num_samples`` draws, a split point N is chosen uniformly
        from {1, ..., P - 1} using numpy's global RNG, where P = inputs.shape[1].
        The first N points are the context and the rest are the targets.
        The summed loss is divided by ``num_samples * P``.

        Raises:
            ValueError: if fewer than 2 points are given, because then no
                split with a non-empty context and target exists.
        """
        num_points = inputs.shape[1]
        if num_points < 2:
            raise ValueError("loss needs at least 2 points per task to form a "
                             f"non-empty context/target split, got {num_points}")

        loss = 0

        for _ in range(num_samples):

            N = np.random.choice(np.arange(1, num_points))

            ctx_in = inputs[:, :N]
            ctx_out = outputs[:, :N]
            trg_in = inputs[:, N:]
            trg_out = outputs[:, N:]

            loss = loss + self._loss(ctx_in,
                                     ctx_out,
                                     trg_in,
                                     trg_out)

        loss = loss / (num_samples * num_points)

        return loss
        
        
        
# =============================================================================
# Fully Connected Translation Equivariant Encoder
# =============================================================================


class FullyConnectedTEEncoder(nn.Module):
    """Translation-equivariant encoder over all ordered pairs of context points.

    Each pair (i, j) is described by (x_j - x_i, y_j, y_i). Inputs enter only
    through differences, so a common shift of every x leaves the features
    unchanged.

    Args:
        deepset: callable applied to the (B, C, C, Din + 2 * Dout) pair tensor.
    """

    def __init__(self, deepset):

        super().__init__()

        self.deepset = deepset

    def forward(self, x_ctx, y_ctx, x_trg):
        """Encode the context set; ``x_trg`` is only rank-checked here.

        Raises:
            AssertionError: if any argument is not rank 3.
        """
        for arg in (x_ctx, y_ctx, x_trg):
            assert arg.dim() == 3

        num_ctx = x_ctx.shape[1]

        # pairs[b, i, j] = (x_j - x_i, y_j, y_i)
        x_diff = x_ctx.unsqueeze(1) - x_ctx.unsqueeze(2)
        y_partner = y_ctx.unsqueeze(1).repeat(1, num_ctx, 1, 1)
        y_anchor = y_ctx.unsqueeze(2).repeat(1, 1, num_ctx, 1)

        pairs = torch.cat([x_diff, y_partner, y_anchor], dim=-1)

        # The caller-supplied deepset reduces the pair tensor; (B, C, R) is
        # the intended result shape.
        return self.deepset(pairs)
    

# =============================================================================
# Standard Translation Equivariant Encoder
# =============================================================================


class StandardFullyConnectedTEEncoder(FullyConnectedTEEncoder):
    """Default pairwise encoder: two 128-128 ReLU MLPs around a mean over axis 1.

    Args:
        input_dim: Din, dimensionality of the inputs x.
        output_dim: Dout, dimensionality of the outputs y.
        rep_dim: R, width of the per-context-point representation.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 rep_dim):

        # MLP settings shared by the element-wise and aggregate networks.
        mlp_kwargs = dict(hidden_dims=[128, 128], nonlinearity='ReLU')

        # Per-pair network: (B, C, C, Din + 2 * Dout) -> (B, C, C, R)
        pair_net = FullyConnectedNetwork(input_dim=input_dim + 2 * output_dim,
                                         output_dim=rep_dim,
                                         **mlp_kwargs)

        # Applied after pooling over axis 1: (B, C, R) -> (B, C, R)
        pooled_net = FullyConnectedNetwork(input_dim=rep_dim,
                                           output_dim=rep_dim,
                                           **mlp_kwargs)

        # NOTE(review): FullyConnectedDeepSet is not defined anywhere in this
        # notebook — confirm where it is defined/imported. Its second argument
        # is presumably the list of axes to mean over.
        super().__init__(deepset=FullyConnectedDeepSet(pair_net, [1], pooled_net))
        
        
        
# =============================================================================
# Fully Connected Translation Equivariant Decoder
# =============================================================================


class FullyConnectedTEDecoder(nn.Module):
    """Translation-equivariant decoder over (context point, target) pairs.

    Args:
        deepset: optional callable applied to the (B, C, T, Din + R + Dout)
            pair tensor. Subclasses may instead assign ``self.deepset`` after
            calling ``super().__init__()``.
    """

    def __init__(self, deepset=None):

        super().__init__()

        self.deepset = deepset

    def forward(self, r, x_ctx, y_ctx, x_trg):
        """Build context/target pair features and apply ``self.deepset``.

        Args:
            r     : (B, C, R) context representation.
            x_ctx : (B, C, Din) context inputs.
            y_ctx : (B, C, Dout) context outputs.
            x_trg : (B, T, Din) target inputs.

        Returns:
            ``self.deepset`` applied to the pair tensor whose entry [b, c, t]
            is (x_ctx[b, c] - x_trg[b, t], r[b, c], y_ctx[b, c]).

        Raises:
            AssertionError: if any argument is not rank 3.
            RuntimeError: if no deepset has been provided.
        """
        for arg in (r, x_ctx, y_ctx, x_trg):
            assert len(arg.shape) == 3

        if self.deepset is None:
            raise RuntimeError("FullyConnectedTEDecoder has no deepset: pass one "
                               "to __init__ or assign self.deepset")

        num_trg = x_trg.shape[1]

        # Input differences between every context point and every target.
        diff = x_ctx[:, :, None, :] - x_trg[:, None, :, :]  # (B, C, T, Din)

        # Tile per-context quantities across the target axis.
        r_tiled = r[:, :, None, :].repeat(1, 1, num_trg, 1)
        y_tiled = y_ctx[:, :, None, :].repeat(1, 1, num_trg, 1)

        ctx = torch.cat([diff, r_tiled, y_tiled], dim=-1)  # (B, C, T, Din + R + Dout)

        return self.deepset(ctx)
        
        

# =============================================================================
# Standard Translation Equivariant Decoder
# =============================================================================


class StandardFullyConnectedTEDecoder(FullyConnectedTEDecoder):
    """Default pair decoder: two 128-128 ReLU MLPs around a mean over context.

    Args:
        input_dim: Din, dimensionality of the inputs x.
        output_dim: Dout, dimensionality of the outputs y.
        rep_dim: R, width of the context representation r.
        embedding_dim: E, number of output channels per target point.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 rep_dim,
                 embedding_dim):

        super().__init__()

        # MLP settings shared by the per-pair and per-target networks.
        mlp_kwargs = dict(hidden_dims=[128, 128], nonlinearity='ReLU')

        # Per-pair network: (B, C, T, Din + R + Dout) -> (B, C, T, R)
        pair_net = FullyConnectedNetwork(input_dim=input_dim + output_dim + rep_dim,
                                         output_dim=rep_dim,
                                         **mlp_kwargs)

        # Applied after pooling over the context axis: (B, T, R) -> (B, T, E)
        target_net = FullyConnectedNetwork(input_dim=rep_dim,
                                           output_dim=embedding_dim,
                                           **mlp_kwargs)

        # NOTE(review): FullyConnectedDeepSet is not defined anywhere in this
        # notebook — confirm where it is defined/imported. Its second argument
        # is presumably the list of axes to mean over.
        self.deepset = FullyConnectedDeepSet(pair_net, [1], target_net)

# Training

In [None]:
# Reproducibility: float32 by default, with fixed seeds for torch and numpy.
# `loss` draws its context/target split from numpy's global RNG.
torch.set_default_dtype(torch.float32)

torch.manual_seed(0)
np.random.seed(0)

# Model: translation-equivariant GNP with an inner-product covariance over
# `num_basis_dim` basis channels and learned homoscedastic noise.
num_basis_dim = 512

covariance = InnerProdCov(num_basis_dim)
add_noise = AddHomoNoise()

fc_gnp = StandardFullyConnectedTEGNP(covariance=covariance, add_noise=add_noise)


# Dataset parameters (passed through to sample_datasets_from_gps)
xmin = - 3e0
xmax = 3e0
num_batches = 16
batch_size = 32
plot_batch_size = 32
scale = 1e0
cov_coeff = 1e0
noise_coeff = 1e-1
num_samples = 2

# Training parameters and optimizer
num_train_steps = int(1e5)
lr = 1e-3
log_every = 100  # print the loss and plot predictions every `log_every` steps

optimizer = torch.optim.Adam(fc_gnp.parameters(), lr=lr)

# Per-step training loss, stored as Python floats so no autograd graphs are kept.
losses = []

for step in range(num_train_steps):

    inputs, outputs = sample_datasets_from_gps(xmin,
                                               xmax,
                                               num_batches,
                                               batch_size,
                                               scale,
                                               cov_coeff,
                                               noise_coeff,
                                               as_tensor=True)

    loss = fc_gnp.loss(inputs, outputs, num_samples=num_samples)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    if step % log_every == 0:

        print(f"step {step:6d} | loss {losses[-1]:.4f}")

        plot_samples_and_predictions(fc_gnp,
                                     xmin,
                                     xmax,
                                     plot_batch_size,
                                     scale,
                                     cov_coeff,
                                     noise_coeff,
                                     step)