In [1]:
import torch
import numpy as np
import time
from sbi import utils as utils
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)

from parsing_functions import save_pickle_data
from simulations_model1 import simulator_sbi

# Dataset creation with simulate_for_sbi

In [2]:
# Bounds of the uniform prior over the three simulator parameters.
# Explicit float32 dtype: torch.tensor([4, 4, 4]) would infer int64, leaving an
# integer upper bound next to a float lower bound; keep both bounds float32.
# NOTE(review): the second parameter's lower bound is 0.1 rather than 0 —
# presumably the model is degenerate/undefined at 0; confirm.
low_tensor = torch.tensor([0.0, 0.1, 0.0], dtype=torch.float32)
high_tensor = torch.tensor([4.0, 4.0, 4.0], dtype=torch.float32)

# Simulation settings forwarded to simulator_sbi (through simulator_to_sbi).
# NOTE(review): the names suggest dt = integration time step, oversampling =
# sub-steps per stored sample, prerun = discarded warm-up length, Npts = number
# of recorded points — confirm against simulations_model1.
dt = 1e-2
oversampling = 10
prerun = 1e2  # float literal; cast with int() where a count is required
Npts = 5e4    # float literal; cast with int() where a count is required


def simulator_to_sbi(pars):
    """Adapt ``simulator_sbi`` to the one-argument call signature sbi expects.

    The fixed simulation settings (``dt``, ``oversampling``, ``prerun``,
    ``Npts``) are read from the notebook's configuration cell; only the
    parameter vector varies between calls.

    Args:
        pars: parameter values proposed by sbi; converted to a NumPy array
            (a copy) before being handed to ``simulator_sbi``.

    Returns:
        Whatever ``simulator_sbi`` returns for these parameters.
    """
    params_array = np.array(pars)
    # prerun / Npts are stored as float literals; the simulator gets ints.
    n_prerun = int(prerun)
    n_points = int(Npts)
    return simulator_sbi(params_array, dt, oversampling, n_prerun, n_points)


# Uniform box prior spanning [low_tensor, high_tensor] in every dimension.
prior_sbi = utils.BoxUniform(low=low_tensor, high=high_tensor)

# sbi input preparation, in order:
#   1. process_prior   -> PyTorch-compatible prior, number of parameters,
#                         and whether the prior returns numpy arrays;
#   2. process_simulator -> wraps simulator_to_sbi so sbi can run it on
#                         batches of parameters;
#   3. check_sbi_inputs -> final consistency check of the prior/simulator pair.
prior, num_parameters, prior_returns_numpy = process_prior(prior_sbi)
simulator = process_simulator(
    simulator_to_sbi,
    prior,
    prior_returns_numpy,
)
check_sbi_inputs(simulator, prior)

In [3]:
num_simulations = 2500
num_workers = 10
# Aim for ~10 batches per worker so every worker stays busy. Integer division
# plus max(1, ...) guarantees a valid batch size: the previous
# int(num_simulations / (num_workers * 10)) evaluated to 0 whenever
# num_simulations < num_workers * 10.
sim_batch_size = max(1, num_simulations // (num_workers * 10))

# Tag encoding the run configuration; used later to name the saved dataset.
specs = f'{num_simulations}sim_{Npts:.0e}np_{dt:.0e}dt_{oversampling}os_{prerun:.0e}pre'

# NOTE(review): the author found parallel simulation (num_workers > 1) needed
# Python < 3.11 in their environment — confirm for the installed sbi version.
# perf_counter is the monotonic clock meant for measuring elapsed intervals.
start = time.perf_counter()
theta, x = simulate_for_sbi(
    simulator,
    proposal=prior,
    num_simulations=num_simulations,
    num_workers=num_workers,
    simulation_batch_size=sim_batch_size,
)
end = time.perf_counter()

# Benchmark notes from earlier runs (the progress bar's "it" is one batch):
#   batch size 10, 10 workers: ~0.62 s/it
#   100 batches: ~0.29 s/it (about the same with 10 batches)

Running 2500 simulations in 100 batches.:   0%|          | 0/100 [00:00<?, ?it/s]

In [4]:
# Report the cost of the simulation cell. The progress bar counts batches, but
# this divides by the number of simulations, and with num_workers > 1 they run
# in parallel — so the second figure is an amortized wall-clock cost per
# simulation, not the duration of a single simulation.
elapsed = end - start
print('time: ', elapsed)
print('mean wall-clock time per simulation: ', elapsed / num_simulations)

time:  72.29905271530151
mean time per iteration:  0.028919621086120607


In [5]:
# Persist the simulated (theta, x) pairs so later notebooks can train on them
# without re-running the simulator. The file name embeds the run configuration
# tag `specs`; save_pickle_data appears to add the .pickle extension (see the
# printed path). NOTE: only load these pickles from trusted sources.
save_dir = 'saved_datasets'
dataset_name = f'dataset_{specs}'
data = {"theta": theta, "x": x}

save_pickle_data(
    data=data,
    filename=dataset_name,
    folder_path=save_dir,
)

Saved dataset at saved_datasets/dataset_2500sim_5e+04np_1e-02dt_10os_1e+02pre.pickle

