In [2]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import equinox as eqx
import optax

import numpyro.distributions as dist

from typing import List, Tuple
from jaxtyping import Int, Array, Float, PyTree

In [3]:
# Defining the MLP model, to be used as a batch MLP
# Inherit the eqx.Module class

class MLP(eqx.Module):
    """Fully-connected network: linear layers with ``tanh`` in between.

    The network maps a single feature vector of size ``layer_sizes[0]`` to a
    vector of size ``layer_sizes[-1]``. It has no batch axis of its own; wrap
    the instance in ``jax.vmap`` to apply it over batches of points.

    Attributes:
        layers: Alternating ``eqx.nn.Linear`` modules and ``jnp.tanh``
            callables, ending with a linear layer (no final activation).
    """

    layers: List

    def __init__(
        self,
        layer_sizes: List[int],
        key: Array,
    ):
        """Build the stack of layers.

        Args:
            layer_sizes: Feature sizes ``[in, hidden_1, ..., out]``. Needs at
                least two entries (input and output size).
            key: JAX PRNG key; it is split once per linear layer so every
                layer is initialised from its own subkey.

        Raises:
            ValueError: If ``layer_sizes`` has fewer than two entries.
        """
        if len(layer_sizes) < 2:
            raise ValueError(
                "layer_sizes must contain at least an input and an output size, "
                f"got {len(layer_sizes)} entries"
            )

        # Build in a local list and assign the field exactly once.
        layers = []
        last_idx = len(layer_sizes) - 2  # index of the final linear layer

        for idx, (feat_in, feat_out) in enumerate(
            zip(layer_sizes[:-1], layer_sizes[1:])
        ):
            key, subkey = jax.random.split(key)

            layers.append(
                eqx.nn.Linear(feat_in, feat_out, use_bias=True, key=subkey)
            )  # fully-connected layer

            # every linear layer except the last is followed by an activation
            if idx < last_idx:
                layers.append(jnp.tanh)

        self.layers = layers

    # __call__ turns an instance of this class into a callable object, which behaves like a function
    def __call__(
        self,
        x: Float[Array, "in_features"],
    ) -> Float[Array, "out_features"]:
        """Apply every layer in order to one (unbatched) feature vector.

        Args:
            x: Input vector whose size matches ``layer_sizes[0]``.

        Returns:
            The output of the final linear layer.
        """
        # apply each layer in sequence
        for layer in self.layers:
            x = layer(x)

        return x

In [None]:
class CNP(eqx.Module):
    """Encoder/decoder model that predicts a Normal distribution at target points.

    Presumably a Conditional Neural Process (the name is inferred from the
    class name). The context pairs ``(x, y)`` are encoded point-wise, averaged
    into one representation per batch element, and that representation is
    concatenated with every target input before being decoded into the mean
    and scale of a ``dist.Normal``.

    Attributes:
        layers: ``[encoder, decoder]``. The encoder is applied to single
            vectors formed by concatenating one context ``x`` and ``y``; the
            decoder is applied to single vectors formed by concatenating the
            representation with one target ``x``. The decoder output is split
            into two equal halves along the last axis (mean and raw scale).
    """

    layers: list

    def __init__(
            self,
            encoder: eqx.Module,
            decoder: eqx.Module,
    ):
        """Store the two sub-networks.

        Args:
            encoder: Callable module applied per context point.
            decoder: Callable module applied per target point.
        """
        self.layers = [encoder, decoder]


    def __call__(
            self,
            x_context: Float[Array, "batch n_context 1"],
            y_context: Float[Array, "batch n_context 1"],
            x_target: Float[Array, "batch n_target 1"],
    ) -> dist.Distribution:
        """Predict a distribution over the targets given the context set.

        Args:
            x_context: Context inputs.
            y_context: Context outputs.
            x_target: Inputs at which to predict.

        Returns:
            A ``dist.Normal`` whose ``loc`` and ``scale`` have one entry per
            target point.
        """
        # get number of target points
        _, n_target, _ = x_target.shape

        # encoder step
        encoded_rep = self._encode(x_context, y_context)  # (batch_size, 1, encoder_dim)

        # tile sample before passing to the decoder
        representation = self._tile(encoded_rep, n_target)  # (batch_size, n_target, encoder_dim)

        # decoder step to produce distribution of functions
        return self._decode(representation, x_target)


    def _encode(
            self,
            x_context: Float[Array, "batch n_context 1"],
            y_context: Float[Array, "batch n_context 1"],
    ) -> Float[Array, "batch 1 encoder_dim"]:
        """Encode each context pair and average over the context axis.

        Returns:
            The mean encoder output, with the context axis kept as size 1.
        """
        xy_context = jnp.concatenate([x_context, y_context], axis=-1)  # (batch_size, n_context, 2)

        per_point = self._encode_mlp(xy_context)  # (batch_size, n_context, encoder_dim)

        # the mean makes the representation independent of the context ordering
        return jnp.mean(per_point, axis=1, keepdims=True)  # (batch_size, 1, encoder_dim)


    def _encode_mlp(
            self,
            xy_context: Float[Array, "batch n_context 2"],
    ) -> Float[Array, "batch n_context encoder_dim"]:
        """Apply the encoder to every context point of every batch element."""
        # outer vmap maps over the batch axis, inner vmap over the points
        return jax.vmap(jax.vmap(self.layers[0]))(xy_context)


    def _tile(
            self,
            z_latent: Float[Array, "batch 1 encoder_dim"],
            n_target: int,
    ) -> Float[Array, "batch n_target encoder_dim"]:
        """Repeat the representation ``n_target`` times along axis 1."""
        return jnp.tile(z_latent, [1, n_target, 1])


    def _decode(
            self,
            representation: Float[Array, "batch n_target encoder_dim"],
            x_target: Float[Array, "batch n_target 1"],
    ) -> dist.Distribution:
        """Decode the representation and target inputs into a Normal distribution.

        Returns:
            ``dist.Normal`` with ``loc`` taken from the first half of the
            decoder output and ``scale`` derived from the second half.
        """
        decoder_in = jnp.concatenate([representation, x_target], axis=-1)  # (batch_size, n_target, encoder_dim + 1)

        # last axis holds the mean followed by the raw scale
        mlp_out = jax.vmap(jax.vmap(self.layers[1]))(decoder_in)  # (batch_size, n_target, 2) for 1-D targets

        mu, sigma = jnp.split(mlp_out, 2, axis=-1)  # two equal halves along the last axis

        # softplus is non-negative, so the scale is bounded below by 0.1
        sigma = 0.1 + 0.9 * jax.nn.softplus(sigma)

        return dist.Normal(loc=mu, scale=sigma)

In [None]:
def loss(
        model: CNP,
        x_context: Float[Array, "batch n_context 1"],
        y_context: Float[Array, "batch n_context 1"],
        x_target: Float[Array, "batch n_target 1"],
        y_target: Float[Array, "batch n_target 1"],
) -> Float[Array, ""]:
    """Mean negative log-likelihood of the targets under the model's prediction.

    Args:
        model: Callable taking ``(x_context, y_context, x_target)`` and
            returning a distribution object with a ``log_prob`` method.
        x_context: Context inputs.
        y_context: Context outputs.
        x_target: Inputs at which the model predicts.
        y_target: Observed outputs at ``x_target``; only used to score the
            prediction, never passed to the model.

    Returns:
        Scalar loss: the negated mean of ``log_prob(y_target)``, so that
        minimising it maximises the likelihood of the targets.
    """
    # the model conditions on the context set and the target inputs only
    distribution = model(x_context, y_context, x_target)

    # jnp (not np) keeps the reduction traceable under jax.grad / jax.jit
    return -jnp.mean(distribution.log_prob(y_target))