Skip to content

Commit

Permalink
Small updates towards Python DecisionTree API
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Aug 1, 2014
1 parent 188cb0d commit 665ba78
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
12 changes: 6 additions & 6 deletions examples/src/main/python/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def swapLabelAndFeature0(labeledPoint):
classificationModel = DecisionTree.trainClassifier(points, numClasses=2)
# Print learned tree and stats.
print "Trained DecisionTree for classification:"
print " Model numNodes: " + classificationModel.numNodes() + "\n"
print " Model depth: " + classificationModel.depth() + "\n"
print " Training accuracy: " + getAccuracy(classificationModel, points) + "\n"
print " Model numNodes: %d\n" % classificationModel.numNodes()
print " Model depth: %d\n" % classificationModel.depth()
print " Training accuracy: %g\n" % getAccuracy(classificationModel, points)
print classificationModel

# Switch labels and first feature to create a regression dataset with categorical features.
Expand All @@ -84,7 +84,7 @@ def swapLabelAndFeature0(labeledPoint):
DecisionTree.trainRegressor(points, categoricalFeaturesInfo=categoricalFeaturesInfo)
# Print learned tree and stats.
print "Trained DecisionTree for regression:"
print " Model numNodes: " + regressionModel.numNodes() + "\n"
print " Model depth: " + regressionModel.depth() + "\n"
print " Training MSE: " + getMSE(regressionModel, points) + "\n"
print " Model numNodes: %d\n" % regressionModel.numNodes()
print " Model depth: %d\n" % regressionModel.depth()
print " Training MSE: %g\n" % getMSE(regressionModel, points)
print regressionModel
7 changes: 5 additions & 2 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
from py4j.java_collections import MapConverter

from pyspark import SparkContext, RDD
from pyspark.mllib._common import \
_get_unmangled_double_vector_rdd, _serialize_double_vector, \
_deserialize_labeled_point, _get_unmangled_labeled_point_rdd
from pyspark.mllib.regression import LabeledPoint
from pyspark.serializers import NoOpSerializer


class DecisionTreeModel(object):
"""
A decision tree model for classification or regression.
Expand All @@ -45,7 +47,8 @@ def predict(self, x):
:param x: Either one data point (feature vector), or a dataset (RDD of feature vectors)
"""
pythonAPI = self._sc._jvm.PythonMLLibAPI()
if type(x) == RDD:
print "predict called for type: " + str(type(x))
if isinstance(x, RDD):
# Bulk prediction
dataBytes = _get_unmangled_double_vector_rdd(x)
jSerializedPreds = pythonAPI.predictDecisionTreeModel(self._java_model, dataBytes._jrdd)
Expand Down

0 comments on commit 665ba78

Please sign in to comment.