Skip to content

Commit

Permalink
add test for r2_score (#2729)
Browse files Browse the repository at this point in the history
* add test for r2_score

* add change to release-notes
  • Loading branch information
aloctavodia authored and Junpeng Lao committed Nov 23, 2017
1 parent 1b1caa6 commit 5d30236
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Expand Up @@ -11,6 +11,7 @@
- Fixed `compareplot` to use `loo` output.
- Add test for `model.logp_array` and `model.bijection` (#2724)
- Fixed `sample_ppc` and `sample_ppc_w` to iterate all chains(#2633)
- Add test for `stats.r2_score` (#2729)



Expand Down
9 changes: 4 additions & 5 deletions pymc3/stats.py
Expand Up @@ -995,14 +995,13 @@ def r2_score(y_true, y_pred, round_to=2):
if y_true.ndim > 1:
dimension = 1

e = y_true - y_pred
var_y_est = np.var(y_pred, dimension)
var_e = np.var(e, dimension)
var_y_est = np.var(y_pred, axis=dimension)
var_e = np.var(y_true - y_pred, axis=dimension)

r2 = var_y_est / (var_y_est + var_e)
r2_median = np.around(np.median(r2), round_to)
r2_mean = np.around(np.mean(r2), round_to)
r2_std = np.around(np.std(r2), round_to)
R2_r = namedtuple('R2_r', 'R2_median, R2_mean, R2_std')
return R2_r(r2_median, r2_mean, r2_std)
r2_r = namedtuple('r2_r', 'r2_median, r2_mean, r2_std')
return r2_r(r2_median, r2_mean, r2_std)

11 changes: 10 additions & 1 deletion pymc3/tests/test_stats.py
Expand Up @@ -6,7 +6,8 @@
from .helpers import SeededTest
from ..tests import backend_fixtures as bf
from ..backends import ndarray
from ..stats import summary, autocorr, hpd, mc_error, quantiles, make_indices, bfmi
from ..stats import (summary, autocorr, hpd, mc_error, quantiles, make_indices,
bfmi, r2_score)
from ..theanof import floatX_array
import pymc3.stats as pmstats
from numpy.random import random, normal
Expand Down Expand Up @@ -276,6 +277,14 @@ def test_bfmi(self):

assert_almost_equal(bfmi(trace), 0.8)

def test_r2_score(self):
x = np.linspace(0, 1, 100)
y = np.random.normal(x, 1)
res = st.linregress(x, y)
assert_almost_equal(res.rvalue ** 2,
r2_score(y, res.intercept +
res.slope * x).r2_median,
2)

class TestDfSummary(bf.ModelBackendSampledTestCase):
backend = ndarray.NDArray
Expand Down

0 comments on commit 5d30236

Please sign in to comment.