In [1]:
import os, sys
from functools import partial
import jax
import jax.numpy as jnp

import haiku as hk

from lfiax.utils.oed import sdm_minebed

import torch
from torch.distributions import Uniform, TransformedDistribution, Distribution
from torch.distributions.transforms import Transform
from torch.distributions import constraints

import distrax

# Make the local `bmp_simulator` directory (relative to the notebook's working
# directory) importable so `simulate_bmp_torch` can be imported below.
current_dir = os.getcwd()
module_dir = os.path.join(current_dir, 'bmp_simulator')
# Guard so re-running this cell in the same kernel does not keep appending
# duplicate entries to sys.path.
if module_dir not in sys.path:
    sys.path.append(module_dir)

from simulate_bmp_torch import bmp_simulator

In [2]:
class LogUniform(Transform):
    """
    Bijective transform from the unit interval to a log-uniform range.

    Maps ``x`` in ``[0, 1]`` to ``y = exp(x * (high - low) + low)``. If ``x`` is
    Uniform(0, 1), then ``log(y)`` is uniform on ``[low, high]``. So ``low`` and
    ``high`` are bounds in *log* space.

    Args:
        low: log of the lower bound of the target support (float or tensor).
        high: log of the upper bound of the target support. It must exceed
            ``low`` for the transform to be increasing (``sign = +1``).
    """
    bijective = True
    sign = +1  # increasing because high > low; would be -1 if high < low

    def __init__(self, low, high):
        super().__init__()
        self.low = low
        self.high = high

    def _call(self, x):
        # Affine map into log space, then exponentiate.
        return torch.exp(x * (self.high - self.low) + self.low)

    def _inverse(self, y):
        return (torch.log(y) - self.low) / (self.high - self.low)

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = (high - low) * exp(x * (high - low) + low), so
        # log|dy/dx| = log|high - low| + (high - low) * x + low.
        # The log|high - low| term is needed for TransformedDistribution.log_prob
        # to return a correctly normalised density.
        scale = torch.as_tensor(self.high - self.low, dtype=x.dtype, device=x.device)
        return torch.log(scale.abs()) + scale * x + self.low

    @property
    def domain(self):
        return constraints.interval(0.0, 1.0)

    @property
    def codomain(self):
        return constraints.positive


def make_torch_bmp_prior(low=1e-6, high=1.0):
    """
    Build a log-uniform prior on ``[low, high]`` as a TransformedDistribution.

    A Uniform(0, 1) base distribution is pushed through
    ``LogUniform(log(low), log(high))``, which maps
    ``x -> exp(x * (log(high) - log(low)) + log(low))``.

    Args:
        low: lower bound of the support; must be > 0. Default 1e-6.
        high: upper bound of the support; must be > ``low``. Default 1.0.

    Returns:
        A ``torch.distributions.TransformedDistribution`` over positive scalars.

    Raises:
        ValueError: if the bounds do not satisfy ``0 < low < high``.
    """
    if not 0 < low < high:
        raise ValueError(f"need 0 < low < high, got low={low}, high={high}")

    log_low = torch.log(torch.tensor(low))
    log_high = torch.log(torch.tensor(high))

    # LogUniform is defined on the unit interval, so the base must be
    # Uniform(0, 1). A base such as Uniform(1e-6, 1) would push the lower
    # end of the resulting support slightly above `low`.
    base = Uniform(low=torch.tensor(0.0), high=torch.tensor(1.0))

    return TransformedDistribution(base, LogUniform(log_low, log_high))


class MultiLogUniform(Distribution):
    """
    Multiple independent log-uniform distributions, treated as one vector prior.

    Each coordinate is an independent prior built by ``make_torch_bmp_prior()``.
    Samples are stacked along the last dimension, so the event shape is
    ``(num_priors,)``.

    Args:
        num_priors: number of independent coordinates; must be >= 1.

    Raises:
        ValueError: if ``num_priors < 1``.
    """
    # This distribution has no validated parameters. Declaring an empty dict
    # stops torch from warning that `arg_constraints` is undefined when
    # argument validation is enabled.
    arg_constraints = {}

    def __init__(self, num_priors):
        if num_priors < 1:
            raise ValueError(f"num_priors must be >= 1, got {num_priors}")
        super().__init__(event_shape=torch.Size([num_priors]))
        self.priors = [make_torch_bmp_prior() for _ in range(num_priors)]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``(*sample_shape, num_priors)``."""
        return torch.stack([prior.sample(sample_shape) for prior in self.priors], dim=-1)

    def log_prob(self, value):
        """Joint log-density of ``value`` (last dim = coordinates).

        The coordinates are independent, so the joint log-density is the sum
        of the per-coordinate log-densities.
        """
        return sum(prior.log_prob(value[..., i]) for i, prior in enumerate(self.priors))

# Seed torch's global RNG so the prior draws (`thetas`) are reproducible.
# NOTE(review): sdm_minebed may also draw from numpy or other RNGs, which are
# not seeded here — confirm if full reproducibility is required.
SEED = 0
torch.manual_seed(SEED)

# Prior: one independent log-uniform prior per BMP parameter.
num_priors = 2
priors = MultiLogUniform(num_priors)

# Simulator (BMP onestep model) to use
model_size = (1,1,1)
fixed_receptor = True
simulator = partial(
    bmp_simulator,
    model_size=model_size,
    model='onestep',
    fixed_receptor=fixed_receptor)

# Settings passed straight through to sdm_minebed.
y_obs = None
DATASIZE = 5_000          # number of prior draws / simulations
BATCHSIZE = DATASIZE      # full-batch: one batch covers all DATASIZE samples
BO_MAX_NUM = 1
NN_layers = 1
NN_hidden = 150
design_dims = 1

# Pre-draw prior samples as a numpy array of shape (DATASIZE, num_priors).
thetas = priors.sample((DATASIZE,)).numpy()

opt_design, sim_samples, inf_time = sdm_minebed(
    simulator = simulator,
    params = priors,
    y_obs = y_obs,
    DATASIZE = DATASIZE,
    BATCHSIZE = BATCHSIZE,
    BO_MAX_NUM = BO_MAX_NUM,
    dims = design_dims,
    NN_layers = NN_layers,
    NN_hidden = NN_hidden,
    prior_sims = thetas,
    )

Initialize Probabilistic Model
Start Bayesian Optimisation


In [None]:
# Display the third value returned by `sdm_minebed` above.
# NOTE(review): the name suggests an inference/run time — confirm against
# lfiax.utils.oed.sdm_minebed.
inf_time

In [4]:
# Exploratory: list every attribute (including dunders) of `opt_design`, the
# first value returned by `sdm_minebed`.
dir(opt_design)

['LB_type',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_compute_optimal_design',
 '_objective',
 '_reset_model',
 '_reset_optimizer',
 '_reset_scheduler',
 '_reset_weights',
 '_store_data',
 'batch_size',
 'bo_obj',
 'constraints',
 'd_opt',
 'domain',
 'init_opt_state',
 'init_sched_state',
 'ma_window',
 'mine_obj',
 'model',
 'model_states',
 'n_epoch',
 'optimizer',
 'prior',
 'save',
 'save_models',
 'scheduler',
 'simulator',
 'train',
 'train_curves',
 'train_final_model',
 'y_obs']

In [8]:
# Read the `n_epoch` attribute of the object returned by `sdm_minebed`
# (presumably the number of training epochs — confirm).
opt_design.n_epoch

10000

In [3]:
# NOTE(review): here `opt_design` displays as an ndarray, while the In [4] and
# In [8] cells treat it as an object with attributes. The outputs come from
# different kernel states — Restart & Run All to get consistent results.
opt_design

array([474.98189133,  78.26579717, 241.45848354, 130.53998316,
       535.79220425, 176.17569778, 453.41530812, 894.35109868,
       102.62779979, 119.87057954])

In [5]:
# NOTE(review): this repeats an earlier `opt_design` display cell, but shows a
# much longer array, so it was run in a different kernel state. Keep a single
# display cell after a clean Restart & Run All.
opt_design

array([806.86810217, 806.20344177, 475.9710831 , 290.4084334 ,
       163.20680804, 169.96205008, 957.72600199, 885.45266693,
        36.69306414,  59.68024322,  97.79581232, 845.49068452,
       879.67875425, 156.48108807, 348.41462141,  90.03640688,
       379.52034936, 523.31585044, 598.11014117, 835.13668931,
       807.90702088, 508.50131256, 537.2414328 , 629.94159156,
       420.85867719, 291.65597374, 588.41196456,  88.90706723,
       792.05306393, 643.10354878, 235.33425007, 559.77737638,
       319.88807603, 827.10032867, 972.78287947, 260.82582545,
       930.61385903, 283.09241928, 734.1811769 , 759.61420321,
       281.323178  , 166.17200807, 455.45930484,  80.49402825,
       436.87749332, 329.23303213, 474.19464043,   7.83761571,
       182.0931305 , 275.84461381,  42.11636713, 628.62024093,
       468.55113461, 904.80197655, 982.5131518 , 531.46076522,
        72.84408505, 911.6526855 , 971.13226328, 516.32879388,
       317.48041928, 445.98051395, 868.15856918, 447.01