# Named Entity Recognition with `chaine`

In [1]:
import datasets
import pandas as pd
from seqeval.metrics import classification_report

import chaine
from chaine.typing import Dataset, Features, Sentence, Tags

## Data

In [2]:
dataset = datasets.load_dataset("conll2003")

print(f"Number of sentences for training: {len(dataset['train']['tokens'])}")
print(f"Number of sentences for evaluation: {len(dataset['test']['tokens'])}")

Reusing dataset conll2003 (/home/severin/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)


  0%|          | 0/3 [00:00<?, ?it/s]

Number of sentences for training: 14042
Number of sentences for evaluation: 3454


## Featurization

In [3]:
def featurize_token(token_index: int, sentence: Sentence, pos_tags: Tags) -> Features:
    """Extract features from a token in a sentence.

    Parameters
    ----------
    token_index : int
        Index of the token to featurize in the sentence.
    sentence : Sentence
        Sequence of tokens.
    pos_tags : Tags
        Sequence of part-of-speech tags corresponding to the tokens in the sentence.

    Returns
    -------
    Features
        Features representing the token.
    """
    token = sentence[token_index]
    pos_tag = pos_tags[token_index]
    features = {
        "token.lower()": token.lower(),
        "token[-3:]": token[-3:],
        "token[-2:]": token[-2:],
        "token.isupper()": token.isupper(),
        "token.istitle()": token.istitle(),
        "token.isdigit()": token.isdigit(),
        "pos_tag": pos_tag,
    }
    if token_index > 0:
        previous_token = sentence[token_index - 1]
        previous_pos_tag = pos_tags[token_index - 1]
        features.update(
            {
                "-1:token.lower()": previous_token.lower(),
                "-1:token.istitle()": previous_token.istitle(),
                "-1:token.isupper()": previous_token.isupper(),
                "-1:pos_tag": previous_pos_tag,
            }
        )
    else:
        features["BOS"] = True
    if token_index < len(sentence) - 1:
        next_token = sentence[token_index + 1]
        next_pos_tag = pos_tags[token_index + 1]
        features.update(
            {
                "+1:token.lower()": next_token.lower(),
                "+1:token.istitle()": next_token.istitle(),
                "+1:token.isupper()": next_token.isupper(),
                "+1:pos_tag": next_pos_tag,
            }
        )
    else:
        features["EOS"] = True
    return features


def featurize_sentence(sentence: Sentence, pos_tags: Tags) -> list[Features]:
    """Extract features from tokens in a sentence.

    Parameters
    ----------
    sentence : Sentence
        Sequence of tokens.
    pos_tags : Tags
        Sequence of part-of-speech tags corresponding to the tokens in the sentence.

    Returns
    -------
    list[Features]
        List of features representing tokens of a sentence.
    """
    return [
        featurize_token(token_index, sentence, pos_tags) for token_index in range(len(sentence))
    ]


def featurize_dataset(dataset: Dataset) -> list[list[Features]]:
    """Extract features from sentences in a dataset.

    Parameters
    ----------
    dataset : Dataset
        Dataset to featurize.

    Returns
    -------
    list[list[Features]]
        Featurized dataset.
    """
    return [
        featurize_sentence(sentence, pos_tags)
        for sentence, pos_tags in zip(dataset["tokens"], dataset["pos_tags"])
    ]


def preprocess_labels(dataset: Dataset) -> list[list[str]]:
    """Translate raw labels (i.e. integers) to the respective string labels.

    Parameters
    ----------
    dataset : Dataset
        Dataset to preprocess labels.

    Returns
    -------
    list[list[Features]]
        Preprocessed labels.
    """
    labels = dataset.features["ner_tags"].feature.names
    return [[labels[index] for index in indices] for indices in dataset["ner_tags"]]

In [4]:
sentences = featurize_dataset(dataset["train"])
labels = preprocess_labels(dataset["train"])

In [5]:
sentences[0][0]

{'token.lower()': 'eu',
 'token[-3:]': 'EU',
 'token[-2:]': 'EU',
 'token.isupper()': True,
 'token.istitle()': False,
 'token.isdigit()': False,
 'pos_tag': 22,
 'BOS': True,
 '+1:token.lower()': 'rejects',
 '+1:token.istitle()': False,
 '+1:token.isupper()': False,
 '+1:pos_tag': 42}

In [6]:
labels[0][0]

'B-ORG'

## Training

In [7]:
model = chaine.train(sentences, labels, verbose=0)

## Evaluation

In [8]:
sentences = featurize_dataset(dataset["test"])
labels = preprocess_labels(dataset["test"])

In [9]:
predictions = model.predict(sentences)

print(classification_report(labels, predictions))

              precision    recall  f1-score   support

         LOC       0.82      0.72      0.77      1668
        MISC       0.66      0.66      0.66       702
         ORG       0.70      0.59      0.64      1661
         PER       0.82      0.77      0.79      1617

   micro avg       0.76      0.69      0.72      5648
   macro avg       0.75      0.69      0.72      5648
weighted avg       0.76      0.69      0.72      5648



## Optimization

In [10]:
model = chaine.train(sentences, labels, verbose=0, optimize_hyperparameters=True)

[2022-05-24 14:12:57,820] [INFO] Starting with arow (1/5)
[2022-05-24 14:12:57,821] [INFO] Baseline for arow
[2022-05-24 14:13:10,700] [INFO] Trial 1/10 for arow
[2022-05-24 14:13:35,112] [INFO] Best baseline model: 0.8053851088265638
[2022-05-24 14:13:35,113] [INFO] Best optimized model: 0.8053851088265638
[2022-05-24 14:13:35,115] [INFO] Trial 2/10 for arow
[2022-05-24 14:13:57,998] [INFO] Best baseline model: 0.8053851088265638
[2022-05-24 14:13:57,999] [INFO] Best optimized model: 0.8053851088265638
[2022-05-24 14:13:57,999] [INFO] Trial 3/10 for arow
[2022-05-24 14:14:22,100] [INFO] Best baseline model: 0.8053851088265638
[2022-05-24 14:14:22,101] [INFO] Best optimized model: 0.8053851088265638
[2022-05-24 14:14:22,101] [INFO] Trial 4/10 for arow
[2022-05-24 14:14:41,419] [INFO] Best baseline model: 0.8053851088265638
[2022-05-24 14:14:41,420] [INFO] Best optimized model: 0.8053851088265638
[2022-05-24 14:14:41,422] [INFO] Trial 5/10 for arow
[2022-05-24 14:15:02,888] [INFO] Best 

In [11]:
predictions = model.predict(sentences)

print(classification_report(labels, predictions))

              precision    recall  f1-score   support

         LOC       0.97      0.95      0.96      1668
        MISC       0.97      0.87      0.91       702
         ORG       0.94      0.93      0.94      1661
         PER       0.97      0.96      0.96      1617

   micro avg       0.96      0.94      0.95      5648
   macro avg       0.96      0.93      0.94      5648
weighted avg       0.96      0.94      0.95      5648



## Inspection

In [12]:
transitions = pd.DataFrame(model.transitions)
transitions.sort_values("weight", ascending=False)[:10]

Unnamed: 0,from,to,weight
12,B-PER,I-PER,8.318353
30,B-ORG,I-ORG,7.168033
22,B-MISC,I-MISC,6.609091
26,I-MISC,I-MISC,6.177694
34,I-ORG,I-ORG,6.17187
0,O,O,5.502018
8,B-LOC,I-LOC,5.196033
18,I-LOC,I-LOC,4.680722
3,O,B-MISC,4.323212
15,I-PER,I-PER,4.204643


In [13]:
states = pd.DataFrame(model.states)
states.sort_values("weight", ascending=False)[:10]

Unnamed: 0,feature,label,weight
4353,token[-2:]:0M,O,9.405752
4354,token[-2:]:5M,O,8.817209
216,+1:token.lower():1996-12-06,B-LOC,5.713073
3807,token.lower():painewebber,B-ORG,5.619434
217,+1:token.lower():1996-12-06,I-LOC,5.243976
3704,+1:token.lower():exxon,O,5.152197
1013,token.lower():italy,B-LOC,5.052394
168,EOS,O,4.99369
1605,-1:token.lower():b,B-PER,4.895346
3233,token.lower():trans-atlantic,B-MISC,4.789043
