Skip to content

Commit

Permalink
moving namedtuple() calls to module level in stats to facilitate pick…
Browse files Browse the repository at this point in the history
…ling/multiprocessing (#3268)

* moving namedtuple() calls to module level in stats

* restoring extra newlines between functions
  • Loading branch information
Spaak authored and ColCarroll committed Nov 23, 2018
1 parent ea23b57 commit 6d07591
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pymc3/stats.py
Expand Up @@ -170,6 +170,8 @@ def logp_vals_point(pt):
points.close()


WAIC_r_pointwise = namedtuple('WAIC_r_pointwise', 'WAIC, WAIC_se, p_WAIC, var_warn, WAIC_i')
WAIC_r = namedtuple('WAIC_r', 'WAIC, WAIC_se, p_WAIC, var_warn')
def waic(trace, model=None, pointwise=False, progressbar=False):
"""Calculate the widely available information criterion, its standard error
and the effective number of parameters of the samples in trace from model.
Expand Down Expand Up @@ -230,13 +232,13 @@ def waic(trace, model=None, pointwise=False, progressbar=False):
please double check the Observed RV in your model to make sure it
returns element-wise logp.
""")
WAIC_r = namedtuple('WAIC_r', 'WAIC, WAIC_se, p_WAIC, var_warn, WAIC_i')
return WAIC_r(waic, waic_se, p_waic, warn_mg, waic_i)
return WAIC_r_pointwise(waic, waic_se, p_waic, warn_mg, waic_i)
else:
WAIC_r = namedtuple('WAIC_r', 'WAIC, WAIC_se, p_WAIC, var_warn')
return WAIC_r(waic, waic_se, p_waic, warn_mg)


LOO_r_pointwise = namedtuple('LOO_r_pointwise', 'LOO, LOO_se, p_LOO, shape_warn, LOO_i')
LOO_r = namedtuple('LOO_r', 'LOO, LOO_se, p_LOO, shape_warn')
def loo(trace, model=None, pointwise=False, reff=None, progressbar=False):
"""Calculates leave-one-out (LOO) cross-validation for out of sample
predictive model fit, following Vehtari et al. (2015). Cross-validation is
Expand Down Expand Up @@ -309,10 +311,8 @@ def loo(trace, model=None, pointwise=False, reff=None, progressbar=False):
please double check the Observed RV in your model to make sure it
returns element-wise logp.
""")
LOO_r = namedtuple('LOO_r', 'LOO, LOO_se, p_LOO, shape_warn, LOO_i')
return LOO_r(loo_lppd, loo_lppd_se, p_loo, warn_mg, loo_lppd_i)
return LOO_r_pointwise(loo_lppd, loo_lppd_se, p_loo, warn_mg, loo_lppd_i)
else:
LOO_r = namedtuple('LOO_r', 'LOO, LOO_se, p_LOO, shape_warn')
return LOO_r(loo_lppd, loo_lppd_se, p_loo, warn_mg)


Expand Down Expand Up @@ -1068,6 +1068,7 @@ def bfmi(trace):
return np.square(np.diff(energy)).mean() / np.var(energy)


r2_r = namedtuple('r2_r', 'r2_median, r2_mean, r2_std')
def r2_score(y_true, y_pred, round_to=2):
R"""R-squared for Bayesian regression models. Only valid for linear models.
http://www.stat.columbia.edu/%7Egelman/research/unpublished/bayes_R2.pdf
Expand Down Expand Up @@ -1099,6 +1100,5 @@ def r2_score(y_true, y_pred, round_to=2):
r2_median = np.around(np.median(r2), round_to)
r2_mean = np.around(np.mean(r2), round_to)
r2_std = np.around(np.std(r2), round_to)
r2_r = namedtuple('r2_r', 'r2_median, r2_mean, r2_std')
return r2_r(r2_median, r2_mean, r2_std)

0 comments on commit 6d07591

Please sign in to comment.