BUG: Make distance metrics return tensors, fix #700 #701
Makes functions in `vak.metrics.distance.functional` return tensors
so that we don't cause errors when Lightning tries to convert
values from numpy to tensors in order to log them.
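
For illustration, a minimal sketch of how such a metric might be logged from a
LightningModule; the module class, the `_predict` helper, and the metric key
are hypothetical, not code taken from vak:

```python
# Minimal sketch, assuming the import path matches the changed file;
# the LightningModule subclass and helper names below are illustrative only.
import lightning.pytorch as pl

from vak.metrics.distance.functional import segment_error_rate


class ExampleModule(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        y_pred, y_true = self._predict(batch)  # hypothetical helper
        rate = segment_error_rate(y_pred, y_true)  # now returns a torch.Tensor
        # because `rate` is already a tensor, self.log() does not have to
        # convert a numpy scalar on the fly when logging
        self.log("val_segment_error_rate", rate)
```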

Letting Lightning do the conversion mostly works,
but it can cause a fatal error
for someone using an Apple M1 with 'mps' as the accelerator;
see https://forum.vocalpy.org/t/vak-tweetynet-with-an-apple-m1-max/78/4?u=nicholdav

I can't find an explicit statement in either the Lightning
or Torchmetrics docs that metrics should always be tensors,
or that returning tensors guarantees there won't be odd issues.
(Right now we get a warning on start-up that all logged scalars
should be float32, though I would expect logging integers
to work as well.)
But judging from issues I read, such as
Lightning-AI/pytorch-lightning#2143,
that does seem to be the expectation,
and I notice that torchmetrics classes tend to do things like
convert values to a float tensor.
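
For example, a rough sketch of that "cast to a float tensor" pattern; the class
and state names here are illustrative, not code from torchmetrics or vak:

```python
# Rough sketch of the pattern; names are illustrative, not from torchmetrics.
import torch
from torchmetrics import Metric


class MeanSegmentErrorRate(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, rate) -> None:
        # accept a Python float, numpy scalar, or tensor; accumulate as float32
        self.total += torch.as_tensor(rate, dtype=torch.float32)
        self.count += 1

    def compute(self) -> torch.Tensor:
        return self.total / self.count
```
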
NickleDave committed Sep 23, 2023
1 parent e005f07 commit d5f0564
Showing 1 changed file with 4 additions and 2 deletions.
`src/vak/metrics/distance/functional.py` — 6 changes: 4 additions & 2 deletions

```diff
@@ -1,4 +1,5 @@
 import numpy as np
+import torch


 def levenshtein(source, target):
@@ -65,7 +66,7 @@ def levenshtein(source, target):

         d0, d1 = d1, d0

-    return d0[-1]
+    return torch.tensor(d0[-1], dtype=torch.int32)


 def segment_error_rate(y_pred, y_true):
@@ -95,4 +96,5 @@ def segment_error_rate(y_pred, y_true):
             "segment error rate is undefined when length of y_true is zero"
         )

-    return levenshtein(y_pred, y_true) / len(y_true)
+    rate = levenshtein(y_pred, y_true) / len(y_true)
+    return torch.tensor(rate, dtype=torch.float32)
```
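
For reference, a small usage sketch showing the return types after this change
(the example inputs here are hypothetical, not taken from vak's tests):

```python
# Hypothetical usage sketch; after this commit both functions return tensors
# rather than Python / numpy scalars, so Lightning can log them directly.
import torch

from vak.metrics.distance.functional import levenshtein, segment_error_rate

y_pred, y_true = "abbc", "abc"

dist = levenshtein(y_pred, y_true)          # edit distance of 1
rate = segment_error_rate(y_pred, y_true)   # 1 / len(y_true) = 1/3

assert isinstance(dist, torch.Tensor) and dist.dtype == torch.int32
assert isinstance(rate, torch.Tensor) and rate.dtype == torch.float32
```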
