# Conditional Random Fields
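This notebook implements a conditional random field (CRF) layer. It follows the algorithm described in [_Natural language processing (almost) from scratch_ by Collobert et al. (2011)](https://arxiv.org/abs/1103.0398), Sect. 3.3.2, notably for the computation of $\delta$. This is the recursive quantity used to obtain the normalizing term: the log of the summed exponentiated scores of all possible tag sequences.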

Programs from the book: [_Python for Natural Language Processing_](https://link.springer.com/book/9783031575488)

__Author__: Pierre Nugues

In [1]:
import torch
import torch.nn as nn

In [2]:
class CRF(nn.Module):
    """Linear-chain conditional random field (CRF) layer.

    Follows Collobert et al. (2011), Sect. 3.3.2. The score of a tag path
    y_1..y_T is

        start_transitions[y_1] + sum_t logits[t][y_t]
        + sum_{t>1} transitions[y_{t-1}, y_t] + end_transitions[y_T],

    so ``transitions[i, j]`` scores tag ``i`` followed by tag ``j``.

    The layer works on one unbatched sequence at a time: ``logits`` is
    indexed as ``logits[t][tag]`` (shape ``(seq_length, tagset_size)``) and
    must contain at least one position.
    """

    def __init__(self, tagset_size):
        """Create the start, end and transition score parameters.

        Args:
            tagset_size: number of distinct tags.
        """
        super().__init__()
        self.tagset_size = tagset_size
        # Scores of starting / ending a sequence with each tag.
        self.start_transitions = nn.Parameter(torch.empty(tagset_size))
        self.end_transitions = nn.Parameter(torch.empty(tagset_size))
        # transitions[i, j]: score of tag i followed by tag j.
        self.transitions = nn.Parameter(torch.empty(tagset_size, tagset_size))

        # Small random initial scores, uniform in [-0.1, 0.1].
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def _compute_deltas(self, logits):
        """Return the log partition function log Z over all tag paths.

        Forward recursion on delta (Collobert et al., 2011):
            delta_1(j) = start[j] + logits[0][j]
            delta_t(j) = logits[t-1][j]
                         + logsumexp_i(delta_{t-1}(i) + transitions[i, j])
            log Z      = logsumexp_j(delta_T(j) + end[j])

        Args:
            logits: emission scores, shape (seq_length, tagset_size).

        Returns:
            A 0-d tensor holding log Z.
        """
        delta = self.start_transitions + logits[0]
        for logit in logits[1:]:
            # Column vector over the previous tag broadcast against
            # transitions[prev, next]; reduce over the previous tag (dim 0).
            delta = torch.logsumexp(
                delta.unsqueeze(1) + self.transitions, dim=0) + logit
        return torch.logsumexp(delta + self.end_transitions, dim=0)

    def _compute_sentence_score(self, logits, tags):
        """Return the unnormalized score of one tag path.

        Args:
            logits: emission scores, shape (seq_length, tagset_size).
            tags: seq_length tag indices (1-D tensor or sequence of ints).

        Returns:
            A 0-d tensor with the path score defined in the class docstring.

        Raises:
            ValueError: if ``tags`` and ``logits`` have different lengths
                (otherwise extra tags would be silently ignored while
                ``tags[-1]`` still picked the end transition).
        """
        if len(tags) != len(logits):
            raise ValueError(
                f"tags has length {len(tags)} but logits has length "
                f"{len(logits)}")
        score = self.start_transitions[tags[0]] + logits[0][tags[0]]
        for i, logit in enumerate(logits[1:], start=1):
            score = score + (self.transitions[tags[i - 1], tags[i]]
                             + logit[tags[i]])
        return score + self.end_transitions[tags[-1]]

    def _viterbi_decode(self, logits):
        """Find the highest-scoring tag path with the Viterbi algorithm.

        All computations stay on the logits' device and inside the autograd
        graph, so the returned score is differentiable w.r.t. every
        parameter on the best path.

        Args:
            logits: emission scores, shape (seq_length, tagset_size).

        Returns:
            (path_score, best_path): ``path_score`` is a 0-d tensor equal to
            the (unnormalized) score of ``best_path``; ``best_path`` is a list
            of seq_length int tag indices.
        """
        backpointers = []
        max_llhoods = self.start_transitions + logits[0]

        for logit in logits[1:]:
            # scores[i, j]: best score ending in tag i, then moving to tag j.
            scores = max_llhoods.unsqueeze(1) + self.transitions
            # For each next tag j, keep the best previous tag i.
            best_scores, best_prev = torch.max(scores, dim=0)
            backpointers.append(best_prev.tolist())
            max_llhoods = best_scores + logit

        max_llhoods = max_llhoods + self.end_transitions
        best_tag_id = int(torch.argmax(max_llhoods))
        path_score = max_llhoods[best_tag_id]

        # Follow the backpointers from the last position to the first.
        best_path = [best_tag_id]
        for backpointers_t in reversed(backpointers):
            best_tag_id = backpointers_t[best_tag_id]
            best_path.append(best_tag_id)
        best_path.reverse()
        return path_score, best_path

    def forward(self, logits, targets=None):
        """Score a given tag sequence, or decode the best one.

        Args:
            logits: emission scores, shape (seq_length, tagset_size).
            targets: optional tag indices, one per position.

        Returns:
            If ``targets`` is given, the conditional log-likelihood
            log p(targets | logits) as a 0-d tensor (negate it to get a
            training loss). Otherwise the ``(path_score, best_path)`` pair
            returned by Viterbi decoding.
        """
        if targets is None:
            return self._viterbi_decode(logits)
        sent_score = self._compute_sentence_score(logits, targets)
        normalizing_score = self._compute_deltas(logits)
        return sent_score - normalizing_score

We set the tagset size and the length of the sequence.

In [3]:
# Number of distinct tags, and number of positions in the sequence
tagset_size, seq_length = 10, 15

We create a CRF object with this tagset size.

In [4]:
crf = CRF(tagset_size=tagset_size)

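The input typically consists of logits: one row of tag scores per position, i.e. a tensor of shape (sequence length, tagset size). Here, we use random values.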

In [5]:
# Random scores in [0, 1): one row per position, one column per tag
inputs = torch.rand(seq_length, tagset_size)
inputs[0:3]

tensor([[0.1525, 0.1313, 0.5306, 0.4148, 0.7014, 0.7169, 0.1856, 0.7635, 0.5954,
         0.7533],
        [0.4602, 0.7714, 0.4050, 0.6981, 0.5786, 0.5871, 0.8359, 0.5807, 0.6846,
         0.9374],
        [0.6913, 0.4968, 0.2297, 0.0448, 0.4424, 0.8568, 0.1832, 0.6979, 0.9804,
         0.6524]])

In [6]:
# One random tag index in [0, tagset_size) per position
true_tags = torch.randint(0, tagset_size, size=(seq_length,))
true_tags

tensor([3, 7, 0, 8, 4, 0, 1, 5, 8, 4, 8, 3, 6, 9, 4])

We compute the log-likelihood of this tag sequence given the logits, i.e. its score minus the normalizing term.

In [7]:
crf(inputs, targets=true_tags)

tensor(-35.5360, grad_fn=<SubBackward0>)

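We compute the best tag sequence given the logits with the Viterbi algorithm, together with its score. This score is not normalized by the partition function, so it is not a log-likelihood.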

In [8]:
# Without targets, the CRF runs Viterbi decoding: (path score, tag list)
best_loglikelihood, best_tags = crf(inputs, targets=None)
(best_loglikelihood, best_tags)

(tensor(13.8228, grad_fn=<SelectBackward0>),
 [5, 9, 8, 6, 2, 9, 2, 8, 5, 0, 7, 2, 7, 9, 1])