# Named Entity Recognition with `chaine`

In [1]:
import datasets
import pandas as pd
from seqeval.metrics import classification_report

import chaine
from chaine.typing import Dataset, Features, Sentence, Tags

## Data

In [2]:
dataset = datasets.load_dataset("conll2003")

print(f"Number of sentences for training: {len(dataset['train']['tokens'])}")
print(f"Number of sentences for evaluation: {len(dataset['test']['tokens'])}")

Downloading builder script:   0%|          | 0.00/2.58k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

Reusing dataset conll2003 (/home/severin/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)


  0%|          | 0/3 [00:00<?, ?it/s]

Number of sentences for training: 14042
Number of sentences for evaluation: 3454


## Featurization

In [3]:
def featurize_token(token_index: int, sentence: Sentence, pos_tags: Tags) -> Features:
    """Extract features from a token in a sentence.

    Parameters
    ----------
    token_index : int
        Index of the token to featurize in the sentence.
    sentence : Sentence
        Sequence of tokens.
    pos_tags : Tags
        Sequence of part-of-speech tags corresponding to the tokens in the sentence.

    Returns
    -------
    Features
        Features representing the token.
    """
    token = sentence[token_index]
    pos_tag = pos_tags[token_index]
    features = {
        "token.lower()": token.lower(),
        "token[-3:]": token[-3:],
        "token[-2:]": token[-2:],
        "token.isupper()": token.isupper(),
        "token.istitle()": token.istitle(),
        "token.isdigit()": token.isdigit(),
        "pos_tag": pos_tag,
    }
    if token_index > 0:
        previous_token = sentence[token_index - 1]
        previous_pos_tag = pos_tags[token_index - 1]
        features.update(
            {
                "-1:token.lower()": previous_token.lower(),
                "-1:token.istitle()": previous_token.istitle(),
                "-1:token.isupper()": previous_token.isupper(),
                "-1:pos_tag": previous_pos_tag,
            }
        )
    else:
        features["BOS"] = True
    if token_index < len(sentence) - 1:
        next_token = sentence[token_index + 1]
        next_pos_tag = pos_tags[token_index + 1]
        features.update(
            {
                "+1:token.lower()": next_token.lower(),
                "+1:token.istitle()": next_token.istitle(),
                "+1:token.isupper()": next_token.isupper(),
                "+1:pos_tag": next_pos_tag,
            }
        )
    else:
        features["EOS"] = True
    return features


def featurize_sentence(sentence: Sentence, pos_tags: Tags) -> list[Features]:
    """Extract features from tokens in a sentence.

    Parameters
    ----------
    sentence : Sentence
        Sequence of tokens.
    pos_tags : Tags
        Sequence of part-of-speech tags corresponding to the tokens in the sentence.

    Returns
    -------
    list[Features]
        List of features representing tokens of a sentence.
    """
    return [
        featurize_token(token_index, sentence, pos_tags) for token_index in range(len(sentence))
    ]


def featurize_dataset(dataset: Dataset) -> list[list[Features]]:
    """Extract features from sentences in a dataset.

    Parameters
    ----------
    dataset : Dataset
        Dataset to featurize.

    Returns
    -------
    list[list[Features]]
        Featurized dataset.
    """
    return [
        featurize_sentence(sentence, pos_tags)
        for sentence, pos_tags in zip(dataset["tokens"], dataset["pos_tags"])
    ]


def preprocess_labels(dataset: Dataset) -> list[list[str]]:
    """Translate raw labels (i.e. integers) to the respective string labels.

    Parameters
    ----------
    dataset : Dataset
        Dataset to preprocess labels.

    Returns
    -------
    list[list[Features]]
        Preprocessed labels.
    """
    labels = dataset.features["ner_tags"].feature.names
    return [[labels[index] for index in indices] for indices in dataset["ner_tags"]]

In [None]:
sentences = featurize_dataset(dataset["train"])
labels = preprocess_labels(dataset["train"])

## Training

In [7]:
model = chaine.train(sentences, labels, verbose=0)

## Evaluation

In [8]:
sentences = featurize_dataset(dataset["test"])
labels = preprocess_labels(dataset["test"])

In [9]:
predictions = model.predict(sentences)

print(classification_report(labels, predictions))

              precision    recall  f1-score   support

         LOC       0.67      0.65      0.66      1668
        MISC       0.72      0.50      0.59       702
         ORG       0.67      0.34      0.45      1661
         PER       0.71      0.66      0.69      1617

   micro avg       0.69      0.54      0.61      5648
   macro avg       0.70      0.54      0.60      5648
weighted avg       0.69      0.54      0.60      5648



## Optimization

In [10]:
model = chaine.train(sentences, labels, verbose=0, optimize_hyperparameters=True)

In [11]:
predictions = model.predict(sentences)

print(classification_report(labels, predictions))

              precision    recall  f1-score   support

         LOC       0.94      0.91      0.92      1668
        MISC       0.91      0.77      0.83       702
         ORG       0.90      0.88      0.89      1661
         PER       0.90      0.91      0.91      1617

   micro avg       0.92      0.88      0.90      5648
   macro avg       0.92      0.87      0.89      5648
weighted avg       0.92      0.88      0.90      5648



## Inspection

In [12]:
transitions = pd.DataFrame(model.transitions)
transitions.sort_values("weight", ascending=False)[:10]

Unnamed: 0,from,to,weight
12,B-PER,I-PER,8.655663
30,B-ORG,I-ORG,7.053397
33,I-ORG,I-ORG,6.440869
0,O,O,5.564141
3,O,B-MISC,5.456652
4,O,B-ORG,5.394471
22,B-MISC,I-MISC,5.257558
8,B-LOC,I-LOC,4.875593
2,O,B-PER,4.716164
1,O,B-LOC,4.539984


In [13]:
states = pd.DataFrame(model.states)
states.sort_values("weight", ascending=False)[:10]

Unnamed: 0,feature,label,weight
4385,token[-2:]:0M,O,5.749702
4386,token[-2:]:5M,O,5.541271
164,EOS,O,5.192606
214,+1:token.lower():1996-12-06,B-LOC,4.757576
215,+1:token.lower():1996-12-06,I-LOC,3.692241
3532,+1:token.lower():1996-12-07,B-LOC,3.504
1658,-1:token.lower():b,B-PER,3.502818
23,token.isdigit(),O,3.362199
36,BOS,O,3.280833
1095,-1:token.lower():at,B-LOC,3.198786
