Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

Add support for Series argument in DataFrame.dot, align the arguments if needed #1915

Closed
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+26 −8
Split
View
@@ -710,20 +710,29 @@ def __neg__(self):
def dot(self, other):
"""
- Matrix multiplication with DataFrame objects. Does no data alignment
+ Matrix multiplication with DataFrame or Series objects
Parameters
----------
- other : DataFrame
+ other : DataFrame or Series
Returns
-------
- dot_product : DataFrame
- """
- lvals = self.values
- rvals = other.values
- result = np.dot(lvals, rvals)
- return DataFrame(result, index=self.index, columns=other.columns)
+ dot_product : DataFrame or Series
+ """
+ common = self.columns.union(other.index)
+ if len(common) > len(self.columns) or len(common) > len(other.index):
+ raise ValueError('matrices are not aligned')
+ left = self.reindex(columns=common, copy=False)
+ right = other.reindex(index=common, copy=False)
+ lvals = left.values
+ rvals = right.values
+ if isinstance(other, DataFrame):
+ return DataFrame(np.dot(lvals, rvals), index=self.index, columns=other.columns)
+ elif isinstance(other, Series):
+ return Series(np.dot(lvals, rvals), index=left.index)
+ else:
+ raise TypeError('unsupported type: %s' % type(other))
#----------------------------------------------------------------------
# IO methods (to / from other formats)
@@ -6629,8 +6629,17 @@ def test_dot(self):
expected = DataFrame(np.dot(a.values, b.values),
index=['a', 'b', 'c'],
columns=['one', 'two'])
+ #Check alignment
+ b1 = b.reindex(index=reversed(b.index))
+ result = a.dot(b)
assert_frame_equal(result, expected)
+ #Check series argument
+ result = a.dot(b['one'])
+ assert_series_equal(result, expected['one'])
+ result = a.dot(b1['one'])
+ assert_series_equal(result, expected['one'])
+
def test_idxmin(self):
frame = self.frame
frame.ix[5:10] = np.nan