Skip to content

Commit

Permalink
feat[python]: Preserve subclass for DataFrame.select (#4620)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 30, 2022
1 parent 5a6b32f commit a772412
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,10 +1525,10 @@ def _compare_to_other_df(
return combined.select(expr)

def _compare_to_non_df(
self,
self: DF,
other: Any,
op: ComparisonOperator,
) -> DataFrame:
) -> DF:
"""Compare a DataFrame with a non-DataFrame object."""
if op == "eq":
return self.select(pli.all() == other)
Expand Down Expand Up @@ -2018,7 +2018,7 @@ def to_series(self, index: int = 0) -> pli.Series:
index = len(self.columns) + index
return pli.wrap_s(self._df.select_at_idx(index))

def reverse(self) -> DataFrame:
def reverse(self: DF) -> DF:
"""
Reverse the DataFrame.
Expand Down Expand Up @@ -4311,11 +4311,11 @@ def get_column(self, name: str) -> pli.Series:
return self[name]

def fill_null(
self,
self: DF,
value: Any | None = None,
strategy: FillNullStrategy | None = None,
limit: int | None = None,
) -> DataFrame:
) -> DF:
"""
Fill null values using the specified value or strategy.
Expand Down Expand Up @@ -4922,9 +4922,9 @@ def lazy(self: DF) -> pli.LazyFrame:
return pli.wrap_ldf(self._df.lazy())

def select(
self,
self: DF,
exprs: str | pli.Expr | pli.Series | Sequence[str | pli.Expr | pli.Series],
) -> DataFrame:
) -> DF:
"""
Select columns from this DataFrame.
Expand Down Expand Up @@ -4957,8 +4957,11 @@ def select(
└─────┘
"""
return (
self.lazy().select(exprs).collect(no_optimization=True, string_cache=False)
return self._from_pydf(
self.lazy()
.select(exprs)
.collect(no_optimization=True, string_cache=False)
._df
)

def with_columns(
Expand Down Expand Up @@ -5397,7 +5400,7 @@ def median(self: DF) -> DF:
"""
return self._from_pydf(self._df.median())

def product(self) -> DataFrame:
def product(self: DF) -> DF:
"""
Aggregate the columns of this DataFrame to their product values.
Expand Down Expand Up @@ -5813,7 +5816,7 @@ def shrink_to_fit(self: DF, in_place: bool = False) -> DF | None:
df._df.shrink_to_fit()
return df

def take_every(self, n: int) -> DataFrame:
def take_every(self: DF, n: int) -> DF:
"""
Take every nth row in the DataFrame and return as a new DataFrame.
Expand Down Expand Up @@ -5883,7 +5886,7 @@ def hash_rows(
k3 = seed_3 if seed_3 is not None else seed
return pli.wrap_s(self._df.hash_rows(k0, k1, k2, k3))

def interpolate(self) -> DataFrame:
def interpolate(self: DF) -> DF:
"""
Interpolate intermediate values. The interpolation method is linear.
Expand Down

0 comments on commit a772412

Please sign in to comment.