Skip to content

Commit

Permalink
BUG: preserve join keys dtype
Browse files Browse the repository at this point in the history
- closes #8596, preserve join keys dtype
- adds an ``Index.where`` method for all Index types (like
``np.where``/``Series.where``) that preserves dtypes

Author: Jeff Reback <jeff@reback.net>
Author: Mike Kelly <mtk@numeric.com>

Closes #13170 from jreback/merge2 and squashes the following commits:

0a267cf [Jeff Reback] BUG: preserve merge keys dtypes when possible
4173dbf [Mike Kelly] Preserve dtype in merge keys when possible
  • Loading branch information
jreback committed May 27, 2016
1 parent 0f1666d commit e8d9e79
Show file tree
Hide file tree
Showing 14 changed files with 464 additions and 75 deletions.
1 change: 1 addition & 0 deletions doc/source/api.rst
Expand Up @@ -1333,6 +1333,7 @@ Modifying and Computations
Index.max
Index.reindex
Index.repeat
Index.where
Index.take
Index.putmask
Index.set_names
Expand Down
58 changes: 57 additions & 1 deletion doc/source/whatsnew/v0.18.2.txt
Expand Up @@ -77,11 +77,20 @@ Other enhancements
- The ``pd.read_csv()`` with ``engine='python'`` has gained support for the ``decimal`` option (:issue:`12933`)

- ``Index.astype()`` now accepts an optional boolean argument ``copy``, which allows optional copying if the requirements on dtype are satisfied (:issue:`13209`)
- ``Index`` now supports the ``.where()`` function for same shape indexing (:issue:`13170`)

.. ipython:: python

idx = pd.Index(['a', 'b', 'c'])
idx.where([True, False, True])

- ``Categorical.astype()`` now accepts an optional boolean argument ``copy``, effective when dtype is categorical (:issue:`13209`)
- Consistent with the Python API, ``pd.read_csv()`` will now interpret ``+inf`` as positive infinity (:issue:`13274`)

- ``pd.read_html()`` has gained support for the ``decimal`` option (:issue:`12907`)



.. _whatsnew_0182.api:

API changes
Expand Down Expand Up @@ -119,7 +128,6 @@ New Behavior:

type(s.tolist()[0])


.. _whatsnew_0182.api.promote:

``Series`` type promotion on assignment
Expand Down Expand Up @@ -171,6 +179,54 @@ This will now convert integers/floats with the default unit of ``ns``.

pd.to_datetime([1, 'foo'], errors='coerce')

.. _whatsnew_0182.api.merging:

Merging changes
^^^^^^^^^^^^^^^

Merging will now preserve the dtype of the join keys (:issue:`8596`)

.. ipython:: python

df1 = pd.DataFrame({'key': [1], 'v1': [10]})
df1
df2 = pd.DataFrame({'key': [1, 2], 'v1': [20, 30]})
df2

Previous Behavior:

.. code-block:: ipython

In [5]: pd.merge(df1, df2, how='outer')
Out[5]:
key v1
0 1.0 10.0
1 1.0 20.0
2 2.0 30.0

In [6]: pd.merge(df1, df2, how='outer').dtypes
Out[6]:
key float64
v1 float64
dtype: object

New Behavior:

We are able to preserve the dtype of the join keys:

.. ipython:: python

pd.merge(df1, df2, how='outer')
pd.merge(df1, df2, how='outer').dtypes

Of course, if missing values are introduced, the resulting dtype will
still be upcast (unchanged from the previous behavior).

.. ipython:: python

pd.merge(df1, df2, how='outer', on='key')
pd.merge(df1, df2, how='outer', on='key').dtypes

.. _whatsnew_0182.api.other:

Other API changes
Expand Down
18 changes: 18 additions & 0 deletions pandas/indexes/base.py
Expand Up @@ -465,6 +465,24 @@ def repeat(self, n, *args, **kwargs):
nv.validate_repeat(args, kwargs)
return self._shallow_copy(self._values.repeat(n))

def where(self, cond, other=None):
    """
    .. versionadded:: 0.18.2

    Pick entries element by element from self or from ``other``.

    Where ``cond`` is True the entry of self is kept; everywhere else
    the corresponding entry of ``other`` is used.  The result has the
    same shape as self.

    Parameters
    ----------
    cond : boolean same length as self
    other : scalar, or array-like
        Replacement used where ``cond`` is False.  When left as None,
        ``self._na_value`` is used instead.

    Returns
    -------
    Whatever ``self._shallow_copy_with_infer`` builds from the selected
    values when asked for ``self.dtype``.
    """
    # fall back to this index's own missing-value marker
    fill = self._na_value if other is None else other
    selected = np.where(cond, self.values, fill)
    return self._shallow_copy_with_infer(selected, dtype=self.dtype)

def ravel(self, order='C'):
"""
return an ndarray of the flattened values of the underlying data
Expand Down
23 changes: 23 additions & 0 deletions pandas/indexes/category.py
Expand Up @@ -307,6 +307,29 @@ def _can_reindex(self, indexer):
""" always allow reindexing """
pass

def where(self, cond, other=None):
    """
    .. versionadded:: 0.18.2

    Pick entries element by element from self or from ``other``.

    Where ``cond`` is True the entry of self is kept; everywhere else
    the corresponding entry of ``other`` is used.  The result has the
    same shape as self.

    Parameters
    ----------
    cond : boolean same length as self
    other : scalar, or array-like
        Replacement used where ``cond`` is False.  When left as None,
        ``self._na_value`` is used instead.

    Returns
    -------
    The result of ``self._shallow_copy`` applied to a Categorical that
    is rebuilt with this index's ``categories`` and ``ordered`` flag.
    """
    fill = self._na_value if other is None else other
    selected = np.where(cond, self.values, fill)

    # imported at call time, not at module level
    from pandas.core.categorical import Categorical

    # re-wrap the selected values with the same categories/ordering as self
    recoded = Categorical(selected,
                          categories=self.categories,
                          ordered=self.ordered)
    return self._shallow_copy(recoded, **self._get_attributes_dict())

def reindex(self, target, method=None, level=None, limit=None,
tolerance=None):
"""
Expand Down
4 changes: 4 additions & 0 deletions pandas/indexes/multi.py
Expand Up @@ -1084,6 +1084,10 @@ def repeat(self, n, *args, **kwargs):
for label in self.labels], names=self.names,
sortorder=self.sortorder, verify_integrity=False)

def where(self, cond, other=None):
    """
    Unsupported on a MultiIndex; the arguments are ignored.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = ".where is not supported for MultiIndex operations"
    raise NotImplementedError(message)

def drop(self, labels, level=None, errors='raise'):
"""
Make new MultiIndex with passed list of labels deleted
Expand Down
14 changes: 13 additions & 1 deletion pandas/tests/indexes/common.py
Expand Up @@ -7,7 +7,7 @@

from pandas import (Series, Index, Float64Index, Int64Index, RangeIndex,
MultiIndex, CategoricalIndex, DatetimeIndex,
TimedeltaIndex, PeriodIndex)
TimedeltaIndex, PeriodIndex, notnull)
from pandas.util.testing import assertRaisesRegexp

import pandas.util.testing as tm
Expand Down Expand Up @@ -363,6 +363,18 @@ def test_numpy_repeat(self):
tm.assertRaisesRegexp(ValueError, msg, np.repeat,
i, rep, axis=0)

def test_where(self):
    # Index.where with the default ``other``: masked-out slots are
    # expected to come back as missing values.
    i = self.create_index()

    # mask derived from i itself; the result is expected to equal i
    result = i.where(notnull(i))
    expected = i
    tm.assert_index_equal(result, expected)

    # i2 carries NaN in its first two slots, so notnull(i2) is False
    # there; i.where(...) is expected to reproduce i2
    i2 = pd.Index([np.nan, np.nan] + i[2:].tolist())
    result = i.where(notnull(i2))
    expected = i2
    tm.assert_index_equal(result, expected)

def test_setops_errorcases(self):
for name, idx in compat.iteritems(self.indices):
# # non-iterable input
Expand Down
15 changes: 14 additions & 1 deletion pandas/tests/indexes/test_category.py
Expand Up @@ -11,7 +11,7 @@

import numpy as np

from pandas import Categorical, compat
from pandas import Categorical, compat, notnull
from pandas.util.testing import assert_almost_equal
import pandas.core.config as cf
import pandas as pd
Expand Down Expand Up @@ -230,6 +230,19 @@ def f(x):
ordered=False)
tm.assert_categorical_equal(result, exp)

def test_where(self):
    # CategoricalIndex.where with the default ``other``.
    i = self.create_index()

    # mask derived from i itself; the result is expected to equal i
    result = i.where(notnull(i))
    expected = i
    tm.assert_index_equal(result, expected)

    # i2 has NaN in its first two slots and shares i's categories;
    # masking i with notnull(i2) is expected to reproduce i2
    i2 = pd.CategoricalIndex([np.nan, np.nan] + i[2:].tolist(),
                             categories=i.categories)
    result = i.where(notnull(i2))
    expected = i2
    tm.assert_index_equal(result, expected)

def test_append(self):

ci = self.create_index()
Expand Down
67 changes: 66 additions & 1 deletion pandas/tests/indexes/test_datetimelike.py
Expand Up @@ -7,7 +7,7 @@
from pandas import (DatetimeIndex, Float64Index, Index, Int64Index,
NaT, Period, PeriodIndex, Series, Timedelta,
TimedeltaIndex, date_range, period_range,
timedelta_range)
timedelta_range, notnull)

import pandas.util.testing as tm

Expand Down Expand Up @@ -449,6 +449,38 @@ def test_astype_raises(self):
self.assertRaises(ValueError, idx.astype, 'datetime64')
self.assertRaises(ValueError, idx.astype, 'datetime64[D]')

def test_where_other(self):

    # other is ndarray or Index
    i = pd.date_range('20130101', periods=3, tz='US/Eastern')

    # each null-like scalar is passed as ``other`` in turn (the loop
    # variable must be forwarded, otherwise NaT is never exercised);
    # the mask comes from i itself so the result is expected to equal i
    for arr in [np.nan, pd.NaT]:
        result = i.where(notnull(i), other=arr)
        expected = i
        tm.assert_index_equal(result, expected)

    # ``other`` supplied as an Index with NaT in the first two slots
    i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
    result = i.where(notnull(i2), i2)
    tm.assert_index_equal(result, i2)

    # ``other`` supplied as the ``.values`` of such an Index
    i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
    result = i.where(notnull(i2), i2.values)
    tm.assert_index_equal(result, i2)

def test_where_tz(self):
    # where() on a tz-aware ('US/Eastern') date_range with default ``other``
    i = pd.date_range('20130101', periods=3, tz='US/Eastern')

    # mask derived from i itself; the result is expected to equal i
    result = i.where(notnull(i))
    expected = i
    tm.assert_index_equal(result, expected)

    # i2 has NaT in its first two slots; masking i with notnull(i2)
    # is expected to reproduce i2
    i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
    result = i.where(notnull(i2))
    expected = i2
    tm.assert_index_equal(result, expected)

def test_get_loc(self):
idx = pd.date_range('2000-01-01', periods=3)

Expand Down Expand Up @@ -776,6 +808,39 @@ def test_get_loc(self):
with tm.assertRaises(KeyError):
idx.get_loc('2000-01-10', method='nearest', tolerance='1 day')

def test_where(self):
    # PeriodIndex.where with the default ``other``.
    i = self.create_index()

    # mask derived from i itself; the result is expected to equal i
    result = i.where(notnull(i))
    expected = i
    tm.assert_index_equal(result, expected)

    # i2 has NaT in its first two slots (freq='D'); masking i with
    # notnull(i2) is expected to reproduce i2
    i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
                        freq='D')
    result = i.where(notnull(i2))
    expected = i2
    tm.assert_index_equal(result, expected)

def test_where_other(self):

    i = self.create_index()

    # each null-like scalar is passed as ``other`` in turn (the loop
    # variable must be forwarded, otherwise NaT is never exercised);
    # the mask comes from i itself so the result is expected to equal i
    for arr in [np.nan, pd.NaT]:
        result = i.where(notnull(i), other=arr)
        expected = i
        tm.assert_index_equal(result, expected)

    # ``other`` supplied as a PeriodIndex with NaT in the first two slots
    i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
                        freq='D')
    result = i.where(notnull(i2), i2)
    tm.assert_index_equal(result, i2)

    # ``other`` supplied as the ``.values`` of such a PeriodIndex
    i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
                        freq='D')
    result = i.where(notnull(i2), i2.values)
    tm.assert_index_equal(result, i2)

def test_get_indexer(self):
idx = pd.period_range('2000-01-01', periods=3).asfreq('H', how='start')
tm.assert_numpy_array_equal(idx.get_indexer(idx), [0, 1, 2])
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/indexes/test_multi.py
Expand Up @@ -78,6 +78,14 @@ def test_labels_dtypes(self):
self.assertTrue((i.labels[0] >= 0).all())
self.assertTrue((i.labels[1] >= 0).all())

def test_where(self):
    # calling .where on a MultiIndex must raise NotImplementedError
    mi = MultiIndex.from_tuples([('A', 1), ('A', 2)])

    with self.assertRaises(NotImplementedError):
        mi.where(True)

def test_repeat(self):
reps = 2
numbers = [1, 2, 3]
Expand Down
40 changes: 40 additions & 0 deletions pandas/tests/types/test_types.py
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
import nose
import numpy as np

from pandas import NaT
from pandas.types.api import (DatetimeTZDtype, CategoricalDtype,
na_value_for_dtype, pandas_dtype)


def test_pandas_dtype():
    # strings naming pandas-specific dtypes compare equal to the
    # corresponding dtype objects
    tz_spec = 'datetime64[ns, US/Eastern]'
    assert pandas_dtype(tz_spec) == DatetimeTZDtype(tz_spec)
    assert pandas_dtype('category') == CategoricalDtype()

    # plain dtype strings compare equal to np.dtype of the same string
    for name in ('M8[ns]', 'm8[ns]', 'object', 'float64', 'int64'):
        assert pandas_dtype(name) == np.dtype(name)


def test_na_value_for_dtype():
    # datetime64, timedelta64 and tz-aware datetime dtypes -> NaT
    datetimelike = (np.dtype('M8[ns]'),
                    np.dtype('m8[ns]'),
                    DatetimeTZDtype('datetime64[ns, US/Eastern]'))
    for dtype in datetimelike:
        assert na_value_for_dtype(dtype) is NaT

    # unsigned and signed integer dtypes -> 0
    for code in ('u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8'):
        assert na_value_for_dtype(np.dtype(code)) == 0

    # bool -> the False singleton
    assert na_value_for_dtype(np.dtype('bool')) is False

    # float dtypes, then object -> NaN
    for code in ('f2', 'f4', 'f8', 'O'):
        assert np.isnan(na_value_for_dtype(np.dtype(code)))


# When this module is executed as a script, hand it to nose's module
# runner with the command-line flags listed in ``argv`` below.
if __name__ == '__main__':
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
exit=False)

0 comments on commit e8d9e79

Please sign in to comment.