diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 292cbaf639f90..666362ae6739a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -47,8 +47,9 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { .collectAsMap() private lazy val confusions = predictionAndLabels .map { case (prediction, label) => - ((prediction, label), 1) - }.reduceByKey(_ + _).collectAsMap() + ((label, prediction), 1) + }.reduceByKey(_ + _) + .collectAsMap() /** * Returns confusion matrix: @@ -57,19 +58,18 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * as in "labels" */ def confusionMatrix: Matrix = { - val transposedFlatMatrix = Array.ofDim[Double](labels.size * labels.size) val n = labels.size - var i, j = 0 - while(i < n){ - j = 0 - while(j < n){ - transposedFlatMatrix(i * labels.size + j) - = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + val values = Array.ofDim[Double](n * n) + var i = 0 + while (i < n) { + var j = 0 + while (j < n) { + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble j += 1 } i += 1 } - Matrices.dense(labels.size, labels.size, transposedFlatMatrix) + Matrices.dense(n, n, values) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 555343d7cdb21..1ea503971c864 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.evaluation +import org.scalatest.FunSuite + import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.LocalSparkContext -import org.scalatest.FunSuite class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { test("Multiclass evaluation metrics") {