Skip to content

Commit

Permalink
add top 5 accuracy metric
Browse files Browse the repository at this point in the history
  • Loading branch information
brettkoonce committed Dec 11, 2020
1 parent 495d657 commit 1632a0b
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion TrainingLoop/Metrics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import TensorFlow
public enum TrainingMetrics {
case loss
case accuracy
case top5
case matthewsCorrelationCoefficient
case perplexity

Expand All @@ -13,6 +14,8 @@ public enum TrainingMetrics {
return "loss"
case .accuracy:
return "accuracy"
case .top5:
return "top5"
case .matthewsCorrelationCoefficient:
return "mcc"
case .perplexity:
Expand All @@ -26,6 +29,8 @@ public enum TrainingMetrics {
return LossMeasurer(self.name)
case .accuracy:
return AccuracyMeasurer(self.name)
case .top5:
return AccuracyMeasurerTop5(self.name)
case .matthewsCorrelationCoefficient:
return MCCMeasurer(self.name)
case .perplexity:
Expand Down Expand Up @@ -89,7 +94,7 @@ public struct LossMeasurer: MetricsMeasurer {
}
}

/// A measurer for measuring accuracy
/// A measurer for measuring accuracy (top 1)
public struct AccuracyMeasurer: MetricsMeasurer {
/// Name of the AccuracyMeasurer.
public var name: String
Expand Down Expand Up @@ -134,6 +139,55 @@ public struct AccuracyMeasurer: MetricsMeasurer {
}
}

/// A measurer for measuring accuracy (top 5)
public struct AccuracyMeasurerTop5: MetricsMeasurer {
/// Name of the AccuracyMeasurer.
public var name: String

/// Count of correct guesses.
private var correctGuessCount: Int32 = 0

/// Count of total guesses.
private var totalGuessCount: Int32 = 0

/// Creates an instance with the AccuracyMeasurerTop5 named `name`.
public init(_ name: String = "top5") {
self.name = name
}

/// Resets correctGuessCount and totalGuessCount to zero.
public mutating func reset() {
correctGuessCount = 0
totalGuessCount = 0
}

/// Computes correct guess count from `loss`, `predictions` and `labels`
/// and adds it to correctGuessCount; Computes total guess count from
/// `labels` shape and adds it to totalGuessCount.
public mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
) {
guard let predictions = predictions as? Tensor<Float>, let labels = labels as? Tensor<Int32>
else {
fatalError(
"For accuracy measurements, the model output must be Tensor<Float>, and the labels must be Tensor<Int>."
)
}
let top5 = _Raw.topKV2(predictions, k: Tensor<Int32>(5), sorted:false)
let labelsReshaped = labels.reshaped(to: [1, 32])
let labelsTranspose = labelsReshaped.transposed(permutation: [1,0])
let ones = Tensor<Int32>([1, 1, 1, 1, 1])
let exandedLabels = labelsTranspose * ones
correctGuessCount += Tensor<Int32>(top5.indices .== exandedLabels).sum().scalarized()
totalGuessCount += Int32(labels.shape.reduce(1, *))
}

/// Computes accuracy as percentage of correct guesses.
public func measure() -> Float {
return Float(correctGuessCount) / Float(totalGuessCount)
}
}

/// A measurer for measuring matthewsCorrelationCoefficient.
public struct MCCMeasurer: MetricsMeasurer {
/// Name of the MCCMeasurer.
Expand Down

0 comments on commit 1632a0b

Please sign in to comment.