In [1]:
import torch
import numpy as np
import time
from sbi import utils as utils
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)

from parsing_functions import save_pickle_data
#from simulations_model1 import simulator_sbi
from simulations_model2 import simulator_sbi

# Dataset creation with simulate_for_sbi

In [2]:
# Lower/upper bounds of the box-uniform prior built from these tensors below.
# Float literals are used in both so the two tensors share the same dtype
# (a list of only integer literals would be inferred as int64).
low_tensor = torch.tensor([0.0, 0.01, 0.0])
high_tensor = torch.tensor([2.0, 2.0, 1.0])

# Settings forwarded to simulator_sbi by simulator_to_sbi.
dt = 1e-2
oversampling = 5
# Stored as floats; they are cast with int() at the point where they are
# handed to the simulator.
prerun = 1e3
Npts = 5e4


def simulator_to_sbi(pars):
    """Call ``simulator_sbi`` on ``pars`` with the notebook-level settings.

    ``pars`` is converted to a NumPy array before the call. ``dt``,
    ``oversampling``, ``prerun`` and ``Npts`` are read from the enclosing
    module scope; the last two are cast to ``int``.

    Returns whatever ``simulator_sbi`` returns.
    """
    params_np = np.array(pars)
    n_prerun = int(prerun)
    n_points = int(Npts)
    return simulator_sbi(params_np, dt, oversampling, n_prerun, n_points)


# Box-uniform prior bounded by low_tensor (lower) and high_tensor (upper).
prior_sbi = utils.BoxUniform(low=low_tensor, high=high_tensor)

# Check prior, return PyTorch prior.
# NOTE(review): the two extra return values are presumably the parameter count
# and a flag telling whether the prior yields NumPy output -- confirm in sbi docs.
prior, num_parameters, prior_returns_numpy = process_prior(prior_sbi)

# Check simulator, returns PyTorch simulator able to simulate batches.
# `simulator` and `prior` are the objects used for the simulation run later on.
simulator = process_simulator(simulator_to_sbi, prior, prior_returns_numpy)

# Consistency check after making ready for sbi.
check_sbi_inputs(simulator, prior)

In [3]:
# Number of simulations to generate and number of parallel workers.
num_simulations = 10000
num_workers = 10
# Aim for 10 batches per worker. Floor division keeps the arithmetic in
# integers, and the lower bound of 1 prevents a batch size of 0 when
# num_simulations is smaller than num_workers * 10.
sim_batch_size = max(1, num_simulations // (num_workers * 10))

#to use parallel processing you need python<3.11
# perf_counter() is a monotonic clock, so the measured duration cannot be
# skewed by system clock adjustments while the simulations are running.
start = time.perf_counter()
theta, x = simulate_for_sbi(
    simulator,
    proposal=prior,
    num_simulations=num_simulations,
    num_workers=num_workers,
    simulation_batch_size=sim_batch_size,
)
end = time.perf_counter()

#bsize = 10, nworkers=10, ~ 0.62s/it
#new sim with 100 batches ~0.29s/it, same with 10 batches

# Elapsed time divided by the number of simulations (not by the number of batches).
print('mean time per iteration: ', (end-start)/num_simulations)

Running 10000 simulations in 100 batches.:   0%|          | 0/100 [00:00<?, ?it/s]

mean time per iteration:  0.03796813266277313


In [4]:
# Persist the simulated dataset together with the settings that produced it.
save_dir = 'saved_datasets/model2/'
data = dict(
    theta=theta,
    x=x,
    num_simulations=num_simulations,
    Npts=Npts,
    dt=dt,
    oversampling=oversampling,
    prerun=prerun,
    low_tensor=low_tensor,
    high_tensor=high_tensor,
    data_type='full',  # Indicates the type of the data
)

# Kept as the last expression of the cell so its return value is displayed.
save_pickle_data(data=data, folder_path=save_dir, prefix=None)

Saved dataset at saved_datasets/model2/dataset_10000sim_5e+04np_1e-02dt_5os_1e+03pre_1.pickle



'saved_datasets/model2/dataset_10000sim_5e+04np_1e-02dt_5os_1e+03pre_1.pickle'

In [5]:
# Show the shape of the simulated outputs `x` returned by simulate_for_sbi.
print(x.shape)

torch.Size([10000, 4, 1000])
