/
__init__.py
71 lines (71 loc) · 1.94 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# flake8: noqa
from sbi.utils.conditional_density_utils import (
conditional_corrcoeff,
eval_conditional_density,
extract_and_transform_mog,
)
from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
from sbi.utils.io import get_data_root, get_log_root, get_project_root
from sbi.utils.kde import KDEWrapper, get_kde
from sbi.utils.plot import conditional_pairplot, pairplot
from sbi.utils.restriction_estimator import RestrictedPrior, RestrictionEstimator
from sbi.utils.sbiutils import (
batched_mixture_mv,
batched_mixture_vmv,
check_dist_class,
check_warn_and_setstate,
clamp_and_warn,
del_entries,
expit,
get_simulations_since_round,
gradient_ascent,
handle_invalid_x,
logit,
mask_sims_from_prior,
mcmc_transform,
mog_log_prob,
standardizing_net,
standardizing_transform,
warn_if_zscoring_changes_data,
warn_on_invalid_x,
warn_on_invalid_x_for_snpec_leakage,
within_support,
x_shape_from_simulation,
match_theta_and_x_batch_shapes,
)
from sbi.utils.torchutils import (
BoxUniform,
assert_all_finite,
cbrt,
create_alternating_binary_mask,
create_mid_split_binary_mask,
create_random_binary_mask,
gaussian_kde_log_eval,
get_num_parameters,
get_temperature,
logabsdet,
merge_leading_dims,
random_orthogonal,
repeat_rows,
searchsorted,
split_leading_dim,
sum_except_batch,
tensor2numpy,
tile,
)
from sbi.utils.typechecks import (
is_bool,
is_int,
is_nonnegative_int,
is_positive_int,
is_power_of_two,
)
from sbi.utils.user_input_checks import (
check_estimator_arg,
process_x,
test_posterior_net_for_multi_d_x,
validate_theta_and_x,
)
from sbi.utils.user_input_checks_utils import MultipleIndependent
from sbi.utils.potentialutils import transformed_potential, pyro_potential_wrapper
from sbi.utils.tensorboard_output import plot_summary, list_all_logs