In [1]:
# setup
import os
# Select the Keras backend via environment variable; this is set before `import keras` below.
os.environ["KERAS_BACKEND"] = "torch"
# NOTE(review): presumably enables PyTorch's CPU fallback for ops missing on Apple MPS — confirm against the PyTorch docs.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


import keras

if keras.backend.backend() == "torch":
    import torch
    print("Use torch backend")
    # Switches autograd off globally for the rest of the kernel session.
    torch.autograd.set_grad_enabled(False)

import sys
# Put the parent directory on the import path (presumably so a local `bayesflow` checkout is found — verify).
sys.path.append("../")

import bayesflow as bf
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

Use torch backend


  torch.logspace(


In [2]:
from bayesflow.simulators.simulator import Simulator
# from bayesflow.types import Shape, Tensor
from torch import Tensor
from torch.distributions import Distribution
import torch.nn as nn
from typing import Callable

In [70]:
class MyGenericSimulator(Simulator):
    """Simulator for a sequence of experiments: context -> parameters -> (design, outcome) pairs.

    Args:
        context_sampler: Callable mapping a batch shape to a batch of contexts.
        prior_sampler: Object whose ``sample(context)`` returns parameters for the given contexts.
        tau_sampler: Zero-argument callable returning the number of experiments to simulate.
        design_generator: Module mapping a batch shape to one design per batch element.
        simulator_var: Fixed simulator settings, stored for use by subclasses.
    """

    def __init__(self, context_sampler: Callable, prior_sampler: Callable, tau_sampler: Callable, design_generator: nn.Module, simulator_var: dict):
        self.context_sampler = context_sampler
        self.prior_sampler = prior_sampler
        self.tau_sampler = tau_sampler
        self.design_generator = design_generator
        self.simulator_var = simulator_var

    def sample(self, batch_size: torch.Size, context: Tensor = None, params : Tensor = None, tau : int = None, **kwargs) -> dict[str, Tensor]:
        """Simulate ``tau`` design/outcome pairs for a batch.

        Any of ``context``, ``params`` and ``tau`` that is not supplied is drawn from the
        corresponding sampler, so callers may fix a subset of them.

        Args:
            batch_size: Batch shape forwarded to the context sampler and the design generator.
            context: Optional pre-drawn contexts.
            params: Optional pre-drawn parameters.
            tau: Optional number of experiments (Python int or single-element tensor).
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Dict with keys ``context``, ``params``, ``n_obs`` (square root of ``tau``, one row per
            batch element), ``designs`` and ``outcomes`` (both stacked along dim 1).

        Raises:
            ValueError: If ``tau`` is smaller than 1 (there would be nothing to stack).
        """
        # Fill in each missing quantity independently, in the order context -> params -> tau.
        if context is None:
            context = self.context_sampler(batch_size)
        if params is None:
            params = self.prior_sampler.sample(context)
        if tau is None:
            tau = self.tau_sampler()

        # Accept both a Python int and a single-element tensor.
        num_steps = int(tau)
        if num_steps < 1:
            raise ValueError(f"tau must be at least 1, got {num_steps}")

        designs = []
        outcomes = []

        for _ in range(num_steps):
            xi = self.design_generator(batch_size)

            # outcome_simulator is declared as (params, xi); subclasses read
            # self.simulator_var themselves, so it is not forwarded here.
            y = self.outcome_simulator(params=params, xi=xi)

            designs.append(xi)
            outcomes.append(y)

        # Shapes below assume the design generator returns one scalar per batch element.
        designs = torch.stack(designs, dim=1).unsqueeze(-1)  #  [B, tau, 1]
        outcomes = torch.stack(outcomes, dim=1) # [B, tau, 1]
        n_obs = torch.sqrt(torch.tensor([float(num_steps)])).repeat(batch_size).unsqueeze(1) # [B, 1]

        out = {"context": context, "params": params, "n_obs": n_obs, "designs": designs, "outcomes": outcomes}

        return out

    def outcome_simulator(self, params: Tensor, xi: Tensor) -> Tensor:
        """Simulate one outcome per batch element; must be provided by subclasses."""
        raise NotImplementedError

In [4]:
class RandomDesign(nn.Module):
    """Design generator that ignores the experiment history and draws designs with ``torch.rand``."""

    def __init__(self):
        super().__init__()

    def forward(self, batch_size: torch.Size, designs: [Tensor] = None, outcomes: [Tensor] = None) -> Tensor:
        """Return a tensor of shape ``batch_size`` filled by ``torch.rand``.

        ``designs`` and ``outcomes`` are accepted but not used.
        """
        random_designs = torch.rand(batch_size)
        return random_designs

In [68]:
class LikelihoodBasedModel(MyGenericSimulator):
    """Simulator whose outcomes are drawn from an explicit likelihood distribution."""

    def __init__(self, context_sampler, prior_sampler, tau_sampler, design_generator, simulator_var) -> None:
        super().__init__(context_sampler, prior_sampler, tau_sampler, design_generator, simulator_var)

    def outcome_likelihood(self, params: Tensor, xi: Tensor, simulator_var: dict = None) -> Distribution:
        """Return the outcome distribution for the given parameters and designs (subclass hook)."""
        raise NotImplementedError

    def outcome_simulator(self, params: Tensor, xi: Tensor, simulator_var: dict = None) -> Tensor:
        """Draw outcomes from ``outcome_likelihood``.

        ``simulator_var`` is optional so the method works both when the caller forwards
        the settings and when it does not; it defaults to ``self.simulator_var``.
        """
        if simulator_var is None:
            simulator_var = self.simulator_var
        return self.outcome_likelihood(params, xi, simulator_var).sample()

    def approximate_log_marginal_likelihood(self, context: Tensor, params: Tensor, xi: Tensor, outcomes: Tensor, log_approx_posterior: bf.networks) -> Tensor:
        """Estimate normalised model weights from likelihood, prior and approximate posterior.

        For every mask ``m`` in ``self.context_sampler.possible_masks`` the rows whose context
        equals that mask are selected and the quantity
        ``sum log p(y | theta) + sum log p(theta | m) - sum log q(theta | y)`` is computed.
        The per-model values are then normalised with a softmax, which equals
        ``exp(v) / sum(exp(v))`` but does not overflow.

        NOTE: despite the method name, the return value is a vector of normalised weights
        (one per mask, summing to 1), not log values.

        Args:
            context: Per-sample masks, compared row-wise against the possible masks.
            params: Per-sample parameters.
            xi: Per-sample designs, passed to ``outcome_likelihood``.
            outcomes: Per-sample outcomes.
            log_approx_posterior: Object exposing ``log_prob(params, outcomes)``.

        Returns:
            1-D tensor with one normalised weight per possible mask.
        """
        possible_masks = self.context_sampler.possible_masks
        num_models = possible_masks.shape[0]

        log_weights = []

        for m in range(num_models):
            # Boolean selector over the batch: rows whose whole mask equals mask m.
            index = (context == possible_masks[m]).all(dim=-1)

            if not bool(index.any()):
                # NOTE(review): a model without any sample gets weight 0 here — confirm this is the intended convention.
                log_weights.append(torch.full((), float("-inf"), dtype=params.dtype, device=params.device))
                continue

            context_m = context[index]
            params_m = params[index]
            xi_m = xi[index]
            outcomes_m = outcomes[index]

            first_term = self.outcome_likelihood(params_m, xi_m, self.simulator_var).log_prob(outcomes_m).sum()
            # The prior's log_prob needs the masks of the selected rows as well.
            second_term = self.prior_sampler.log_prob(params_m, context_m).sum()
            third_term = log_approx_posterior.log_prob(params_m, outcomes_m).sum()

            log_weights.append(first_term + second_term - third_term)

        marginal_likelihood = torch.softmax(torch.stack(log_weights, dim=0), dim=0)

        return marginal_likelihood

In [6]:
class PolynomialRegression(LikelihoodBasedModel):
    """Cubic polynomial regression: ``y ~ Normal(p0 + p1*xi + p2*xi^2 + p3*xi^3, sigma)``."""

    def __init__(self, context_sampler, prior_sampler, tau_sampler, design_generator, simulator_var) -> None:
        super().__init__(context_sampler, prior_sampler, tau_sampler, design_generator, simulator_var)

    def outcome_likelihood(self, params: Tensor, xi: Tensor, simulator_var: dict = None) -> Distribution:
        """Return the Normal outcome distribution for the given coefficients and designs.

        Args:
            params: Polynomial coefficients with 4 entries on the last dim.
            xi: Designs; the powers are stacked on a new last dim so leading dims can broadcast
                against ``params``.
            simulator_var: Dict providing ``"sigma"``; defaults to ``self.simulator_var``.

        Returns:
            ``torch.distributions.Normal`` whose mean keeps a trailing dim of size 1.
        """
        if simulator_var is None:
            simulator_var = self.simulator_var

        # Stacking on the last dim is identical to dim=1 for 1-D xi and additionally
        # allows extra leading dims.
        xi_powers = torch.stack([torch.ones_like(xi), xi, xi ** 2, xi ** 3], dim=-1)
        mean = torch.sum(params * xi_powers, dim=-1, keepdim=True)
        sigma = simulator_var["sigma"]
        return torch.distributions.Normal(mean, sigma)

    def analytical_log_marginal_likelihood(self, outcomes, params: Tensor, param_mask: Tensor) -> Tensor:
        """Closed-form log marginal likelihood (not implemented yet).

        Declared as an instance method so it can be called on a model instance like the
        other methods of this class.
        """
        raise NotImplementedError # TODO

In [7]:
class ParameterMask:
    """Sampler that draws rows uniformly at random from a table of parameter masks.

    Args:
        num_parameters: Size of the default lower-triangular mask table.
        possible_masks: Optional table of masks (tensor or nested sequence), one mask per row.
            When omitted, a ``num_parameters x num_parameters`` lower-triangular matrix of
            ones is used.
    """

    def __init__(self, num_parameters: int = 4, possible_masks: Tensor = None) -> None:
        self.num_parameters = num_parameters

        if possible_masks is None:
            self.possible_masks = torch.tril(torch.ones((num_parameters, num_parameters)))
        elif isinstance(possible_masks, torch.Tensor):
            # Copy without going through torch.tensor(tensor), which emits a copy-construct warning.
            self.possible_masks = possible_masks.detach().clone().to(torch.float32)
        else:
            self.possible_masks = torch.tensor(possible_masks, dtype=torch.float32)

    def __call__(self, batch_shape: torch.Size) -> Tensor:
        """Return masks of shape ``batch_shape + (mask_width,)`` drawn uniformly from the table."""
        index_samples = torch.randint(0, self.possible_masks.shape[0], batch_shape, dtype=torch.long)
        return self.possible_masks[index_samples]

In [8]:
class Prior():
    """Base class for priors built from one distribution per batch element.

    Subclasses implement ``dist_list``; ``sample`` and ``log_prob`` stack the
    per-distribution results along dim 0.
    """

    def __init__(self) -> None:
        super().__init__()

    def dist_list(self, param_mask: Tensor) -> [Distribution]:
        """Return the list of distributions for ``param_mask`` (subclass hook)."""
        raise NotImplementedError

    def sample(self, param_mask: Tensor) -> Tensor:
        """Draw one sample from every distribution and stack the draws along dim 0."""
        draws = []
        for component in self.dist_list(param_mask):
            draws.append(component.sample())
        return torch.stack(draws, dim = 0)

    def log_prob(self, params: Tensor, param_mask: Tensor) -> Tensor:
        """Evaluate each distribution at its matching entry of ``params`` and stack along dim 0."""
        components = self.dist_list(param_mask)
        log_densities = []
        for component, value in zip(components, params):
            log_densities.append(component.log_prob(value))
        return torch.stack(log_densities, dim = 0)

In [9]:
class PriorPolynomialReg(Prior):
    """Prior over the four polynomial coefficients, switched per coefficient by a 0/1 mask.

    A coefficient whose mask entry is 1 gets the (mean, std) pair from
    ``ACTIVE_HYPER_PARAMS``; otherwise it gets mean 0 and std ``delta``.
    """

    # (mean, std) for coefficients 0..3 when the corresponding mask entry equals 1.
    ACTIVE_HYPER_PARAMS = ((5.0, 2.0), (3.0, 1.0), (0.0, 0.8), (0.0, 0.5))

    def __init__(self, delta: Tensor = None) -> None:
        """Store ``delta``, the std used for masked-out coefficients (defaults to 0.1).

        The default is created per instance instead of sharing one tensor object
        across all instances through the signature.
        """
        super().__init__()
        self.delta = Tensor([0.1]) if delta is None else delta

    def dist_list(self, param_mask: Tensor) -> [Distribution]:
        """Build one diagonal ``MultivariateNormal`` per row of ``param_mask``.

        Args:
            param_mask: 2-D tensor with one 0/1 mask row per batch element; only the first
                four columns are used.

        Returns:
            List with one distribution per mask row.

        Raises:
            ValueError: If ``param_mask`` has fewer columns than there are coefficients.
        """
        self.param_mask = param_mask

        active = torch.tensor(self.ACTIVE_HYPER_PARAMS, device=param_mask.device)  # [4, 2]
        num_params = active.shape[0]
        if param_mask.shape[-1] < num_params:
            raise ValueError(
                f"param_mask needs at least {num_params} columns, got {param_mask.shape[-1]}"
            )

        # float() accepts both a Python number and a single-element tensor.
        inactive = torch.tensor([0.0, float(self.delta)], device=param_mask.device)  # [2]

        # [B, 4, 1] selector broadcast against the [4, 2] table -> [B, 4, 2].
        is_active = param_mask[:, :num_params].unsqueeze(-1) == 1
        hyper_params = torch.where(is_active, active, inactive)

        mean_s = hyper_params[:, :, 0]
        sigma_s = hyper_params[:, :, 1]

        # scale_tril is the diagonal matrix of standard deviations.
        return [
            torch.distributions.MultivariateNormal(mean, scale_tril=torch.diag(sigma))
            for mean, sigma in zip(mean_s, sigma_s)
        ]

In [10]:
# Demo: draw a batch of parameter masks, sample parameters from the masked prior
# and evaluate the prior log-density of those samples.
B = 64
batch_size = torch.Size([B])

param_mask_generator = ParameterMask()
param_mask = param_mask_generator(batch_size)
# NOTE(review): `polynomial_reg` is a prior object here; the same name is rebound to a
# PolynomialRegression simulator in a later cell.
polynomial_reg = PriorPolynomialReg()
params = polynomial_reg.sample(param_mask)
# NOTE(review): despite its name, `likelihood` holds the result of the prior's log_prob.
likelihood = polynomial_reg.log_prob(params, param_mask)

In [11]:
# Print the shapes of the tensors produced in the previous cell; expected shapes are noted per line.
print(f"Shape of parameter mask {param_mask.shape}") # [B, model_dim]
print(f"Shape of parameters {params.shape}") # [B, param_dim]
print(f"Shape of likelihood {likelihood.shape}") # [B]

Shape of parameter mask torch.Size([64, 4])
Shape of parameters torch.Size([64, 4])
Shape of likelihood torch.Size([64])


In [78]:
# Building blocks for the simulator: context (mask) sampler, design generator and prior,
# all constructed with their default arguments.
parameter_mask = ParameterMask()
random_design_generator = RandomDesign()
prior = PriorPolynomialReg()

In [75]:
class random_num_obs():
    """Sampler for the number of observations, uniform on the integers ``min_obs..max_obs``.

    Args:
        min_obs: Smallest number of observations that can be drawn.
        max_obs: Largest number of observations that can be drawn (inclusive). The default is
            a literal so the class can be defined before any notebook-level constant exists.

    Raises:
        ValueError: If ``min_obs`` is larger than ``max_obs``.
    """

    def __init__(self, min_obs : int = 1, max_obs : int = 10) -> None:
        if min_obs > max_obs:
            raise ValueError(f"min_obs ({min_obs}) must not exceed max_obs ({max_obs})")
        self.min_obs = min_obs
        self.max_obs = max_obs # T

    def __call__(self) -> Tensor:
        """Return a tensor of shape ``(1,)`` holding one draw."""
        # torch.randint excludes the upper bound, hence the + 1.
        return torch.randint(self.min_obs, self.max_obs + 1, (1,))

In [76]:
# Maximum number of experiments per simulated sequence.
T = 10
# NOTE(review): this rebinds the class name `random_num_obs` to an instance, so running
# this cell a second time calls the instance instead of the class.
random_num_obs = random_num_obs(min_obs = 1, max_obs = T)

In [77]:
# Assemble the polynomial-regression simulator from the components created above;
# "sigma" is read by PolynomialRegression.outcome_likelihood as the Normal scale.
polynomial_reg = PolynomialRegression(context_sampler = parameter_mask,
                                      prior_sampler = prior,
                                      tau_sampler = random_num_obs,
                                      design_generator = random_design_generator,
                                      simulator_var = {"sigma": 1.0})

In [15]:
# Simulate one batch; context, parameters and the number of experiments are all drawn by the simulator.
out = polynomial_reg.sample(batch_size)

In [16]:
# List the entries of the simulated batch.
out.keys()

dict_keys(['context', 'params', 'n_obs', 'designs', 'outcomes'])

In [17]:
class MyDataSet(keras.utils.PyDataset):
    """Online dataset that simulates a fresh batch on every access.

    Args:
        batch_size: Batch shape passed to the generative model.
        stage: Training stage (1, 2 or 3); only stage 1 is implemented.
        initial_generative_model: Simulator used in stage 1.
        design_network: Optional design network, stored for the later stages.
    """

    def __init__(self, batch_size: torch.Size, stage: int, initial_generative_model: MyGenericSimulator, design_network: nn.Module = None):
        super().__init__()

        self.batch_size = batch_size
        self.stage = stage # stage 1,2,3
        self.initial_generative_model = initial_generative_model
        self.design_network = design_network

    def __getitem__(self, item:int) -> dict[str, Tensor]:
        """Return a newly simulated batch; ``item`` is ignored.

        Raises:
            NotImplementedError: For stages 2 and 3, whose generative models do not exist yet.
            ValueError: For any other stage value.
        """
        if self.stage == 1:
            return self.initial_generative_model.sample(self.batch_size)

        if self.stage in (2, 3):
            # The original placeholders referenced undefined names; fail with a clear message instead.
            raise NotImplementedError(f"Data generation for stage {self.stage} is not implemented yet")

        raise ValueError(f"Unknown stage {self.stage!r}; expected 1, 2 or 3")

    @property
    def num_batches(self):
        # infinite dataset
        return None

In [18]:
# Stage-1 dataset: every batch is simulated by `polynomial_reg`.
dataset = MyDataSet(batch_size = batch_size, stage = 1, initial_generative_model = polynomial_reg)

In [39]:
# Inference network: a BayesFlow coupling flow with depth 8.
# NOTE(review): `dropout_prob = False` passes a bool where a probability is presumably expected — confirm against the BayesFlow API.
inference_network = bf.networks.CouplingFlow(depth = 8, subnet_kwargs=dict(kernel_regularizer=None, dropout_prob = False))

In [40]:
# Summary network: a BayesFlow DeepSet configured with summary_dim = 10.
summary_network = bf.networks.DeepSet(summary_dim = 10)

In [56]:
# Wire the networks together; the variable names refer to keys of the dict returned by the simulator.
approximator = bf.Approximator(
    inference_network = inference_network,
    summary_network = summary_network,
    inference_variables = ["params"],
    inference_conditions = ["context", "n_obs"],
    summary_variables = ["outcomes", "designs"]
)

approximator.compile(optimizer="AdamW")

In [57]:
# Train on simulated batches: 10 epochs of 10 steps each (the dataset reports no fixed length).
approximator.fit(dataset, epochs=10, steps_per_epoch=10)

Epoch 1/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 2s/step - inference/loss: 15.8874 - loss: 15.8874 - summary/loss: 0.0000e+00
Epoch 2/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 2s/step - inference/loss: 7.0853 - loss: 7.0853 - summary/loss: 0.0000e+00
Epoch 3/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 2s/step - inference/loss: 4.9676 - loss: 4.9676 - summary/loss: 0.0000e+00
Epoch 4/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 2s/step - inference/loss: 4.0515 - loss: 4.0515 - summary/loss: 0.0000e+00
Epoch 5/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 2s/step - inference/loss: 3.1545 - loss: 3.1545 - summary/loss: 0.0000e+00
Epoch 6/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 2s/step - inference/loss: 3.2678 - loss: 3.2678 - summary/loss: 0.0000e+00
Epoch 7/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 2s/step - infere

<keras.src.callbacks.history.History at 0x155f879d0>

In [65]:
# Probe the configurator with a toy dict; the recorded output is a single 2-element tensor.
approximator.configurator.configure_summary_variables({"outcomes": Tensor([0]), "designs": Tensor([0])})

tensor([0., 0.], device='mps:0')

In [67]:
#outcomes = torch.zeros((1, 2))
#approximator.sample(batch_size, data = {"outcomes": Tensor([0]), "designs": Tensor([0])})

In [406]:
# TODO at last

class InferenceDesignApproximator:
    """Design sketch (not runnable yet) of a wrapper that trains a BayesFlow approximator
    and a design network in three stages.

    NOTE(review): several names are unresolved in this draft:
      * `summary_network` is used as a bare name although only `self.summary_network` is set;
      * `DesignNetwork` is not defined in the visible source;
      * `self.dataset` is only annotated, never assigned, yet `train` writes `self.dataset.stage`;
      * `train` calls `self.design_approximator`, while `__init__` creates `self.design_net`.
    """

    def __init__(self, hyperparameters: dict, bf_settings: dict, design_settings: dict):

        # Placeholder for the summary network that both components are meant to receive.
        self.summary_network = ...

        # Dataset object: online dataset https://github.com/stefanradev93/BayesFlow/blob/streamlined-backend/bayesflow/datasets/online_dataset.py
        # NOTE(review): annotation only — no attribute is created by this statement.
        self.dataset: MyDataSet

        # BayesFlow approximator is encapsulated
        self.bf_approximator = bf.approximators.Approximator(..., summary_network, **bf_settings)
        
        # Design network object is encapsulated
        self.design_net = DesignNetwork(..., summary_network, **design_settings)

        # Hyperparameters: weight terms to balance losses, etc.
        self.hyperparameters = hyperparameters

    def train(self, dataset):
        """Run the staged training schedule; stage 3 is not written yet."""
        # Stage 1: Train Bayesflow, use random design
        # NOTE(review): elsewhere in this notebook the approximator is trained via `.fit(...)` — confirm `.train` exists.
        self.bf_approximator.train(dataset)
        self.dataset.stage = 2

        # Stage 2: Fix BayesFlow, train design network
        self.bf_approximator.freeze_weights() # implement this
        self.design_approximator.train(dataset)
        self.dataset.stage = 3

        # Stage 3: Joint training

In [54]:
class DeepAdaptiveDesign(nn.Module):
  """Design network that proposes the next design from the history of design/outcome pairs.

  Args:
    encoder_net: Module called as ``encoder_net(xi, y)`` to embed one design/outcome pair.
    decoder_net: Module mapping the concatenated embeddings to the next design.
    design_shape: Shape of a single design; also the shape of the learnable first design.
  """

  def __init__(
      self,
      # Kept as a string so the union is not evaluated when the class is defined.
      encoder_net: "nn.Module | bf.networks.CouplingFlow",
      decoder_net: nn.Module,
      design_shape: torch.Size
    ) -> None:
    super().__init__()
    self.design_shape = design_shape
    # learnable first design, initialised to the constant 0.1
    self.register_parameter(
        "initial_design",
        nn.Parameter(0.1 * torch.ones(design_shape, dtype=torch.float32))
    )
    self.encoder_net = encoder_net
    self.decoder_net = decoder_net

  def forward(self, designs: list[Tensor] = None, outcomes: list[Tensor] = None) -> Tensor:
    """Return the next design.

    Args:
      designs: Designs chosen so far (one tensor per past experiment).
      outcomes: Outcomes observed so far, aligned with ``designs``.

    Returns:
      The learnable initial design when there is no history (``outcomes`` is ``None`` or
      empty), otherwise the decoder output for the concatenated pair embeddings.
    """
    if outcomes is None or len(outcomes) == 0:
      return self.initial_design

    # embed design-outcome pairs
    embeddings = torch.cat([self.encoder_net(xi, y) for (xi, y) in zip(designs, outcomes)]) # TODO aggregate [, dim]
    # get next design
    return self.decoder_net(embeddings)

In [55]:
class EmitterNetwork(nn.Module):
    """Fully connected network: input layer, optional extra hidden blocks, linear output layer.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Size of the output.
        n_hidden_layers: Total number of hidden layers; values above 1 add
            ``n_hidden_layers - 1`` (linear, activation) blocks after the input layer.
        activation: Activation class, instantiated once per use.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        n_hidden_layers=2,
        activation=nn.Softplus,
    ):
        super().__init__()
        self.activation_layer = activation()
        self.input_layer = nn.Linear(input_dim, hidden_dim)

        # Build the additional hidden blocks one by one; none are created for a single hidden layer.
        extra_blocks = []
        for _ in range(n_hidden_layers - 1):
            extra_blocks.append(
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim), activation())
            )
        self.middle = nn.Sequential(*extra_blocks) if n_hidden_layers > 1 else nn.Identity()

        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, r):
        """Apply input layer, activation, extra hidden blocks and output layer to ``r``."""
        hidden = self.activation_layer(self.input_layer(r))
        return self.output_layer(self.middle(hidden))

In [None]:
class MutualInformation(nn.Module):
  """Base class for mutual-information objectives over a joint model.

  Args:
    joint_model: Model that simulates parameters, designs and outcomes.
    batch_size: Batch size (or batch shape) used when simulating.
  """

  def __init__(self, joint_model, batch_size: "int | torch.Size") -> None:
    super().__init__()
    self.joint_model = joint_model
    self.batch_size = batch_size

  def forward(self) -> Tensor:
    """Compute the training loss; must be provided by subclasses."""
    # NotImplementedError is the exception class; raising the NotImplemented constant is a TypeError.
    raise NotImplementedError

  def estimate(self, num_eval_samples: int = None) -> float:
    """Return an estimate of the mutual information; must be provided by subclasses."""
    raise NotImplementedError

class NestedMonteCarlo(MutualInformation):
  """Nested Monte Carlo (contrastive) estimate of mutual information for future experiments.

  Args:
    joint_model: Likelihood-based simulator providing ``sample``, ``outcome_likelihood``,
      ``approximate_log_marginal_likelihood``, ``context_sampler``, ``tau_sampler`` and
      ``simulator_var``.
    amortized_posterior: Object exposing ``sample(...)``; it is also handed to
      ``approximate_log_marginal_likelihood`` as the approximate posterior.
    batch_size: Batch shape used for simulation.
    num_negative_samples: Number of contrastive parameter samples (L).
    lower_bound: If True, the primary log-probability is added to the denominator.
  """

  def __init__(
      self,
      joint_model: LikelihoodBasedModel,
      amortized_posterior: bf.networks,
      batch_size: torch.Size,
      num_negative_samples: int,
      lower_bound: bool = True
      ) -> None:
    super().__init__(joint_model=joint_model, batch_size=batch_size)
    # forward() reads this attribute, so it has to be stored.
    self.amortized_posterior = amortized_posterior
    self.num_negative_samples = num_negative_samples # L
    self.lower_bound = lower_bound

  def forward(self) -> Tensor:
    """Return the negative mutual-information estimate (a loss to minimise)."""

    # simulate history; entries are read by key rather than by dict order
    history = self.joint_model.sample(self.batch_size)
    context_h, params_h = history["context"], history["params"]
    xi_h, y_h = history["designs"], history["outcomes"]

    # normalised weight per candidate model, given the history
    post_model_prob = self.joint_model.approximate_log_marginal_likelihood(
        context_h, params_h, xi_h, y_h, self.amortized_posterior
    )

    # draw one model (mask row) per batch element according to those weights
    possible_masks = self.joint_model.context_sampler.possible_masks
    num_draws = torch.Size(self.batch_size).numel()
    model_index = torch.multinomial(post_model_prob, num_draws, replacement=True)
    context = possible_masks[model_index].reshape(*self.batch_size, -1)

    prior_samples_primary = self.amortized_posterior.sample(context) # TODO

    # number of experiments still available after the history (history length is dim 1 of the designs)
    # NOTE(review): when the history already has max_obs experiments this is 0 and the simulator
    # has nothing to stack — decide how that case should be handled.
    remaining = self.joint_model.tau_sampler.max_obs - xi_h.shape[1]

    future = self.joint_model.sample(
        self.batch_size,
        context=context,
        params=prior_samples_primary,
        tau=torch.tensor([remaining]),
    )
    designs, outcomes = future["designs"], future["outcomes"]
    simulator_var = self.joint_model.simulator_var

    # iterate over experiments (dim 1), not over the batch; the trailing design dim added by
    # the simulator is dropped so each xi has the shape the likelihood saw during simulation
    steps = list(zip(designs.squeeze(-1).unbind(dim=1), outcomes.unbind(dim=1)))

    # we can reuse negative samples
    prior_samples_negative = self.amortized_posterior.sample(
        torch.Size([self.num_negative_samples])
    ).unsqueeze(1) # [num_neg_samples, ...] -> [num_neg_samples, 1, ...]

    # evaluate the logprob of outcomes under the primary:
    logprob_primary = torch.stack([
        self.joint_model.outcome_likelihood(
            prior_samples_primary, xi, simulator_var
        ).log_prob(y).squeeze(-1) for (xi, y) in steps
    ], dim=0).sum(0) # [T, B] -> [B]

    # evaluate the logprob of outcomes under the contrastive parameter samples:
    # NOTE(review): relies on outcome_likelihood broadcasting params [L, 1, P] against
    # xi [1, B] — confirm the concrete model supports this.
    logprob_negative = torch.stack([
        self.joint_model.outcome_likelihood(
            prior_samples_negative, xi.unsqueeze(0), simulator_var # add dim for <num_neg_samples>
        ).log_prob(y.unsqueeze(0)).squeeze(-1) for (xi, y) in steps
    ], dim=0).sum(0) # [T, num_neg_samples, B] -> [num_neg_samples, B]

    # if lower bound, log_prob primary should be added to the denominator
    if self.lower_bound:
      # concat primary and negative to get [negative_b + 1, B] for the logsumexp
      logprob_negative = torch.cat([
          logprob_negative, logprob_primary.unsqueeze(0)]
      ) # [num_neg_samples + 1, B]
      num_terms = self.num_negative_samples + 1
    else:
      num_terms = self.num_negative_samples

    # torch.log needs a tensor, not a Python int
    to_logmeanexp = torch.log(torch.tensor(float(num_terms)))

    log_denom = torch.logsumexp(logprob_negative, dim=0) - to_logmeanexp # [B]
    mi = (logprob_primary - log_denom).mean(0) # [B] -> scalar
    return -mi

  def estimate(self) -> float:
    """Evaluate ``forward`` without gradients and return the estimate as a float."""
    with torch.no_grad():
      loss = self.forward()
    return -loss.item()