From c007dba44ca3d72753f0621eeac0d36d4c95cb3d Mon Sep 17 00:00:00 2001 From: sinhrks Date: Fri, 8 May 2015 11:50:15 +0900 Subject: [PATCH] ENH: add expand kw to str.extract and str.get_dummies --- pandas/core/strings.py | 123 ++++++++++------------- pandas/tests/test_categorical.py | 1 + pandas/tests/test_strings.py | 167 +++++++++++++++++++++++++++++-- 3 files changed, 210 insertions(+), 81 deletions(-) diff --git a/pandas/core/strings.py b/pandas/core/strings.py index a8907ac192707..08f7d1b6481f4 100644 --- a/pandas/core/strings.py +++ b/pandas/core/strings.py @@ -424,17 +424,21 @@ def str_extract(arr, pat, flags=0): Pattern or regular expression flags : int, default 0 (no flags) re module flags, e.g. re.IGNORECASE + expand : None or bool, default None + * If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups) + * If True, return DataFrame/MultiIndex expanding dimensionality. + * If False, return Series/Index. Returns ------- - extracted groups : Series (one group) or DataFrame (multiple groups) + extracted groups : Series/Index or DataFrame/MultiIndex of objects Note that dtype of the result is always object, even when no match is found and the result is a Series or DataFrame containing only NaN values. Examples -------- - A pattern with one group will return a Series. Non-matches will be NaN. + A pattern with one group returns a Series. Non-matches will be NaN. >>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)') 0 1 @@ -466,11 +470,14 @@ def str_extract(arr, pat, flags=0): 1 b 2 2 NaN NaN - """ - from pandas.core.series import Series - from pandas.core.frame import DataFrame - from pandas.core.index import Index + Or you can specify ``expand=False`` to return Series. 
+ >>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False) + 0 [a, 1] + 1 [b, 2] + 2 [nan, 3] + Name: [0, 1], dtype: object + """ regex = re.compile(pat, flags=flags) # just to be safe, check this if regex.groups == 0: @@ -490,18 +497,9 @@ def f(x): result = np.array([f(val)[0] for val in arr], dtype=object) name = _get_single_group_name(regex) else: - if isinstance(arr, Index): - raise ValueError("only one regex group is supported with Index") - name = None names = dict(zip(regex.groupindex.values(), regex.groupindex.keys())) - columns = [names.get(1 + i, i) for i in range(regex.groups)] - if arr.empty: - result = DataFrame(columns=columns, dtype=object) - else: - result = DataFrame([f(val) for val in arr], - columns=columns, - index=arr.index, - dtype=object) + name = [names.get(1 + i, i) for i in range(regex.groups)] + result = np.array([f(val) for val in arr], dtype=object) return result, name @@ -514,10 +512,13 @@ def str_get_dummies(arr, sep='|'): ---------- sep : string, default "|" String to split on. + expand : bool, default True + * If True, return DataFrame/MultiIndex expanding dimensionality. + * If False, return Series/Index. Returns ------- - dummies : DataFrame + dummies : Series/Index or DataFrame/MultiIndex of objects Examples -------- @@ -537,14 +538,7 @@ def str_get_dummies(arr, sep='|'): -------- pandas.get_dummies """ - from pandas.core.frame import DataFrame from pandas.core.index import Index - - # GH9980, Index.str does not support get_dummies() as it returns a frame - if isinstance(arr, Index): - raise TypeError("get_dummies is not supported for string methods on Index") - - # TODO remove this hack? 
arr = arr.fillna('') try: arr = sep + arr + sep @@ -561,7 +555,7 @@ def str_get_dummies(arr, sep='|'): for i, t in enumerate(tags): pat = sep + t + sep dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x) - return DataFrame(dummies, arr.index, tags) + return dummies, tags def str_join(arr, sep): @@ -1081,7 +1075,10 @@ def __iter__(self): i += 1 g = self.get(i) - def _wrap_result(self, result, use_codes=True, name=None): + def _wrap_result(self, result, use_codes=True, name=None, expand=False): + + if not isinstance(expand, bool): + raise ValueError("expand must be True or False") # for category, we do the stuff on the categories, so blow it up # to the full series again @@ -1095,39 +1092,11 @@ def _wrap_result(self, result, use_codes=True, name=None): # can be merged to _wrap_result_expand in v0.17 from pandas.core.series import Series from pandas.core.frame import DataFrame - from pandas.core.index import Index + from pandas.core.index import Index, MultiIndex - if not hasattr(result, 'ndim'): - return result name = name or getattr(result, 'name', None) or self._orig.name - if result.ndim == 1: - if isinstance(self._orig, Index): - # if result is a boolean np.array, return the np.array - # instead of wrapping it into a boolean Index (GH 8875) - if is_bool_dtype(result): - return result - return Index(result, name=name) - return Series(result, index=self._orig.index, name=name) - else: - assert result.ndim < 3 - return DataFrame(result, index=self._orig.index) - - def _wrap_result_expand(self, result, expand=False): - if not isinstance(expand, bool): - raise ValueError("expand must be True or False") - - # for category, we do the stuff on the categories, so blow it up - # to the full series again - if self._is_categorical: - result = take_1d(result, self._orig.cat.codes) - - from pandas.core.index import Index, MultiIndex - if not hasattr(result, 'ndim'): - return result - if isinstance(self._orig, Index): - name = getattr(result, 'name', None) # if result 
is a boolean np.array, return the np.array # instead of wrapping it into a boolean Index (GH 8875) if hasattr(result, 'dtype') and is_bool_dtype(result): @@ -1137,7 +1106,7 @@ def _wrap_result_expand(self, result, expand=False): result = list(result) return MultiIndex.from_tuples(result, names=name) else: - return Index(result, name=name) + return Index(result, name=name, tupleize_cols=False) else: index = self._orig.index if expand: @@ -1148,9 +1117,11 @@ def cons_row(x): return [ x ] cons = self._orig._constructor_expanddim data = [cons_row(x) for x in result] - return cons(data, index=index) + return cons(data, index=index, columns=name, + dtype=result.dtype) else: - name = getattr(result, 'name', None) + if result.ndim > 1: + result = list(result) cons = self._orig._constructor return cons(result, name=name, index=index) @@ -1158,20 +1129,22 @@ def cons_row(x): def cat(self, others=None, sep=None, na_rep=None): data = self._orig if self._is_categorical else self._data result = str_cat(data, others=others, sep=sep, na_rep=na_rep) + if not hasattr(result, 'ndim'): + # str_cat may result in np.nan or str + return result return self._wrap_result(result, use_codes=(not self._is_categorical)) - @deprecate_kwarg('return_type', 'expand', mapping={'series': False, 'frame': True}) @copy(str_split) def split(self, pat=None, n=-1, expand=False): result = str_split(self._data, pat, n=n) - return self._wrap_result_expand(result, expand=expand) + return self._wrap_result(result, expand=expand) @copy(str_rsplit) def rsplit(self, pat=None, n=-1, expand=False): result = str_rsplit(self._data, pat, n=n) - return self._wrap_result_expand(result, expand=expand) + return self._wrap_result(result, expand=expand) _shared_docs['str_partition'] = (""" Split the string at the %(side)s occurrence of `sep`, and return 3 elements @@ -1222,7 +1195,7 @@ def rsplit(self, pat=None, n=-1, expand=False): def partition(self, pat=' ', expand=True): f = lambda x: x.partition(pat) result = 
_na_map(f, self._data) - return self._wrap_result_expand(result, expand=expand) + return self._wrap_result(result, expand=expand) @Appender(_shared_docs['str_partition'] % {'side': 'last', 'return': '3 elements containing two empty strings, followed by the string itself', @@ -1230,7 +1203,7 @@ def partition(self, pat=' ', expand=True): def rpartition(self, pat=' ', expand=True): f = lambda x: x.rpartition(pat) result = _na_map(f, self._data) - return self._wrap_result_expand(result, expand=expand) + return self._wrap_result(result, expand=expand) @copy(str_get) def get(self, i): @@ -1371,12 +1344,13 @@ def wrap(self, width, **kwargs): return self._wrap_result(result) @copy(str_get_dummies) - def get_dummies(self, sep='|'): + def get_dummies(self, sep='|', expand=True): # we need to cast to Series of strings as only that has all # methods available for making the dummies... data = self._orig.astype(str) if self._is_categorical else self._data - result = str_get_dummies(data, sep) - return self._wrap_result(result, use_codes=(not self._is_categorical)) + result, name = str_get_dummies(data, sep) + return self._wrap_result(result, use_codes=(not self._is_categorical), + name=name, expand=expand) @copy(str_translate) def translate(self, table, deletechars=None): @@ -1389,9 +1363,18 @@ def translate(self, table, deletechars=None): findall = _pat_wrapper(str_findall, flags=True) @copy(str_extract) - def extract(self, pat, flags=0): - result, name = str_extract(self._data, pat, flags=flags) - return self._wrap_result(result, name=name) + def extract(self, pat, flags=0, expand=None): + result, name = str_extract(self._orig, pat, flags=flags) + if expand is None and hasattr(result, 'ndim'): + # to be compat with previous behavior + if len(result) == 0: + # for empty input + expand = True if isinstance(name, list) else False + elif result.ndim > 1: + expand = True + else: + expand = False + return self._wrap_result(result, name=name, use_codes=False, expand=expand) 
_shared_docs['find'] = (""" Return %(side)s indexes in each strings in the Series/Index diff --git a/pandas/tests/test_categorical.py b/pandas/tests/test_categorical.py index e98c98fdec8b3..8d8d1e6d456c0 100755 --- a/pandas/tests/test_categorical.py +++ b/pandas/tests/test_categorical.py @@ -3714,6 +3714,7 @@ def test_str_accessor_api_for_categorical(self): for func, args, kwargs in func_defs: + print(func, args, kwargs, c) res = getattr(c.str, func)(*args, **kwargs) exp = getattr(s.str, func)(*args, **kwargs) diff --git a/pandas/tests/test_strings.py b/pandas/tests/test_strings.py index 0013a6579718a..b43d637c3c4c5 100644 --- a/pandas/tests/test_strings.py +++ b/pandas/tests/test_strings.py @@ -522,15 +522,24 @@ def test_extract(self): exp = DataFrame([['BAD__', 'BAD'], er, er]) tm.assert_frame_equal(result, exp) + result = values.str.extract('.*(BAD[_]+).*(BAD)', expand=False) + exp = Series([['BAD__', 'BAD'], er, er], name=[0, 1]) + tm.assert_series_equal(result, exp) + # mixed mixed = Series(['aBAD_BAD', NA, 'BAD_b_BAD', True, datetime.today(), 'foo', None, 1, 2.]) - rs = Series(mixed).str.extract('.*(BAD[_]+).*(BAD)') + rs = mixed.str.extract('.*(BAD[_]+).*(BAD)') exp = DataFrame([['BAD_', 'BAD'], er, ['BAD_', 'BAD'], er, er, er, er, er, er]) tm.assert_frame_equal(rs, exp) + rs = mixed.str.extract('.*(BAD[_]+).*(BAD)', expand=False) + exp = Series([['BAD_', 'BAD'], er, ['BAD_', 'BAD'], er, er, + er, er, er, er], name=[0, 1]) + tm.assert_series_equal(rs, exp) + # unicode values = Series([u('fooBAD__barBAD'), NA, u('foo')]) @@ -538,12 +547,25 @@ def test_extract(self): exp = DataFrame([[u('BAD__'), u('BAD')], er, er]) tm.assert_frame_equal(result, exp) - # GH9980 - # Index only works with one regex group since - # multi-group would expand to a frame + result = values.str.extract('.*(BAD[_]+).*(BAD)', expand=False) + exp = Series([[u('BAD__'), u('BAD')], er, er], name=[0, 1]) + tm.assert_series_equal(result, exp) + idx = Index(['A1', 'A2', 'A3', 'A4', 'B5']) - 
with tm.assertRaisesRegexp(ValueError, "supported"): - idx.str.extract('([AB])([123])') + result = idx.str.extract('([AB])([123])') + exp = MultiIndex.from_tuples([('A', '1'), ('A', '2'), ('A', '3'), + (NA, NA), (NA, NA)], names=[0, 1]) + tm.assert_index_equal(result, exp) + + # check warning to return single Series + s = Series(['A1', 'A2']) + result = s.str.extract(r'(A)\d') + tm.assert_series_equal(result, Series(['A', 'A'])) + + # expand=True results in DataFrame + s = Series(['A1', 'A2']) + result = s.str.extract(r'(A)\d', expand=True) + tm.assert_frame_equal(result, DataFrame(['A', 'A'])) # these should work for both Series and Index for klass in [Series, Index]: @@ -635,6 +657,110 @@ def check_index(index): tm.makeDateIndex, tm.makePeriodIndex ]: check_index(index()) + def test_extract_index(self): + values = Index(['fooBAD__barBAD', NA, 'foo']) + er = [NA, NA] # empty row + + result = values.str.extract('.*(BAD[_]+).*(BAD)') + exp = MultiIndex.from_tuples([['BAD__', 'BAD'], er, er], names=[0, 1]) + tm.assert_index_equal(result, exp) + + # mixed + mixed = Index(['aBAD_BAD', NA, 'BAD_b_BAD', True, datetime.today(), + 'foo', None, 1, 2.]) + + rs = mixed.str.extract('.*(BAD[_]+).*(BAD)') + exp = MultiIndex.from_tuples([['BAD_', 'BAD'], er, ['BAD_', 'BAD'], er, er, + er, er, er, er], names=[0, 1]) + tm.assert_index_equal(rs, exp) + + # unicode + values = Index([u('fooBAD__barBAD'), NA, u('foo')]) + + result = values.str.extract('.*(BAD[_]+).*(BAD)') + exp = MultiIndex.from_tuples([[u('BAD__'), u('BAD')], er, er], names=[0, 1]) + tm.assert_index_equal(result, exp) + + s = Index(['A1', 'B2', 'C3']) + # one group, no matches + result = s.str.extract('(_)') + exp = Index([NA, NA, NA], dtype=object) + tm.assert_index_equal(result, exp) + + # two groups, no matches + result = s.str.extract('(_)(_)') + exp = MultiIndex.from_tuples([[NA, NA], [NA, NA], [NA, NA]], + names=[0, 1]) + tm.assert_index_equal(result, exp) + + # one group, some matches + result = 
s.str.extract('([AB])[123]') + exp = Index(['A', 'B', NA]) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 1) + + # two groups, some matches + result = s.str.extract('([AB])([123])') + exp = MultiIndex.from_tuples([['A', '1'], ['B', '2'], [NA, NA]], names=[0, 1]) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 2) + + result = s.str.extract('([ABC])([123])', expand=False) + exp = Index(np.array([['A', '1'], ['B', '2'], ['C', '3']]), + dtype=object, name=[0, 1]) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 1) + + # one named group + result = s.str.extract('(?P<letter>[AB])') + exp = Index(['A', 'B', NA], name='letter') + tm.assert_index_equal(result, exp) + + # two named groups + result = s.str.extract('(?P<letter>[AB])(?P<number>[123])') + exp = MultiIndex.from_tuples([['A', '1'], ['B', '2'], [NA, NA]], + names=['letter', 'number']) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 2) + + result = s.str.extract('(?P<letter>[ABC])(?P<number>[123])', expand=False) + exp = Index(np.array([['A', '1'], ['B', '2'], ['C', '3']]), + dtype=object, name=['letter', 'number']) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 1) + + # mix named and unnamed groups + result = s.str.extract('([AB])(?P<number>[123])') + exp = MultiIndex.from_tuples([['A', '1'], ['B', '2'], [NA, NA]], + names=[0, 'number']) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 2) + + # one normal group, one non-capturing group + result = s.str.extract('([AB])(?:[123])') + exp = Index(['A', 'B', NA]) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 1) + + # two normal groups, one non-capturing group + result = Index(['A11', 'B22', 'C33']).str.extract('([AB])([123])(?:[123])') + exp = MultiIndex.from_tuples([['A', '1'], ['B', '2'], [NA, NA]], names=[0, 1]) + tm.assert_index_equal(result, exp) + + # one optional group followed by one normal group + result = Index(['A1', 'B2', 
'3']).str.extract('(?P<letter>[AB])?(?P<number>[123])') + exp = MultiIndex.from_tuples([['A', '1'], ['B', '2'], [NA, '3']], + names=['letter', 'number']) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 2) + + # one normal group followed by one optional group + result = Index(['A1', 'B2', 'C']).str.extract('(?P<letter>[ABC])(?P<number>[123])?') + exp = MultiIndex.from_tuples([['A', '1'], ['B', '2'], ['C', NA]], + names=['letter', 'number']) + tm.assert_index_equal(result, exp) + self.assertTrue(result.nlevels, 2) + def test_extract_single_series_name_is_preserved(self): s = Series(['a3', 'b3', 'c2'], name='bob') r = s.str.extract(r'(?P[a-z])') @@ -777,11 +903,30 @@ def test_get_dummies(self): columns=list('7ab')) tm.assert_frame_equal(result, expected) - # GH9980 - # Index.str does not support get_dummies() as it returns a frame - with tm.assertRaisesRegexp(TypeError, "not supported"): - idx = Index(['a|b', 'a|c', 'b|c']) - idx.str.get_dummies('|') + idx = Index(['a|b', 'a|c', 'b|c']) + result = idx.str.get_dummies('|') + expected = MultiIndex.from_tuples([(1, 1, 0), (1, 0, 1), (0, 1, 1)], + names=['a', 'b', 'c']) + tm.assert_index_equal(result, expected) + + def test_get_dummies_without_expand(self): + s = Series(['a|b', 'a|c', np.nan]) + result = s.str.get_dummies('|', expand=False) + expected = Series([[1, 1, 0], [1, 0, 1], [0, 0, 0]], name=['a', 'b', 'c']) + tm.assert_series_equal(result, expected) + + s = Series(['a;b', 'a', 7]) + result = s.str.get_dummies(';', expand=False) + expected = Series([[0, 1, 1], [0, 1, 0], [1, 0, 0]], + name=list('7ab')) + tm.assert_series_equal(result, expected) + + idx = Index(['a|b', 'a|c', 'b|c']) + result = idx.str.get_dummies('|', expand=False) + expected = Index(np.array([(1, 1, 0), (1, 0, 1), (0, 1, 1)]), + name=['a', 'b', 'c']) + tm.assert_index_equal(result, expected) + self.assertEqual(result.nlevels, 1) def test_join(self): values = Series(['a_b_c', 'c_d_e', np.nan, 'f_g_h'])