Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 5 additions & 30 deletions torcharrow/idataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,44 +171,16 @@ def columns(self):
"""The column labels of the DataFrame."""
return [f.name for f in self.dtype.fields]

@abc.abstractmethod
def _set_field_data(self, name: str, col: Column, empty_df: bool):
"""
PRIVATE hook that stores the data for field ``name``, appending the
field if it doesn't exist yet.

self._dtype is already updated upon invocation, so the schema visible
to an implementation already contains ``name``.

name: label of the field being written
col: column holding the data for that field
empty_df: True when the frame had no fields before self._dtype was updated

The base body only raises via self._not_supported; concrete frames are
expected to override it.
"""
raise self._not_supported("_set_field_data")

def __contains__(self, key: str) -> bool:
    """Return True if ``key`` is the name of one of this frame's fields.

    Field names are compared with ``==`` in declaration order and the scan
    stops at the first match, so ``key in df`` costs at most one pass over
    ``self.dtype.fields``. A frame without fields contains nothing.
    """
    return any(key == f.name for f in self.dtype.fields)

# Column assignment entry point (``df[name] = value``).
# NOTE(review): in this diff view the span appears to combine the removed
# base-class body with the added abstract stub (the `@abc.abstractmethod`
# decorator and the trailing `raise NotImplementedError`) — confirm against
# the checked-in file before relying on the statement order shown here.
@trace
@abc.abstractmethod
def __setitem__(self, name: str, value: Any) -> None:
# A Column is used as-is and must live on this frame's device; any other
# value is converted with ta.column.
if isinstance(value, Column):
assert self.device == value.device
col = value
else:
col = ta.column(value)

# Recorded before self._dtype is replaced below, so it reflects the frame's
# state prior to this assignment.
empty_df = len(self.dtype.fields) == 0

# Update dtype
# pyre-fixme[16]: `DType` has no attribute `get_index`.
idx = self.dtype.get_index(name)
if idx is None:
# append column
new_fields = self.dtype.fields + [dt.Field(name, col.dtype)]
else:
# override column
# (copy the field list first so the existing list is not mutated in place)
new_fields = list(self.dtype.fields)
new_fields[idx] = dt.Field(name, col.dtype)
self._dtype = dt.Struct(fields=new_fields)

# Update field data
self._set_field_data(name, col, empty_df)
raise NotImplementedError

@trace
def copy(self):
Expand Down Expand Up @@ -659,6 +631,9 @@ class DataFrameVar(Var, DataFrame):
# Initialise the variable by forwarding `name` and `qualname` unchanged to the
# next __init__ in the MRO; no additional state is set here.
def __init__(self, name: str, qualname: str = ""):
super().__init__(name, qualname)

# Item assignment on a DataFrameVar: passes the method name to
# self._not_supported and returns whatever that call yields; `name` and
# `value` are not used.
# NOTE(review): presumably _not_supported raises to reject the operation, in
# which case the `return` is never reached — confirm against its definition.
def __setitem__(self, name: str, value: Any) -> None:
return self._not_supported("__setitem__")

# Null-append on a DataFrameVar: passes the method name to
# self._not_supported and returns whatever that call yields.
# NOTE(review): presumably _not_supported raises to reject the operation —
# confirm against its definition.
def _append_null(self):
return self._not_supported("_append_null")

Expand Down
16 changes: 8 additions & 8 deletions torcharrow/test/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,11 @@ def test_simple_df_ops_succeed(self):
# print("TRACE", cmds(Scope.default.trace.statements()))
verdict = [
"c0 = torcharrow.scope.Scope._DataFrame(None, dtype=None, columns=None, device='cpu')",
"_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'a', [1, 2, 3])",
"_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'b', [11, 22, 33])",
"_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'c', [111, 222, 333])",
"_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'a', [1, 2, 3])",
"_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'b', [11, 22, 33])",
"_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'c', [111, 222, 333])",
"c4 = torcharrow.icolumn.Column.__getitem__(c0, 'a')",
"_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'd', c4)",
"_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'd', c4)",
"c11 = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.drop(c0, ['a'])",
"c16 = torcharrow.icolumn.Column.__getitem__(c0, ['a', 'c'])",
"c21 = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.rename(c16, {'c': 'e'})",
Expand All @@ -285,7 +285,7 @@ def test_df_trace_equivalence(self):
df["a"] = [1, 2, 3]
self.assertEqual(
Scope.default.trace.statements()[-1],
"_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'a', [1, 2, 3])",
"_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'a', [1, 2, 3])",
)

df["b"] = [11, 22, 33]
Expand All @@ -295,9 +295,9 @@ def test_df_trace_equivalence(self):
# print("TRACE", cmds(Scope.default.trace.statements()))
verdict = [
"c0 = torcharrow.scope.Scope._DataFrame(None, dtype=None, columns=None, device='cpu')",
"_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'a', [1, 2, 3])",
"_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'b', [11, 22, 33])",
"_ = torcharrow.idataframe.DataFrame.__setitem__(c0, 'c', [111, 222, 333])",
"_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'a', [1, 2, 3])",
"_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'b', [11, 22, 33])",
"_ = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.__setitem__(c0, 'c', [111, 222, 333])",
"c9 = torcharrow.velox_rt.dataframe_cpu.DataFrameCpu.where(c0, torcharrow.idataframe.me.__getitem__('a').__gt__(1))",
]

Expand Down
25 changes: 25 additions & 0 deletions torcharrow/velox_rt/dataframe_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,31 @@ def _check_columns(self, columns: Iterable[str]):

# implementing abstract methods ----------------------------------------------

@trace
def __setitem__(self, name: str, value: Any) -> None:
    """Bind ``value`` to the column ``name``, adding the column if absent.

    ``value`` may be a ``Column`` (used directly; asserted to be on the
    same device as this frame) or any other object, which is converted
    with ``ta.column``. The frame's struct dtype is rebuilt so that the
    field ``name`` carries the new column's dtype: replaced at its
    existing index when present, appended otherwise. Storing the data is
    then delegated to ``self._set_field_data``.
    """
    if not isinstance(value, Column):
        col = ta.column(value)
    else:
        assert self.device == value.device
        col = value

    # Must be read before the dtype is swapped out further down.
    was_empty = len(self.dtype.fields) == 0

    # Locate the field (None means it does not exist yet).
    # pyre-fixme[16]: `DType` has no attribute `get_index`.
    position = self.dtype.get_index(name)
    field = dt.Field(name, col.dtype)
    if position is not None:
        # Existing column: swap the field in a copy of the schema.
        fields = list(self.dtype.fields)
        fields[position] = field
    else:
        # New column: extend the schema.
        fields = self.dtype.fields + [field]
    self._dtype = dt.Struct(fields=fields)

    # Hand the actual data over now that the schema is in place.
    self._set_field_data(name, col, was_empty)

def _set_field_data(self, name: str, col: Column, empty_df: bool):
if not empty_df and len(col) != len(self):
raise TypeError("all columns/lists must have equal length")
Expand Down