Skip to content

Commit

Permalink
ENH: add expand kw to str.extract and str.get_dummies
Browse files Browse the repository at this point in the history
  • Loading branch information
sinhrks committed May 23, 2015
1 parent 9b04bd0 commit f30f63c
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 75 deletions.
117 changes: 54 additions & 63 deletions pandas/core/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,17 +421,21 @@ def str_extract(arr, pat, flags=0):
Pattern or regular expression
flags : int, default 0 (no flags)
re module flags, e.g. re.IGNORECASE
expand : None or bool, default None
* If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups)
* If True, return DataFrame/MultiIndex expanding dimensionality.
* If False, return Series/Index.
Returns
-------
extracted groups : Series (one group) or DataFrame (multiple groups)
extracted groups : Series/Index or DataFrame/MultiIndex of objects
Note that dtype of the result is always object, even when no match is
found and the result is a Series or DataFrame containing only NaN
values.
Examples
--------
A pattern with one group will return a Series. Non-matches will be NaN.
A pattern with one group returns a Series. Non-matches will be NaN.
>>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
0 1
Expand Down Expand Up @@ -463,11 +467,14 @@ def str_extract(arr, pat, flags=0):
1 b 2
2 NaN NaN
"""
from pandas.core.series import Series
from pandas.core.frame import DataFrame
from pandas.core.index import Index
Or you can specify ``expand=False`` to return Series.
>>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False)
0 [a, 1]
1 [b, 2]
2 [nan, 3]
Name: [0, 1], dtype: object
"""
regex = re.compile(pat, flags=flags)
# just to be safe, check this
if regex.groups == 0:
Expand All @@ -487,18 +494,9 @@ def f(x):
result = np.array([f(val)[0] for val in arr], dtype=object)
name = _get_single_group_name(regex)
else:
if isinstance(arr, Index):
raise ValueError("only one regex group is supported with Index")
name = None
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
columns = [names.get(1 + i, i) for i in range(regex.groups)]
if arr.empty:
result = DataFrame(columns=columns, dtype=object)
else:
result = DataFrame([f(val) for val in arr],
columns=columns,
index=arr.index,
dtype=object)
name = [names.get(1 + i, i) for i in range(regex.groups)]
result = np.array([f(val) for val in arr], dtype=object)
return result, name


Expand All @@ -511,10 +509,13 @@ def str_get_dummies(arr, sep='|'):
----------
sep : string, default "|"
String to split on.
expand : bool, default True
* If True, return DataFrame/MultiIndex expanding dimensionality.
* If False, return Series/Index.
Returns
-------
dummies : DataFrame
dummies : Series/Index or DataFrame/MultiIndex of objects
Examples
--------
Expand All @@ -534,15 +535,15 @@ def str_get_dummies(arr, sep='|'):
--------
pandas.get_dummies
"""
from pandas.core.frame import DataFrame
from pandas.core.index import Index

# GH9980, Index.str does not support get_dummies() as it returns a frame
# TODO: Add fillna GH 10089
if isinstance(arr, Index):
raise TypeError("get_dummies is not supported for string methods on Index")

# TODO remove this hack?
arr = arr.fillna('')
# temp hack
values = arr.values
values[isnull(values)] = ''
arr = Index(values)
else:
arr = arr.fillna('')
try:
arr = sep + arr + sep
except TypeError:
Expand All @@ -558,7 +559,7 @@ def str_get_dummies(arr, sep='|'):
for i, t in enumerate(tags):
pat = sep + t + sep
dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x)
return DataFrame(dummies, arr.index, tags)
return dummies, tags


def str_join(arr, sep):
Expand Down Expand Up @@ -1043,40 +1044,19 @@ def __iter__(self):
i += 1
g = self.get(i)

def _wrap_result(self, result, **kwargs):

# leave as it is to keep extract and get_dummies results
# can be merged to _wrap_result_expand in v0.17
from pandas.core.series import Series
from pandas.core.frame import DataFrame
from pandas.core.index import Index

if not hasattr(result, 'ndim'):
return result
name = kwargs.get('name') or getattr(result, 'name', None) or self.series.name

if result.ndim == 1:
if isinstance(self.series, Index):
# if result is a boolean np.array, return the np.array
# instead of wrapping it into a boolean Index (GH 8875)
if is_bool_dtype(result):
return result
return Index(result, name=name)
return Series(result, index=self.series.index, name=name)
else:
assert result.ndim < 3
return DataFrame(result, index=self.series.index)
def _wrap_result(self, result, expand=False, name=None):
from pandas.core.index import Index, MultiIndex

def _wrap_result_expand(self, result, expand=False):
if not isinstance(expand, bool):
raise ValueError("expand must be True or False")

from pandas.core.index import Index, MultiIndex
if name is None:
name = getattr(result, 'name', None) or self.series.name

if not hasattr(result, 'ndim'):
return result

if isinstance(self.series, Index):
name = getattr(result, 'name', None)
# if result is a boolean np.array, return the np.array
# instead of wrapping it into a boolean Index (GH 8875)
if hasattr(result, 'dtype') and is_bool_dtype(result):
Expand All @@ -1092,10 +1072,12 @@ def _wrap_result_expand(self, result, expand=False):
if expand:
cons_row = self.series._constructor
cons = self.series._constructor_expanddim
data = [cons_row(x) for x in result]
return cons(data, index=index)
data = [cons_row(x, index=name) for x in result]
return cons(data, index=index, columns=name,
dtype=result.dtype)
else:
name = getattr(result, 'name', None)
if result.ndim > 1:
result = list(result)
cons = self.series._constructor
return cons(result, name=name, index=index)

Expand All @@ -1109,7 +1091,7 @@ def cat(self, others=None, sep=None, na_rep=None):
@copy(str_split)
def split(self, pat=None, n=-1, expand=False):
result = str_split(self.series, pat, n=n)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

_shared_docs['str_partition'] = ("""
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
Expand Down Expand Up @@ -1160,15 +1142,15 @@ def split(self, pat=None, n=-1, expand=False):
def partition(self, pat=' ', expand=True):
f = lambda x: x.partition(pat)
result = _na_map(f, self.series)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@Appender(_shared_docs['str_partition'] % {'side': 'last',
'return': '3 elements containing two empty strings, followed by the string itself',
'also': 'partition : Split the string at the first occurrence of `sep`'})
def rpartition(self, pat=' ', expand=True):
f = lambda x: x.rpartition(pat)
result = _na_map(f, self.series)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@copy(str_get)
def get(self, i):
Expand Down Expand Up @@ -1309,9 +1291,9 @@ def wrap(self, width, **kwargs):
return self._wrap_result(result)

@copy(str_get_dummies)
def get_dummies(self, sep='|'):
result = str_get_dummies(self.series, sep)
return self._wrap_result(result)
def get_dummies(self, sep='|', expand=True):
result, name = str_get_dummies(self.series, sep)
return self._wrap_result(result, name=name, expand=expand)

@copy(str_translate)
def translate(self, table, deletechars=None):
Expand All @@ -1324,9 +1306,18 @@ def translate(self, table, deletechars=None):
findall = _pat_wrapper(str_findall, flags=True)

@copy(str_extract)
def extract(self, pat, flags=0):
def extract(self, pat, flags=0, expand=None):
result, name = str_extract(self.series, pat, flags=flags)
return self._wrap_result(result, name=name)
if expand is None and hasattr(result, 'ndim'):
# to be compat with previous behavior
if len(result) == 0:
# for empty input
expand = True if isinstance(name, list) else False
elif result.ndim > 1:
expand = True
else:
expand = False
return self._wrap_result(result, name=name, expand=expand)

_shared_docs['find'] = ("""
Return %(side)s indexes in each strings in the Series/Index
Expand Down
Loading

0 comments on commit f30f63c

Please sign in to comment.