# Named Entity Recognition with `chaine`

In [1]:
import datasets
import pandas as pd
from seqeval.metrics import classification_report

import chaine
from chaine.typing import Dataset, Features, Sentence, Tags

## Data

In [2]:
# Load the CoNLL-2003 dataset through the Hugging Face `datasets` library.
dataset = datasets.load_dataset("conll2003")

# Report the size of the splits used below: "train" for fitting the model and
# "test" for evaluating it.
print(f"Number of sentences for training: {len(dataset['train']['tokens'])}")
print(f"Number of sentences for evaluation: {len(dataset['test']['tokens'])}")

Reusing dataset conll2003 (/home/severin/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)


  0%|          | 0/3 [00:00<?, ?it/s]

Number of sentences for training: 14042
Number of sentences for evaluation: 3454


## Featurization

In [3]:
def featurize_token(token_index: int, sentence: Sentence, pos_tags: Tags) -> Features:
    """Build the feature mapping for a single token of a sentence.

    The mapping describes the token itself (lowercased form, its last three
    and last two characters, casing and digit flags, part-of-speech tag) and,
    where they exist, its direct left and right neighbours. A missing left
    neighbour is flagged with ``"BOS"``, a missing right neighbour with
    ``"EOS"``.

    Parameters
    ----------
    token_index : int
        Position in ``sentence`` of the token to describe.
    sentence : Sentence
        Tokens of the sentence.
    pos_tags : Tags
        Part-of-speech tags aligned with ``sentence``.

    Returns
    -------
    Features
        Feature names mapped to their values for the selected token.
    """
    current, current_tag = sentence[token_index], pos_tags[token_index]
    described = {
        "token.lower()": current.lower(),
        "token[-3:]": current[-3:],
        "token[-2:]": current[-2:],
        "token.isupper()": current.isupper(),
        "token.istitle()": current.istitle(),
        "token.isdigit()": current.isdigit(),
        "pos_tag": current_tag,
    }

    # Left context, or a beginning-of-sentence flag when there is none.
    if token_index <= 0:
        described["BOS"] = True
    else:
        left, left_tag = sentence[token_index - 1], pos_tags[token_index - 1]
        described["-1:token.lower()"] = left.lower()
        described["-1:token.istitle()"] = left.istitle()
        described["-1:token.isupper()"] = left.isupper()
        described["-1:pos_tag"] = left_tag

    # Right context, or an end-of-sentence flag when there is none.
    if token_index >= len(sentence) - 1:
        described["EOS"] = True
    else:
        right, right_tag = sentence[token_index + 1], pos_tags[token_index + 1]
        described["+1:token.lower()"] = right.lower()
        described["+1:token.istitle()"] = right.istitle()
        described["+1:token.isupper()"] = right.isupper()
        described["+1:pos_tag"] = right_tag

    return described


def featurize_sentence(sentence: Sentence, pos_tags: Tags) -> list[Features]:
    """Featurize every token of a sentence, in order.

    Parameters
    ----------
    sentence : Sentence
        Tokens of the sentence.
    pos_tags : Tags
        Part-of-speech tags aligned with ``sentence``.

    Returns
    -------
    list[Features]
        One feature mapping per token, in sentence order.
    """
    featurized = []
    for position in range(len(sentence)):
        featurized.append(featurize_token(position, sentence, pos_tags))
    return featurized


def featurize_dataset(dataset: Dataset) -> list[list[Features]]:
    """Featurize every sentence of a dataset.

    Sentences are read from the ``"tokens"`` column and paired with the
    entries of the ``"pos_tags"`` column.

    Parameters
    ----------
    dataset : Dataset
        Dataset providing the ``"tokens"`` and ``"pos_tags"`` columns.

    Returns
    -------
    list[list[Features]]
        For each sentence, the list of feature mappings of its tokens.
    """
    featurized = []
    for tokens, tags in zip(dataset["tokens"], dataset["pos_tags"]):
        featurized.append(featurize_sentence(tokens, tags))
    return featurized


def preprocess_labels(dataset: Dataset) -> list[list[str]]:
    """Map the integer NER tags of a dataset to their string names.

    The names are taken from ``dataset.features["ner_tags"].feature.names``
    and each integer in the ``"ner_tags"`` column is used as an index into
    them.

    Parameters
    ----------
    dataset : Dataset
        Dataset whose ``"ner_tags"`` column holds integer labels.

    Returns
    -------
    list[list[str]]
        For each sentence, the string label of every token.
    """
    names = dataset.features["ner_tags"].feature.names
    translated = []
    for tag_ids in dataset["ner_tags"]:
        sentence_labels = []
        for tag_id in tag_ids:
            sentence_labels.append(names[tag_id])
        translated.append(sentence_labels)
    return translated

In [4]:
# Turn the training split into model inputs: one feature mapping per token and
# one string label per token.
train_sentences = featurize_dataset(dataset["train"])
train_labels = preprocess_labels(dataset["train"])

In [5]:
# Show the features of the first token of the first training sentence.
train_sentences[0][0]

{'token.lower()': 'eu',
 'token[-3:]': 'EU',
 'token[-2:]': 'EU',
 'token.isupper()': True,
 'token.istitle()': False,
 'token.isdigit()': False,
 'pos_tag': 22,
 'BOS': True,
 '+1:token.lower()': 'rejects',
 '+1:token.istitle()': False,
 '+1:token.isupper()': False,
 '+1:pos_tag': 42}

In [6]:
# Show the label of the first token of the first training sentence.
train_labels[0][0]

'B-ORG'

## Training

In [7]:
# Fit a model on the featurized training sentences and their labels.
# NOTE(review): verbose=0 presumably suppresses training output; confirm against
# the chaine documentation.
model = chaine.train(train_sentences, train_labels, verbose=0)

## Evaluation

In [8]:
# Prepare the test split exactly like the training split.
test_sentences = featurize_dataset(dataset["test"])
test_labels = preprocess_labels(dataset["test"])

In [9]:
# Predict labels for the test sentences and compare them with the gold labels.
predictions = model.predict(test_sentences)

# seqeval's report lists precision, recall, F1 and support per entity type.
print(classification_report(test_labels, predictions))

              precision    recall  f1-score   support

         LOC       0.82      0.72      0.77      1668
        MISC       0.66      0.66      0.66       702
         ORG       0.70      0.59      0.64      1661
         PER       0.82      0.77      0.79      1617

   micro avg       0.76      0.69      0.72      5648
   macro avg       0.75      0.69      0.72      5648
weighted avg       0.76      0.69      0.72      5648



## Optimization

In [10]:
# Retrain with hyperparameter optimization enabled. This rebinds `model`, so
# the cells below use the optimized model.
model = chaine.train(train_sentences, train_labels, verbose=0, optimize_hyperparameters=True)

[2022-05-25 09:23:22,872] [INFO] Starting with arow (1/5)
[2022-05-25 09:23:22,873] [INFO] Baseline for arow
[2022-05-25 09:24:20,396] [INFO] Trial 1/10 for arow
[2022-05-25 09:25:29,783] [INFO] Best baseline model: 0.8538541405229366
[2022-05-25 09:25:29,784] [INFO] Best optimized model: 0.8909010732379357
[2022-05-25 09:25:29,784] [INFO] Trial 2/10 for arow
[2022-05-25 09:26:37,112] [INFO] Best baseline model: 0.8538541405229366
[2022-05-25 09:26:37,113] [INFO] Best optimized model: 0.8978907390633878
[2022-05-25 09:26:37,114] [INFO] Trial 3/10 for arow
[2022-05-25 09:27:56,355] [INFO] Best baseline model: 0.8538541405229366
[2022-05-25 09:27:56,355] [INFO] Best optimized model: 0.8978907390633878
[2022-05-25 09:27:56,356] [INFO] Trial 4/10 for arow
[2022-05-25 09:29:04,767] [INFO] Best baseline model: 0.8538541405229366
[2022-05-25 09:29:04,768] [INFO] Best optimized model: 0.8978907390633878
[2022-05-25 09:29:04,769] [INFO] Trial 5/10 for arow
[2022-05-25 09:30:12,147] [INFO] Best 

In [11]:
# Re-evaluate on the test split, now using the model trained with
# hyperparameter optimization.
predictions = model.predict(test_sentences)

print(classification_report(test_labels, predictions))

              precision    recall  f1-score   support

         LOC       0.85      0.79      0.82      1668
        MISC       0.80      0.70      0.75       702
         ORG       0.77      0.64      0.70      1661
         PER       0.80      0.84      0.82      1617

   micro avg       0.81      0.75      0.78      5648
   macro avg       0.80      0.74      0.77      5648
weighted avg       0.81      0.75      0.77      5648



## Inspection

In [12]:
# Load the model's transition weights into a DataFrame and show the ten
# transitions with the highest weight.
transitions = pd.DataFrame(model.transitions)
transitions.sort_values("weight", ascending=False)[:10]

Unnamed: 0,from,to,weight
6,B-ORG,I-ORG,5.480503
10,O,O,5.142529
60,I-ORG,I-ORG,5.054006
31,B-PER,I-PER,4.621046
25,B-MISC,I-MISC,4.454587
11,O,B-MISC,4.309001
70,I-MISC,I-MISC,4.072832
12,O,B-PER,3.98464
53,B-LOC,I-LOC,3.686581
80,I-LOC,I-LOC,3.629553


In [13]:
# Load the model's state weights into a DataFrame and show the ten entries
# with the highest weight.
states = pd.DataFrame(model.states)
states.sort_values("weight", ascending=False)[:10]

Unnamed: 0,feature,label,weight
405,EOS,O,4.452106
58,BOS,O,3.263974
751,token[-3:]:day,O,2.879297
142519,token[-2:]:5M,O,2.85629
142632,token[-2:]:0M,O,2.812777
39882,-1:token.lower():v,B-ORG,2.672666
4651,-1:token.lower():at,B-LOC,2.664489
31888,token[-2:]:I,O,2.658717
31879,token[-3:]:I,O,2.658717
76909,token.lower():clinton,B-PER,2.546819
