In [1]:
# setup
import os
# Keras selects its backend from KERAS_BACKEND, so this must be set before `import keras` below.
os.environ["KERAS_BACKEND"] = "torch"
# Let PyTorch fall back to the CPU for operations not implemented on Apple's MPS device.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


import keras

if keras.backend.backend() == "torch":
    # NOTE(review): `torch` is only imported on this branch, yet later cells use it
    # unconditionally; they will fail with a NameError if another Keras backend is selected.
    import torch
    print("Use torch backend")
    # Turns autograd grad-mode off from here on.
    # NOTE(review): confirm that keras/bayesflow re-enable gradients inside fit();
    # otherwise training cannot backpropagate.
    torch.autograd.set_grad_enabled(False)

import sys
# Make modules in the parent directory importable (presumably a local bayesflow checkout —
# TODO confirm). `append` puts it LAST on sys.path, so an installed bayesflow package
# would still take precedence over the local copy.
sys.path.append("../")

import bayesflow as bf
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

Use torch backend


  torch.logspace(


In [2]:
from bayesflow.simulators.simulator import Simulator
# from bayesflow.types import Shape, Tensor
from torch import Tensor
from torch.distributions import Distribution
import torch.nn as nn
from typing import Callable

In [3]:
class MyGenericSimulator(Simulator):
    """Generic simulator for sequential experimental design.

    One call to ``sample`` draws a context, parameters conditioned on that context, and a
    number of experiments ``tau`` shared by the whole batch. It then runs ``tau`` rounds of
    "propose a design -> simulate an outcome".

    Args:
        context_sampler: callable ``context_sampler(batch_size) -> Tensor``.
        prior_sampler: object with a ``sample(context) -> Tensor`` method.
        tau_sampler: zero-argument callable returning the number of experiments, either a
            Python int or a one-element integer tensor.
        design_generator: module called as ``design_generator(batch_size)``; returns the
            designs for one round.
        simulator_var: settings forwarded to ``outcome_simulator`` as ``simulator_var``.
    """

    def __init__(self, context_sampler: Callable, prior_sampler: Callable, tau_sampler: Callable, design_generator: nn.Module, simulator_var: dict):
        self.context_sampler = context_sampler
        self.prior_sampler = prior_sampler
        self.tau_sampler = tau_sampler
        self.design_generator = design_generator
        self.simulator_var = simulator_var

    def sample(self, batch_size: torch.Size, **kwargs) -> dict[str, Tensor]:
        """Simulate one batch of experiment histories.

        Args:
            batch_size: batch shape passed to the context sampler and the design generator.
            **kwargs: accepted for interface compatibility; currently unused.

        Returns:
            Dict with keys ``"context"``, ``"params"``, ``"n_obs"``, ``"designs"`` and
            ``"outcomes"``. For a 1-D ``batch_size`` of B and 1-D designs, ``designs`` and
            ``outcomes`` are ``[B, tau]`` and ``n_obs`` is ``[B, 1]``.

        Raises:
            ValueError: if ``tau_sampler`` returns fewer than one experiment.
        """
        context = self.context_sampler(batch_size)
        params = self.prior_sampler.sample(context)

        # Normalize tau to a one-element long tensor. This lets both the loop bound and the
        # `repeat` below work whether tau_sampler returns an int or a tensor.
        tau = torch.as_tensor(self.tau_sampler(), dtype=torch.long).reshape(1)
        num_experiments = int(tau.item())
        if num_experiments < 1:
            # torch.stack below cannot stack an empty list of rounds.
            raise ValueError(f"tau_sampler must return at least 1 experiment, got {num_experiments}")

        designs = []
        outcomes = []
        for _ in range(num_experiments):
            # NOTE(review): the history (previous designs/outcomes) is not passed to the
            # design generator yet, so only history-independent policies work here.
            xi = self.design_generator(batch_size)
            y = self.outcome_simulator(params=params, xi=xi, simulator_var=self.simulator_var)
            designs.append(xi)
            outcomes.append(y)

        designs = torch.stack(designs, dim=1)  # [B, tau]
        # Each y is assumed to be [B, 1] (true for PolynomialRegression) — TODO confirm for other models.
        outcomes = torch.stack(outcomes, dim=1).squeeze(-1)  # [B, tau]
        n_obs = tau.repeat(batch_size).unsqueeze(1)  # [B, 1]

        return {"context": context, "params": params, "n_obs": n_obs, "designs": designs, "outcomes": outcomes}

    def outcome_simulator(self, params: Tensor, xi: Tensor, simulator_var: dict = None) -> Tensor:
        """Simulate outcomes for ``params`` at design ``xi``; must be overridden.

        ``sample`` calls this with ``simulator_var=`` as a keyword argument, so overrides
        must accept that parameter.
        """
        raise NotImplementedError

In [4]:
class RandomDesign(nn.Module):
    """History-independent design policy: draws designs uniformly from ``[low, high)``.

    Args:
        low: lower bound of the design interval (default 0.0).
        high: upper bound of the design interval (default 1.0).

    Raises:
        ValueError: if ``high`` is not greater than ``low``.
    """

    def __init__(self, low: float = 0.0, high: float = 1.0):
        super().__init__()
        if not high > low:
            raise ValueError(f"high ({high}) must be greater than low ({low})")
        self.low = low
        self.high = high

    def forward(self, batch_size: torch.Size, designs: list[Tensor] = None, outcomes: list[Tensor] = None) -> Tensor:
        """Return a tensor of shape ``batch_size`` with uniform designs.

        ``designs`` and ``outcomes`` (the history) are accepted for interface compatibility
        with adaptive design policies and are ignored.
        """
        unit_samples = torch.rand(batch_size)
        # With the default bounds this is exactly `unit_samples` (0.0 + 1.0 * x == x).
        return self.low + (self.high - self.low) * unit_samples

In [5]:
class LikelihoodBasedModel(MyGenericSimulator):
    """Simulator whose outcomes are drawn from an explicit likelihood ``p(y | params, xi)``.

    Subclasses implement ``outcome_likelihood``; ``outcome_simulator`` then samples from it.
    """

    def __init__(self, context_sampler, prior_sampler, tau_sampler, design_generator, simulator_var) -> None:
        super().__init__(context_sampler, prior_sampler, tau_sampler, design_generator, simulator_var)

    def outcome_likelihood(self, params: Tensor, xi: Tensor, simulator_var: dict) -> Distribution:
        """Return the distribution of outcomes for ``params`` at design ``xi``; must be overridden."""
        raise NotImplementedError

    def outcome_simulator(self, params: Tensor, xi: Tensor, simulator_var: dict) -> Tensor:
        """Draw outcomes by sampling ``outcome_likelihood(params, xi, simulator_var)``."""
        return self.outcome_likelihood(params, xi, simulator_var).sample()

    def approximate_log_marginal_likelihood(self, params: Tensor, outcomes: Tensor, log_approx_posterior, xi: Tensor = None, simulator_var: dict = None) -> Tensor:
        """Estimate ``log p(y | xi)`` using an approximate posterior.

        Uses the identity
        ``log p(y | xi) = log p(y | params, xi) + log p(params) - log p(params | y, xi)``,
        with the true posterior density replaced by ``log_approx_posterior``.

        Args:
            params: parameter draws, with the batch along dim 0.
            outcomes: observed outcomes at designs ``xi``.
            log_approx_posterior: object exposing ``log_prob(params, outcomes, xi)``.
            xi: designs at which ``outcomes`` were observed (required).
            simulator_var: likelihood settings; defaults to ``self.simulator_var``.

        Returns:
            The per-draw estimate (not averaged). NOTE(review): this is one value per batch
            element only if the prior and approximate-posterior log-probs are shaped [B] —
            TODO confirm.

        Raises:
            ValueError: if ``xi`` is not given.
        """
        if xi is None:
            raise ValueError("xi (the designs at which `outcomes` were observed) is required")
        if simulator_var is None:
            simulator_var = self.simulator_var

        log_likelihood = self.outcome_likelihood(params, xi, simulator_var).log_prob(outcomes)
        if log_likelihood.dim() > 1:
            # Sum the per-element log-densities over non-batch axes to get a joint
            # log-likelihood. This assumes the likelihood factorizes over those axes (true
            # for an elementwise Normal) — TODO confirm for other models.
            log_likelihood = log_likelihood.flatten(start_dim=1).sum(dim=-1)
        # NOTE(review): mask-conditioned priors need to know the mask the params were drawn under.
        log_prior = self.prior_sampler.log_prob(params)
        log_posterior = log_approx_posterior.log_prob(params, outcomes, xi)
        return log_likelihood + log_prior - log_posterior




In [6]:
class PolynomialRegression(LikelihoodBasedModel):
    """Polynomial regression model: ``y ~ Normal(sum_k params[:, k] * xi**k, sigma)``.

    The polynomial degree follows the number of coefficients along the last axis of
    ``params``. For example, 4 coefficients give a cubic with terms ``1, xi, xi**2, xi**3``.
    """

    def __init__(self, context_sampler, prior_sampler, tau_sampler, design_generator, simulator_var) -> None:
        super().__init__(context_sampler, prior_sampler, tau_sampler, design_generator, simulator_var)

    def outcome_likelihood(self, params: Tensor, xi: Tensor, simulator_var: dict) -> Distribution:
        """Return ``Normal(mean, sigma)`` with ``mean = sum_k params[..., k] * xi**k``.

        Args:
            params: coefficients, ordered from the constant term upwards along the last axis.
            xi: designs; presumably shape [B], so the stacked powers are [B, K] — TODO confirm.
            simulator_var: must contain ``"sigma"``, the observation-noise standard deviation.

        Returns:
            A ``torch.distributions.Normal`` whose mean keeps a trailing singleton axis.
        """
        num_coefficients = params.shape[-1]
        # xi ** 0 evaluates to ones, so column k holds xi ** k for k = 0 .. K-1.
        xi_powers = torch.stack([xi ** k for k in range(num_coefficients)], dim=1)
        mean = torch.sum(params * xi_powers, dim=-1, keepdim=True)
        sigma = simulator_var["sigma"]
        return torch.distributions.Normal(mean, sigma)

    def analytical_log_marginal_likelihood(self, outcomes, params: Tensor, param_mask: Tensor) -> Tensor:
        """Closed-form log marginal likelihood of ``outcomes`` (not implemented yet)."""
        raise NotImplementedError  # TODO

In [7]:
class ParameterMask:
    """Samples binary masks that select which model coefficients are active.

    By default, the candidate masks are the rows of a lower-triangular matrix of ones. For
    ``num_parameters=4`` these are [1,0,0,0], [1,1,0,0], [1,1,1,0] and [1,1,1,1].

    Args:
        num_parameters: size of the default mask set; ignored when ``possible_masks`` is given.
        possible_masks: optional ``(num_masks, mask_width)`` candidates, given as any
            array-like accepted by ``torch.as_tensor``. They are copied and stored as float32.
    """

    def __init__(self, num_parameters: int = 4, possible_masks: Tensor = None) -> None:
        self.num_parameters = num_parameters
        if possible_masks is None:
            self.possible_masks = torch.tril(torch.ones((num_parameters, num_parameters)))
        else:
            # This copies and detaches like torch.tensor(...). It avoids the UserWarning that
            # torch.tensor emits when copy-constructing from an existing Tensor.
            self.possible_masks = torch.as_tensor(possible_masks, dtype=torch.float32).detach().clone()

    def __call__(self, batch_shape: torch.Size) -> Tensor:
        """Draw one candidate mask uniformly at random for every batch element.

        Args:
            batch_shape: batch shape (``torch.Size``/tuple), or an int for a 1-D batch.

        Returns:
            Tensor of shape ``(*batch_shape, mask_width)``.
        """
        if isinstance(batch_shape, int):
            batch_shape = (batch_shape,)
        index_samples = torch.randint(0, self.possible_masks.shape[0], batch_shape, dtype=torch.long)
        return self.possible_masks[index_samples]

In [8]:
class Prior:
    """Base class for mask-conditioned priors built from one distribution per batch element.

    Subclasses implement ``dist_list(param_mask)``. ``sample`` and ``log_prob`` pair the
    i-th returned distribution with the i-th batch element.
    """

    def __init__(self) -> None:
        super().__init__()

    def dist_list(self, param_mask: Tensor) -> list[Distribution]:
        """Return the per-batch-element prior distributions for ``param_mask``; must be overridden."""
        raise NotImplementedError

    def sample(self, param_mask: Tensor) -> Tensor:
        """Draw one sample from each distribution in ``dist_list(param_mask)``, stacked along dim 0."""
        return torch.stack([dist.sample() for dist in self.dist_list(param_mask)], dim=0)

    def log_prob(self, params: Tensor, param_mask: Tensor = None) -> Tensor:
        """Evaluate the prior log-density of each row of ``params``.

        Args:
            params: stacked parameters, iterated along dim 0.
            param_mask: the mask the parameters were drawn under. If omitted, the instance's
                ``param_mask`` attribute is used when one exists (subclasses may cache it).

        Returns:
            The per-element log-probabilities stacked along dim 0.

        Raises:
            ValueError: if no mask is available, or if the number of distributions does not
                match the number of parameter rows.
        """
        if param_mask is None:
            param_mask = getattr(self, "param_mask", None)
        if param_mask is None:
            raise ValueError("param_mask is required (none given and none cached on the prior)")
        dists = self.dist_list(param_mask)
        if len(dists) != len(params):
            # zip() would silently drop the surplus instead of failing.
            raise ValueError(f"got {len(dists)} prior distributions for {len(params)} parameter rows")
        return torch.stack([dist.log_prob(param) for dist, param in zip(dists, params)], dim=0)

In [9]:
class PriorPolynomialReg(Prior):
    """Independent Gaussian prior over polynomial coefficients, switched by a parameter mask.

    Coefficient k is ``Normal(mean_k, std_k)`` when ``param_mask[:, k] == 1``. Otherwise it is
    ``Normal(0, delta)``, a narrow prior around zero.

    Args:
        delta: standard deviation for inactive coefficients (a float or a one-element tensor).
        active_hyperparams: optional ``(num_coefficients, 2)`` rows of ``(mean, std)`` for
            active coefficients. Defaults to ``[[5, 2], [3, 1], [0, 0.8], [0, 0.5]]``.
    """

    # (mean, std) for each coefficient when it is active in the mask.
    DEFAULT_ACTIVE_HYPERPARAMS = ((5.0, 2.0), (3.0, 1.0), (0.0, 0.8), (0.0, 0.5))

    def __init__(self, delta: Tensor = Tensor([0.1]), active_hyperparams: Tensor = None) -> None:
        super().__init__()
        self.delta = delta
        if active_hyperparams is None:
            active_hyperparams = self.DEFAULT_ACTIVE_HYPERPARAMS
        self.active_hyperparams = torch.as_tensor(active_hyperparams, dtype=torch.get_default_dtype())

    def dist_list(self, param_mask: Tensor) -> list[Distribution]:
        """Build one diagonal ``MultivariateNormal`` per row of ``param_mask``.

        The mask is cached on ``self.param_mask`` so later calls can reuse it.

        Raises:
            ValueError: if the mask width differs from the number of coefficients.
        """
        self.param_mask = param_mask

        num_coefficients = self.active_hyperparams.shape[0]
        if param_mask.shape[-1] != num_coefficients:
            raise ValueError(f"param_mask has {param_mask.shape[-1]} columns, expected {num_coefficients}")

        # (mean, std) used for every inactive coefficient.
        inactive = torch.tensor([0.0, float(self.delta)])
        is_active = (param_mask == 1).unsqueeze(-1)  # [B, K, 1]
        hyper_params = torch.where(is_active, self.active_hyperparams, inactive)  # [B, K, 2]

        mean_s = hyper_params[..., 0]
        sigma_s = hyper_params[..., 1]
        # scale_tril = diag(std) gives covariance diag(std**2), i.e. independent coefficients.
        return [torch.distributions.MultivariateNormal(mean, scale_tril=torch.diag(sigma)) for mean, sigma in zip(mean_s, sigma_s)]

In [10]:
# Smoke test: sample masks and coefficients from the prior and evaluate their prior log-density.
batch_size = torch.Size([20])

param_mask_generator = ParameterMask()
param_mask = param_mask_generator(batch_size)
# Named distinctly from `polynomial_reg`, which later refers to the PolynomialRegression simulator.
prior_polynomial_reg = PriorPolynomialReg()
params = prior_polynomial_reg.sample(param_mask)
# Despite the name, this is the prior log-density of `params` under `param_mask`, not an outcome likelihood.
likelihood = prior_polynomial_reg.log_prob(params)

In [72]:
# Sanity-check the shapes produced by the previous cell (B = batch size).
print(f"Shape of parameter mask {param_mask.shape}") # [B, model_dim]
print(f"Shape of parameters {params.shape}") # [B, param_dim]
print(f"Shape of likelihood {likelihood.shape}") # [B]

Shape of parameter mask torch.Size([20, 4])
Shape of parameters torch.Size([20, 4])
Shape of likelihood torch.Size([20])


In [73]:
# Building blocks of the polynomial-regression simulator: the mask-conditioned prior,
# a random (history-independent) design policy, and the mask sampler used as context.
prior = PriorPolynomialReg()
random_design_generator = RandomDesign()
parameter_mask = ParameterMask()

In [74]:
T = 10  # default upper bound on the number of observations per simulated history


def random_num_obs(min_obs: int = 1, max_obs: int = T) -> Tensor:
    """Draw the number of observations uniformly from ``{min_obs, ..., max_obs}``.

    Returns:
        A one-element int64 tensor of shape ``(1,)``.
    """
    upper_exclusive = max_obs + 1  # torch.randint excludes its upper bound
    return torch.randint(min_obs, upper_exclusive, (1,))

In [75]:
# Observation-noise settings; PolynomialRegression.outcome_likelihood reads "sigma" from these.
simulator_settings = {"sigma": 1.0}

polynomial_reg = PolynomialRegression(
    context_sampler=parameter_mask,
    prior_sampler=prior,
    tau_sampler=random_num_obs,
    design_generator=random_design_generator,
    simulator_var=simulator_settings,
)

In [76]:
# Simulate one batch of experiment histories with random designs.
out = polynomial_reg.sample(batch_size=batch_size)

In [77]:
# Inspect which entries the simulator returns for one batch.
out.keys()

dict_keys(['context', 'params', 'n_obs', 'designs', 'outcomes'])

In [78]:
class MyDataSet(keras.utils.PyDataset):
    """Infinite, on-the-fly dataset whose sampling strategy depends on the training stage.

    Stage 1 draws batches from ``initial_generative_model``. Stages 2 and 3 are not
    implemented yet.

    Args:
        batch_size: batch shape passed to the generative model.
        stage: training stage (1, 2 or 3). It may be reassigned after construction.
        initial_generative_model: simulator used in stage 1.
        design_network: design policy, stored for the later stages.
    """

    def __init__(self, batch_size: torch.Size, stage: int, initial_generative_model: MyGenericSimulator, design_network: nn.Module = None):
        super().__init__()

        self.batch_size = batch_size
        self.stage = stage  # stage 1,2,3
        self.initial_generative_model = initial_generative_model
        self.design_network = design_network

    def __getitem__(self, item: int) -> dict[str, Tensor]:
        """Return a freshly simulated batch; ``item`` is ignored because data is generated online.

        Raises:
            NotImplementedError: for stages 2 and 3.
            ValueError: for any other stage value.
        """
        if self.stage == 1:
            return self.initial_generative_model.sample(self.batch_size)

        if self.stage in (2, 3):
            # TODO: stage 2 should simulate with designs proposed by `self.design_network`;
            # stage 3 is joint training.
            raise NotImplementedError(f"Data generation for stage {self.stage} is not implemented yet.")

        raise ValueError(f"Unknown stage {self.stage!r}; expected 1, 2 or 3.")

    @property
    def num_batches(self):
        # infinite dataset
        return None

In [96]:
# Stage-1 dataset: batches come straight from the random-design simulator.
dataset = MyDataSet(
    batch_size=batch_size,
    stage=1,
    initial_generative_model=polynomial_reg,
)

In [106]:
# Coupling-flow posterior network; `depth` is forwarded to bayesflow unchanged.
coupling_depth = 8
inference_network = bf.networks.CouplingFlow(depth=coupling_depth)

In [107]:
# DeepSet summary network producing a learned summary of size `summary_size`.
summary_size = 10
summary_network = bf.networks.DeepSet(summary_dim=summary_size)

In [108]:
# Wire the simulator's output keys to the approximator's inputs.
# NOTE(review): MyGenericSimulator.sample documents "outcomes" and "designs" as [B, tau]
# with no trailing feature axis; the fit() call below fails with a min_ndim=2 error,
# so one of these inputs may need reshaping — TODO confirm which.
approximator = bf.Approximator(
    summary_variables=["outcomes", "designs"],
    inference_conditions=["context", "n_obs"],
    inference_variables=["params"],
    summary_network=summary_network,
    inference_network=inference_network,
)

approximator.compile(optimizer="AdamW")

In [111]:
# NOTE(review): as saved, this call raised
#   ValueError: Input 0 of layer "dense_822" is incompatible with the layer:
#   expected min_ndim=2, found ndim=1. Full shape received: (128,)
# i.e. some Dense layer received a rank-1 tensor; TODO trace which approximator input causes it.
approximator.fit(dataset, epochs=1, steps_per_epoch=1)

ValueError: Input 0 of layer "dense_822" is incompatible with the layer: expected min_ndim=2, found ndim=1. Full shape received: (128,)

In [56]:
# TODO at last

class InferenceDesignApproximator:
    def __init__(self, hyperparameters: dict, bf_settings: dict, design_settings: dict):
        """Sketch of a wrapper coupling the BayesFlow approximator with a design network.

        Work in progress: the body is pseudo-code and cannot run as written.

        NOTE(review):
          - ``summary_network`` passed to the approximator and to the design network is not
            defined in this scope (only ``self.summary_network`` is); presumably the
            attribute was intended.
          - ``DesignNetwork`` is not defined anywhere in the visible notebook.
          - ``...`` arguments are placeholders still to be filled in.
          - ``self.dataset: MyDataSet`` is an annotation only and assigns nothing, so
            ``self.dataset`` does not exist after construction although ``train`` uses it.
          - ``train`` calls ``self.design_approximator``, but this method sets ``self.design_net``.

        Args:
            hyperparameters: training hyperparameters (e.g. loss-balancing weights).
            bf_settings: keyword arguments forwarded to the BayesFlow approximator.
            design_settings: keyword arguments forwarded to the design network.
        """

        self.summary_network = ...  # placeholder (Ellipsis): shared summary network not chosen yet

        # Dataset object: online dataset https://github.com/stefanradev93/BayesFlow/blob/streamlined-backend/bayesflow/datasets/online_dataset.py
        self.dataset: MyDataSet

        # BayesFlow approximator is encapsulated
        self.bf_approximator = bf.approximators.Approximator(..., summary_network, **bf_settings)
        
        # Design network object is encapsulated
        self.design_net = DesignNetwork(..., summary_network, **design_settings)

        # Hyperparameters: weight terms to balance losses, etc.
        self.hyperparameters = hyperparameters

    def train(self, dataset):
        # Stage 1: Train Bayesflow, use random design
        self.bf_approximator.train(dataset)
        self.dataset.stage = 2

        # Stage 2: Fix BayesFlow, train design network
        self.bf_approximator.freeze_weights() # implement this
        self.design_approximator.train(dataset)
        self.dataset.stage = 3

        # Stage 3: Joint training