In [1]:
import sys
sys.path.append("../")

import keras
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from bayesflow.networks import CouplingFlow
from bayesflow.networks import InferenceNetwork
from bayesflow.distributions import DiagonalNormal
from bayesflow import Approximator
from bayesflow import OfflineDataset
from bayesflow import OnlineDataset
from bayesflow.simulators import TwoMoonsSimulator

Find example datasets here: https://github.com/jrmcornish/cif/blob/master/cif/datasets/two_d.py

In [2]:
# def sample_2lines(batch_size, num_batches):
#     samples = batch_size * num_batches
#     x1 = np.empty(samples)
#     x1[:samples//2] = -1.
#     x1[samples//2:] = 1.
#     x1 += 0.01 * (np.random.rand(samples) - .5)
#     x2 = 2 * (np.random.rand(samples) - 0.5)
#     return dict(x1=x1, x2=x2)

# def extract_params(x: dict):
#     z = np.stack([x[key] for key in x.keys()])
#     return z

batch_size = 128
num_batches = 32
# data = sample_2lines(batch_size, num_batches)

simulator = TwoMoonsSimulator()
data = simulator.sample((batch_size * num_batches,))
dataset = OfflineDataset(data, workers=4, batch_size=batch_size)
print("Batch size:", dataset.batch_size)
print([key for key in dataset[0].keys()])


Batch size: 128
['r', 'alpha', 'theta', 'x']


In [3]:
class CINF2(InferenceNetwork):
    def __init__(self, **kwargs):
        super().__init__(base_distribution="normal", **kwargs)
        # Member variables according to paper
        # prior is base dist
        self.p_u_density = CouplingFlow()
        self.q_u_density = CouplingFlow()
        
        self.feature_net = CouplingFlow()
    
    def build(self, xz_shape, conditions_shape):
        super().build(xz_shape)
        
        xz = keras.ops.zeros(xz_shape)
        # xz = keras.random.beta(xz_shape, 2, 2)
        if conditions_shape is None:
            conditions = None
        else:
            conditions = keras.ops.zeros(conditions_shape)
            # conditions = keras.random.beta(conditions_shape, 2, 2)
            
        
        # Build local layers and couplings
        self.p_u_density.build(xz_shape, xz_shape)
        self.q_u_density.build(xz_shape, conditions_shape)
        self.feature_net.build(xz_shape, conditions_shape)
        
        
        self.call(xz, conditions)
    
    def call(self, xz, conditions):
        return self._forward(xz, conditions)
    
    def bijection(self, x):
        # TODO: Make conditional
        # z = keras.ops.log(x) - keras.ops.log(1-x)
        z = keras.ops.sigmoid(x)
        
        # Log-jacobian
        eps = 1e-7
        x_clipped = keras.ops.clip(x, eps, 1-eps)
        z_log_clipped = -keras.ops.log(x_clipped) - keras.ops.log(1-x_clipped)
        z_log_jac = -keras.ops.sum(z_log_clipped, axis=1, keepdims=True)
        z_log_jac = keras.ops.squeeze(z_log_jac)
        
        return z, z_log_jac
    
    def _forward(self, x, conditions):
        return self.elbo(x, conditions)
    
    
    def _inverse(self, x, conditions):
        return self.elbo(x, conditions)
    
    
    
    
    
    def elbo(self, x, conditions):
        # sample_shape = 1 if x.shape[0] is None else x.shape[0]
        # fix_this = 128 # needs to somehow infer batch size generally
        print(x.shape)
        fix_this = keras.ops.shape(x)[0]
        u = self.q_u_density.sample((128,), conditions=conditions)
        log_q_u = self.q_u_density.log_prob(u, conditions=conditions)
        
        # bijection sampling
        z, z_log_jac = self.bijection(x)
        
        # p_u sampling
        log_p_u = self.p_u_density.log_prob(u, conditions=z)
        
        # prior sampling
        log_prior = self.base_distribution.log_prob(z)
        
        # elbo
        log_p = z_log_jac + log_p_u + log_prior
        log_q = log_q_u # missing prior elbo call, which is just zeros
        log_density = log_p - log_q
        
        return z, log_density
    
    
    def compute_metrics(self, data, stage="training"):
        base_metrics = super().compute_metrics(data, stage=stage)
        inference_variables = data["inference_variables"]
        inference_conditions = data.get("inference_conditions")
        
        z, log_density = self(inference_variables, conditions=inference_conditions)
        loss = -keras.ops.mean(log_density)
        return base_metrics | {"loss": loss}
        
        

In [4]:
class CINF(InferenceNetwork):
    def __init__(self, **kwargs):
        super().__init__(base_distribution="normal", **kwargs)
        # Member variables according to nux implementation
        self.feature_net = CouplingFlow() 	 # no conditions
        self.flow = CouplingFlow() 			 # bijective transformer
        self.u_dist = self.base_distribution # Gaussian prior
        self.v_dist = CouplingFlow()		 # conditioned flow / parameterized prior
        
    
    def build(self, xz_shape, conditions_shape):
        super().build(xz_shape)            
        self.feature_net.build(xz_shape)
        self.flow.build(xz_shape, xz_shape)
        self.v_dist.build(xz_shape, xz_shape)
        
    
    def call(self, xz, conditions, inverse=False, **kwargs):
        if inverse:
            return self._inverse(xz, conditions, **kwargs)
        return self._forward(xz, conditions, **kwargs)
    
    
    def _forward(self, x, conditions, density=False, **kwargs):
        # Sample u ~ q(u|phi_x)
        phi_x = self.feature_net(x, conditions=None)
        u, log_qu = self.v_dist(keras.ops.zeros_like(x), conditions=phi_x, inverse=True, density=True)
        
        # Compute z = f(x; phi_u) and p(x|u)
        phi_u = self.feature_net(u, conditions=None)
        z, log_px = self.flow(x, conditions=phi_u, inverse=False, density=True)
        
        # Compute p(u)
        log_pu = self.base_distribution.log_prob(u)
        
        # Log likelihood?
        llc = log_px + log_pu - log_qu
        
        # NOTE - this can be moved up when I'm done tinkering
        if density:
            return z, llc
        return z
    
    
    def _inverse(self, z, conditions, density=False, **kwargs):
        # Sample u ~ p(u)
        u = self.base_distribution.sample(keras.ops.shape(z))
        log_pu = self.base_distribution.log_prob(keras.ops.zeros_like(z))
        
        # Compute inverse of f(z; u)
        phi_u = self.feature_net(u)
        x, log_px = self.flow(z, conditions=phi_u, inverse=True, density=True)
        
        # Predict q(u|x)
        phi_x = self.feature_net(x)
        _, log_qu = self.v_dist(u, conditions=phi_x, inverse=False, density=True)
        
        # Log likelihood?
        llc = log_px + log_pu - log_qu
        
        # NOTE: this can be moved up when I'm done tinkering
        if density:
            return x, llc
        return x
    
    
    def compute_metrics(self, data, stage="training"):
        base_metrics = super().compute_metrics(data, stage=stage)
        inference_variables = data["inference_variables"]
        inference_conditions = data.get("inference_conditions")
        
        z, log_density = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
        loss = -keras.ops.mean(log_density)
        return base_metrics | {"loss": loss}
        
        

In [5]:
cinf = CINF()
approximator = Approximator(
    inference_network=cinf,
    inference_variables=["theta"],
    inference_conditions=["r", "alpha", "x"]
)
approximator.compile(optimizer="adamw", loss="mse")
approximator.build_from_data(next(iter(dataset)))

In [6]:
# metrics = approximator.evaluate(dataset, return_dict=True)
approximator.fit(dataset, epochs=20)

Epoch 1/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 110ms/step - loss: nan
Epoch 2/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 114ms/step - loss: nan
Epoch 3/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 106ms/step - loss: nan
Epoch 4/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 105ms/step - loss: nan
Epoch 5/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 100ms/step - loss: nan
Epoch 6/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 98ms/step - loss: nan
Epoch 7/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 97ms/step - loss: nan
Epoch 8/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 94ms/step - loss: nan
Epoch 9/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 96ms/step - loss: nan
Epoch 10/20
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 92ms/step - loss: nan
Epoch 11/20
[1m32/32[

<keras.src.callbacks.history.History at 0x13b3c1a93a0>

In [10]:
samples = approximator.sample((128,2), next(iter(dataset)))

InvalidArgumentError: Exception encountered when calling SingleCoupling.call().

[1m{{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [128,2,1] vs. shape[1] = [128,2,2,2] [Op:ConcatV2] name: concat[0m

Arguments received by SingleCoupling.call():
  • x1=tf.Tensor(shape=(128, 2, 1), dtype=float32)
  • x2=tf.Tensor(shape=(128, 2, 1), dtype=float32)
  • conditions=tf.Tensor(shape=(128, 2, 2, 2), dtype=float32)
  • inverse=True
  • kwargs=<class 'inspect._empty'>