diff --git a/third_party/3/pyspark/sql/dataframe.pyi b/third_party/3/pyspark/sql/dataframe.pyi
index 03d96628..9a025c17 100644
--- a/third_party/3/pyspark/sql/dataframe.pyi
+++ b/third_party/3/pyspark/sql/dataframe.pyi
@@ -10,6 +10,7 @@ from py4j.java_gateway import JavaObject # type: ignore
 from pyspark.sql._typing import ColumnOrName, Literal, LiteralType
 from pyspark.sql.types import *
 from pyspark.sql.context import SQLContext
+from pyspark.sql.group import GroupedData
 from pyspark.sql.readwriter import DataFrameWriter
 from pyspark.sql.streaming import DataStreamWriter
 from pyspark.sql.column import Column
@@ -91,17 +92,17 @@ class DataFrame:
     def selectExpr(self, *expr: List[str]) -> 'DataFrame': ...
     def filter(self, condition: ColumnOrName) -> 'DataFrame': ...
     @overload
-    def groupBy(self, *cols: ColumnOrName) -> 'DataFrame': ...
+    def groupBy(self, *cols: ColumnOrName) -> GroupedData: ...
     @overload
-    def groupBy(self, __cols: List[ColumnOrName]) -> 'DataFrame': ...
+    def groupBy(self, __cols: List[ColumnOrName]) -> GroupedData: ...
     @overload
-    def rollup(self, *cols: ColumnOrName) -> 'DataFrame': ...
+    def rollup(self, *cols: ColumnOrName) -> GroupedData: ...
     @overload
-    def rollup(self, __cols: List[ColumnOrName]) -> 'DataFrame': ...
+    def rollup(self, __cols: List[ColumnOrName]) -> GroupedData: ...
     @overload
-    def cube(self, *cols: ColumnOrName) -> 'DataFrame': ...
+    def cube(self, *cols: ColumnOrName) -> GroupedData: ...
     @overload
-    def cube(self, __cols: List[ColumnOrName]) -> 'DataFrame': ...
+    def cube(self, __cols: List[ColumnOrName]) -> GroupedData: ...
     def agg(self, *exprs: Union[Column, Dict[str, str]]) -> 'DataFrame': ...
     def union(self, other: 'DataFrame') -> 'DataFrame': ...
     def unionAll(self, other: 'DataFrame') -> 'DataFrame': ...
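
What the corrected return types buy downstream: with groupBy, rollup, and cube annotated as returning GroupedData rather than DataFrame, a type checker resolves GroupedData-only methods such as agg, count, and pivot on the result instead of rejecting them or matching DataFrame's signatures. A minimal sketch, not part of the patch; the session setup, sample data, and column names are assumptions for illustration:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    # Hypothetical session and data, illustration only.
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["key", "value"])

    grouped = df.groupBy("key")            # now typed as GroupedData, not DataFrame
    totals = grouped.agg(F.sum("value"))   # GroupedData.agg(...) -> DataFrame
    counts = df.rollup("key").count()      # rollup/cube gain the same typing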