From 1411b7157ead7965a923f6aede921b77cad0004e Mon Sep 17 00:00:00 2001 From: Wenlei Xie Date: Tue, 21 Jun 2022 21:20:20 -0700 Subject: [PATCH] Lazy update for dtype in tracing mode (#383) Summary: Pull Request resolved: https://github.com/pytorch/torcharrow/pull/383 X-link: https://github.com/facebookresearch/torcharrow/pull/383 Reviewed By: dracifer Differential Revision: D37192215 fbshipit-source-id: f223920b5e50c1bebdbdc151e723ecc5b17c9eb4 --- torcharrow/idataframe.py | 35 ++++------------------------ torcharrow/test/test_trace.py | 16 ++++++------- torcharrow/velox_rt/dataframe_cpu.py | 25 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/torcharrow/idataframe.py b/torcharrow/idataframe.py index 2da9ec335..eb417497a 100644 --- a/torcharrow/idataframe.py +++ b/torcharrow/idataframe.py @@ -171,14 +171,6 @@ def columns(self): """The column labels of the DataFrame.""" return [f.name for f in self.dtype.fields] - @abc.abstractmethod - def _set_field_data(self, name: str, col: Column, empty_df: bool): - """ - PRIVATE _set field data, append if field doesn't exist - self._dtype is already updated upon invocation - """ - raise self._not_supported("_set_field_data") - def __contains__(self, key: str) -> bool: for f in self.dtype.fields: if key == f.name: @@ -186,29 +178,9 @@ def __contains__(self, key: str) -> bool: return False @trace + @abc.abstractmethod def __setitem__(self, name: str, value: Any) -> None: - if isinstance(value, Column): - assert self.device == value.device - col = value - else: - col = ta.column(value) - - empty_df = len(self.dtype.fields) == 0 - - # Update dtype - # pyre-fixme[16]: `DType` has no attribute `get_index`. - idx = self.dtype.get_index(name) - if idx is None: - # append column - new_fields = self.dtype.fields + [dt.Field(name, col.dtype)] - else: - # override column - new_fields = list(self.dtype.fields) - new_fields[idx] = dt.Field(name, col.dtype) - self._dtype = dt.Struct(fields=new_fields) - - # Update field data - self._set_field_data(name, col, empty_df) + raise NotImplementedError @trace def copy(self): @@ -659,6 +631,9 @@ class DataFrameVar(Var, DataFrame): def __init__(self, name: str, qualname: str = ""): super().__init__(name, qualname) + def __setitem__(self, name: str, value: Any) -> None: + return self._not_supported("__setitem__") + def _append_null(self): return self._not_supported("_append_null") diff --git a/torcharrow/test/test_trace.py b/torcharrow/test/test_trace.py index 87091b496..d4319730f 100644 --- a/torcharrow/test/test_trace.py +++ b/torcharrow/test/test_trace.py @@ -259,11 +259,11 @@ def test_simple_df_ops_succeed(self): # print("TRACE", cmds(Scope.default.trace.statements())) verdict = [ "c0 = torcharrow.scope.Scope._DataFrame(None, dtype=None, columns=None, device='cpu')", - "_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'a', [1, 2, 3])", - "_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'b', [11, 22, 33])", - "_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'c', [111, 222, 333])", + "_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'a', [1, 2, 3])", + "_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'b', [11, 22, 33])", + "_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'c', [111, 222, 333])", "c4 = torcharrow.icolumn.Column.__getitem__(c0, 'a')", - "_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'd', c4)", + "_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'd', c4)", "c11 = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.drop(c0, ['a'])", "c16 = torcharrow.icolumn.Column.__getitem__(c0, ['a', 'c'])", "c21 = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.rename(c16, {'c': 'e'})", @@ -285,7 +285,7 @@ def test_df_trace_equivalence(self): df["a"] = [1, 2, 3] self.assertEqual( Scope.default.trace.statements()[-1], - "_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'a', [1, 2, 3])", + "_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'a', [1, 2, 3])", ) df["b"] = [11, 22, 33] @@ -295,9 +295,9 @@ def test_df_trace_equivalence(self): # print("TRACE", cmds(Scope.default.trace.statements())) verdict = [ "c0 = torcharrow.scope.Scope._DataFrame(None, dtype=None, columns=None, device='cpu')", - "_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'a', [1, 2, 3])", - "_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'b', [11, 22, 33])", - "_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'c', [111, 222, 333])", + "_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'a', [1, 2, 3])", + "_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'b', [11, 22, 33])", + "_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'c', [111, 222, 333])", "c9 = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.where(c0, torcharrow.idataframe.me.__getitem__('a').__gt__(1))", ] diff --git a/torcharrow/velox_rt/dataframe_cpu.py b/torcharrow/velox_rt/dataframe_cpu.py index bfefe3277..c86e8244b 100644 --- a/torcharrow/velox_rt/dataframe_cpu.py +++ b/torcharrow/velox_rt/dataframe_cpu.py @@ -315,6 +315,31 @@ def _check_columns(self, columns: Iterable[str]): # implementing abstract methods ---------------------------------------------- + @trace + def __setitem__(self, name: str, value: Any) -> None: + if isinstance(value, Column): + assert self.device == value.device + col = value + else: + col = ta.column(value) + + empty_df = len(self.dtype.fields) == 0 + + # Update dtype + # pyre-fixme[16]: `DType` has no attribute `get_index`. + idx = self.dtype.get_index(name) + if idx is None: + # append column + new_fields = self.dtype.fields + [dt.Field(name, col.dtype)] + else: + # override column + new_fields = list(self.dtype.fields) + new_fields[idx] = dt.Field(name, col.dtype) + self._dtype = dt.Struct(fields=new_fields) + + # Update field data + self._set_field_data(name, col, empty_df) + def _set_field_data(self, name: str, col: Column, empty_df: bool): if not empty_df and len(col) != len(self): raise TypeError("all columns/lists must have equal length")