diff --git a/torcharrow/_interop.py b/torcharrow/_interop.py index 17a6042f9..81e86f6e6 100644 --- a/torcharrow/_interop.py +++ b/torcharrow/_interop.py @@ -229,7 +229,7 @@ def _dtype_to_arrowtype(t: dt.DType) -> pa.DataType: pa.field( f.name, _dtype_to_arrowtype(f.dtype), nullable=f.dtype.nullable ) - for f in t.fields + for f in cast(dt.Struct, t).fields ] ) raise NotImplementedError(f"Unsupported DType to Arrow type: {str(t)}") diff --git a/torcharrow/dtypes.py b/torcharrow/dtypes.py index 35ad16f9e..b83c1703f 100644 --- a/torcharrow/dtypes.py +++ b/torcharrow/dtypes.py @@ -54,8 +54,6 @@ def __str__(self): @dataclass(frozen=True) # type: ignore class DType(ABC): - fields: ty.ClassVar[ty.Any] = NotImplementedError - typecode: ty.ClassVar[str] = "__TO_BE_DEFINED_IN_SUBCLASS__" arraycode: ty.ClassVar[str] = "__TO_BE_DEFINED_IN_SUBCLASS__" @@ -573,7 +571,7 @@ def contains_tuple(t: DType): # pyre-fixme[16]: `DType` has no attribute `key_dtype`. return contains_tuple(t.key_dtype) or contains_tuple(t.item_dtype) if is_struct(t): - return any(contains_tuple(f.dtype) for f in t.fields) + return any(contains_tuple(f.dtype) for f in ty.cast(Struct, t).fields) return False @@ -708,9 +706,13 @@ def common_dtype(l: DType, r: DType) -> ty.Optional[DType]: return String(l.nullable or r.nullable) if is_boolean_or_numerical(l) and is_boolean_or_numerical(r): return promote(l, r) - if is_tuple(l) and is_tuple(r) and len(l.fields) == len(r.fields): + if ( + is_tuple(l) + and is_tuple(r) + and len(ty.cast(Struct, l).fields) == len(ty.cast(Struct, r).fields) + ): res = [] - for i, j in zip(l.fields, r.fields): + for i, j in zip(ty.cast(Struct, l).fields, ty.cast(Struct, r).fields): m = common_dtype(i, j) if m is None: return None diff --git a/torcharrow/idataframe.py b/torcharrow/idataframe.py index b7a881d10..751e2d520 100644 --- a/torcharrow/idataframe.py +++ b/torcharrow/idataframe.py @@ -11,6 +11,7 @@ from typing import ( Any, Callable, + cast, Dict, get_type_hints, Iterable, @@ -171,6 +172,11 @@ def columns(self): """The column labels of the DataFrame.""" return [f.name for f in self.dtype.fields] + @property + @traceproperty + def dtype(self) -> dt.Struct: + return cast(dt.Struct, self._dtype) + def __contains__(self, key: str) -> bool: for f in self.dtype.fields: if key == f.name: diff --git a/torcharrow/velox_rt/dataframe_cpu.py b/torcharrow/velox_rt/dataframe_cpu.py index 9e655e6c7..2d9fe7088 100644 --- a/torcharrow/velox_rt/dataframe_cpu.py +++ b/torcharrow/velox_rt/dataframe_cpu.py @@ -329,7 +329,6 @@ def __setitem__(self, name: str, value: Any) -> None: empty_df = len(self.dtype.fields) == 0 # Update dtype - # pyre-fixme[16]: `DType` has no attribute `get_index`. idx = self.dtype.get_index(name) if idx is None: # append column @@ -354,7 +353,7 @@ def _set_field_data(self, name: str, col: Column, empty_df: bool): new_delegate.set_length(len(col._data)) # Set columns for new_delegate - for idx in range(len(self._dtype.fields)): + for idx in range(len(self.dtype.fields)): if idx != column_idx: new_delegate.set_child(idx, self._data.child_at(idx)) else: @@ -1477,6 +1476,7 @@ def __pos__(self): def log(self) -> DataFrameCpu: return self._fromdata( + # pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]` { self.dtype.fields[i] # pyre-fixme[16]: `Column` has no attribute `log`. @@ -1498,6 +1498,7 @@ def log(self) -> DataFrameCpu: def isin(self, values: Union[list, dict, Column]): if isinstance(values, list): return self._fromdata( + # pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]` { self.dtype.fields[i] .name: ColumnCpuMixin._from_velox( @@ -1538,6 +1539,7 @@ def fill_null(self, fill_value: Optional[Union[dt.ScalarTypes, Dict]]): return self if isinstance(fill_value, Column._scalar_types): return self._fromdata( + # pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]` { self.dtype.fields[i] .name: ColumnCpuMixin._from_velox( @@ -1844,6 +1846,7 @@ def drop(self, columns: Union[str, List[str]]): columns = [columns] self._check_columns(columns) return self._fromdata( + # pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]` { self.dtype.fields[i].name: ColumnCpuMixin._from_velox( self.device, @@ -1865,6 +1868,7 @@ def _keep(self, columns: List[str]): """ self._check_columns(columns) return self._fromdata( + # pyre-fixme[6]: Incompatible parameter type [6]: In call `DataFrameCpu._fromdata`, for 1st positional only parameter expected `OrderedDict[str, Column]` but got `Dict[str, typing.Any]` { self.dtype.fields[i].name: ColumnCpuMixin._from_velox( self.device, @@ -2173,7 +2177,6 @@ def groupby( key_fields = [] item_fields = [] for k in key_columns: - # pyre-fixme[16]: `DType` has no attribute `get`. key_fields.append(dt.Field(k, self.dtype.get(k))) for f in self.dtype.fields: if f.name not in key_columns: