Skip to content

Commit

Permalink
[SPARK-2627] more misc PEP 8 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nchammas committed Aug 3, 2014
1 parent fe57ed0 commit 6f4900b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
14 changes: 8 additions & 6 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_classification(self):
self.assertTrue(nb_model.predict(features[2]) <= 0)
self.assertTrue(nb_model.predict(features[3]) > 0)

categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
dt_model = \
DecisionTree.trainClassifier(rdd, numClasses=2,
categoricalFeaturesInfo=categoricalFeaturesInfo)
Expand Down Expand Up @@ -176,9 +176,10 @@ def test_regression(self):
self.assertTrue(rr_model.predict(features[2]) <= 0)
self.assertTrue(rr_model.predict(features[3]) > 0)

categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
dt_model = \
DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
DecisionTree.trainRegressor(
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
Expand Down Expand Up @@ -290,7 +291,7 @@ def test_classification(self):
self.assertTrue(nb_model.predict(features[2]) <= 0)
self.assertTrue(nb_model.predict(features[3]) > 0)

categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
dt_model = DecisionTree.trainClassifier(rdd, numClasses=2,
categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
Expand Down Expand Up @@ -329,8 +330,9 @@ def test_regression(self):
self.assertTrue(rr_model.predict(features[2]) <= 0)
self.assertTrue(rr_model.predict(features[3]) > 0)

categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
dt_model = DecisionTree.trainRegressor(
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
Expand Down
7 changes: 5 additions & 2 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from pyspark.mllib.regression import LabeledPoint
from pyspark.serializers import NoOpSerializer


class DecisionTreeModel(object):

"""
A decision tree model for classification or regression.
Expand Down Expand Up @@ -77,6 +79,7 @@ def __str__(self):


class DecisionTree(object):

"""
Learning algorithm for a decision tree model
for classification or regression.
Expand Down Expand Up @@ -174,7 +177,6 @@ def trainRegressor(data, categoricalFeaturesInfo={},
categoricalFeaturesInfo,
impurity, maxDepth, maxBins)


@staticmethod
def train(data, algo, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins=100):
Expand Down Expand Up @@ -216,7 +218,8 @@ def _test():
import doctest
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/mllib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None):
parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l))
if numFeatures <= 0:
parsed.cache()
numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
numFeatures = parsed.map(
lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2])))

@staticmethod
Expand Down
8 changes: 5 additions & 3 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,9 +981,10 @@ def registerFunction(self, name, f, returnType=StringType()):
env = MapConverter().convert(self._sc.environment,
self._sc._gateway._gateway_client)
includes = ListConverter().convert(self._sc._python_includes,
self._sc._gateway._gateway_client)
self._sc._gateway._gateway_client)
self._ssql_ctx.registerPython(name,
bytearray(CloudPickleSerializer().dumps(command)),
bytearray(
CloudPickleSerializer().dumps(command)),
env,
includes,
self._sc.pythonExec,
Expand Down Expand Up @@ -1525,7 +1526,8 @@ def registerTempTable(self, name):
self._jschema_rdd.registerTempTable(name)

def registerAsTable(self, name):
warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
warnings.warn(
"Use registerTempTable instead of registerAsTable.", DeprecationWarning)
self.registerTempTable(name)

def insertInto(self, tableName, overwrite=False):
Expand Down

0 comments on commit 6f4900b

Please sign in to comment.