ENH: add expand kw to str.extract and str.get_dummies
sinhrks committed Nov 24, 2015
1 parent 5bc191a commit c007dba
Showing 3 changed files with 210 additions and 81 deletions.
123 changes: 53 additions & 70 deletions pandas/core/strings.py
@@ -424,17 +424,21 @@ def str_extract(arr, pat, flags=0):
Pattern or regular expression
flags : int, default 0 (no flags)
re module flags, e.g. re.IGNORECASE
expand : None or bool, default None
* If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups).
* If True, return DataFrame/MultiIndex expanding dimensionality.
* If False, return Series/Index.
Returns
-------
extracted groups : Series (one group) or DataFrame (multiple groups)
extracted groups : Series/Index or DataFrame/MultiIndex of objects
Note that dtype of the result is always object, even when no match is
found and the result is a Series or DataFrame containing only NaN
values.
Examples
--------
A pattern with one group will return a Series. Non-matches will be NaN.
A pattern with one group returns a Series. Non-matches will be NaN.
>>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
0 1
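A short usage sketch of the new keyword, following the docstring above (the data and pattern are arbitrary; only the returned container is of interest here):

import pandas as pd

s = pd.Series(['a1', 'b2', 'c3'])

# default expand=None: a one-group pattern still comes back as a Series
ser = s.str.extract(r'[ab](\d)')

# expand=True forces a DataFrame even for a single group
frame = s.str.extract(r'[ab](\d)', expand=True)

# expand=False keeps the lower-dimensional Series/Index result
flat = s.str.extract(r'[ab](\d)', expand=False)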
@@ -466,11 +470,14 @@ def str_extract(arr, pat, flags=0):
1 b 2
2 NaN NaN
"""
from pandas.core.series import Series
from pandas.core.frame import DataFrame
from pandas.core.index import Index
Or you can specify ``expand=False`` to return a Series.
>>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False)
0 [a, 1]
1 [b, 2]
2 [nan, 3]
Name: [0, 1], dtype: object
"""
regex = re.compile(pat, flags=flags)
# just to be safe, check this
if regex.groups == 0:
@@ -490,18 +497,9 @@ def f(x):
result = np.array([f(val)[0] for val in arr], dtype=object)
name = _get_single_group_name(regex)
else:
if isinstance(arr, Index):
raise ValueError("only one regex group is supported with Index")
name = None
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
columns = [names.get(1 + i, i) for i in range(regex.groups)]
if arr.empty:
result = DataFrame(columns=columns, dtype=object)
else:
result = DataFrame([f(val) for val in arr],
columns=columns,
index=arr.index,
dtype=object)
name = [names.get(1 + i, i) for i in range(regex.groups)]
result = np.array([f(val) for val in arr], dtype=object)
return result, name
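The rewritten multi-group branch no longer builds a DataFrame inside str_extract: it returns a 2-D object array plus the column names and leaves the wrapping to _wrap_result. A simplified, numpy-only mirror of that intermediate result, with f reduced to a stand-in for the helper defined above:

import re
import numpy as np

regex = re.compile(r'(?P<letter>[ab])(\d)')
arr = ['a1', 'b2', 'c3']

def f(x):
    # stand-in for the f above: one value per group, NaN-filled on no match
    m = regex.search(x)
    return m.groups() if m else (np.nan,) * regex.groups

names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
name = [names.get(1 + i, i) for i in range(regex.groups)]  # ['letter', 1]
result = np.array([f(val) for val in arr], dtype=object)   # 2-D, shape (3, 2)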


@@ -514,10 +512,13 @@ def str_get_dummies(arr, sep='|'):
----------
sep : string, default "|"
String to split on.
expand : bool, default True
* If True, return DataFrame/MultiIndex expanding dimensionality.
* If False, return Series/Index.
Returns
-------
dummies : DataFrame
dummies : Series/Index or DataFrame/MultiIndex of objects
Examples
--------
@@ -537,14 +538,7 @@ def str_get_dummies(arr, sep='|'):
--------
pandas.get_dummies
"""
from pandas.core.frame import DataFrame
from pandas.core.index import Index

# GH9980, Index.str does not support get_dummies() as it returns a frame
if isinstance(arr, Index):
raise TypeError("get_dummies is not supported for string methods on Index")

# TODO remove this hack?
arr = arr.fillna('')
try:
arr = sep + arr + sep
@@ -561,7 +555,7 @@ def str_get_dummies(arr, sep='|'):
for i, t in enumerate(tags):
pat = sep + t + sep
dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x)
return DataFrame(dummies, arr.index, tags)
return dummies, tags
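str_get_dummies now likewise returns the raw indicator matrix together with the tag list instead of a finished DataFrame. A simplified sketch of the loop above, with a plain comprehension standing in for lib.map_infer:

import numpy as np

# values as they look after the sep-padding step above (arr = sep + arr + sep)
arr = np.array(['|a|b|', '|a|', '|a|c|'], dtype=object)
tags = ['a', 'b', 'c']
sep = '|'

dummies = np.empty((len(arr), len(tags)), dtype='int64')
for i, t in enumerate(tags):
    pat = sep + t + sep
    dummies[:, i] = [pat in x for x in arr]

# the accessor wraps (dummies, tags) through _wrap_result: a DataFrame for
# Series callers, or a MultiIndex now that the Index restriction is gone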


def str_join(arr, sep):
@@ -1081,7 +1075,10 @@ def __iter__(self):
i += 1
g = self.get(i)

def _wrap_result(self, result, use_codes=True, name=None):
def _wrap_result(self, result, use_codes=True, name=None, expand=False):

if not isinstance(expand, bool):
raise ValueError("expand must be True or False")

# for category, we do the stuff on the categories, so blow it up
# to the full series again
@@ -1095,39 +1092,11 @@ def _wrap_result(self, result, use_codes=True, name=None):
# can be merged to _wrap_result_expand in v0.17
from pandas.core.series import Series
from pandas.core.frame import DataFrame
from pandas.core.index import Index
from pandas.core.index import Index, MultiIndex

if not hasattr(result, 'ndim'):
return result
name = name or getattr(result, 'name', None) or self._orig.name

if result.ndim == 1:
if isinstance(self._orig, Index):
# if result is a boolean np.array, return the np.array
# instead of wrapping it into a boolean Index (GH 8875)
if is_bool_dtype(result):
return result
return Index(result, name=name)
return Series(result, index=self._orig.index, name=name)
else:
assert result.ndim < 3
return DataFrame(result, index=self._orig.index)

def _wrap_result_expand(self, result, expand=False):
if not isinstance(expand, bool):
raise ValueError("expand must be True or False")

# for category, we do the stuff on the categories, so blow it up
# to the full series again
if self._is_categorical:
result = take_1d(result, self._orig.cat.codes)

from pandas.core.index import Index, MultiIndex
if not hasattr(result, 'ndim'):
return result

if isinstance(self._orig, Index):
name = getattr(result, 'name', None)
# if result is a boolean np.array, return the np.array
# instead of wrapping it into a boolean Index (GH 8875)
if hasattr(result, 'dtype') and is_bool_dtype(result):
@@ -1137,7 +1106,7 @@ def _wrap_result_expand(self, result, expand=False):
result = list(result)
return MultiIndex.from_tuples(result, names=name)
else:
return Index(result, name=name)
return Index(result, name=name, tupleize_cols=False)
else:
index = self._orig.index
if expand:
@@ -1148,30 +1117,34 @@ def cons_row(x):
return [ x ]
cons = self._orig._constructor_expanddim
data = [cons_row(x) for x in result]
return cons(data, index=index)
return cons(data, index=index, columns=name,
dtype=result.dtype)
else:
name = getattr(result, 'name', None)
if result.ndim > 1:
result = list(result)
cons = self._orig._constructor
return cons(result, name=name, index=index)
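The merged _wrap_result dispatches on the caller type and on expand. An illustrative-only summary of that dispatch (the helper below is not part of the change):

def wrapped_container(orig_is_index, expand):
    # boolean results for Index callers are returned as a bare ndarray (GH 8875)
    if orig_is_index:
        # expand=True -> MultiIndex.from_tuples, otherwise a flat Index
        return 'MultiIndex' if expand else 'Index'
    # Series callers: expand=True builds a frame via _constructor_expanddim
    return 'DataFrame' if expand else 'Series'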

@copy(str_cat)
def cat(self, others=None, sep=None, na_rep=None):
data = self._orig if self._is_categorical else self._data
result = str_cat(data, others=others, sep=sep, na_rep=na_rep)
if not hasattr(result, 'ndim'):
# str_cat may result in np.nan or str
return result
return self._wrap_result(result, use_codes=(not self._is_categorical))


@deprecate_kwarg('return_type', 'expand',
mapping={'series': False, 'frame': True})
@copy(str_split)
def split(self, pat=None, n=-1, expand=False):
result = str_split(self._data, pat, n=n)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@copy(str_rsplit)
def rsplit(self, pat=None, n=-1, expand=False):
result = str_rsplit(self._data, pat, n=n)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)
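split and rsplit keep their expand keyword, with the old return_type spelling still accepted through the deprecate_kwarg mapping above, and now wrap through _wrap_result. A brief usage sketch (values are arbitrary):

import pandas as pd

s = pd.Series(['a_b_c', 'd_e_f'])

# expand=False (default): a Series of lists
lists = s.str.split('_')

# expand=True: one column per piece, built through _wrap_result
frame = s.str.split('_', expand=True)

# the deprecated spelling maps onto expand and warns
legacy = s.str.split('_', return_type='frame')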

_shared_docs['str_partition'] = ("""
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1222,15 +1195,15 @@ def rsplit(self, pat=None, n=-1, expand=False):
def partition(self, pat=' ', expand=True):
f = lambda x: x.partition(pat)
result = _na_map(f, self._data)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@Appender(_shared_docs['str_partition'] % {'side': 'last',
'return': '3 elements containing two empty strings, followed by the string itself',
'also': 'partition : Split the string at the first occurrence of `sep`'})
def rpartition(self, pat=' ', expand=True):
f = lambda x: x.rpartition(pat)
result = _na_map(f, self._data)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)
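partition and rpartition take the same route through _wrap_result; a minimal usage sketch of expand there:

import pandas as pd

s = pd.Series(['A_B_C', 'D_E_F'])

# default expand=True: three columns (text before sep, the sep, text after)
parts = s.str.partition('_')

# expand=False keeps a Series of 3-tuples instead
tuples = s.str.partition('_', expand=False)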

@copy(str_get)
def get(self, i):
@@ -1371,12 +1344,13 @@ def wrap(self, width, **kwargs):
return self._wrap_result(result)

@copy(str_get_dummies)
def get_dummies(self, sep='|'):
def get_dummies(self, sep='|', expand=True):
# we need to cast to Series of strings as only that has all
# methods available for making the dummies...
data = self._orig.astype(str) if self._is_categorical else self._data
result = str_get_dummies(data, sep)
return self._wrap_result(result, use_codes=(not self._is_categorical))
result, name = str_get_dummies(data, sep)
return self._wrap_result(result, use_codes=(not self._is_categorical),
name=name, expand=expand)
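On the accessor side, get_dummies passes the tags through as column names and exposes the new expand keyword. A usage sketch, assuming wrapping behaves as _wrap_result above suggests:

import pandas as pd

s = pd.Series(['a|b', 'a', 'a|c'])
frame = s.str.get_dummies(sep='|')   # indicator columns 'a', 'b', 'c'

# with the Index-only TypeError removed from str_get_dummies above,
# Index callers can be wrapped as well (previously GH9980 raised)
idx = pd.Index(['a|b', 'a', 'a|c'])
mi = idx.str.get_dummies(sep='|')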

@copy(str_translate)
def translate(self, table, deletechars=None):
@@ -1389,9 +1363,18 @@ def translate(self, table, deletechars=None):
findall = _pat_wrapper(str_findall, flags=True)

@copy(str_extract)
def extract(self, pat, flags=0):
result, name = str_extract(self._data, pat, flags=flags)
return self._wrap_result(result, name=name)
def extract(self, pat, flags=0, expand=None):
result, name = str_extract(self._orig, pat, flags=flags)
if expand is None and hasattr(result, 'ndim'):
# to be compat with previous behavior
if len(result) == 0:
# for empty input
expand = True if isinstance(name, list) else False
elif result.ndim > 1:
expand = True
else:
expand = False
return self._wrap_result(result, name=name, use_codes=False, expand=expand)
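extract defaults to expand=None so existing callers keep their old result type; the keyword is only inferred when left unset. An illustrative-only restatement of that rule (resolve_expand is not a real helper in this change):

def resolve_expand(expand, result, name):
    if expand is not None:
        return expand                    # an explicit expand always wins
    if len(result) == 0:
        return isinstance(name, list)    # empty input: expand iff several groups
    return result.ndim > 1               # 2-D result expands, 1-D does not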

_shared_docs['find'] = ("""
Return %(side)s indexes in each strings in the Series/Index
1 change: 1 addition & 0 deletions pandas/tests/test_categorical.py
@@ -3714,6 +3714,7 @@ def test_str_accessor_api_for_categorical(self):


for func, args, kwargs in func_defs:
print(func, args, kwargs, c)
res = getattr(c.str, func)(*args, **kwargs)
exp = getattr(s.str, func)(*args, **kwargs)

