Skip to content

Commit

Permalink
ENH: MultiComparison check input in group_order, allow subsets
Browse files Browse the repository at this point in the history
  • Loading branch information
josef-pkt committed Sep 22, 2014
1 parent 69d1c6d commit 4f5b4a8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
28 changes: 23 additions & 5 deletions statsmodels/sandbox/stats/multicomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,6 @@ def plot_simultaneous(self, comparison_name=None, ax=None, figsize=(10,6),
return fig



class MultiComparison(object):
'''Tests for multiple comparisons
Expand All @@ -778,13 +777,14 @@ class MultiComparison(object):
group labels corresponding to each data point
group_order : list of strings, optional
the desired order for the group mean results to be reported in. If
not specified, results are reported in increasing order
not specified, results are reported in increasing order.
If group_order does not contain all labels that are in groups, then a
warning is issued and only observations whose label is in group_order are kept.
'''

def __init__(self, data, groups, group_order=None):
if len(set(groups)) < 2:
raise ValueError('2 or more groups required for multiple comparisons')

if len(data) != len(groups):
raise ValueError('data has %d elements and groups has %d' % (len(data), len(groups)))
self.data = np.asarray(data)
Expand All @@ -801,16 +801,34 @@ def __init__(self, data, groups, group_order=None):
raise ValueError(
"group_order value '%s' not found in groups"%grp)
self.groupsunique = np.array(group_order)
self.groupintlab = np.zeros(len(data))
self.groupintlab = np.empty(len(data), int)
self.groupintlab.fill(-999) # instead of a nan
count = 0
for name in self.groupsunique:
idx = np.where(self.groups == name)[0]
count += len(idx)
self.groupintlab[idx] = np.where(self.groupsunique == name)[0]
if count != data.shape[0]:
#raise ValueError('group_order does not contain all groups')
# warn and keep only observations with label in group_order
import warnings
warnings.warn('group_order does not contain all groups:' +
' dropping observations')

mask_keep = self.groupintlab != -999
self.groupintlab = self.groupintlab[mask_keep]
self.data = self.data[mask_keep]
self.groups = self.groups[mask_keep]

if len(self.groupsunique) < 2:
raise ValueError('2 or more groups required for multiple comparisons')

self.datali = [data[self.groups == k] for k in self.groupsunique]
self.pairindices = np.triu_indices(len(self.groupsunique), 1) #tuple
self.nobs = self.data.shape[0]
self.ngroups = len(self.groupsunique)


def getranks(self):
'''convert data to rankdata and attach
Expand Down
32 changes: 31 additions & 1 deletion statsmodels/stats/tests/test_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from statsmodels.compat.python import BytesIO, asbytes, range
import numpy as np
from numpy.testing import assert_almost_equal, assert_equal, assert_, assert_raises
from numpy.testing import (assert_almost_equal, assert_equal, assert_,
assert_raises, assert_allclose)

from statsmodels.stats.libqsturng import qsturng

Expand Down Expand Up @@ -246,6 +247,35 @@ def test_incorrect_output(self):
# just one group
assert_raises(ValueError, MultiComparison, np.array([1] * 10), [1] * 10)

# group_order doesn't select all observations, only one group left
assert_raises(ValueError, MultiComparison, np.array([1] * 10),
[1, 2] * 5, group_order=[1])

# group_order doesn't select all observations,
# we do tukey_hsd with reduced set of observations
data = np.arange(15)
groups = np.repeat([1, 2, 3], 5)
mod1 = MultiComparison(np.array(data), groups, group_order=[1, 2])
res1 = mod1.tukeyhsd(alpha=0.01)
mod2 = MultiComparison(np.array(data[:10]), groups[:10])
res2 = mod2.tukeyhsd(alpha=0.01)

attributes = ['confint', 'data', 'df_total', 'groups', 'groupsunique',
'meandiffs', 'q_crit', 'reject', 'reject2', 'std_pairs',
'variance']
for att in attributes:
err_msg = att + 'failed'
assert_allclose(getattr(res1, att), getattr(res2, att), rtol=1e-14,
err_msg=err_msg)

attributes = ['data', 'datali', 'groupintlab', 'groups', 'groupsunique',
'ngroups', 'nobs', 'pairindices']
for att in attributes:
err_msg = att + 'failed'
assert_allclose(getattr(mod1, att), getattr(mod2, att), rtol=1e-14,
err_msg=err_msg)


class TestTuckeyHSD2s(CheckTuckeyHSDMixin):
@classmethod
def setup_class(self):
Expand Down

0 comments on commit 4f5b4a8

Please sign in to comment.