# Sequence Labeling using LSTM+CRF - Part 2

### ---- Tensorflow CRF source code explained

## 1. Overview

The neural architecture is depicted below:

<img src='images/ner_neural_architecture_2.png' style='height:400px;width:550px'>

The word embedding, encoder and decoder can be implemented with a standard Bi-LSTM model. In this notebook, we explain how the CRF layer is implemented in TensorFlow.

The following picture depicts the end-to-end procedure of the Bi-LSTM + CRF model:


<img src='images/crf_training_process.png' style='height:500px;width:750px'>

1. Given a training example $(X, y)$, the LSTM+CRF model computes the score $S(X,y)$.
2. A softmax over the scores of all possible tag sequences gives the probability $p(y|X)$.
3. To train the model, we can either maximize the log-probability $\log p(y|X)$ or, equivalently, minimize the negative log-probability $-\log p(y|X)$, which is the cross-entropy between the target tag sequence and the distribution predicted by the model. We minimize $-\log p(y|X)$ with the help of the TensorFlow framework.
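To make steps 1 to 3 concrete, here is a minimal brute-force sketch in NumPy. It assumes $S(X,y)$ is the sum of the per-word tag scores (emissions) of the chosen tags plus the transition scores between consecutive tags, which is how the CRF functions below decompose it, and it computes the softmax denominator by enumerating all $T^L$ tag sequences. The function names and toy values are hypothetical, and the enumeration is only feasible for tiny inputs.

```python
import itertools

import numpy as np


def sequence_score(emissions, tags, transitions):
    """S(X, y): emission scores of the chosen tags plus transition scores between them."""
    unary = sum(emissions[l, t] for l, t in enumerate(tags))
    binary = sum(transitions[a, b] for a, b in zip(tags[:-1], tags[1:]))
    return unary + binary


def log_prob_brute_force(emissions, tags, transitions):
    """log p(y|X) = S(X, y) - log(sum over all y' of exp(S(X, y')))."""
    seq_len, num_tags = emissions.shape
    all_scores = np.array([
        sequence_score(emissions, candidate, transitions)
        for candidate in itertools.product(range(num_tags), repeat=seq_len)
    ])
    # Numerically stable log-sum-exp: the log of the softmax denominator
    max_score = all_scores.max()
    log_z = max_score + np.log(np.exp(all_scores - max_score).sum())
    return sequence_score(emissions, tags, transitions) - log_z


# Hypothetical toy example: 3 words, 2 tags
emissions = np.array([[1.0, 0.5],
                      [0.2, 1.5],
                      [0.3, 0.1]])
transitions = np.array([[0.1, 0.4],
                        [0.6, 0.2]])
gold_tags = (0, 1, 0)
loss = -log_prob_brute_force(emissions, gold_tags, transitions)
print(loss)
```

Minimizing `loss` corresponds to step 3. In TensorFlow, the exponential enumeration is replaced by a dynamic-programming recurrence inside `crf_log_norm`.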

To minimize the negative log-probability $-\log p(y|X)$, we first need to calculate $\log p(y|X)$. Fortunately, TensorFlow implements a function that computes $\log p(y|X)$: `tf.contrib.crf.crf_log_likelihood()`.

If you walk through the [source code](https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/contrib/crf/python/ops/crf.py) of `tf.contrib.crf.crf_log_likelihood()`, you will find that it closely follows the mathematical procedure for calculating $\log p(y|X)$. Before we dive into the source code, let us first look at `tf.contrib.crf.crf_log_likelihood()` from a high-level view:

<img src='images/crf_code_procedure.png' style='height:550px;width:750px'>

`crf_log_likelihood()` decomposes the task of calculating $\log p(y|X)$ into two sub-tasks, performed by the functions `crf_sequence_score()` and `crf_log_norm()` respectively.

* `crf_sequence_score()` calculates the score $S(X,y)$ and in turn decomposes this task into two further sub-tasks, performed by the functions `crf_binary_score` and `crf_unary_score`, which calculate:

<img src='images/s_score_decomposition.png' style='height:70px;width:200px'>

* `crf_log_norm` calculates:

<img src='images/crf_log_norm.png' style='height:50px;width:140px'>
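The normalizer $\log \sum_{y'} \exp S(X, y')$ can be computed without enumerating every tag sequence by using the forward algorithm, which `crf_log_norm` implements as a recurrent loop over time steps. The NumPy sketch below is a single-sequence illustration of that recurrence, not the TensorFlow code; the helper names and random inputs are hypothetical, and the brute-force function is there only to check the result.

```python
import itertools

import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def log_norm_forward(emissions, transitions):
    """Forward algorithm: log of the sum of exp(S(X, y')) over all tag sequences y'."""
    # alpha[j]: log-sum-exp of the scores of all prefixes that end in tag j
    alpha = emissions[0]
    for l in range(1, emissions.shape[0]):
        # alpha[:, None] + transitions has shape [from_tag, to_tag];
        # reduce over from_tag, then add the emission scores of word l
        alpha = logsumexp(alpha[:, None] + transitions, axis=0) + emissions[l]
    return logsumexp(alpha, axis=0)


def log_norm_brute_force(emissions, transitions):
    """Reference: enumerate all T**L tag sequences explicitly."""
    seq_len, num_tags = emissions.shape
    scores = []
    for tags in itertools.product(range(num_tags), repeat=seq_len):
        score = sum(emissions[l, t] for l, t in enumerate(tags))
        score += sum(transitions[a, b] for a, b in zip(tags[:-1], tags[1:]))
        scores.append(score)
    return logsumexp(np.array(scores), axis=0)


rng = np.random.RandomState(0)
emissions = rng.randn(4, 3)    # hypothetical: 4 words, 3 tags
transitions = rng.randn(3, 3)
print(log_norm_forward(emissions, transitions))
print(log_norm_brute_force(emissions, transitions))  # should agree up to rounding
```

The forward pass costs $O(L \cdot T^2)$ instead of $O(T^L)$, which is why the brute-force version is only useful as a check.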

Now let us quickly go over the source code. We do not need to understand every detail yet; for now, we only need to know the call stack of these functions.

```python
def crf_log_likelihood(inputs, tag_indices, sequence_lengths, transition_params=None):
  # Get shape information.
  num_tags = inputs.get_shape()[2].value

  # Get the transition matrix if not provided.
  if transition_params is None:
    transition_params = vs.get_variable("transitions", [num_tags, num_tags])

  sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
                                       transition_params)
  log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)

  # Normalize the scores to get the log-likelihood per example.
  log_likelihood = sequence_scores - log_norm
  return log_likelihood, transition_params
```



```python
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
    # Compute the scores of the given tag sequence.
    unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
    binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                     transition_params)
    sequence_scores = unary_scores + binary_scores
    return sequence_scores
```

```python
def crf_unary_score(tag_indices, sequence_lengths, inputs):

  batch_size = array_ops.shape(inputs)[0]
  max_seq_len = array_ops.shape(inputs)[1]
  num_tags = array_ops.shape(inputs)[2]

  flattened_inputs = array_ops.reshape(inputs, [-1])

  offsets = array_ops.expand_dims(
      math_ops.range(batch_size) * max_seq_len * num_tags, 1)
  offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
  # Use int32 or int64 based on tag_indices' dtype.
  if tag_indices.dtype == dtypes.int64:
    offsets = math_ops.to_int64(offsets)
  flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])

  unary_scores = array_ops.reshape(
      array_ops.gather(flattened_inputs, flattened_tag_indices),
      [batch_size, max_seq_len])

  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=dtypes.float32)

  unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
        
  # unary_scores is a tensor with shape [batch_size]
  return unary_scores
```

```python
def crf_binary_score(tag_indices, sequence_lengths, transition_params):

  # Get shape information.
  num_tags = transition_params.get_shape()[0]
  num_transitions = array_ops.shape(tag_indices)[1] - 1

  # Truncate by one on each side of the sequence to get the start and end
  # indices of each transition.
  start_tag_indices = array_ops.slice(tag_indices, [0, 0],
                                      [-1, num_transitions])
  end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])

  # Encode the indices in a flattened representation.
  flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
  flattened_transition_params = array_ops.reshape(transition_params, [-1])

  # Get the binary scores based on the flattened representation.
  binary_scores = array_ops.gather(flattened_transition_params,
                                   flattened_transition_indices)

  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=dtypes.float32)
  truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
  binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
  return binary_scores
```

The functions `crf_log_likelihood` and `crf_sequence_score` are straightforward: they break their tasks into smaller sub-tasks and delegate them to other functions.

The functions `crf_unary_score` and `crf_binary_score` are more involved, so we explain them next.

## 2. `crf_unary_score` and `crf_binary_score`

The arguments that are required by the two functions are:
* `inputs`: A `[batch_size, max_seq_len, num_tags]` tensor of unary potentials to use as input to the CRF layer.
* `tag_indices`: A `[batch_size, max_seq_len]` matrix of tag indices for which we compute the log-likelihood.
* `sequence_lengths`: A `[batch_size]` vector of true sequence lengths.
* `transition_params`: A `[num_tags, num_tags]` transition matrix, if available.

For illustration purposes, we define mathematical notation for the arguments `inputs` and `tag_indices`. All indices start from 0, matching the Python code.

We define input sequences as follows:
* $x_i$ represents the $i$th input sequence in a batch, $i=0,1,\dots,N-1$, where $N$ is the batch size.
* $x_{i,l}$ represents the $l$th word in the $i$th input sequence, $l=0,1,\dots,L-1$, where $L$ is the max sequence length.
* $x_{i,l,t}$ represents the score of the $t$th tag for the $l$th word in the $i$th sequence, $t=0,1,\dots,T-1$, where $T$ is the number of labels/tags.

We define tag indices as follows:
* $y_i$ represents the $i$th tag sequence in a batch, $i=0,1,\dots,N-1$, where $N$ is the batch size.
* $y_{i,l}$ represents the $l$th tag index in the $i$th tag sequence, $l=0,1,\dots,L-1$, where $L$ is the max sequence length. Each $y_{i,l}$ is in $\{0, 1, \dots, T-1\}$.

The following is an example of a batch of size 2, with a max sequence length of 3 and 4 tags.
* Sequences in a batch may have different lengths. In the picture below, $x_0$ has length 3 while $x_1$ has length 2.
* The third word of $x_1$, i.e. $x_{1,2}$, is padding filled with zeros.

<img src='images/batch_example.png' style='height:310px;width:540px'>
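Both functions below use `array_ops.sequence_mask` to tell real words from padding. As a rough NumPy stand-in (the helper `sequence_mask` here is hypothetical, not a TensorFlow API), the mask for this batch, with lengths 3 and 2 and max length 3, can be built by comparing word positions against each sequence length:

```python
import numpy as np


def sequence_mask(lengths, maxlen):
    """Return a [batch_size, maxlen] float mask: 1.0 for real words, 0.0 for padding."""
    lengths = np.asarray(lengths)
    positions = np.arange(maxlen)  # [maxlen]
    # Broadcast [1, maxlen] against [batch_size, 1]
    return (positions[None, :] < lengths[:, None]).astype(np.float32)


masks = sequence_mask([3, 2], maxlen=3)
print(masks.tolist())  # [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
```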

### 2.1 `crf_unary_score`

In this section, we explain what task `crf_unary_score` performs and how it performs it.

### 2.1.1 What

`crf_unary_score` gathers, for each word $x_{i,l}$, the score of its target tag $y_{i,l}$ (i.e., $x_{i,l,y_{i,l}}$), and sums these scores over the real (non-padding) words of each sequence $x_i$.

<img src='images/crf_unary_task.png' style='height:420px;width:770px'>
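The task can be written as a plain loop. This NumPy sketch is a hypothetical reference for checking intuition, not the TensorFlow code: for every sequence $i$ it adds up $x_{i,l,y_{i,l}}$ over the first `sequence_lengths[i]` words. The toy `inputs` are filled with `0, 1, 2, ...`, so each score equals its own position in the flattened array, which makes the following steps easy to trace.

```python
import numpy as np


def unary_score_loop(inputs, tag_indices, sequence_lengths):
    """For each sequence i, sum inputs[i, l, tag_indices[i, l]] over its real words."""
    batch_size = inputs.shape[0]
    scores = np.zeros(batch_size)
    for i in range(batch_size):
        for l in range(sequence_lengths[i]):  # padding positions are skipped
            scores[i] += inputs[i, l, tag_indices[i, l]]
    return scores


# Hypothetical batch: N=2, L=3, T=4; the second sequence has length 2
inputs = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
tag_indices = np.array([[1, 2, 3],
                        [0, 1, 0]])
sequence_lengths = np.array([3, 2])
print(unary_score_loop(inputs, tag_indices, sequence_lengths))  # [18. 29.]
```

Here the first sequence picks scores 1, 6 and 11, and the second picks 12 and 17 (its padding word is ignored).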

### 2.1.2 How

In this section, we walk through `crf_unary_score` step by step and explain how it performs its work.

**Step 1**

Flatten the batch of input sequences into a 1-D array:

```python
flattened_inputs = array_ops.reshape(inputs, [-1])
```

<img src='images/crf_unary_flatten.png' style='height:220px;width:570px'>

**Step 2**

Compute the position in the 1-D array `flattened_inputs` of the score selected by each $y_{i,l}$, for every tag sequence $y_i$.

For each $y_{i,l}$, its index in `flattened_inputs` is computed by:

$$ i \times L \times T + l \times T + y_{i,l}$$

For example, the index of $y_{1,1}$ is $1 \times 3 \times 4 + 1 \times 4 + 1 = 17$ (where $L = 3$, $T = 4$ and $y_{1,1}=1$).

The following picture depicts the result:

<img src='images/crf_unary_flatten_example.png' style='height:45px;width:570px'>
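The formula is the row-major (C-order) position of element `[i, l, y]` in an `[N, L, T]` array. A quick NumPy check using the example's $N = 2$, $L = 3$, $T = 4$ (the helper `flat_index` is hypothetical):

```python
import numpy as np

N, L, T = 2, 3, 4  # batch size, max sequence length, number of tags


def flat_index(i, l, y):
    """Position of inputs[i, l, y] in inputs.reshape(-1)."""
    return i * L * T + l * T + y


print(flat_index(1, 1, 1))                         # 17
print(np.ravel_multi_index((1, 1, 1), (N, L, T)))  # 17, NumPy's row-major conversion

inputs = np.arange(N * L * T).reshape(N, L, T)
# Every element agrees with its counterpart in the flattened array
assert all(inputs[i, l, y] == inputs.reshape(-1)[flat_index(i, l, y)]
           for i in range(N) for l in range(L) for y in range(T))
```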

The following code computes the indices for all $y_{i,l}$ (for every tag sequence in a batch) using tensor operations. Make sure you understand it; a small experiment, such as the one in the Appendix that prints the intermediate offsets, shows what `expand_dims` does.

```python
offsets = array_ops.expand_dims(math_ops.range(batch_size) * max_seq_len * num_tags, 1)
offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
# Use int32 or int64 based on tag_indices' dtype.
if tag_indices.dtype == dtypes.int64:
    offsets = math_ops.to_int64(offsets)
flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])
```
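The same offset computation can be reproduced in NumPy, where `np.expand_dims` plays the role of `array_ops.expand_dims` and broadcasting adds a `[N, 1]` column to a `[1, L]` row. The tag indices are hypothetical.

```python
import numpy as np

batch_size, max_seq_len, num_tags = 2, 3, 4
tag_indices = np.array([[1, 2, 3],
                        [0, 1, 0]])  # hypothetical tags

# [batch_size, 1]: where each sequence's block starts in the flat array
row_offsets = np.expand_dims(np.arange(batch_size) * max_seq_len * num_tags, 1)
# [1, max_seq_len]: where each word's block starts within a sequence
col_offsets = np.expand_dims(np.arange(max_seq_len) * num_tags, 0)
# Broadcasting [N, 1] + [1, L] -> [N, L]
offsets = row_offsets + col_offsets
flattened_tag_indices = (offsets + tag_indices).reshape(-1)
print(offsets.tolist())                # [[0, 4, 8], [12, 16, 20]]
print(flattened_tag_indices.tolist())  # [1, 6, 11, 12, 17, 20]
```

The `offsets` values match the Appendix output of the TensorFlow version.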

**Step 3**

Gather the tag score (i.e., unary score) of each $x_{i,l}$ at the index of $y_{i,l}$, and compute the sum of the tag scores for each sequence $x_i$:

```python
unary_scores = array_ops.reshape(
    array_ops.gather(flattened_inputs, flattened_tag_indices),
    [batch_size, max_seq_len])
masks = array_ops.sequence_mask(sequence_lengths,
                                maxlen=array_ops.shape(tag_indices)[1],
                                dtype=dtypes.float32)
unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
```

Note that `masks` is used to exclude padding positions from the sum.
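Putting steps 1 to 3 together, here is a NumPy sketch that mirrors the TensorFlow ops one for one, with `np.take` standing in for `array_ops.gather` and a broadcast comparison standing in for `array_ops.sequence_mask`. It is an illustration under those substitutions, run on a hypothetical toy batch.

```python
import numpy as np


def crf_unary_score_np(tag_indices, sequence_lengths, inputs):
    """NumPy sketch of crf_unary_score: returns a [batch_size] vector of unary scores."""
    batch_size, max_seq_len, num_tags = inputs.shape
    # Step 1: flatten the [N, L, T] scores into a 1-D array
    flattened_inputs = inputs.reshape(-1)
    # Step 2: flat position of inputs[i, l, tag_indices[i, l]]
    offsets = (np.arange(batch_size)[:, None] * max_seq_len * num_tags
               + np.arange(max_seq_len)[None, :] * num_tags)
    flattened_tag_indices = (offsets + tag_indices).reshape(-1)
    # Step 3: gather, reshape back to [N, L], mask out padding and sum per sequence
    unary_scores = np.take(flattened_inputs, flattened_tag_indices).reshape(
        batch_size, max_seq_len)
    masks = (np.arange(max_seq_len)[None, :] < sequence_lengths[:, None]).astype(inputs.dtype)
    return (unary_scores * masks).sum(axis=1)


# Hypothetical toy batch: N=2, L=3, T=4; the second sequence has length 2
inputs = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
tag_indices = np.array([[1, 2, 3],
                        [0, 1, 0]])
sequence_lengths = np.array([3, 2])
print(crf_unary_score_np(tag_indices, sequence_lengths, inputs))  # [18. 29.]
```

In NumPy the gather could also be written with advanced indexing, `inputs[np.arange(N)[:, None], np.arange(L)[None, :], tag_indices]`; the flatten-and-offset form lets a single 1-D `gather` do the same work.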

### 2.2 `crf_binary_score`

In this section, we explain what task `crf_binary_score` performs and how it performs it.

If you understand how `crf_unary_score` works, `crf_binary_score` should be easy to follow, since it uses the same logic.

### 2.2.1 What

`crf_binary_score` calculates the following formula for each tag sequence $y_i$:

$$ \sum_{l=0}^{L-2} A_{y_{i,l},\,y_{i,l+1}} $$

where $A$ is the `transition_params` matrix. Transitions that lead into a padding position are excluded by a mask.

The following picture shows how `crf_binary_score` processes a single tag sequence; in practice, it processes a whole batch of tag sequences at once.

<img src='images/crf_binary_task.png' style='height:420px;width:650px'>
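As with the unary score, the task can be written as a plain loop. This NumPy sketch is a hypothetical reference implementation, not the TensorFlow code: a sequence of length $n$ contributes its $n-1$ transitions. The transition matrix and tags are taken from the Appendix example; the sequence lengths are hypothetical.

```python
import numpy as np


def binary_score_loop(tag_indices, sequence_lengths, transition_params):
    """For each sequence i, sum A[y[i, l], y[i, l+1]] over transitions between real words."""
    batch_size = tag_indices.shape[0]
    scores = np.zeros(batch_size)
    for i in range(batch_size):
        # A sequence of length n has n - 1 transitions
        for l in range(sequence_lengths[i] - 1):
            scores[i] += transition_params[tag_indices[i, l], tag_indices[i, l + 1]]
    return scores


transition_params = np.array([[0.5, 0.7, 0.3],
                              [0.4, 0.1, 0.2],
                              [0.6, 0.9, 0.8]])
tag_indices = np.array([[1, 2, 2],
                        [0, 1, 0]])
sequence_lengths = np.array([3, 2])  # hypothetical: the second sequence has 2 words
print(binary_score_loop(tag_indices, sequence_lengths, transition_params))
# approximately [1.0, 0.7]: A[1,2] + A[2,2] for the first sequence, A[0,1] for the second
```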

### 2.2.2 How



**Step 1**

Flatten the transition params matrix into a 1-D array:

```python
flattened_transition_params = array_ops.reshape(transition_params, [-1])
```

**Step 2**

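Truncate the tag sequences by one position on each side to get the start and end indices of each transition: `start_tag_indices` drops the last tag of each sequence and `end_tag_indices` drops the first.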

```python
num_transitions = array_ops.shape(tag_indices)[1] - 1
start_tag_indices = array_ops.slice(tag_indices, [0, 0], [-1, num_transitions])
end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])
```
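In NumPy terms, each `array_ops.slice` call above is a column slice. This sketch uses the tag indices from the Appendix example and reproduces the start and end indices printed there:

```python
import numpy as np

tag_indices = np.array([[1, 2, 2],
                        [0, 1, 0]])
num_transitions = tag_indices.shape[1] - 1

# array_ops.slice(tag_indices, [0, 0], [-1, num_transitions]): drop the last column
start_tag_indices = tag_indices[:, :num_transitions]
# array_ops.slice(tag_indices, [0, 1], [-1, num_transitions]): drop the first column
end_tag_indices = tag_indices[:, 1:1 + num_transitions]
print(start_tag_indices.tolist())  # [[1, 2], [0, 1]]
print(end_tag_indices.tolist())    # [[2, 2], [1, 0]]
```

Column `l` of the two arrays together forms the transition $(y_{i,l}, y_{i,l+1})$.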

**Step 3**

Compute the position in the 1-D array `flattened_transition_params` of each transition $(y_{i,l}, y_{i,l+1})$, for every tag sequence $y_i$.

For each transition $(y_{i,l}, y_{i,l+1})$, its index in `flattened_transition_params` is computed by:

$$ y_{i,l} \times T + y_{i,l+1}$$

For example, the index of transition $(y_{0,0}, y_{0,1})$ is $1 \times 4 + 0 = 4$ (where $T = 4$, $y_{0,0}=1$ and $y_{0,1}=0$).

The following code computes the indices of all transitions $(y_{i,l}, y_{i,l+1})$ (for every tag sequence in a batch) using element-wise tensor operations:

```python
flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
```
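A NumPy check of this step, using the transition matrix and tags from the Appendix example: looking up the flat index `start * T + end` in the flattened matrix returns the same value as indexing `A[start, end]` directly.

```python
import numpy as np

transition_params = np.array([[0.5, 0.7, 0.3],
                              [0.4, 0.1, 0.2],
                              [0.6, 0.9, 0.8]])
num_tags = transition_params.shape[0]
start_tag_indices = np.array([[1, 2], [0, 1]])
end_tag_indices = np.array([[2, 2], [1, 0]])

# Row-major position of A[start, end] in the flattened [T * T] array
flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
flattened_transition_params = transition_params.reshape(-1)
binary_scores = np.take(flattened_transition_params, flattened_transition_indices)
print(flattened_transition_indices.tolist())  # [[5, 8], [1, 3]]
print(binary_scores.tolist())                 # [[0.2, 0.8], [0.7, 0.4]]
```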

**Step 4**

Gather the transition score (i.e., binary score) $A_{y_{i,l},y_{i,l+1}}$ at the index of $(y_{i,l}, y_{i,l+1})$, and compute the sum of transition scores for each sequence $y_i$:

```python
binary_scores = array_ops.gather(flattened_transition_params, flattened_transition_indices)
masks = array_ops.sequence_mask(sequence_lengths, maxlen=array_ops.shape(tag_indices)[1], dtype=dtypes.float32)
truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
```

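Note that the mask is shifted by one position (`truncated_masks` drops its first column), so a transition $(y_{i,l}, y_{i,l+1})$ is counted only when word $l+1$ is a real word rather than padding.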
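Putting steps 1 to 4 together, here is a NumPy sketch of `crf_binary_score` under the same substitutions as before (`np.take` for `array_ops.gather`, a broadcast comparison for `array_ops.sequence_mask`, column slices for `array_ops.slice`). The transition matrix and tags come from the Appendix example; the sequence lengths are hypothetical.

```python
import numpy as np


def crf_binary_score_np(tag_indices, sequence_lengths, transition_params):
    """NumPy sketch of crf_binary_score: returns a [batch_size] vector of binary scores."""
    tag_indices = np.asarray(tag_indices)
    sequence_lengths = np.asarray(sequence_lengths)
    num_tags = transition_params.shape[0]
    max_seq_len = tag_indices.shape[1]
    # Steps 1-3: flatten A and look up A[start, end] through flat indices
    start_tag_indices = tag_indices[:, :-1]
    end_tag_indices = tag_indices[:, 1:]
    flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
    binary_scores = np.take(transition_params.reshape(-1), flattened_transition_indices)
    # Step 4: masks[:, 1:] keeps transition l -> l+1 only if word l+1 is not padding
    masks = (np.arange(max_seq_len)[None, :] < sequence_lengths[:, None]).astype(
        transition_params.dtype)
    truncated_masks = masks[:, 1:]
    return (binary_scores * truncated_masks).sum(axis=1)


transition_params = np.array([[0.5, 0.7, 0.3],
                              [0.4, 0.1, 0.2],
                              [0.6, 0.9, 0.8]])
tag_indices = [[1, 2, 2],
               [0, 1, 0]]
sequence_lengths = [3, 2]  # hypothetical: the second sequence ends after two words
print(crf_binary_score_np(tag_indices, sequence_lengths, transition_params))
```

For the second sequence, the transition $(1, 0)$ into the padding position is dropped by the shifted mask, leaving only $A_{0,1} = 0.7$.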

## 3. Appendix

```python
import tensorflow as tf  # TensorFlow 1.x (the cells below use tf.Session)
import numpy as np
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
```


```python
batch_size = 2
max_seq_len = 3
num_tags = 4

rg = math_ops.range(batch_size)
offsets_1_pre = rg * max_seq_len * num_tags
offsets_1 = array_ops.expand_dims(offsets_1_pre, 1)

offsets_2 = array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)

offsets_3 = offsets_1 + offsets_2
with tf.Session() as sess:
    offset_1_o = sess.run(offsets_1)
    offset_2_o = sess.run(offsets_2)
    offset_3_o = sess.run(offsets_3)
    print('offset_1_o shape', offset_1_o.shape)
    print('offset_1_o', offset_1_o)
    print('offset_2_o shape', offset_2_o.shape)
    print('offset_2_o', offset_2_o)
    print('offset_3_o shape', offset_3_o.shape)
    print('offset_3_o', offset_3_o)
```

Output:

```
offset_1_o shape (2, 1)
offset_1_o [[ 0]
 [12]]
offset_2_o shape (1, 3)
offset_2_o [[0 4 8]]
offset_3_o shape (2, 3)
offset_3_o [[ 0  4  8]
 [12 16 20]]
```


```python
transition_params = np.asarray([[0.5, 0.7, 0.3],
                                [0.4, 0.1, 0.2],
                                [0.6, 0.9, 0.8]])
tag_indices = [[1, 2, 2],
               [0, 1, 0]]

# num_tags must be a scalar: transition_params.shape would be the tuple (3, 3)
num_tags = transition_params.shape[0]
num_transitions = array_ops.shape(tag_indices)[1] - 1
start_tag_indices = array_ops.slice(tag_indices, [0, 0], [-1, num_transitions])
end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])

# Encode the indices in a flattened representation.
flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
flattened_transition_params = array_ops.reshape(transition_params, [-1])

binary_scores = array_ops.gather(flattened_transition_params, flattened_transition_indices)
with tf.Session() as sess:
    start_tag_indices_o = sess.run(start_tag_indices)
    end_tag_indices_o = sess.run(end_tag_indices)
    flattened_transition_indices_o = sess.run(flattened_transition_indices)
    binary_scores_o = sess.run(binary_scores)
    print('start_tag_indices_o', start_tag_indices_o)
    print('end_tag_indices_o', end_tag_indices_o)
    print('flattened_transition_indices_o', flattened_transition_indices_o)
    print('binary_scores_o', binary_scores_o)
```

Output:

```
start_tag_indices_o [[1 2]
 [0 1]]
end_tag_indices_o [[2 2]
 [1 0]]
flattened_transition_indices_o [[5 8]
 [1 3]]
binary_scores_o [[0.2 0.8]
 [0.7 0.4]]
```
