# In [2]:
import torch
import torch.nn as nn

class CRF(nn.Module):
    def __init__(
        self, 
        num_tags: int,
        batch_first: bool = True
    ):
        """init parameters of CRF
        
        Args:
            num_tags(int): number of tags
            batch_first(bool): 
        """
        super(CRF, self).__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        
        self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.Tensor(num_tags))
        self.end_transitions = nn.Parameter(torch.Tensor(num_tags))
        
    def forward(
        self,
        emissions: torch.Tensor,
        tags: torch.LongTensor,
        mask: torch.ByteTensor | None = None
    ) -> torch.Tensor:
        """计算给定标签序列的负对数似然

        Args:
            emissions (torch.Tensor): _description_
            tags (torch.LongTensor): _description_
            mask (torch.ByteTensor | None, optional): _description_. Defaults to None.
        
        Returns:
            torch.tensor: 输入tags的负对数似然
        """
        if mask is None:
            mask = torch.ones_like(tags)
            
        if self.batch_fisrt:
            emissions = emissions.permute(1, 0, 2)
            tags = tags.permute(1, 0)
            mask = mask.permute(1, 0)
            
        score = self._compute_score(emissions, tags, mask)
        partition = self._compute_partition(emissions, mask)
        
        return partition - score
    
    def _compute_score(
        self,
        emissions: torch.Tensor,
        tags: torch.LongTensor,
        mask: torch.ByteTensor
    ) -> torch.Tensor:
        """_summary_

        Args:
            emissions (torch.Tensor): _description_
            tags (torch.LongTensor): _description_
            mask (torch.ByteTensor): _description_

        Returns:
            torch.Tensor: _description_
        """
        seq_len, batch_size = tags.shape
        first_tags = tags[0]
        
        score = self.start_transitions[first_tags]
        score += emissions[0, torch.arange(batch_size), first_tags]
        
        mask = mask.type_as(emissions)
        for i in range(1, seq_len):
            score += (
                self.transitions[tags[i], tags[i + 1]] + 
                emissions[i, torch.arange(batch_size), tags[i]]
            ) * mask[i]
            
        last_valid_index = mask.long().sum(dim=0) - 1
        last_tags = tags[last_valid_index, torch.arange(batch_size)]
        score += self.end_transitions[last_tags]
        
        return score
    
    def _compute_partition(
        self,
        emissions: torch.Tensor,
        mask: torch.ByteTensor
    ) -> torch.Tensor:
        """_summary_

        Args:
            emissions (torch.Tensor): _description_
            mask (torch.ByteTensor): _description_

        Returns:
            torch.Tensor: _description_
        """
        
        seq_len = emissions.shape[0]
        score = self.start_transitions.unsqueeze(0) + emissions[0]
        
        for i in range(1, seq_len):
            broadcast_score = score.unsqueeze(2)
            broadcast_emissions = emissions[i].unsqueeze(1)
            current_score = broadcast_score + self.transitions + broadcast_emissions
            current_score = torch.logsumexp(current_score, dim=1) # shape (batch_size, num_tags)
            score = torch.where(mask[i].bool().unsqueeze(1), current_score, score)
            
        score += self.end_transitions
        return torch.logsumexp(score, dim=1)
    
    def decode(
        self,
        emissions: torch.Tensor,
        mask: torch.ByteTensor | None = None
    ):
        """_summary_

        Args:
            emissions (torch.Tensor): _description_
            mask (torch.ByteTensor | None, optional): _description_. Defaults to None.
        """
        if mask is None:
            mask = torch.ones(emissions[:2])
        
        if self.batch_first:
            emissions = emissions.permute(1, 0)
            mask = mask.permute(1, 0)
        
        return self._viterbi_decode(emissions, mask)
    
    def _viterbi_decode(
        self,
        emissions: torch.Tensor,
        mask: torch.ByteTensor
    ):
        """_summary_

        Args:
            emissions (torch.Tensor): (seq_len, batch_size, num_tags)
            mask (torch.ByteTensor): (seq_len, batch_size)
        """
        seq_len, batch_size = mask.shape
        score = self.start_transitions + emissions[0]
        history: list[torch.Tensor] = []
        
        for i in range(1, seq_len):
            broadcast_score = score.unsqueeze(2)
            broadcast_emissions = emissions.unsqueeze(1)
            current_score = broadcast_score + self.transitions + broadcast_emissions
            best_score, indices = torch.max(current_score, dim=1)
            score = torch.where(mask[i].bool().unsqueeze(1), best_score, score)
            history.append(indices)
            
        score += self.end_transitions
        best_score, indices = torch.max(score, dim=1)
        seq_end_tags = mask.long().sum(dim=0) - 1
        best_paths: list[list[int]] = []
        
        for i in range(batch_size):
            best_last_tag = indices[i]
            this_path = [best_last_tag.item()]
            for hist in reversed(history[: seq_end_tags[i]]):
                best_last_tag = hist[i][this_path[-1]]
                this_path.append(best_last_tag.item())
            this_path.reverse()
            best_paths.append(this_path)
        
        return best_paths