In [None]:
import torch
import torch.nn as nn
from torch import Tensor

class Loss(nn.Module):
    """Negative log-likelihood of a monotonic character / null ("phi") alignment.

    Every frame of ``probs`` emits either the next target character or the
    null symbol ``phi_idx``. For each batch row the loss sums, in log space,
    over every way of interleaving the row's ``target_length`` characters
    (in order) with ``max_length - target_length`` nulls. It returns the batch
    mean of the negative log-likelihood.

    Args:
        phi_idx: Vocabulary index of the null symbol.
        device: Kept for backward compatibility. Intermediate tensors are
            created on ``probs.device`` so the loss works wherever the inputs
            live.
    """

    def __init__(self, phi_idx: int, device='cpu') -> None:
        super().__init__()
        self.phi_idx = phi_idx
        self.device = device

    def forward(
            self,
            probs: Tensor,
            target: Tensor,
            target_lengths: Tensor
            ) -> Tensor:
        """Compute the mean alignment NLL over the batch.

        Args:
            probs: ``(batch, max_length, vocab)`` per-frame probabilities.
                These must be probabilities, not log-probabilities, because
                every transition goes through ``self.log``.
            target: ``(batch, >= max(target_lengths))`` integer character
                indices. Entries past a row's length do not affect its loss.
            target_lengths: ``(batch,)`` number of valid characters per row.

        Returns:
            Scalar tensor: mean negative log-likelihood over the batch.

        Raises:
            ValueError: if any target is longer than ``max_length``.
        """
        batch_size, max_length, *_ = probs.shape
        target_lengths = target_lengths.to(probs.device)
        n_chars = int(target_lengths.max().item())
        min_chars = int(target_lengths.min().item())
        if n_chars > max_length:
            raise ValueError(
                f"target length {n_chars} exceeds number of frames {max_length}"
            )
        # The shortest target needs the most nulls. Size the null axis for it
        # so every row's end point (L_b, max_length - L_b) lies in the grid.
        n_nulls = max_length - min_chars

        scores = self.get_score_matrix(
            batch_size, n_chars, n_nulls, device=probs.device, dtype=probs.dtype
        )

        for c in range(n_chars + 1):
            # Cells with c + p > max_length would read past the last frame and
            # are never on a path to any row's end point, so they are skipped.
            for p in range(min(n_nulls, max_length - c) + 1):
                if c == 0 and p == 0:
                    continue  # log(1) = 0: the empty alignment
                scores = self.update_scores(scores, probs, target, p, c)
        return self.calc_loss(scores, target_lengths)

    def calc_loss(self, scores: Tensor, target_lengths: Tensor) -> Tensor:
        """Read each row's end-point log-likelihood and return the mean NLL.

        Row ``b`` ends at ``(c, p) = (L_b, max_length - L_b)``. The null axis
        of ``scores`` is sized for the shortest target, whose end point is the
        last column, so ``p_b = last_col - (L_b - min(L))``. When all lengths
        are equal this is simply the last column.
        """
        lengths = target_lengths.to(device=scores.device, dtype=torch.long)
        last_col = scores.shape[2] - 1
        null_counts = last_col - (lengths - lengths.min())
        rows = torch.arange(scores.shape[0], device=scores.device)
        # Advanced indexing picks one cell per row (O(batch)), instead of
        # building a batch x batch matrix and taking its diagonal.
        log_likelihood = scores[rows, lengths, null_counts]
        return -log_likelihood.mean()

    def get_score_matrix(
            self, batch_size: int, n_chars: int, n_nulls: int,
            device=None, dtype=None
            ) -> Tensor:
        """Allocate the ``(batch, n_chars + 1, n_nulls + 1)`` log-score grid.

        Cell ``(c, p)`` holds the log-probability of emitting the first ``c``
        characters and ``p`` nulls over the first ``c + p`` frames. Zeros make
        cell ``(0, 0)`` equal ``log(1)``. ``device`` and ``dtype`` default to
        torch's defaults, as before.
        """
        return torch.zeros(
            batch_size, n_chars + 1, n_nulls + 1, device=device, dtype=dtype
        )

    def log(self, x: Tensor) -> Tensor:
        """Natural log applied to every transition probability.

        This is a separate method so subclasses can override it, e.g. to clamp
        values before taking the log.
        """
        return torch.log(x)

    def update_scores(
            self, scores: Tensor, probs: Tensor, target: Tensor, p: int, c: int
            ) -> Tensor:
        """Fill cell ``(c, p)`` from its predecessors, writing into ``scores`` in place.

        The frame consumed by this step is ``c + p - 1``. Every branch
        converts probabilities to log space before adding them to the
        log-valued scores.
        """
        if p == 0:
            # Only reachable by emitting character c (no nulls so far).
            chars_probs = self.get_chars_probs(probs, target, c, p)
            scores[:, c, p] = scores[:, c - 1, p] + self.log(chars_probs)
            return scores
        if c == 0:
            # Only reachable by emitting another null (no characters yet).
            phi_probs = self.get_phi_probs(probs, c, p)
            scores[:, c, p] = scores[:, c, p - 1] + self.log(phi_probs)
            return scores
        chars_probs = self.get_chars_probs(probs, target, c, p)
        phi_probs = self.get_phi_probs(probs, c, p)
        # Log-space sum over the two predecessors: "emit null" or "emit char".
        scores[:, c, p] = torch.logsumexp(
            torch.stack([
                scores[:, c, p - 1] + self.log(phi_probs),
                scores[:, c - 1, p] + self.log(chars_probs),
            ]),
            dim=0,
        )
        return scores

    def get_phi_probs(self, probs: Tensor, c: int, p: int) -> Tensor:
        """Probability of the null symbol at frame ``c + p - 1`` for every row."""
        return probs[:, c + p - 1, self.phi_idx]

    def get_chars_probs(
            self, probs: Tensor, target: Tensor, c: int, p: int
            ) -> Tensor:
        """Probability of each row's ``c``-th target character at frame ``c + p - 1``."""
        frame = probs[:, p + c - 1]
        index = target[:, c - 1].to(device=frame.device, dtype=torch.long)
        # gather picks frame[b, index[b]] directly (O(batch)), instead of
        # index_select + diagonal which materialises a batch x batch matrix.
        return frame.gather(-1, index.unsqueeze(-1)).squeeze(-1)