Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

ENH: Allow update to use an on keyword. Allow one to many update. #6604

Closed
wants to merge 3 commits into
from
Jump to file or symbol
Failed to load files and symbols.
+89 −25
Split
View
@@ -3003,10 +3003,9 @@ def combiner(x, y, needs_i8_conversion=False):
return self.combine(other, combiner, overwrite=False)
def update(self, other, join='left', overwrite=True, filter_func=None,
- raise_conflict=False):
+ raise_conflict=False, on=None):
"""
- Modify DataFrame in place using non-NA values from passed
- DataFrame. Aligns on indices
+ Modify DataFrame in place using non-NA values from passed DataFrame.
Parameters
----------
@@ -3020,6 +3019,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
raise_conflict : boolean
If True, will raise an error if the DataFrame and other both
contain data in the same place.
+ on : label or list, optional
+ Column label(s) used to match up observations in other and self.
+ If None, other is reindexed to self's index, so the indexes must
+ match to get a meaningful result.
"""
# TODO: Support other joins
if join != 'left': # pragma: no cover
@@ -3028,31 +3031,55 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
if not isinstance(other, DataFrame):
other = DataFrame(other)
- other = other.reindex_like(self)
+ if on is None:
+ other = other.reindex(index=self.index)
+ else:
+ try:
+ old_index = self.index
+ col_order = self.columns
+ self.set_index(on, inplace=True)
+ other.set_index(on, inplace=True)
+ other = other.reindex(index=self.index)
+ except Exception, err:
+ self.reset_index(inplace=True)
+ self.set_index(old_index)
+ raise(err)
- for col in self.columns:
- this = self[col].values
- that = other[col].values
- if filter_func is not None:
- mask = -filter_func(this) | isnull(that)
- else:
- if raise_conflict:
- mask_this = notnull(that)
- mask_that = notnull(this)
- if any(mask_this & mask_that):
- raise ValueError("Data overlaps.")
+ try:
+ for col in other.columns:
+ if col not in self: # don't update what doesn't exist
+ continue
+ this = self[col].values
+ that = other[col].values
+ if filter_func is not None:
+ mask = -filter_func(this) | isnull(that)
+ else:
+ if raise_conflict:
+ mask_this = notnull(that)
+ mask_that = notnull(this)
+ if any(mask_this & mask_that):
+ raise ValueError("Data overlaps.")
+
+ if overwrite:
+ mask = isnull(that)
+
+ # don't overwrite columns unnecessarily
+ if mask.all():
+ continue
+ else:
+ mask = notnull(this)
- if overwrite:
- mask = isnull(that)
+ self[col] = expressions.where(
+ mask, this, that, raise_on_error=True)
- # don't overwrite columns unecessarily
- if mask.all():
- continue
- else:
- mask = notnull(this)
+ except Exception, err:
+ raise(err)
- self[col] = expressions.where(
- mask, this, that, raise_on_error=True)
+ finally:
+ if on is not None:
+ self.reset_index(inplace=True)
+ self.set_index(old_index)
+ self = self[col_order]
#----------------------------------------------------------------------
# Misc methods
View
@@ -24,7 +24,7 @@
from numpy.random import randn
import numpy as np
import numpy.ma as ma
-from numpy.testing import assert_array_equal
+from numpy.testing import assert_array_equal, assert_
import numpy.ma.mrecords as mrecords
import pandas.core.nanops as nanops
@@ -9974,6 +9974,43 @@ def test_update(self):
[1.5, nan, 7.]])
assert_frame_equal(df, expected)
+ def test_update_on(self):
+ df = DataFrame([[np.nan, 'A'],
+ [np.nan, 'A'],
+ [np.nan, 'A'],
+ [1.5, 'B'],
+ [2.2, 'C'],
+ [3.1, 'C'],
+ [1.2, 'B']], columns=['number', 'name'])
+
+ df2 = DataFrame([[3.5, 'A']], columns=['number', 'name'])
+
+ expected = DataFrame([[3.5, 'A'],
+ [3.5, 'A'],
+ [3.5, 'A'],
+ [1.5, 'B'],
+ [2.2, 'C'],
+ [3.1, 'C'],
+ [1.2, 'B']], columns=['number', 'name'])
+ df.update(df2, on='name')
+ assert_frame_equal(df, expected)
+
+ df = DataFrame([[np.nan, 'A'],
+ [np.nan, 'A'],
+ [np.nan, 'A'],
+ [1.5, 'B'],
+ [2.2, 'C'],
+ [3.1, 'C'],
+ [1.2, 'B']], columns=['number', 'name'])
+
+ df2 = DataFrame([[3.5, 'A'], [2.5, 'A']],
+ columns=['number', 'name'])
+
+ assertRaises(ValueError, df.update, df2, on='name')
+
+ ## and the original index should be restored after the failed update
+ assert_(df.index.equals(pd.Index(range(7))))
+
def test_update_dtypes(self):
# gh 3016