Align dataframe code with pylint recommendations
tools4origins committed Nov 13, 2019
1 parent 01f271f commit ca1802a
Showing 1 changed file with 107 additions and 84 deletions.
191 changes: 107 additions & 84 deletions pysparkling/sql/dataframe.py
@@ -103,7 +103,14 @@ def exceptAll(self, other):
>>> from pysparkling import Context
>>> from pysparkling.sql.session import SparkSession
>>> spark = SparkSession(Context())
>>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"])
>>> df1 = spark.createDataFrame([
... ("a", 1),
... ("a", 1),
... ("a", 1),
... ("a", 2),
... ("b", 3),
... ("c", 4)
... ], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"])
>>> df1.exceptAll(df2).show()
+---+---+
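
A pattern that recurs through this commit: long doctest lines are reflowed to satisfy pylint's line-too-long check (C0301; the configured limit is an assumption, commonly 100 or 120 characters). The `... ` continuation markers keep the call identical, since doctest joins them into one statement before executing:

    >>> rows = [
    ...     ("a", 1),
    ...     ("b", 3),
    ... ]
    >>> len(rows)
    2
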
@@ -124,6 +131,7 @@ def isLocal(self):
return True

def isStreaming(self):
# pylint: disable=W0511
# todo: Add support of streaming
return False
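
For context, W0511 is pylint's `fixme` check, which reports every TODO/FIXME comment. The commit keeps its todo notes and suppresses the check locally rather than deleting them; a minimal sketch of the pattern, assuming default pylint settings (a trailing `# pylint: disable=W0511` on the offending line itself also works):

    def is_streaming():
        # scoped suppression: applies from here to the end of the function
        # pylint: disable=W0511
        # todo: add support of streaming
        return False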

@@ -169,7 +177,9 @@ def show(self, n=20, truncate=True, vertical=False):
| [2 -> 2, 4 -> 4]|
+---------------------------------------------------------+
>>> c = col("id")
>>> spark.range(9, 11).select(c, c*2, c**2).show(vertical=True)# doctest: +NORMALIZE_WHITESPACE
>>> (spark.range(9, 11)
... .select(c, c*2, c**2)
... .show(vertical=True)) # doctest: +NORMALIZE_WHITESPACE
-RECORD 0-------------
id | 9
(id * 2) | 18
@@ -297,7 +307,8 @@ def foreachPartition(self, f):
>>> from pysparkling import Context
>>> from pysparkling.sql.session import SparkSession
>>> spark = SparkSession(Context())
>>> result = spark.range(4, numPartitions=2).foreachPartition(lambda partition: print(list(partition)))
>>> result = (spark.range(4, numPartitions=2)
... .foreachPartition(lambda partition: print(list(partition))))
[Row(id=0), Row(id=1)]
[Row(id=2), Row(id=3)]
>>> result is None
@@ -330,7 +341,9 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY):
True
"""
if storageLevel != StorageLevel.MEMORY_ONLY:
raise NotImplementedError("Pysparkling currently only supports memory as the storage level")
raise NotImplementedError(
"Pysparkling currently only supports memory as the storage level"
)
return DataFrame(self._jdf.persist(storageLevel), self.sql_ctx)

@property
@@ -351,8 +364,7 @@ def storageLevel(self):
"""
if self.is_cached:
return self._jdf.storageLevel
else:
return StorageLevel(False, False, False, False, 1)
return StorageLevel(False, False, False, False, 1)
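
The `storageLevel` change is pylint's no-else-return refactor (R1705, "unnecessary `else` after `return`"): when the `if` branch returns, the `else` wrapper adds nothing and the fallback can be dedented. A minimal before/after sketch:

    def storage_level(is_cached):   # flagged by R1705
        if is_cached:
            return "MEMORY_ONLY"
        else:
            return "NONE"

    def storage_level(is_cached):   # preferred: early return, flat fallback
        if is_cached:
            return "MEMORY_ONLY"
        return "NONE"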

def unpersist(self, blocking=False):
"""Cache the DataFrame
@@ -437,16 +449,15 @@ def repartition(self, numPartitions, *cols):
if isinstance(numPartitions, int):
if not cols:
return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
else:
def partitioner(row: Row):
return sum(hash(row[c]) for c in cols)

repartitioned_jdf = self._jdf.partitionValues(numPartitions, partitioner)
return DataFrame(repartitioned_jdf, self.sql_ctx)
elif isinstance(numPartitions, (basestring, Column)):
def partitioner(row: Row):
return sum(hash(row[c]) for c in cols)

repartitioned_jdf = self._jdf.partitionValues(numPartitions, partitioner)
return DataFrame(repartitioned_jdf, self.sql_ctx)
if isinstance(numPartitions, (basestring, Column)):
return self.repartition(200, numPartitions, *cols)
else:
raise TypeError("numPartitions should be an int, str or Column")
raise TypeError("numPartitions should be an int, str or Column")
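
After the same treatment, `repartition` reads as a chain of guarded returns, with the trailing `raise` reached only when no branch matched. The control flow in isolation, as a hedged sketch (the 200-partition default is taken from the code above; `Column` handling is elided):

    def dispatch_repartition(num_partitions, *cols):
        # illustrative stand-in for the real method's type dispatch
        if isinstance(num_partitions, int):
            return ("repartition", num_partitions, cols)
        if isinstance(num_partitions, str):
            # a column name came first: recurse with the default count
            return dispatch_repartition(200, num_partitions, *cols)
        raise TypeError("numPartitions should be an int, str or Column")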

def repartitionByRange(self, numPartitions, *cols):
"""
@@ -479,23 +490,22 @@ def repartitionByRange(self, numPartitions, *cols):
[Row(v=4)]
"""
# pylint: disable=W0511
# todo: support sort orders and assume "ascending nulls first" if needed
if isinstance(numPartitions, int):
if not cols:
raise ValueError("At least one partition-by expression must be specified.")
else:
repartitioned_jdf = self._jdf.repartitionByRange(numPartitions, *cols)
return DataFrame(repartitioned_jdf, self.sql_ctx)
elif isinstance(numPartitions, (basestring, Column)):
repartitioned_jdf = self._jdf.repartitionByRange(numPartitions, *cols)
return DataFrame(repartitioned_jdf, self.sql_ctx)
if isinstance(numPartitions, (basestring, Column)):
return self.repartitionByRange(200, numPartitions, *cols)
else:
raise TypeError("numPartitions should be an int, str or Column")
raise TypeError("numPartitions should be an int, str or Column")

def distinct(self):
return DataFrame(self._jdf.distinct(), self.sql_ctx)

def sample(self, withReplacement=None, fraction=None, seed=None):
is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float)
is_withReplacement_set = isinstance(withReplacement, bool) and isinstance(fraction, float)
is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float)
is_withReplacement_omitted_args = isinstance(withReplacement, float)
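
The `sample` change swaps `type(withReplacement) == bool` for `isinstance(withReplacement, bool)`, pylint's unidiomatic-typecheck fix (C0123). For `bool` the two agree, since `bool` cannot be subclassed, but in general `isinstance` also accepts subclasses while the exact-type comparison rejects them:

    >>> type(True) == bool, isinstance(True, bool)
    (True, True)
    >>> type(True) == int, isinstance(True, int)  # bool is a subclass of int
    (False, True)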

@@ -699,8 +709,16 @@ def join(self, other, on=None, how="inner"):
>>> from pysparkling.sql.session import SparkSession
>>> from pysparkling.sql.functions import length, col, lit
>>> spark = SparkSession(Context())
>>> left_df = spark.range(1, 3).select(lit("test_value"), (col("id")*2).alias("id"), lit("left").alias("side"))
>>> right_df = spark.range(1, 3).select(lit("test_value"), col("id"), lit("right").alias("side"))
>>> left_df = spark.range(1, 3).select(
... lit("test_value"),
... (col("id")*2).alias("id"),
... lit("left").alias("side")
... )
>>> right_df = spark.range(1, 3).select(
... lit("test_value"),
... col("id"),
... lit("right").alias("side")
... )
>>>
>>> left_df.join(right_df, on="id", how="inner").show()
+---+----------+----+----------+-----+
@@ -786,7 +804,8 @@ def sortWithinPartitions(self, *cols, ascending=True):
>>> from pysparkling.sql.session import SparkSession
>>> spark = SparkSession(Context())
>>> df = spark.range(4, numPartitions=2)
>>> df.sortWithinPartitions("id", ascending=False).foreachPartition(lambda p: print(list(p)))
>>> (df.sortWithinPartitions("id", ascending=False)
... .foreachPartition(lambda p: print(list(p))))
[Row(id=1), Row(id=0)]
[Row(id=3), Row(id=2)]
"""
@@ -843,6 +862,7 @@ def orderBy(self, *cols, **kwargs):
def _sort_cols(cols, kwargs):
""" Return a list of Columns that describes the sort order
"""
# pylint: disable=W0511
# todo: use this function in sort methods to add support of custom orders
if not cols:
raise ValueError("should sort by at least one column")
@@ -900,7 +920,7 @@ def describe(self, *cols):
"""
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
if len(cols) == 0:
if not cols:
cols = ["*"]
return DataFrame(self._jdf.describe(cols), self.sql_ctx)
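
`if not cols:` replaces `if len(cols) == 0:` per pylint's len-as-condition check (C1801): empty sequences are falsy, so the two tests agree for the list and tuple arguments `describe` receives:

    >>> not [], len([]) == 0
    (True, True)
    >>> not ["*"], len(["*"]) == 0
    (False, False)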

@@ -979,14 +999,13 @@ def first(self):
def __getitem__(self, item):
if isinstance(item, basestring):
return getattr(self, item)
elif isinstance(item, Column):
if isinstance(item, Column):
return self.filter(item)
elif isinstance(item, (list, tuple)):
if isinstance(item, (list, tuple)):
return self.select(*item)
elif isinstance(item, int):
if isinstance(item, int):
return Column(FieldAsExpression(self._jdf.bound_schema[item]))
else:
raise TypeError("unexpected item type: %s" % type(item))
raise TypeError("unexpected item type: %s" % type(item))

def __getattr__(self, name):
if name.startswith("_"):
@@ -1005,7 +1024,8 @@ def select(self, *cols):
>>> from pysparkling import Context
>>> from pysparkling.sql.session import SparkSession
>>> spark = SparkSession(Context())
>>> from pysparkling.sql.functions import explode, split, posexplode, posexplode_outer, col, avg
>>> from pysparkling.sql.functions import (explode, split, posexplode,
... posexplode_outer, col, avg)
>>> df = spark.createDataFrame(
... [Row(age=2, name='Alice'), Row(age=5, name='Bob')]
... )
@@ -1085,7 +1105,6 @@ def selectExpr(self, *expr):
This is a variant of :func:`select` that accepts SQL expressions.
# todo: handle this:
# >>> df.selectExpr("age * 2", "abs(age)").collect()
# [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]
@@ -1105,6 +1124,7 @@ def selectExpr(self, *expr):
"""
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0]
# pylint: disable=W0511
# todo: handle expr like abs(age)
# with jdf = self._jdf.selectExpr(expr)
jdf = self._jdf.select(*expr)
@@ -1319,7 +1339,8 @@ def subtract(self, other):
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

def dropDuplicates(self, subset=None):
"""Return a new DataFrame without any duplicate values between rows or between rows for a subset of fields
"""Return a new DataFrame without any duplicate values
between rows or between rows for a subset of fields
>>> from pysparkling import Context, Row
>>> sc = Context()
@@ -1373,23 +1394,16 @@ def fillna(self, value, subset=None):

if isinstance(value, dict):
return DataFrame(self._jdf.fillna(value), self.sql_ctx)
elif subset is None:
if subset is None:
return DataFrame(self._jdf.fillna(value), self.sql_ctx)
else:
if isinstance(subset, basestring):
subset = [subset]
elif not isinstance(subset, (list, tuple)):
raise ValueError("subset should be a list or tuple of column names")
if isinstance(subset, basestring):
subset = [subset]
elif not isinstance(subset, (list, tuple)):
raise ValueError("subset should be a list or tuple of column names")

return DataFrame(self._jdf.fillna(value, subset), self.sql_ctx)
return DataFrame(self._jdf.fillna(value, subset), self.sql_ctx)

def replace(self, to_replace, value=_NoValue, subset=None):
if value is _NoValue:
if isinstance(to_replace, dict):
value = None
else:
raise TypeError("value argument is required when to_replace is not a dictionary.")

# Helper functions
def all_of(types):
def all_of_(xs):
@@ -1401,27 +1415,7 @@ def all_of_(xs):
all_of_str = all_of(basestring)
all_of_numeric = all_of((float, int, long))

# Validate input types
valid_types = (bool, float, int, long, basestring, list, tuple)
if not isinstance(to_replace, valid_types) and not isinstance(to_replace, dict):
raise ValueError(
"to_replace should be a bool, float, int, long, string, list, tuple, or dict. "
"Got {0}".format(type(to_replace)))

if not isinstance(value, valid_types) and value is not None \
and not isinstance(to_replace, dict):
raise ValueError("If to_replace is not a dict, value should be "
"a bool, float, int, long, string, list, tuple or None. "
"Got {0}".format(type(value)))

if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
if len(to_replace) != len(value):
raise ValueError("to_replace and value lists should be of the same length. "
"Got {0} and {1}".format(len(to_replace), len(value)))

if not (subset is None or isinstance(subset, (list, tuple, basestring))):
raise ValueError("subset should be a list or tuple of column names, "
"column name or None. Got {0}".format(type(subset)))
value = self._check_replace_inputs(subset, to_replace, value)

# Reshape input arguments if necessary
if isinstance(to_replace, (float, int, long, basestring)):
@@ -1447,8 +1441,34 @@ def all_of_(xs):

if subset is None:
return DataFrame(self._jdf.replace('*', rep_dict), self.sql_ctx)
else:
return DataFrame(self._jdf.replace(subset, rep_dict), self.sql_ctx)
return DataFrame(self._jdf.replace(subset, rep_dict), self.sql_ctx)

def _check_replace_inputs(self, subset, to_replace, value):
if value is _NoValue:
if isinstance(to_replace, dict):
value = None
else:
raise TypeError("value argument is required when to_replace is not a dictionary.")

# Validate input types
valid_types = (bool, float, int, long, basestring, list, tuple)
if not isinstance(to_replace, valid_types) and not isinstance(to_replace, dict):
raise ValueError(
"to_replace should be a bool, float, int, long, string, list, tuple, or dict. "
"Got {0}".format(type(to_replace)))
if not isinstance(value, valid_types) and value is not None \
and not isinstance(to_replace, dict):
raise ValueError("If to_replace is not a dict, value should be "
"a bool, float, int, long, string, list, tuple or None. "
"Got {0}".format(type(value)))
if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
if len(to_replace) != len(value):
raise ValueError("to_replace and value lists should be of the same length. "
"Got {0} and {1}".format(len(to_replace), len(value)))
if not (subset is None or isinstance(subset, (list, tuple, basestring))):
raise ValueError("subset should be a list or tuple of column names, "
"column name or None. Got {0}".format(type(subset)))
return value
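
Hoisting the validation out of `replace` into `_check_replace_inputs` is the usual extract-method response when pylint flags a function for too many branches or statements (R0912/R0915; which check fired here is an assumption). The shape of the refactor, as a self-contained sketch with simplified signatures:

    def _check_inputs(to_replace, value):
        # validate, then hand back the normalized value
        if value is None and not isinstance(to_replace, dict):
            raise TypeError("value argument is required when to_replace is not a dictionary.")
        return value

    def replace(to_replace, value=None):
        value = _check_inputs(to_replace, value)
        # ...reshape and dispatch on the validated arguments...
        return to_replace, value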

def approxQuantile(self, col, probabilities, relativeError):
"""
@@ -1512,7 +1532,7 @@ def corr(self, col1, col2, method=None):
raise ValueError("col2 should be a string.")
if not method:
method = "pearson"
if not method == "pearson":
if method != "pearson":
raise ValueError("Currently only the calculation of the Pearson Correlation " +
"coefficient is supported.")
return self._jdf.corr(col1, col2, method)
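
`not method == "pearson"` parses as `not (method == "pearson")`, so pylint's unneeded-not check (C0113) rewrites it to the direct comparison `method != "pearson"`; the two are equivalent for any value:

    >>> method = "spearman"
    >>> (not method == "pearson") == (method != "pearson")
    True
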
@@ -1532,7 +1552,9 @@ def cov(self, col1, col2):
return self._jdf.cov(col1, col2)

def crosstab(self, col1, col2):
# pylint: disable=W0511
# todo: extra workin here
# pylint: disable=W0511
# todo: tests on schema
if not isinstance(col1, basestring):
raise ValueError("col1 should be a string.")
@@ -1602,7 +1624,7 @@ def drop(self, *cols):
"""
if len(cols) == 1:
col = cols[0]
if isinstance(col, basestring) or isinstance(col, Column):
if isinstance(col, (basestring, Column)):
jdf = self._jdf.drop([col])
else:
raise TypeError("col should be a string or a Column")
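
The two `isinstance` calls joined by `or` collapse into one call with a tuple of types, pylint's consider-merging-isinstance refactor (R1701). Sketched with built-in types, since `basestring` and `Column` come from this module's own imports:

    >>> col = "name"
    >>> isinstance(col, str) or isinstance(col, bytes)
    True
    >>> isinstance(col, (str, bytes))  # merged form, same result
    True
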
@@ -1654,6 +1676,7 @@ def toPandas(self):
else:
timezone = None

# pylint: disable=W0511
# todo: Handle sql_ctx_conf.arrowEnabled()
# Below is toPandas without Arrow optimization.
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
@@ -1675,13 +1698,14 @@ def toPandas(self):

if timezone is None:
return pdf
else:
for field in self.schema:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
return pdf

for field in self.schema:
# pylint: disable=W0511
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
return pdf

def groupby(self, *cols):
return self.groupBy(*cols)
@@ -1754,13 +1778,12 @@ def _to_corrected_pandas_type(dt):
This method gets the corrected data type for Pandas if that type may be inferred uncorrectly.
"""
import numpy as np
if type(dt) == ByteType:
if isinstance(dt, ByteType):
return np.int8
elif type(dt) == ShortType:
if isinstance(dt, ShortType):
return np.int16
elif type(dt) == IntegerType:
if isinstance(dt, IntegerType):
return np.int32
elif type(dt) == FloatType:
if isinstance(dt, FloatType):
return np.float32
else:
return None
return None
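
One nuance of rewriting `type(dt) == ByteType` as `isinstance(dt, ByteType)`: the `isinstance` form also matches subclasses of `ByteType`, which the exact-type test did not. A sketch with a hypothetical subclass (import path assumed to mirror pyspark's `sql.types` layout):

    >>> from pysparkling.sql.types import ByteType
    >>> class TinyByteType(ByteType):  # hypothetical subclass, for illustration only
    ...     pass
    >>> dt = TinyByteType()
    >>> type(dt) == ByteType, isinstance(dt, ByteType)
    (False, True)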
