In [2]:
# setup
# KERAS_BACKEND is read when keras is imported, so these environment variables
# must be set before the `import keras` line below.
import os
os.environ["KERAS_BACKEND"] = "torch"
# Presumably lets ops unsupported on Apple-silicon (MPS) fall back to CPU — TODO confirm still needed.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import keras

if keras.backend.backend() == "torch":
    import torch
    print("Use torch backend")
    # Globally disable autograd for tensor ops run in this notebook.
    # NOTE(review): training below relies on the approximator re-enabling gradients
    # inside its own train step; fit() reports a loss, but confirm gradients flow.
    torch.autograd.set_grad_enabled(False)

# Make modules in the parent directory importable (presumably the local bayesflow
# checkout and the custom_* / design_* helper modules).
# NOTE(review): this path is relative to the kernel's working directory.
import sys
sys.path.append("../")

import bayesflow as bf
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

Use torch backend


  torch.logspace(


In [3]:
from bayesflow.simulators.simulator import Simulator
# from bayesflow.types import Shape, Tensor
from torch import Tensor
from torch.distributions import Distribution
import torch.nn as nn
from typing import Callable

In [4]:
from custom_simulators import MyGenericSimulator, LikelihoodBasedModel, ParameterMask, Prior, RandomNumObs
from design_networks import RandomDesign, DeepAdaptiveDesign
from design_loss import NestedMonteCarlo

In [5]:
class PolynomialRegression(LikelihoodBasedModel):
    """Polynomial regression model with Gaussian observation noise.

    For a design ``xi`` the mean outcome is ``sum_k params[..., k] * xi**k`` for
    ``k = 0 .. params.shape[-1] - 1``. With the four coefficients produced by
    ``PriorPolynomialReg`` this is a cubic. The outcome distribution is
    ``Normal(mean, simulator_var["sigma"])``.
    """

    def __init__(self, mask_sampler, prior_sampler, tau_sampler, design_generator, simulator_var) -> None:
        super().__init__(mask_sampler, prior_sampler, tau_sampler, design_generator, simulator_var)

    def outcome_likelihood(self, params: Tensor, xi: Tensor, simulator_var: dict) -> Distribution:
        """Return the outcome distribution for the given coefficients and designs.

        Args:
            params: polynomial coefficients along the last axis, lowest power first.
                Assumed shape ``[B, n_coeffs]`` — TODO confirm against the simulator.
            xi: design points. Powers are stacked along axis 1, so presumably shape ``[B]``.
            simulator_var: dict that must contain ``"sigma"``, the noise scale.

        Returns:
            ``torch.distributions.Normal`` whose mean keeps a trailing axis of size 1.
        """
        n_coeffs = params.shape[-1]
        # Basis xi**0 .. xi**(n_coeffs - 1). xi**0 is all ones (the intercept column),
        # so four coefficients give exactly the cubic basis [1, xi, xi^2, xi^3].
        xi_powers = torch.stack([xi ** k for k in range(n_coeffs)], dim=1)
        mean = torch.sum(params * xi_powers, dim=-1, keepdim=True)
        sigma = simulator_var["sigma"]
        return torch.distributions.Normal(mean, sigma)

    def analytical_log_marginal_likelihood(self, outcomes, params: Tensor, masks: Tensor) -> Tensor:
        """Closed-form log marginal likelihood (not implemented yet)."""
        raise NotImplementedError  # TODO

In [None]:
class PriorPolynomialReg(Prior):
    """Masked Gaussian prior over the four polynomial coefficients.

    Coefficient ``k`` uses the informative ``(mean, sd)`` pair
    ``ACTIVE_HYPERPARAMS[k]`` when ``masks[:, k] == 1``. Otherwise it uses
    ``N(0, delta)``, which for small ``delta`` keeps the coefficient close to zero.
    """

    # (mean, sd) per coefficient when its mask entry is 1, in power order:
    # intercept, linear, quadratic, cubic.
    ACTIVE_HYPERPARAMS = ((5.0, 2.0), (3.0, 1.0), (0.0, 0.8), (0.0, 0.5))

    def __init__(self, delta: Tensor = Tensor([0.1])) -> None:
        super().__init__()
        # Standard deviation used for coefficients whose mask entry is not 1.
        self.delta = delta

    def dist(self, masks: Tensor) -> Distribution:
        """Build the prior distribution for a batch of masks.

        The base-class initializer is deliberately not re-run here. Doing so on
        every call would reset any state the base class holds.

        Args:
            masks: indicator tensor; column ``k`` (k = 0..3) selects the prior of
                coefficient ``k``. Presumably shape ``[B, 4]`` — TODO confirm with ParameterMask.

        Returns:
            ``MultivariateNormal`` with one 4-dimensional event per batch row and a
            diagonal covariance.
        """
        # Cache the masks of the most recent call on the instance.
        # Presumably read elsewhere, e.g. by the Prior base class — TODO confirm.
        self.masks = masks

        n_coeffs = len(self.ACTIVE_HYPERPARAMS)
        active = torch.tensor(self.ACTIVE_HYPERPARAMS)  # [n_coeffs, 2]
        default = Tensor([[0, self.delta]])  # [1, 2]: (mean, sd) for masked-off terms

        # Only the first n_coeffs mask columns are used. [B, n_coeffs, 1] broadcasts
        # against [n_coeffs, 2] and [1, 2], giving hyper_params of shape [B, n_coeffs, 2].
        is_active = masks[:, :n_coeffs].unsqueeze(-1) == 1
        hyper_params = torch.where(is_active, active, default)

        means = hyper_params[..., 0]
        sds = hyper_params[..., 1]

        # Diagonal covariance: each row of sds becomes the diagonal of its scale_tril.
        return torch.distributions.MultivariateNormal(means, scale_tril=torch.diag_embed(sds))

In [None]:
B = 64  # batch size used for all sampling below
batch_size = torch.Size([B])

mask_sampler = ParameterMask()
masks = mask_sampler(batch_size)
prior_sampler = PriorPolynomialReg()
params = prior_sampler.sample(masks)
# NOTE: despite the name, this is prior_sampler's log-density of `params` given
# `masks` (a prior term), not a likelihood of observed data.
likelihood = prior_sampler.log_prob(params, masks)

In [None]:
print(f"Shape of parameter mask {masks.shape}") # [B, model_dim]
print(f"Shape of parameters {params.shape}") # [B, param_dim]
print(f"Shape of likelihood {likelihood.shape}") # [B]

# Pin the shapes documented above so a change in the samplers fails loudly on Run All.
assert masks.shape[0] == params.shape[0] == likelihood.shape[0] == B, "batch dimensions disagree"
assert likelihood.ndim == 1, "expected exactly one log-density value per batch element"

In [None]:
# Design generator that proposes random designs (used for stage-1 training, where
# the BayesFlow approximator is trained on random designs).
random_design_generator = RandomDesign()

In [None]:
# Upper bound on the number of observations per simulated dataset.
# RandomNumObs draws the count between min_obs and max_obs
# (presumably inclusive — TODO confirm).
T = 10
random_num_obs = RandomNumObs(
    min_obs=1,
    max_obs=T,
)

In [34]:
# Assemble the polynomial-regression simulator from the samplers defined above.
# The observation-noise scale is supplied through simulator_var["sigma"].
polynomial_reg = PolynomialRegression(
    mask_sampler=mask_sampler,
    prior_sampler=prior_sampler,
    tau_sampler=random_num_obs,
    design_generator=random_design_generator,
    simulator_var={"sigma": 1.0},
)

In [None]:
class MyDataSet(keras.utils.PyDataset):
    """Infinite Keras dataset that simulates a fresh batch on every access.

    ``stage`` selects the data-generating process and may be reassigned
    between training phases:

    * stage 1: sample from ``initial_generative_model``;
    * stage 2: designs from ``design_network`` (not implemented yet);
    * stage 3: joint training (not implemented yet).
    """

    def __init__(self, batch_size: torch.Size, stage: int, initial_generative_model: MyGenericSimulator, design_network: nn.Module = None):
        super().__init__()

        self.batch_size = batch_size
        self.stage = stage  # 1, 2 or 3; may be reassigned later
        self.initial_generative_model = initial_generative_model
        self.design_network = design_network

    def __getitem__(self, item: int) -> dict[str, Tensor]:
        """Return a freshly simulated batch.

        ``item`` is ignored because the dataset is infinite.

        Raises:
            NotImplementedError: for stages 2 and 3, which are still placeholders.
            ValueError: for any other value of ``stage``.
        """
        if self.stage == 1:
            return self.initial_generative_model.sample(self.batch_size)

        if self.stage in (2, 3):
            # TODO: stage 2 should sample designs via self.design_network;
            # stage 3 is joint training.
            raise NotImplementedError(f"Data generation for stage {self.stage} is not implemented yet.")

        raise ValueError(f"Unknown stage {self.stage!r}; expected 1, 2 or 3.")

    @property
    def num_batches(self):
        # infinite dataset
        return None

In [None]:
# Stage-1 dataset: every batch is simulated from polynomial_reg,
# which uses the random design generator.
dataset = MyDataSet(
    batch_size=batch_size,
    stage=1,
    initial_generative_model=polynomial_reg,
)

In [42]:
# NOTE: this configuration is disabled because FlowMatching does not accept `depth`
# (it raised "ValueError: Unrecognized keyword arguments passed to FlowMatching: {'depth': 8}").
# The default FlowMatching() is used in the next cell instead.
# inference_network = bf.networks.FlowMatching(depth = 8, subnet_kwargs=dict(kernel_regularizer=None, dropout_prob = False))

ValueError: Unrecognized keyword arguments passed to FlowMatching: {'depth': 8}

In [44]:
# Inference network for the "params" targets of the approximator, with library defaults.
inference_network = bf.networks.FlowMatching()

In [46]:
# DeepSet summary network with summary_dim = 10 outputs. It is applied to the
# approximator's summary_variables; presumably permutation-invariant over observations.
summary_network = bf.networks.DeepSet(summary_dim = 10)

In [47]:
# Amortized posterior approximator:
#   - "params" are the inference targets;
#   - "masks" and "n_obs" are passed as inference conditions;
#   - "outcomes" and "designs" go through the summary network.
approximator = bf.Approximator(
    inference_network=inference_network,
    summary_network=summary_network,
    inference_variables=["params"],
    inference_conditions=["masks", "n_obs"],
    summary_variables=["outcomes", "designs"],
)
approximator.compile(optimizer="AdamW")

In [48]:
# Short smoke-test run: one epoch of 10 batches, each freshly simulated by `dataset`.
approximator.fit(dataset, epochs=1, steps_per_epoch=10)

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 714ms/step - inference/loss: 4.6162 - loss: 4.6162 - summary/loss: 0.0000e+00


<keras.src.callbacks.history.History at 0x14f5c9150>

In [35]:
# Draw one batch of simulated data. It is a dict; "params" is inspected below.
out = polynomial_reg.sample(batch_size)
# TODO: marginal-likelihood estimate not wired up yet. `designs` and `outcomes`
# are not standalone variables; they would presumably come from `out`.
# ml = polynomial_reg.approximate_log_marginal_likelihood(masks, params, designs, outcomes, approximator)

In [None]:
# `designs` is never defined as a standalone variable, so read it from the simulated batch.
out["designs"].shape

In [37]:
# Inspect the sampled polynomial coefficients (one row per simulated dataset).
out["params"]

tensor([[ 6.7436e+00,  4.3102e+00,  1.1993e-01,  9.0188e-02],
        [ 7.1472e+00,  3.6448e+00, -3.2375e-01,  5.1362e-01],
        [ 4.2119e+00,  1.7138e+00,  5.9136e-02,  1.1370e-02],
        [ 6.1270e+00,  4.9598e+00,  2.7261e-01,  7.6893e-02],
        [ 3.3783e+00,  3.1083e+00, -5.6946e-01,  1.1533e+00],
        [ 5.7629e+00,  3.2712e+00,  8.4440e-01, -1.3402e-01],
        [ 5.2291e+00, -1.1635e-02, -2.0583e-02, -1.3015e-02],
        [ 7.6896e-01,  2.2936e+00,  2.0385e-01,  1.8806e-03],
        [ 7.0878e+00,  1.4856e+00,  1.0000e-02,  8.7278e-02],
        [ 6.3467e+00,  2.7240e-03, -3.1471e-02, -7.7249e-02],
        [ 4.3338e+00,  6.2948e-02, -3.7126e-02, -7.8729e-02],
        [ 4.9342e+00, -1.0691e-03,  7.4798e-02, -1.8127e-02],
        [ 2.3736e+00,  3.0084e+00, -1.9020e-01,  2.0153e-01],
        [ 6.1500e+00,  4.0420e+00,  1.0929e+00,  5.5368e-02],
        [ 1.7408e+00,  1.1703e-01,  1.2565e-01, -5.2945e-02],
        [ 6.0542e+00,  2.8295e+00,  4.2144e-02,  5.0379e-02],
        

In [None]:
# Posterior draws for the simulated batch `out`. The approximator was configured with
# summary variables ("outcomes", "designs") AND inference conditions ("masks", "n_obs"),
# so pass the full simulated batch rather than placeholder tensors.
# batch_shape is (number of datasets, draws per dataset), as in the evaluation cell below.
num_draws_per_dataset = 100
approximator.sample(batch_shape=(B, num_draws_per_dataset), data=out)

In [None]:
# TODO at last

class InferenceDesignApproximator:
    """Sketch of a three-stage trainer for joint inference and experimental design.

    Stages:

    1. Train the BayesFlow approximator on random designs.
    2. Freeze the approximator and train the design network.
    3. Joint training (not written yet).

    Several pieces are still placeholders: the ``...`` arguments, ``DesignNetwork``
    and ``freeze_weights``.
    """

    def __init__(self, hyperparameters: dict, bf_settings: dict, design_settings: dict):
        # Summary network shared by the approximator and the design network (placeholder).
        # Referenced via self so the class does not depend on a notebook-global variable.
        self.summary_network = ...

        # Dataset object: online dataset https://github.com/stefanradev93/BayesFlow/blob/streamlined-backend/bayesflow/datasets/online_dataset.py
        # Assigned in train(). A bare annotation here would leave the attribute undefined.
        self.dataset = None

        # BayesFlow approximator is encapsulated and uses the shared summary network.
        self.bf_approximator = bf.approximators.Approximator(..., self.summary_network, **bf_settings)

        # Design network object is encapsulated. TODO: DesignNetwork is not defined or imported yet.
        self.design_net = DesignNetwork(..., self.summary_network, **design_settings)

        # Hyperparameters: weight terms to balance losses, etc.
        self.hyperparameters = hyperparameters

    def train(self, dataset):
        """Run the staged training schedule on ``dataset``, advancing ``dataset.stage``.

        Args:
            dataset: object with a mutable ``stage`` attribute (e.g. ``MyDataSet``).
        """
        self.dataset = dataset

        # Stage 1: Train BayesFlow, use random design
        dataset.stage = 1  # set explicitly so re-running train() starts from stage 1
        self.bf_approximator.train(dataset)  # TODO: the Keras approximator above trains via fit()
        dataset.stage = 2

        # Stage 2: Fix BayesFlow, train design network
        self.bf_approximator.freeze_weights()  # TODO: implement this
        self.design_net.train(dataset)
        dataset.stage = 3

        # Stage 3: Joint training
        # TODO: not implemented yet.

In [49]:
# Evaluate on one freshly simulated batch. The first entry of batch_shape is the
# number of datasets in `test_dataset`, so derive it from the dataset's batch size
# instead of hard-coding a value that must stay in sync with B.
num_test_samples = dataset.batch_size[0]
num_posterior_samples = 500

test_dataset = dataset[0]  # MyDataSet ignores the index and simulates a new batch

samples = approximator.sample(batch_shape=(num_test_samples, num_posterior_samples), data=test_dataset)

In [53]:
# Keys returned by the approximator's sampler (the recorded output shows "params" and "summaries").
samples.keys()

dict_keys(['params', 'summaries'])