Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up tests for discrete distributions #3138

Merged
merged 5 commits into from
Dec 15, 2013
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 37 additions & 105 deletions scipy/stats/tests/test_discrete_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
check_var_expect, check_skew_expect, check_kurt_expect,
check_entropy, check_private_entropy, check_edge_support,
check_named_args)

DECIMAL_meanvar = 0 # 1 # was 0
knf = npt.dec.knownfailureif

distdiscrete = [
['bernoulli',(0.3,)],
Expand All @@ -34,31 +33,21 @@

def test_discrete_basic():
for distname, arg in distdiscrete:
distfn = getattr(stats,distname)
distfn = getattr(stats, distname)
np.random.seed(9765456)
rvs = distfn.rvs(size=2000,*arg)
rvs = distfn.rvs(size=2000, *arg)
supp = np.unique(rvs)
m, v = distfn.stats(*arg)
# yield npt.assert_almost_equal(rvs.mean(), m, decimal=4,err_msg='mean')
# yield npt.assert_almost_equal, rvs.mean(), m, 2, 'mean' # does not work
yield check_sample_meanvar, rvs.mean(), m, distname + ' sample mean test'
yield check_sample_meanvar, rvs.var(), v, distname + ' sample var test'
yield check_cdf_ppf, distfn, arg, distname + ' cdf_ppf'
yield check_cdf_ppf2, distfn, arg, supp, distname + ' cdf_ppf'
yield check_pmf_cdf, distfn, arg, distname + ' pmf_cdf'

yield check_oth, distfn, arg, distname + ' oth'
skurt = stats.kurtosis(rvs)
sskew = stats.skew(rvs)
yield check_sample_skew_kurt, distfn, arg, skurt, sskew, \
distname + ' skew_kurt'
yield check_cdf_ppf, distfn, arg, supp, distname + ' cdf_ppf'

yield check_pmf_cdf, distfn, arg, distname
yield check_oth, distfn, arg, supp, distname + ' oth'
yield check_edge_support, distfn, arg

alpha = 0.01
yield check_discrete_chisquare, distfn, arg, rvs, alpha, \
distname + ' chisquare'

yield check_edge_support, distfn, arg

seen = set()
for distname, arg in distdiscrete:
if distname in seen:
Expand All @@ -81,7 +70,6 @@ def test_discrete_basic():


def test_moments():
knf = npt.dec.knownfailureif
for distname, arg in distdiscrete:
distfn = getattr(stats,distname)
m, v, s, k = distfn.stats(*arg, moments='mvsk')
Expand All @@ -97,108 +85,52 @@ def test_moments():
msg = distname + ' fails kurtosis'
yield knf(cond, msg)(check_kurt_expect), distfn, arg, m, v, k, distname

# frozen distr moments
yield check_moment_frozen, distfn, arg, m, 1
yield check_moment_frozen, distfn, arg, v+m*m, 2

@npt.dec.skipif(True)
def test_discrete_private():
# testing private methods mostly for debugging
# some tests might fail by design,
# e.g. incorrect definition of distfn.a and distfn.b
for distname, arg in distdiscrete:
distfn = getattr(stats,distname)
rvs = distfn.rvs(size=10000,*arg)
m,v = distfn.stats(*arg)

yield check_ppf_ppf, distfn, arg
yield check_cdf_ppf_private, distfn, arg, distname
yield check_generic_moment, distfn, arg, m, 1, 3 # last is decimal
yield check_generic_moment, distfn, arg, v+m*m, 2, 3 # last is decimal
yield check_moment_frozen, distfn, arg, m, 1, 3 # last is decimal
yield check_moment_frozen, distfn, arg, v+m*m, 2, 3 # last is decimal


def check_sample_meanvar(sm,m,msg):
if not np.isinf(m):
npt.assert_almost_equal(sm, m, decimal=DECIMAL_meanvar, err_msg=msg +
' - finite moment')
else:
npt.assert_(sm > 10000, msg='infinite moment, sm = ' + str(sm))


def check_sample_var(sm,m,msg):
npt.assert_almost_equal(sm, m, decimal=DECIMAL_meanvar, err_msg=msg + 'var')


def check_cdf_ppf(distfn,arg,msg):
ppf05 = distfn.ppf(0.5,*arg)
cdf05 = distfn.cdf(ppf05,*arg)
npt.assert_almost_equal(distfn.ppf(cdf05-1e-6,*arg),ppf05,
err_msg=msg + 'ppf-cdf-median')
npt.assert_((distfn.ppf(cdf05+1e-4,*arg) > ppf05), msg + 'ppf-cdf-next')


def check_cdf_ppf2(distfn,arg,supp,msg):
npt.assert_array_equal(distfn.ppf(distfn.cdf(supp,*arg),*arg),
def check_cdf_ppf(distfn, arg, supp, msg):
# cdf is a step function, and ppf(q) = min{k : cdf(k) >= q, k integer}
npt.assert_array_equal(distfn.ppf(distfn.cdf(supp, *arg), *arg),
supp, msg + '-roundtrip')
npt.assert_array_equal(distfn.ppf(distfn.cdf(supp,*arg)-1e-8,*arg),
npt.assert_array_equal(distfn.ppf(distfn.cdf(supp, *arg) - 1e-8, *arg),
supp, msg + '-roundtrip')
supp1 = supp[supp < distfn.b]
npt.assert_array_equal(distfn.ppf(distfn.cdf(supp1, *arg) + 1e-8, *arg),
supp1 + distfn.inc, msg + 'ppf-cdf-next')
# -1e-8 could cause an error if pmf < 1e-8


def check_cdf_ppf_private(distfn,arg,msg):
ppf05 = distfn._ppf(0.5,*arg)
cdf05 = distfn.cdf(ppf05,*arg)
npt.assert_almost_equal(distfn._ppf(cdf05-1e-6,*arg),ppf05,
err_msg=msg + '_ppf-cdf-median ')
npt.assert_((distfn._ppf(cdf05+1e-4,*arg) > ppf05), msg + '_ppf-cdf-next')


def check_ppf_ppf(distfn, arg):
npt.assert_(distfn.ppf(0.5,*arg) < np.inf)
ppfs = distfn.ppf([0.5,0.9],*arg)
ppf_s = [distfn._ppf(0.5,*arg), distfn._ppf(0.9,*arg)]
npt.assert_(np.all(ppfs < np.inf))
npt.assert_(ppf_s[0] == distfn.ppf(0.5,*arg))
npt.assert_(ppf_s[1] == distfn.ppf(0.9,*arg))
npt.assert_(ppf_s[0] == ppfs[0])
npt.assert_(ppf_s[1] == ppfs[1])

def check_pmf_cdf(distfn, arg, distname):
startind = np.int(distfn.ppf(0.01, *arg) - 1)
index = list(range(startind, startind + 10))
cdfs, pmfs_cum = distfn.cdf(index,*arg), distfn.pmf(index, *arg).cumsum()

def check_pmf_cdf(distfn, arg, msg):
startind = np.int(distfn._ppf(0.01,*arg)-1)
index = list(range(startind,startind+10))
cdfs = distfn.cdf(index,*arg)
npt.assert_almost_equal(cdfs, distfn.pmf(index, *arg).cumsum() +
cdfs[0] - distfn.pmf(index[0],*arg),
decimal=4, err_msg=msg + 'pmf-cdf')
atol, rtol = 1e-10, 1e-10
if distname == 'skellam': # ncx2 accuracy
atol, rtol = 1e-5, 1e-5
npt.assert_allclose(cdfs - cdfs[0], pmfs_cum - pmfs_cum[0],
atol=atol, rtol=rtol)


def check_generic_moment(distfn, arg, m, k, decim):
npt.assert_almost_equal(distfn.generic_moment(k,*arg), m, decimal=decim,
err_msg=str(distfn) + ' generic moment test')
def check_moment_frozen(distfn, arg, m, k):
npt.assert_allclose(distfn(*arg).moment(k), m,
atol=1e-10, rtol=1e-10)


def check_moment_frozen(distfn, arg, m, k, decim):
npt.assert_almost_equal(distfn(*arg).moment(k), m, decimal=decim,
err_msg=str(distfn) + ' frozen moment test')
def check_oth(distfn, arg, supp, msg):
# checking other methods of distfn
npt.assert_allclose(distfn.sf(supp, *arg), 1. - distfn.cdf(supp, *arg),
atol=1e-10, rtol=1e-10)

q = np.linspace(0.01, 0.99, 20)
npt.assert_allclose(distfn.isf(q, *arg), distfn.ppf(1. - q, *arg),
atol=1e-10, rtol=1e-10)

def check_oth(distfn, arg, msg):
# checking other methods of distfn
meanint = round(float(distfn.stats(*arg)[0])) # closest integer to mean
npt.assert_almost_equal(distfn.sf(meanint, *arg), 1 -
distfn.cdf(meanint, *arg), decimal=8)
median_sf = distfn.isf(0.5, *arg)

npt.assert_(distfn.sf(median_sf - 1, *arg) > 0.5)
npt.assert_(distfn.cdf(median_sf + 1, *arg) > 0.5)
npt.assert_equal(distfn.isf(0.5, *arg), distfn.ppf(0.5, *arg))


def check_sample_skew_kurt(distfn, arg, sk, ss, msg):
k,s = distfn.stats(moments='ks', *arg)
check_sample_meanvar, sk, k, msg + 'sample skew test'
check_sample_meanvar, ss, s, msg + 'sample kurtosis test'



def check_discrete_chisquare(distfn, arg, rvs, alpha, msg):
Expand Down