ENH: Add skipna parameter to infer_dtype
cpcloud committed Jul 24, 2017
1 parent c55dbf0 commit c069ab1
Showing 3 changed files with 129 additions and 40 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.21.0.txt
@@ -24,6 +24,8 @@ New features
<https://www.python.org/dev/peps/pep-0519/>`_ on most readers and writers (:issue:`13823`)
- Added ``__fspath__`` method to :class:`~pandas.HDFStore`, :class:`~pandas.ExcelFile`,
and :class:`~pandas.ExcelWriter` to work properly with the file system path protocol (:issue:`13823`)
- Added ``skipna`` parameter to :func:`~pandas.api.types.infer_dtype` to support
type inference in the presence of missing values (:issue:`17059`).


.. _whatsnew_0210.enhancements.infer_objects:
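
As a quick usage sketch of the enhancement above (not part of the committed diff; it assumes pandas 0.21.0 or later, where ``skipna`` is available), a missing value no longer forces a ``'mixed'`` result when the flag is set:

>>> import numpy as np
>>> from pandas.api.types import infer_dtype
>>> infer_dtype(['a', np.nan, 'b'])
'mixed'
>>> infer_dtype(['a', np.nan, 'b'], skipna=True)
'string'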
134 changes: 96 additions & 38 deletions pandas/_libs/src/inference.pyx
@@ -222,14 +222,16 @@ cdef _try_infer_map(v):
return None


def infer_dtype(object value):
def infer_dtype(object value, bint skipna=False):
"""
Efficiently infer the type of a passed value or list-like
array of values. Return a string describing the type.
Parameters
----------
value : scalar, list, ndarray, or pandas type
skipna : bool, default False
Ignore NaN values when inferring the type.
Returns
-------
@@ -272,6 +274,9 @@ def infer_dtype(object value):
>>> infer_dtype(['foo', 'bar'])
'string'
>>> infer_dtype(['a', np.nan, 'b'], skipna=True)
'string'
>>> infer_dtype([b'foo', b'bar'])
'bytes'
@@ -310,7 +315,6 @@ def infer_dtype(object value):
>>> infer_dtype(pd.Series(list('aabc')).astype('category'))
'categorical'
"""
cdef:
Py_ssize_t i, n
@@ -356,7 +360,7 @@ def infer_dtype(object value):
values = values.ravel()

# try to use a valid value
for i from 0 <= i < n:
for i in range(n):
val = util.get_value_1d(values, i)

# do not use is_null_datetimelike to keep
@@ -403,11 +407,11 @@ def infer_dtype(object value):
return 'datetime'

elif is_date(val):
if is_date_array(values):
if is_date_array(values, skipna=skipna):
return 'date'

elif is_time(val):
if is_time_array(values):
if is_time_array(values, skipna=skipna):
return 'time'

elif is_decimal(val):
@@ -420,19 +424,19 @@ def infer_dtype(object value):
return 'mixed-integer-float'

elif util.is_bool_object(val):
if is_bool_array(values):
if is_bool_array(values, skipna=skipna):
return 'boolean'

elif PyString_Check(val):
if is_string_array(values):
if is_string_array(values, skipna=skipna):
return 'string'

elif PyUnicode_Check(val):
if is_unicode_array(values):
if is_unicode_array(values, skipna=skipna):
return 'unicode'

elif PyBytes_Check(val):
if is_bytes_array(values):
if is_bytes_array(values, skipna=skipna):
return 'bytes'

elif is_period(val):
@@ -593,10 +597,11 @@ cdef inline bint is_timedelta(object o):
return PyDelta_Check(o) or util.is_timedelta64_object(o)


cpdef bint is_bool_array(ndarray values):
cpdef bint is_bool_array(ndarray values, bint skipna=False):
cdef:
Py_ssize_t i, n = len(values)
ndarray[object] objbuf
object val

if issubclass(values.dtype.type, np.bool_):
return True
@@ -606,9 +611,16 @@ cpdef bint is_bool_array(ndarray values):
if n == 0:
return False

for i in range(n):
if not util.is_bool_object(objbuf[i]):
return False
if skipna:
for i in range(n):
val = objbuf[i]
if not util._checknull(val) and not util.is_bool_object(val):
return False
else:
for i in range(n):
val = objbuf[i]
if not util.is_bool_object(val):
return False
return True
else:
return False
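
The same ``skipna`` branch is repeated in each element-wise checker below (strings, unicode, bytes, dates, times). A rough pure-Python sketch of the rule these Cython loops implement (a hypothetical helper; the real code uses ``util._checknull`` and typed object buffers) could look like:

def is_typed_array(values, type_check, skipna=False):
    # True only if every element satisfies type_check; with skipna=True,
    # null elements (None or float NaN) are tolerated rather than failing.
    if len(values) == 0:
        return False
    for val in values:
        if skipna and (val is None or val != val):  # val != val is True only for NaN
            continue
        if not type_check(val):
            return False
    return True

# For example:
# is_typed_array(['a', float('nan'), 'b'], lambda v: isinstance(v, str), skipna=True)   -> True
# is_typed_array(['a', float('nan'), 'b'], lambda v: isinstance(v, str))                -> False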
@@ -639,6 +651,7 @@ cpdef bint is_integer_float_array(ndarray values):
cdef:
Py_ssize_t i, n = len(values)
ndarray[object] objbuf
object val

if issubclass(values.dtype.type, np.integer):
return True
@@ -649,9 +662,8 @@ cpdef bint is_integer_float_array(ndarray values):
return False

for i in range(n):
if not (util.is_integer_object(objbuf[i]) or
util.is_float_object(objbuf[i])):

val = objbuf[i]
if not (util.is_integer_object(val) or util.is_float_object(val)):
return False
return True
else:
@@ -679,10 +691,11 @@ cpdef bint is_float_array(ndarray values):
return False


cpdef bint is_string_array(ndarray values):
cpdef bint is_string_array(ndarray values, bint skipna=False):
cdef:
Py_ssize_t i, n = len(values)
ndarray[object] objbuf
object val

if ((PY2 and issubclass(values.dtype.type, np.string_)) or
not PY2 and issubclass(values.dtype.type, np.unicode_)):
@@ -693,18 +706,26 @@ cpdef bint is_string_array(ndarray values):
if n == 0:
return False

for i in range(n):
if not PyString_Check(objbuf[i]):
return False
if skipna:
for i in range(n):
val = objbuf[i]
if not util._checknull(val) and not PyString_Check(val):
return False
else:
for i in range(n):
val = objbuf[i]
if not PyString_Check(val):
return False
return True
else:
return False


cpdef bint is_unicode_array(ndarray values):
cpdef bint is_unicode_array(ndarray values, bint skipna=False):
cdef:
Py_ssize_t i, n = len(values)
ndarray[object] objbuf
object val

if issubclass(values.dtype.type, np.unicode_):
return True
@@ -714,18 +735,26 @@ cpdef bint is_unicode_array(ndarray values):
if n == 0:
return False

for i in range(n):
if not PyUnicode_Check(objbuf[i]):
return False
if skipna:
for i in range(n):
val = objbuf[i]
if not util._checknull(val) and not PyUnicode_Check(val):
return False
else:
for i in range(n):
val = objbuf[i]
if not PyUnicode_Check(val):
return False
return True
else:
return False


cpdef bint is_bytes_array(ndarray values):
cpdef bint is_bytes_array(ndarray values, bint skipna=False):
cdef:
Py_ssize_t i, n = len(values)
ndarray[object] objbuf
object val

if issubclass(values.dtype.type, np.bytes_):
return True
@@ -735,9 +764,16 @@ cpdef bint is_bytes_array(ndarray values):
if n == 0:
return False

for i in range(n):
if not PyBytes_Check(objbuf[i]):
return False
if skipna:
for i in range(n):
val = objbuf[i]
if not util._checknull(val) and not PyBytes_Check(val):
return False
else:
for i in range(n):
val = objbuf[i]
if not PyBytes_Check(val):
return False
return True
else:
return False
@@ -856,23 +892,45 @@ cpdef bint is_timedelta_or_timedelta64_array(ndarray values):
return null_count != n


cpdef bint is_date_array(ndarray[object] values):
cdef Py_ssize_t i, n = len(values)
cpdef bint is_date_array(ndarray[object] values, bint skipna=False):
cdef:
Py_ssize_t i, n = len(values)
object val

if n == 0:
return False
for i in range(n):
if not is_date(values[i]):
return False

if skipna:
for i in range(n):
val = values[i]
if not util._checknull(val) and not is_date(val):
return False
else:
for i in range(n):
val = values[i]
if not is_date(val):
return False
return True


cpdef bint is_time_array(ndarray[object] values):
cdef Py_ssize_t i, n = len(values)
cpdef bint is_time_array(ndarray[object] values, bint skipna=False):
cdef:
Py_ssize_t i, n = len(values)
object val

if n == 0:
return False
for i in range(n):
if not is_time(values[i]):
return False

if skipna:
for i in range(n):
val = values[i]
if not util._checknull(val) and not is_time(val):
return False
else:
for i in range(n):
val = values[i]
if not is_time(val):
return False
return True


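To see the effect of the date/time helpers changed above from the user side (again a usage sketch rather than part of the commit; assumes pandas 0.21.0+):

>>> import datetime
>>> import numpy as np
>>> from pandas.api.types import infer_dtype
>>> infer_dtype([datetime.date(2017, 7, 24), np.nan])
'mixed'
>>> infer_dtype([datetime.date(2017, 7, 24), np.nan], skipna=True)
'date'
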
33 changes: 31 additions & 2 deletions pandas/tests/dtypes/test_inference.py
@@ -240,6 +240,9 @@ def test_infer_dtype_bytes(self):
arr = arr.astype(object)
assert lib.infer_dtype(arr) == compare

# object array of bytes with missing values
assert lib.infer_dtype([b'a', np.nan, b'c'], skipna=True) == compare

def test_isinf_scalar(self):
# GH 11352
assert lib.isposinf_scalar(float('inf'))
@@ -445,6 +448,10 @@ def test_bools(self):
result = lib.infer_dtype(arr)
assert result == 'boolean'

arr = np.array([True, np.nan, False], dtype='O')
result = lib.infer_dtype(arr, skipna=True)
assert result == 'boolean'

def test_floats(self):
arr = np.array([1., 2., 3., np.float64(4), np.float32(5)], dtype='O')
result = lib.infer_dtype(arr)
@@ -473,11 +480,26 @@ def test_decimals(self):
result = lib.infer_dtype(arr)
assert result == 'mixed'

arr = np.array([Decimal(1), Decimal('NaN'), Decimal(3)])
result = lib.infer_dtype(arr)
assert result == 'decimal'

arr = np.array([Decimal(1), np.nan, Decimal(3)], dtype='O')
result = lib.infer_dtype(arr)
assert result == 'decimal'

def test_string(self):
pass

def test_unicode(self):
pass
arr = [u'a', np.nan, u'c']
result = lib.infer_dtype(arr)
assert result == 'mixed'

arr = [u'a', np.nan, u'c']
result = lib.infer_dtype(arr, skipna=True)
expected = 'unicode' if PY2 else 'string'
assert result == expected

def test_datetime(self):

@@ -715,10 +737,17 @@ def test_is_datetimelike_array_all_nan_nat_like(self):

def test_date(self):

dates = [date(2012, 1, x) for x in range(1, 20)]
dates = [date(2012, 1, day) for day in range(1, 20)]
index = Index(dates)
assert index.inferred_type == 'date'

dates = [date(2012, 1, day) for day in range(1, 20)] + [np.nan]
result = lib.infer_dtype(dates)
assert result == 'mixed'

result = lib.infer_dtype(dates, skipna=True)
assert result == 'date'

def test_to_object_array_tuples(self):
r = (5, 6)
values = [r]
