Threshold metrics calculation fix when unseen labels are present #293
…grifAI into km/label-trim-metrics
LGTM, minor comment
@@ -188,7 +188,8 @@ private[op] class OpMultiClassificationEvaluator
 def computeMetrics(scoresAndLabels: (Array[Double], Double)): MetricsMap = {
 val scores: Array[Double] = scoresAndLabels._1
 val label: Label = scoresAndLabels._2.toInt
-val trueClassScore: Double = scores(label)
+// The label may be unseen during model training, so treat scores for unseen classes as all being zero
+val trueClassScore: Double = if (scores.isDefinedAt(label)) scores(label) else 0.0
How about scores.lift(label).getOrElse(0.0)?
Ahh, I was trying to remember how to do that!
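For reference, a minimal sketch of why the two forms discussed above are interchangeable. The object and method names here are hypothetical and not part of the PR; only `isDefinedAt`, `lift`, and `getOrElse` come from the thread.

```scala
object TrueClassScore {
  // Guarded form from the diff: index only when the label is a valid position.
  def viaIsDefinedAt(scores: Array[Double], label: Int): Double =
    if (scores.isDefinedAt(label)) scores(label) else 0.0

  // Suggested form: lift wraps indexing in an Option, which is None when
  // the index is out of range (including negative indices).
  def viaLift(scores: Array[Double], label: Int): Double =
    scores.lift(label).getOrElse(0.0)
}
```

Both methods return the stored score for an in-range label and 0.0 for any label outside the array, so the choice between them is stylistic.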
…label-trim-metrics
Related issues
Issue exposed by fixes in #263
Describe the proposed solution
Simple one-line fix to the threshold metrics calculation. The true label's score is now read from the model's score array when the label is one the model was trained on (the label set may have been pruned by DataCutter), and is set to 0 otherwise (e.g. when the label was not seen during training).
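A rough illustration of the failure mode and of the behaviour after the fix, under simplifying assumptions: `UnseenLabelSketch` and `thresholdsCleared` are hypothetical names, and the comparison against thresholds is a stand-in for the real metrics code rather than a copy of it.

```scala
import scala.util.Try

object UnseenLabelSketch {
  // Score of the true class, or 0.0 when the label has no slot in the score array.
  def trueClassScore(scores: Array[Double], label: Int): Double =
    scores.lift(label).getOrElse(0.0)

  // Hypothetical helper: the thresholds that the true class score meets or exceeds.
  def thresholdsCleared(scores: Array[Double], label: Int, thresholds: Seq[Double]): Seq[Double] =
    thresholds.filter(t => trueClassScore(scores, label) >= t)

  // A model trained on two labels emits two scores; label 2 was never seen.
  val scores: Array[Double] = Array(0.6, 0.4)
  val unseenLabel: Int = 2

  // Direct indexing, as in the old code, throws for the unseen label.
  val directIndexingFails: Boolean = Try(scores(unseenLabel)).isFailure
}
```

In this sketch an unseen label gets a score of 0.0, so it clears no positive threshold and can never be counted as a correct prediction, while seen labels are unaffected.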
Describe alternatives you've considered
n/a
Additional context
Merge #263 first since this PR was branched off of it.
Unit test resides in MultiClassificationModelSelectorTest