# Linear Gaussian

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell executes, so edits to library code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pytest
import sbi.simulators as simulators
import sbi.utils as utils
import torch
from sbi.inference.snpe.snpe_c import APT
from torch import distributions

# Make newly created floating-point tensors default to CPU float32.
torch.set_default_tensor_type(torch.FloatTensor)

# Fix torch's global RNG seed so the simulations below are reproducible.
torch.manual_seed(seed=0)

In [None]:
# Problem setup: parameter dimension and simulator noise scale.
dim = 3
std = 1.0
simulator = simulators.LinearGaussianSimulator(dim=dim, std=std)

# Standard-normal prior over the `dim`-dimensional parameters
# (zero mean, identity covariance).
prior = distributions.MultivariateNormal(torch.zeros(dim), torch.eye(dim))

In [None]:
# Observed data passed to APT: a length-`dim` vector of zeros.
true_observation = torch.zeros(dim)

# TODO: this constructor call fails with an unknown keyword argument under
# newer sbi versions; adapt the keyword arguments to the installed APT signature.
apt = APT(
    simulator=simulator,
    true_observation=true_observation,
    prior=prior,
    # NOTE(review): -1 presumably means "use every sample in the batch as an
    # atom" — confirm against the APT docs for this sbi version.
    num_atoms=-1,
    # presumably selects a masked autoregressive flow — confirm
    density_estimator='maf',
    calibration_kernel=None,
    z_score_obs=True,
    use_combined_loss=False,
    train_with_mcmc=False,
    mcmc_method="slice-np",
    summary_net=None,
    retrain_from_scratch_each_round=False,
    discard_prior_samples=False,
)

In [None]:
# Run APT inference. Per the argument names: 2 rounds with 500 simulations
# per round.
num_rounds, num_simulations_per_round = 2, 500
apt.run_inference(
    num_rounds=num_rounds, num_simulations_per_round=num_simulations_per_round
)

In [None]:
# Draw 2500 posterior samples from APT, convert them to a NumPy array, and
# plot histogram marginals on the window [-4, 4].
samples = utils.tensor2numpy(apt.sample(2500))
figure = utils.plot_hist_marginals(data=samples, lims=[-4, 4])

In [None]:
# Evaluate `apt` at the origin and print the resulting log probability.
# The origin is built from `dim` rather than hard-coded as a 3-vector, so this
# cell keeps working if the problem dimension is changed.
log_prob = apt.evaluate(torch.zeros(dim))
# No trailing space in the label: print() already inserts one separator space.
print("log probability of origin:", log_prob)

# SNPE-B

In [None]:
from sbi.inference.snpe.snpe_b import SNPE_B

In [None]:
# Observed data passed to SNPE-B: a length-`dim` vector of zeros.
true_observation = torch.zeros(dim)

# NOTE(review): these keyword arguments match the APT constructor cell above,
# which is flagged as failing under newer sbi versions; this call likely needs
# the same adaptation — confirm against the installed SNPE_B signature.
snpe_b = SNPE_B(
    simulator=simulator,
    true_observation=true_observation,
    prior=prior,
    # presumably selects a masked autoregressive flow — confirm
    density_estimator='maf',
    calibration_kernel=None,
    z_score_obs=True,
    use_combined_loss=False,
    train_with_mcmc=False,
    mcmc_method="slice-np",
    summary_net=None,
    retrain_from_scratch_each_round=False,
    discard_prior_samples=False,
)

In [None]:
# Run SNPE-B inference. Per the argument names: 2 rounds with 500
# simulations per round.
num_rounds = 2
num_simulations_per_round = 500
snpe_b.run_inference(
    num_rounds=num_rounds,
    num_simulations_per_round=num_simulations_per_round,
)

In [None]:
# Draw 2500 posterior samples from SNPE-B, convert them to a NumPy array, and
# plot histogram marginals on the window [-4, 4].
samples = utils.tensor2numpy(snpe_b.sample(2500))
figure = utils.plot_hist_marginals(data=samples, lims=[-4, 4])

# SNPE-A