In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
from jax import random
from jax import numpy as jnp

from flax.training import train_state
import optax

from npf.jax.models import CNP, NP, AttnCNP, AttnNP, BNP, AttnBNP, ConvCNP, ConvNP, ConvBNP, NeuBNP
from npf.jax.data import GPSampler, RBFKernel

from tqdm.auto import trange

In [3]:
def get_train_step(model, **kwargs):
    @jax.jit
    def train_step(state, rngs, x_ctx, y_ctx, x_tar, y_tar, mask_ctx, mask_tar):
        def loss_fn(params):
            loss = model.apply(
                params,
                x_ctx=x_ctx,
                y_ctx=y_ctx,
                x_tar=x_tar,
                y_tar=y_tar,
                mask_ctx=mask_ctx,
                mask_tar=mask_tar,
                rngs=rngs,
                method=model.loss,
                **kwargs,
            )
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss
    return train_step

In [4]:
def get_eval_step(model, **kwargs):
    @jax.jit
    def eval_step(state, rngs, x_ctx, y_ctx, x_tar, y_tar, mask_ctx, mask_tar):
        loss = model.apply(
            state.params,
            x_ctx=x_ctx,
            y_ctx=y_ctx,
            x_tar=x_tar,
            y_tar=y_tar,
            mask_ctx=mask_ctx,
            mask_tar=mask_tar,
            rngs=rngs,
            method=model.loss,
            **kwargs,
        )
        return loss
    return eval_step

In [5]:
x_ctx    = jnp.ones((1, 3, 1))
y_ctx    = jnp.ones((1, 3, 1))
x_tar    = jnp.ones((1, 4, 1))
y_tar    = jnp.ones((1, 4, 1))
mask_ctx = jnp.ones(3)
mask_tar = jnp.ones(4)

In [6]:
def init_model(model):
    key = random.PRNGKey(0)
    key, params_init_key, sample_init_key = random.split(key, 3)

    params = model.init(dict(
        params=params_init_key,
        sample=sample_init_key,
    ), x_ctx, y_ctx, x_tar, mask_ctx, mask_tar)

    tx = optax.adam(learning_rate=5e-4)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state

In [7]:
def train_model(key, model, state, num_steps=10000, **kwargs):
    train_step = get_train_step(model, **kwargs)
    eval_step  = get_eval_step(model, **kwargs)
    
    for i in trange(1, num_steps + 1):
        key, model_key, data_key = random.split(key, 3)
        batch = sampler.sample(data_key, batch_size=256)
        state, loss = train_step(
            state,
            dict(sample=model_key),
            x_ctx=batch.x_ctx,
            y_ctx=batch.y_ctx,
            x_tar=batch.x,
            y_tar=batch.y,
            mask_ctx=batch.mask_ctx,
            mask_tar=batch.mask,
        )

        if i % 100 == 0:
            batch = sampler.sample(data_key, batch_size=512)
            loss_ctx = eval_step(
                state,
                dict(sample=model_key),
                x_ctx=batch.x_ctx,
                y_ctx=batch.y_ctx,
                x_tar=batch.x_ctx,
                y_tar=batch.y_ctx,
                mask_ctx=batch.mask_ctx,
                mask_tar=batch.mask_ctx,
            )
            loss_tar = eval_step(
                state,
                dict(sample=model_key),
                x_ctx=batch.x_ctx,
                y_ctx=batch.y_ctx,
                x_tar=batch.x_tar,
                y_tar=batch.y_tar,
                mask_ctx=batch.mask_ctx,
                mask_tar=batch.mask_tar,
            )

            print(f"Step {i}/{num_steps}  CTX Loss: {loss_ctx:7.4f}  TAR Loss: {loss_tar:7.4f}  Loss: {loss:7.4f}")
    return state

In [8]:
sampler = GPSampler(RBFKernel())

In [11]:
model = NeuBNP(y_dim=1)
state = init_model(model)
state = train_model(random.PRNGKey(0), model, state, num_samples=5)

  0%|          | 0/10000 [00:00<?, ?it/s]

Step 100/10000  CTX Loss:  0.1170  TAR Loss:     nan  Loss:  0.2498
Step 200/10000  CTX Loss:  0.0794  TAR Loss:     nan  Loss:  0.4106
Step 300/10000  CTX Loss:  0.0351  TAR Loss:     nan  Loss:  0.1544
Step 400/10000  CTX Loss: -0.0637  TAR Loss:     nan  Loss: -0.0805
Step 500/10000  CTX Loss: -0.0801  TAR Loss:     nan  Loss:  0.0002
Step 600/10000  CTX Loss: -0.4965  TAR Loss:     nan  Loss:  0.3843
Step 700/10000  CTX Loss: -0.2920  TAR Loss:     nan  Loss:  0.1538
Step 800/10000  CTX Loss: -0.1858  TAR Loss:     nan  Loss: -0.1333
Step 900/10000  CTX Loss: -0.7575  TAR Loss:     nan  Loss:  0.1944
Step 1000/10000  CTX Loss: -0.6376  TAR Loss:     nan  Loss:  0.0547
Step 1100/10000  CTX Loss: -0.3592  TAR Loss:     nan  Loss: -0.3612
Step 1200/10000  CTX Loss: -0.5350  TAR Loss:     nan  Loss: -0.3505
Step 1300/10000  CTX Loss: -0.5403  TAR Loss:     nan  Loss: -0.2933
Step 1400/10000  CTX Loss: -0.5543  TAR Loss:     nan  Loss: -0.1712
Step 1500/10000  CTX Loss: -0.4446  TAR Los

In [10]:
raise Exception("Barrier")

Exception: Barrier

In [None]:
model = CNP(y_dim=1)
state = init_model(model)
train_model(random.PRNGKey(0), model, state)

  0%|          | 0/10000 [00:00<?, ?it/s]

Step 100/10000  CTX Loss:  0.4053  TAR Loss:  0.4900  Loss:  0.4532
Step 200/10000  CTX Loss:  0.4009  TAR Loss:  0.5270  Loss:  0.5206
Step 300/10000  CTX Loss:  0.3657  TAR Loss:  0.4229  Loss:  0.4524
Step 400/10000  CTX Loss:  0.2599  TAR Loss:  0.3327  Loss:  0.2098
Step 500/10000  CTX Loss:  0.4103  TAR Loss:  0.4571  Loss:  0.4368
Step 600/10000  CTX Loss:  0.0409  TAR Loss:  0.7083  Loss:  0.5128
Step 700/10000  CTX Loss:  0.2515  TAR Loss:  0.4890  Loss:  0.3600
Step 800/10000  CTX Loss:  0.2026  TAR Loss:  0.3145  Loss:  0.1958
Step 900/10000  CTX Loss: -0.1382  TAR Loss:  0.9736  Loss:  0.5885
Step 1000/10000  CTX Loss:  0.0310  TAR Loss:  0.4535  Loss:  0.2546
Step 1100/10000  CTX Loss:  0.0718  TAR Loss:  0.1600  Loss:  0.0417
Step 1200/10000  CTX Loss:  0.0706  TAR Loss:  0.3105  Loss:  0.0547
Step 1300/10000  CTX Loss: -0.0155  TAR Loss:  0.1775  Loss:  0.0668
Step 1400/10000  CTX Loss:  0.0235  TAR Loss:  0.2288  Loss:  0.0757
Step 1500/10000  CTX Loss:  0.0228  TAR Los

In [None]:
model = AttnCNP(y_dim=1)
state = init_model(model)
train_model(random.PRNGKey(0), model, state)

  0%|          | 0/10000 [00:00<?, ?it/s]

Step 100/10000  CTX Loss:  0.2974  TAR Loss:  0.3893  Loss:  0.3388
Step 200/10000  CTX Loss:  0.3066  TAR Loss:  0.4439  Loss:  0.4393
Step 300/10000  CTX Loss:  0.1301  TAR Loss:  0.1761  Loss:  0.2319
Step 400/10000  CTX Loss:  0.0410  TAR Loss:  0.1388  Loss:  0.0253
Step 500/10000  CTX Loss:  0.0654  TAR Loss:  0.1876  Loss:  0.1367
Step 600/10000  CTX Loss: -0.1792  TAR Loss:  0.6703  Loss:  0.5355
Step 700/10000  CTX Loss: -0.0340  TAR Loss:  0.3742  Loss:  0.1759
Step 800/10000  CTX Loss: -0.0941  TAR Loss:  0.0882  Loss: -0.0866
Step 900/10000  CTX Loss: -0.2952  TAR Loss:  0.7256  Loss:  0.3284
Step 1000/10000  CTX Loss: -0.1652  TAR Loss:  0.4289  Loss:  0.1359
Step 1100/10000  CTX Loss: -0.2417  TAR Loss: -0.0859  Loss: -0.2607
Step 1200/10000  CTX Loss: -0.2318  TAR Loss:  0.1554  Loss: -0.2341
Step 1300/10000  CTX Loss: -0.3414  TAR Loss: -0.0756  Loss: -0.2280
Step 1400/10000  CTX Loss: -0.2661  TAR Loss:  0.0234  Loss: -0.1891
Step 1500/10000  CTX Loss: -0.2967  TAR Los

In [None]:
model = BNP(y_dim=1)
state = init_model(model)
train_model(random.PRNGKey(0), model, state, num_samples=5)

  0%|          | 0/10000 [00:00<?, ?it/s]

Step 100/10000  CTX Loss:  0.8429  TAR Loss:  1.0419  Loss:  0.9270
Step 200/10000  CTX Loss:  0.7931  TAR Loss:  1.0553  Loss:  1.0319
Step 300/10000  CTX Loss:  0.7532  TAR Loss:  0.8729  Loss:  0.9421
Step 400/10000  CTX Loss:  0.5245  TAR Loss:  0.6489  Loss:  0.4179
Step 500/10000  CTX Loss:  0.7572  TAR Loss:  0.8536  Loss:  0.8438
Step 600/10000  CTX Loss:  0.0820  TAR Loss:  1.4975  Loss:  1.0546
Step 700/10000  CTX Loss:  0.4992  TAR Loss:  0.9769  Loss:  0.7214
Step 800/10000  CTX Loss:  0.4232  TAR Loss:  0.6541  Loss:  0.3870
Step 900/10000  CTX Loss: -0.2599  TAR Loss:  1.8279  Loss:  1.1455
Step 1000/10000  CTX Loss:  0.0978  TAR Loss:  0.9172  Loss:  0.5394
Step 1100/10000  CTX Loss:  0.1285  TAR Loss:  0.3221  Loss:  0.0639
Step 1200/10000  CTX Loss:  0.1346  TAR Loss:  0.6226  Loss:  0.0904
Step 1300/10000  CTX Loss: -0.0260  TAR Loss:  0.3658  Loss:  0.1177
Step 1400/10000  CTX Loss:  0.0169  TAR Loss:  0.4342  Loss:  0.1349
Step 1500/10000  CTX Loss:  0.0245  TAR Los

In [None]:
model = ConvCNP(y_dim=1, x_min=-2, x_max=2)
state = init_model(model)
train_model(random.PRNGKey(0), model, state)

In [None]:
model = ConvBNP(y_dim=1, x_min=-2, x_max=2)
state = init_model(model)
state = train_model(random.PRNGKey(0), model, state, num_samples=5)