Skip to content

Commit

Permalink
add a basic BinaryClassificationMetrics to PySpark/MLlib
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 3, 2015
1 parent 2db6a85 commit dcddab5
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 0 deletions.
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.binary._
import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.sql.DataFrame

/**
* :: Experimental ::
Expand Down Expand Up @@ -53,6 +54,13 @@ class BinaryClassificationMetrics(
*/
def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)

/**
 * Auxiliary constructor that accepts the (score, label) pairs as a DataFrame.
 *
 * Each row is read positionally: column 0 as the score and column 1 as the label,
 * both via `getDouble`, so both columns must hold doubles.
 *
 * @param scoreAndLabels a DataFrame with two double columns: score and label
 */
private[mllib] def this(scoreAndLabels: DataFrame) =
  this(scoreAndLabels.map { row => (row.getDouble(0), row.getDouble(1)) })

/** Unpersist intermediate RDDs used in the computation. */
def unpersist() {
cumulativeCounts.unpersist()
Expand Down
7 changes: 7 additions & 0 deletions python/docs/pyspark.mllib.rst
Expand Up @@ -16,6 +16,13 @@ pyspark.mllib.clustering module
:members:
:undoc-members:

pyspark.mllib.evaluation module
-------------------------------

.. automodule:: pyspark.mllib.evaluation
    :members:
    :undoc-members:

pyspark.mllib.feature module
-------------------------------

Expand Down
79 changes: 79 additions & 0 deletions python/pyspark/mllib/evaluation.py
@@ -0,0 +1,79 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.mllib.common import JavaModelWrapper
from pyspark.sql import SQLContext


class BinaryClassificationMetrics(JavaModelWrapper):
    """
    Evaluator for binary classification.

    Wraps the JVM class
    ``org.apache.spark.mllib.evaluation.BinaryClassificationMetrics`` and
    forwards each metric call to it.

    >>> scoreAndLabels = sc.parallelize([
    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
    >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
    >>> metrics.areaUnderROC()
    0.70...
    >>> metrics.areaUnderPR()
    0.83...
    >>> metrics.unpersist()
    """

    def __init__(self, scoreAndLabels):
        """
        :param scoreAndLabels: an RDD of (score, label) pairs
        """
        sc = scoreAndLabels.ctx
        # Instantiated only for its side effect: creating a SQLContext
        # monkey patches RDD.toDF, which is used below.
        SQLContext(sc)
        # The JVM constructor reads both columns with getDouble, so make sure
        # the inferred DataFrame columns are doubles even when the caller
        # passes integer scores or labels such as (1, 0).
        pairs = scoreAndLabels.map(lambda p: (float(p[0]), float(p[1])))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
        java_model = java_class(pairs.toDF()._jdf)
        super(BinaryClassificationMetrics, self).__init__(java_model)

    def areaUnderROC(self):
        """
        Computes the area under the receiver operating characteristic
        (ROC) curve.
        """
        return self.call("areaUnderROC")

    def areaUnderPR(self):
        """
        Computes the area under the precision-recall curve.
        """
        return self.call("areaUnderPR")

    def unpersist(self):
        """
        Unpersists intermediate RDDs used in the computation.
        """
        self.call("unpersist")


def _test():
    """
    Run this module's doctests against a local SparkContext.

    The context is exposed to the doctests as ``sc`` and is always stopped,
    even if the doctest run itself raises. Exits the process with status -1
    when any doctest fails.
    """
    import doctest
    import sys
    from pyspark import SparkContext
    import pyspark.mllib.evaluation
    globs = pyspark.mllib.evaluation.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop the context whether or not the doctest run completed.
        sc.stop()
    if failure_count:
        # sys.exit rather than the bare exit() helper, which is only
        # injected by the site module.
        sys.exit(-1)


# Run the module's doctests when this file is executed directly as a script.
if __name__ == "__main__":
    _test()
1 change: 1 addition & 0 deletions python/run-tests
Expand Up @@ -75,6 +75,7 @@ function run_mllib_tests() {
echo "Run mllib tests ..."
run_test "pyspark/mllib/classification.py"
run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/evaluation.py"
run_test "pyspark/mllib/feature.py"
run_test "pyspark/mllib/linalg.py"
run_test "pyspark/mllib/rand.py"
Expand Down

0 comments on commit dcddab5

Please sign in to comment.