From c79ac9ba2576cc408b4402e9e9d17b6593ea758e Mon Sep 17 00:00:00 2001 From: brett koonce Date: Fri, 11 Dec 2020 02:43:36 +0000 Subject: [PATCH] fix batch sizes --- TrainingLoop/Metrics.swift | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/TrainingLoop/Metrics.swift b/TrainingLoop/Metrics.swift index 9ef7595d429..a6800a74094 100644 --- a/TrainingLoop/Metrics.swift +++ b/TrainingLoop/Metrics.swift @@ -173,9 +173,10 @@ public struct AccuracyMeasurerTop5: MetricsMeasurer { "For accuracy measurements, the model output must be Tensor, and the labels must be Tensor." ) } - let top5 = _Raw.topKV2(predictions, k: Tensor(5), sorted:false) - let labelsReshaped = labels.reshaped(to: [1, 32]) - let labelsTranspose = labelsReshaped.transposed(permutation: [1,0]) + let top5 = _Raw.topKV2(predictions, k: Tensor(5), sorted: false) + let batchSize = labels.shape[0] + let labelsReshaped = labels.reshaped(to: [1, batchSize]) + let labelsTranspose = labelsReshaped.transposed(permutation: [1, 0]) let ones = Tensor([1, 1, 1, 1, 1]) let exandedLabels = labelsTranspose * ones correctGuessCount += Tensor(top5.indices .== exandedLabels).sum().scalarized()