In [1]:
import itertools
from typing import Sequence

import editdistance
import pytorch_lightning as pl
import torch

단어 간 거리는 형태적 거리와 의미적 거리로 분류됩니다.

- 의미적 거리: Word2Vec (e.g. King ≈ Man — 형태는 다르지만 의미가 가까움)
- 형태적 거리: Levenshtein distance(edit distance) (e.g. 점심 먹자 -> 저녁 먹자)

## Levenshtein (Edit) distance

Levenshtein distance는 한 string s1을 s2로 변환하는 데 필요한 최소 편집 횟수를 두 string 간의 거리로 정의합니다.

횟수(거리)는 **삭제, 삽입, 치환**의 연산으로 이루어집니다.

Ref: [Levenshtein (edit) distance 를 이용한 한국어 단어의 형태적 유사성](https://lovit.github.io/nlp/2018/08/28/levenshtein_hangle/)

In [17]:
class CharacterErrorRate(pl.metrics.Metric):
    """Character error rate (CER) metric, computed using Levenshtein distance.

    For each (prediction, target) row, the tokens listed in ``ignore_tokens``
    are dropped from both sequences, the edit distance between what remains is
    computed, and that distance is divided by the length of the longer
    sequence. ``compute`` returns the mean of these per-row errors over all
    rows seen so far.

    Parameters
    ----------
    ignore_tokens
        Token ids that are removed from both predictions and targets before
        scoring (the demo in this notebook passes ``[0, 1]``).
    *args
        Forwarded unchanged to ``pl.metrics.Metric.__init__``.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args):
        super().__init__(*args)
        # A set gives O(1) membership tests while filtering every token.
        self.ignore_tokens = set(ignore_tokens)
        # Both states are registered with dist_reduce_fx="sum", i.e. the base
        # Metric is asked to add them up when reducing across processes.
        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")  # pylint: disable=not-callable
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")  # pylint: disable=not-callable
        # Annotations only: tell type checkers what add_state created.
        self.error: torch.Tensor
        self.total: torch.Tensor

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate the per-row error of one batch.

        Parameters
        ----------
        preds
            Predicted token ids, indexed as ``preds[row]``; one sequence per row.
        targets
            Reference token ids; row ``i`` is compared against ``preds[i]``.
        """
        N = preds.shape[0]
        for ind in range(N):
            pred = [token for token in preds[ind].tolist() if token not in self.ignore_tokens]
            target = [token for token in targets[ind].tolist() if token not in self.ignore_tokens]

            distance = editdistance.distance(pred, target)
            # The prints are intentional: the notebook uses them to walk
            # through the computation step by step.
            print(f'Distance is {distance}')
            longest = max(len(pred), len(target))
            # If both sequences are empty after filtering they are identical,
            # so the error is 0; this also avoids a ZeroDivisionError.
            error = distance / longest if longest else 0.0
            print(f'Error is {error}')
            self.error = self.error + error
        self.total = self.total + N
        print(f'Total is {self.total}')
        print('-'* 80)

    def compute(self) -> torch.Tensor:
        """Return the accumulated error divided by the number of rows seen."""
        return self.error / self.total


In [18]:
# Token ids 0 and 1 are ignored by the metric, so only the ids >= 2 in each
# row take part in the edit-distance computation.
metric = CharacterErrorRate(ignore_tokens=[0, 1])

# Predictions: one sequence per row.
X = torch.tensor(  # pylint: disable=not-callable
   [
      [0, 2, 2, 3, 3, 1],  # identical to the target        -> error 0
      [0, 2, 1, 1, 1, 1],  # [2] vs [2, 2, 3, 3]: 3 edits/4 -> error .75
      [0, 2, 2, 4, 4, 1],  # two substitutions out of 4     -> error .5
   ]
)
# Targets: the same reference sequence for every prediction row.
Y = torch.tensor([[0, 2, 2, 3, 3, 1]] * 3)  # pylint: disable=not-callable

metric(X, Y)
print(metric.compute())

Distance is 0
Error is 0.0
Distance is 3
Error is 0.75
Distance is 2
Error is 0.5
Total is 3
--------------------------------------------------------------------------------
Distance is 0
Error is 0.0
Distance is 3
Error is 0.75
Distance is 2
Error is 0.5
Total is 3
--------------------------------------------------------------------------------
tensor(0.4167)


## Greedy Search Decoder

softmax를 통과한 후에, 가장 확률값이 큰 인덱스를 뽑아 해당 time-step의 y_hat으로 사용하는 것.

![img](https://guillaumegenthial.github.io/assets/img2latex/seq2seq_vanilla_decoder.svg)

In [19]:
def greedy_decode(self, logprobs: torch.Tensor, max_length: int) -> torch.Tensor:
   """
   Greedily decode sequences, collapsing repeated tokens, and removing the CTC blank token.
   See the "Inference" sections of https://distill.pub/2017/ctc/

   Using groupby inspired by https://github.com/nanoporetech/fast-ctc-decode/blob/master/tests/benchmark.py#L8

   Reads ``self.blank_index`` (the token dropped after collapsing repeats) and
   ``self.padding_index`` (the value used to fill unused output positions).
   Requires ``itertools`` to be imported at module level.

   Parameters
   ----------
   logprobs
      (B, C, S) log probabilities
   max_length
      max length of a sequence; longer decoded sequences are truncated

   Returns
   -------
   torch.Tensor
      (B, max_length) int32 class indices on the same device as ``logprobs``,
      right-padded with ``self.padding_index``

   """
   B = logprobs.shape[0]
   # Most likely class at every time step: (B, C, S) -> (B, S).
   argmax = logprobs.argmax(1)
   # Start from an all-padding buffer so short sequences need no extra padding step.
   decoded = torch.full((B, max_length), self.padding_index, dtype=torch.int32, device=logprobs.device)
   for i, path in enumerate(argmax.tolist()):
      # groupby merges runs of identical consecutive tokens; blanks are dropped
      # afterwards so that "a, blank, a" still decodes to two separate "a"s.
      seq = [token for token, _group in itertools.groupby(path) if token != self.blank_index][:max_length]
      if seq:
         # One slice assignment per row instead of one tensor write per token.
         decoded[i, : len(seq)] = torch.tensor(seq, dtype=decoded.dtype, device=decoded.device)
   return decoded