diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 34f6310e9e8b1..16cb49cc0cfff 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -17,6 +17,7 @@ from pyspark.mllib.common import JavaModelWrapper from pyspark.sql import SQLContext +from pyspark.sql.types import StructField, StructType, DoubleType class BinaryClassificationMetrics(JavaModelWrapper): @@ -38,9 +39,12 @@ def __init__(self, scoreAndLabels): :param scoreAndLabels: an RDD of (score, label) pairs """ sc = scoreAndLabels.ctx - SQLContext(sc) # monkey patch RDD.toRDD + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ + StructField("score", DoubleType(), nullable=False), + StructField("label", DoubleType(), nullable=False)])) java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics - java_model = java_class(scoreAndLabels.toDF()._jdf) + java_model = java_class(df._jdf) super(BinaryClassificationMetrics, self).__init__(java_model) def areaUnderROC(self):