Skip to content

Commit

Permalink
fix batch sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
brettkoonce committed Dec 11, 2020
1 parent 1632a0b commit c79ac9b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions TrainingLoop/Metrics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ public struct AccuracyMeasurerTop5: MetricsMeasurer {
"For accuracy measurements, the model output must be Tensor<Float>, and the labels must be Tensor<Int>."
)
}
let top5 = _Raw.topKV2(predictions, k: Tensor<Int32>(5), sorted:false)
let labelsReshaped = labels.reshaped(to: [1, 32])
let labelsTranspose = labelsReshaped.transposed(permutation: [1,0])
let top5 = _Raw.topKV2(predictions, k: Tensor<Int32>(5), sorted: false)
let batchSize = labels.shape[0]
let labelsReshaped = labels.reshaped(to: [1, batchSize])
let labelsTranspose = labelsReshaped.transposed(permutation: [1, 0])
let ones = Tensor<Int32>([1, 1, 1, 1, 1])
let exandedLabels = labelsTranspose * ones
correctGuessCount += Tensor<Int32>(top5.indices .== exandedLabels).sum().scalarized()
Expand Down

0 comments on commit c79ac9b

Please sign in to comment.