In [1]:
# setup
import os
os.environ["KERAS_BACKEND"] = "torch"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import keras

if keras.backend.backend() == "torch":
    import torch
    print("Use torch backend")
    torch.autograd.set_grad_enabled(False)

import sys
sys.path.append("../")

import bayesflow as bf
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

Use torch backend


  torch.logspace(


In [2]:
from bayesflow.simulators.simulator import Simulator
# from bayesflow.types import Shape, Tensor
from torch import Tensor
from torch.distributions import Distribution
import torch.nn as nn

In [3]:
from custom_simulators import MyGenericSimulator, LikelihoodBasedModel, ParameterMask, Prior, RandomNumObs
from design_networks import RandomDesign, DeepAdaptiveDesign, EmitterNetwork
from design_loss import NestedMonteCarlo

In [4]:
class PolynomialRegression(LikelihoodBasedModel):
    def __init__(self, mask_sampler, prior_sampler, tau_sampler, design_generator, simulator_var) -> None:
        super().__init__(mask_sampler, prior_sampler, tau_sampler, design_generator, simulator_var)

    def outcome_likelihood(self, params: Tensor, xi: Tensor, simulator_var: dict) -> Distribution:

        xi_powers = torch.stack([torch.ones_like(xi), xi, xi ** 2, xi ** 3], dim=1)
        mean = torch.sum(params * xi_powers, dim=-1, keepdim=True)
        sigma = simulator_var["sigma"]
        return torch.distributions.Normal(mean, sigma)
    
    def analytical_log_marginal_likelihood(outcomes, params: Tensor, masks: Tensor) -> Tensor:
        raise NotImplementedError # TODO

In [5]:
class PriorPolynomialReg(Prior):
    def __init__(self, delta: Tensor = Tensor([0.1])) -> None:
        super().__init__()
        self.delta = delta

    def dist(self, masks: Tensor) -> [Distribution]:
        super().__init__()
        
        self.masks = masks

        default = Tensor([[0, self.delta]])
        masks_ = masks.unsqueeze(-1)

        prior_0 = torch.where(masks_[:, 0] == 1, Tensor([5, 2]), default)
        prior_1 = torch.where(masks_[:, 1] == 1, Tensor([3, 1]), default)
        prior_2 = torch.where(masks_[:, 2] == 1, Tensor([0, 0.8]), default)
        prior_3 = torch.where(masks_[:, 3] == 1, Tensor([0, 0.5]), default)

        hyper_params = torch.stack([prior_0, prior_1, prior_2, prior_3], dim=1)

        means = hyper_params[:, :, 0]
        sds = hyper_params[:, :, 1]
    
        dist = torch.distributions.MultivariateNormal(means, scale_tril=torch.stack([torch.diag(sd) for sd in sds]))

        return dist

In [6]:
B = 64
batch_size = torch.Size([B])

mask_sampler = ParameterMask()
masks = mask_sampler(batch_size)
prior_sampler = PriorPolynomialReg()
params = prior_sampler.sample(masks)
likelihood = prior_sampler.log_prob(params, masks)

In [7]:
print(f"Shape of parameter mask {masks.shape}") # [B, model_dim]
print(f"Shape of parameters {params.shape}") # [B, param_dim]
print(f"Shape of likelihood {likelihood.shape}") # [B]

Shape of parameter mask torch.Size([64, 4])
Shape of parameters torch.Size([64, 4])
Shape of likelihood torch.Size([64])


In [8]:
random_design_generator = RandomDesign()

In [9]:
T = 10
random_num_obs = RandomNumObs(min_obs = 1, max_obs = T)

In [10]:
polynomial_reg = PolynomialRegression(mask_sampler = mask_sampler,
                                      prior_sampler = prior_sampler,
                                      tau_sampler = random_num_obs,
                                      design_generator = random_design_generator,
                                      simulator_var = {"sigma": 1.0})

In [11]:
class MyDataSet(keras.utils.PyDataset):
    def __init__(self, batch_size: torch.Size, stage: int, initial_generative_model: MyGenericSimulator, design_network: nn.Module = None):
        super().__init__()

        self.batch_size = batch_size
        self.stage = stage # stage 1,2,3
        self.initial_generative_model = initial_generative_model
        self.design_network = design_network

    def __getitem__(self, item:int) -> dict[str, Tensor]:
        if self.stage == 1:

            data = self.initial_generative_model.sample(self.batch_size)
            return data

        if self.stage == 2:
            second_generative_model = 1
            data = self.second_generative_model.sampel(self.batch_size, )
            return data

        if self.stage == 3:
            ...
            return data
    
    @property
    def num_batches(self):
        # infinite dataset
        return None

In [12]:
dataset = MyDataSet(batch_size = batch_size, stage = 1, initial_generative_model = polynomial_reg)

In [13]:
# inference_network = bf.networks.CouplingFlow(depth = 8, subnet_kwargs=dict(kernel_regularizer=None, dropout_prob = False))

In [14]:
inference_network = bf.networks.FlowMatching()

In [15]:
summary_network = bf.networks.DeepSet(summary_dim = 10)

In [16]:
approximator = bf.Approximator(
    inference_network = inference_network,
    summary_network = summary_network,
    inference_variables = ["params"],
    inference_conditions = ["masks", "n_obs"],
    summary_variables = ["outcomes", "designs"]
)

approximator.compile(optimizer="AdamW")

In [17]:
approximator.fit(dataset, epochs=1, steps_per_epoch=10)

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 731ms/step - inference/loss: 4.7153 - loss: 4.7153 - summary/loss: 0.0000e+00


<keras.src.callbacks.history.History at 0x12a5c3a90>

In [18]:
# TODO at last

class InferenceDesignApproximator:
    def __init__(self, hyperparameters: dict, bf_settings: dict, design_settings: dict):

        self.summary_network = ...

        # Dataset object: online dataset https://github.com/stefanradev93/BayesFlow/blob/streamlined-backend/bayesflow/datasets/online_dataset.py
        self.dataset: MyDataSet

        # BayesFlow approximator is encapsuled
        self.bf_approximator = bf.approximators.Approximator(..., summary_network, **bf_settings)
        
        # Design network object is encapsuled
        self.design_net = DesignNetwork(..., summary_network, **design_settings)

        # Hyperparameters: weight terms to balance losses, etc.
        self.hyperparameters = hyperparameters

    def train(self, dataset):
        # Stage 1: Train Bayesflow, use random design
        self.bf_approximator.train(dataset)
        self.dataset.stage = 2

        # Stage 2: Fix BayesFlow, train design network
        self.bf_approximator.freeze_weights() # implement this
        self.design_approximator.train(dataset)
        self.dataset.stage = 3

        # Stage 3: Joint training

In [19]:
num_test_samples = 64
num_posterior_samples = 500

test_sims = dataset.__getitem__(0)
samples = approximator.sample(batch_shape=(num_test_samples, num_posterior_samples), data=test_sims)

In [20]:
# observed data
out = polynomial_reg.sample(torch.Size([1]))

In [21]:
out["masks"]

tensor([[1., 1., 0., 0.]])

In [22]:
obs_data = {"designs": out["designs"], "outcomes": out["outcomes"], "masks": out["masks"], "n_obs": out["n_obs"]}
post_given_obs  = approximator.sample(batch_shape = (1, num_posterior_samples), data = obs_data)

In [25]:
obs_data["masks"].shape

torch.Size([1, 4])

In [26]:
obs_data = {"designs": out["designs"], "outcomes": out["outcomes"]}

In [36]:
obs_data.values

<function dict.values>

In [27]:
obs_data["outcomes"].shape

torch.Size([1, 9, 1])

In [28]:
from bayesflow.utils import filter_concatenate

filter_concatenate(obs_data, keys=["outcomes", "designs"]).shape

torch.Size([1, 9, 2])

In [29]:
summary_out = summary_network(filter_concatenate(obs_data, keys=["outcomes", "designs"]))

In [29]:
decoder_net = EmitterNetwork(input_dim = 10, hidden_dim = 24, output_dim = 1)

In [30]:
designs_net = DeepAdaptiveDesign(encoder_net=summary_network,
                                 decoder_net = decoder_net,
                                 design_shape = torch.Size([1]), 
                                 summary_variables=["outcomes", "designs"])

In [31]:
summary_batched = torch.randn(64, 10)

In [32]:
summary_batched.shape

torch.Size([64, 10])

In [33]:
decoder_net(summary_batched).shape

torch.Size([64, 1])

In [34]:
tensor = torch.empty((0, 1, 1))

In [36]:
tensor.shape[0]

0

In [40]:
xi = torch.tensor([1])

# Adding two singleton dimensions and repeating
B = 3  # Example size for repetition
expanded_xi = xi.view(1, 1, 1).repeat(B, 1, 1)

print(expanded_xi)

tensor([[[1]],

        [[1]],

        [[1]]])


In [41]:
expanded_xi.shape

torch.Size([3, 1, 1])