# Named Entity Recognition using CRF Model

We use the **CoNLL-2003** dataset to train a Conditional Random Field (CRF) model for Named Entity Recognition.

**9 Distinct NER Tags:**

| Tag | Meaning |
|-----|---------|
| O | Outside any entity |
| B-PER | Beginning of a Person |
| I-PER | Inside a Person |
| B-ORG | Beginning of an Organization |
| I-ORG | Inside an Organization |
| B-LOC | Beginning of a Location |
| I-LOC | Inside a Location |
| B-MISC | Beginning of a Miscellaneous entity |
| I-MISC | Inside a Miscellaneous entity |
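
The B-/I- prefixes constrain which tags may follow each other. The check below is a minimal sketch, assuming the IOB2 convention in which every entity starts with a `B-` tag (consistent with the single-token `B-ORG` for "EU" in the sample sentence). The helper name `is_valid_bio` is hypothetical and not part of any library.

```python
def is_valid_bio(tags):
    """Return True if every I-X tag directly continues a B-X or I-X of the same type."""
    prev = "O"
    for tag in tags:
        if tag.startswith("I-"):
            etype = tag[2:]
            # An I- tag is only legal right after B- or I- of the same entity type.
            if prev not in (f"B-{etype}", f"I-{etype}"):
                return False
        prev = tag
    return True


print(is_valid_bio(["B-PER", "I-PER", "O", "B-LOC"]))  # True
print(is_valid_bio(["O", "I-ORG"]))                    # False: I- with no opening B-
print(is_valid_bio(["B-LOC", "I-ORG"]))                # False: entity type changes mid-span
```

A sequence model such as a CRF can learn these constraints from data as transition weights, rather than having them hard-coded.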

#### Importing Libraries

In [15]:
from datasets import load_dataset
from sklearn_crfsuite import CRF
from sklearn_crfsuite.metrics import flat_f1_score, flat_classification_report

#### Loading the CoNLL-2003 Dataset

In [16]:
# Recent `datasets` releases no longer support `trust_remote_code`, so it is omitted.
# Note: in the original run, `conll2003` could not be found on the Hugging Face Hub
# and the library fell back to a locally cached copy of the dataset.
dataset = load_dataset("conll2003")
print(dataset)


DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})


In [17]:
# NER tag mapping in CoNLL-2003
tag_names = dataset["train"].features["ner_tags"].feature.names
print("NER Tags:", tag_names)
print("Number of tags:", len(tag_names))

NER Tags: ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
Number of tags: 9


In [18]:
# Explore a sample sentence
sample = dataset["train"][0]
print("Tokens:", sample["tokens"])
print("POS tags:", sample["pos_tags"])
print("NER tags:", [tag_names[t] for t in sample["ner_tags"]])

Tokens: ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
POS tags: [22, 42, 16, 21, 35, 37, 16, 21, 7]
NER tags: ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']


In [19]:
# Dataset statistics
for split in ["train", "validation", "test"]:
    print(f"{split}: {len(dataset[split])} sentences")

train: 14041 sentences
validation: 3250 sentences
test: 3453 sentences


#### Preparing Sentences as (word, POS, NER) Tuples

Convert each dataset example into a list of `(word, pos_tag, ner_tag)` tuples for CRF processing.

In [20]:
pos_tag_names = dataset["train"].features["pos_tags"].feature.names

def convert_to_tuples(example):
    """Convert a dataset example to list of (word, pos, ner_tag) tuples."""
    return list(zip(
        example["tokens"],
        [pos_tag_names[p] for p in example["pos_tags"]],
        [tag_names[t] for t in example["ner_tags"]]
    ))

train_sents = [convert_to_tuples(ex) for ex in dataset["train"]]
test_sents = [convert_to_tuples(ex) for ex in dataset["test"]]

print("Sample sentence:")
print(train_sents[0])

Sample sentence:
[('EU', 'NNP', 'B-ORG'), ('rejects', 'VBZ', 'O'), ('German', 'JJ', 'B-MISC'), ('call', 'NN', 'O'), ('to', 'TO', 'O'), ('boycott', 'VB', 'O'), ('British', 'JJ', 'B-MISC'), ('lamb', 'NN', 'O'), ('.', '.', 'O')]


#### Feature Extraction

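For each word, we extract its lowercase form, a compressed word shape, its length, casing flags, digit and punctuation patterns, and prefixes and suffixes of up to four characters. The same features are computed for up to two neighboring tokens on each side, together with bigrams of the current word and its immediate neighbors and sentence-boundary flags (`BOS`, `EOS`). The POS tag stored in each tuple is not used by this feature set.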

In [29]:
import re
import string

_PUNCT = set(string.punctuation)

def word_shape(w: str) -> str:
    out = []
    for ch in w:
        if ch.isupper():
            out.append('X')
        elif ch.islower():
            out.append('x')
        elif ch.isdigit():
            out.append('d')
        else:
            out.append(ch)
    shape = ''.join(out)
    shape = re.sub(r'(.)\1+', r'\1', shape)  # compress repeats
    return shape

def has_any_digit(w: str) -> bool:
    return any(ch.isdigit() for ch in w)

def is_punct_token(w: str) -> bool:
    return len(w) > 0 and all(ch in _PUNCT for ch in w)

def add_word_features(feats: dict, w: str, prefix: str):
    wl = w.lower()

    feats[f'{prefix}w.lower'] = wl
    feats[f'{prefix}w.shape'] = word_shape(w)
    feats[f'{prefix}w.len'] = len(w)

    # casing
    feats[f'{prefix}is_upper'] = w.isupper()
    feats[f'{prefix}is_title'] = w.istitle()
    feats[f'{prefix}is_lower'] = w.islower()

    # digits & patterns
    feats[f'{prefix}is_digit'] = w.isdigit()
    feats[f'{prefix}has_digit'] = has_any_digit(w)
    feats[f'{prefix}is_year'] = bool(re.fullmatch(r'(19|20)\d{2}', w))
    feats[f'{prefix}is_decimal'] = bool(re.fullmatch(r'\d+\.\d+', w))
    feats[f'{prefix}is_ordinal'] = bool(re.fullmatch(r'\d+(st|nd|rd|th)', wl))

    # punctuation / special
    feats[f'{prefix}is_punct'] = is_punct_token(w)
    feats[f'{prefix}has_hyphen'] = '-' in w
    feats[f'{prefix}has_apostrophe'] = "'" in w
    feats[f'{prefix}has_dot'] = '.' in w
    feats[f'{prefix}has_slash'] = '/' in w
    feats[f'{prefix}is_initial'] = bool(re.fullmatch(r'[A-Za-z]\.', w))

    # affixes (1-4)
    for k in (1, 2, 3, 4):
        if len(wl) >= k:
            feats[f'{prefix}pref{k}'] = wl[:k]
            feats[f'{prefix}suf{k}'] = wl[-k:]

    # length buckets
    L = len(w)
    feats[f'{prefix}len<=2'] = (L <= 2)
    feats[f'{prefix}len3-5'] = (3 <= L <= 5)
    feats[f'{prefix}len6-8'] = (6 <= L <= 8)
    feats[f'{prefix}len>=9'] = (L >= 9)

def get_token(sent, i):
    """
    Works for:
      - sent[i] = (token, label)
      - sent[i] = (token, pos, label)
      - sent[i] = (token,)  etc.
    """
    return sent[i][0]

def word2features(sent, i):
    w = get_token(sent, i)

    feats = {'bias': 1.0}

    # current token
    add_word_features(feats, w, prefix='0:')

    # previous tokens
    if i > 0:
        w_1 = get_token(sent, i-1)
        add_word_features(feats, w_1, prefix='-1:')
        feats['-1:wl|0:wl'] = w_1.lower() + '|' + w.lower()
    else:
        feats['BOS'] = True

    if i > 1:
        w_2 = get_token(sent, i-2)
        add_word_features(feats, w_2, prefix='-2:')

    # next tokens
    if i < len(sent) - 1:
        w1 = get_token(sent, i+1)
        add_word_features(feats, w1, prefix='+1:')
        feats['0:wl|+1:wl'] = w.lower() + '|' + w1.lower()
    else:
        feats['EOS'] = True

    if i < len(sent) - 2:
        w2 = get_token(sent, i+2)
        add_word_features(feats, w2, prefix='+2:')

    return feats

def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    # supports (token,label) or (token,pos,label)
    return [tok[-1] for tok in sent]
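
To make the word-shape feature concrete, here is a compact restatement of the same idea applied to a few tokens. It is an illustrative sketch rather than a replacement for `word_shape` above. The function name `shape` and the example tokens are chosen here for illustration.

```python
import re


def shape(token):
    # Map each character to a class symbol, keeping punctuation unchanged.
    classes = "".join(
        "X" if ch.isupper() else "x" if ch.islower() else "d" if ch.isdigit() else ch
        for ch in token
    )
    # Collapse runs of the same symbol, e.g. "Xxxxxx" -> "Xx".
    return re.sub(r"(.)\1+", r"\1", classes)


for tok in ["EU", "German", "1996-08-22", "U.S.", "McDonald"]:
    print(tok, "->", shape(tok))
```

Compressing repeats means "German" and "British" share the shape `Xx`, so the model can generalize over capitalized words regardless of their length.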


In [30]:
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

print(f"Training sentences: {len(X_train)}")
print(f"Test sentences: {len(X_test)}")

Training sentences: 14041
Test sentences: 3453


#### Training the CRF Model

In [31]:
crf = CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)

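| Parameter | Value |
|-----------|-------|
| algorithm | 'lbfgs' |
| min_freq | None |
| all_possible_states | None |
| all_possible_transitions | True |
| c1 | 0.1 |
| c2 | 0.1 |
| max_iterations | 100 |
| num_memories | None |
| epsilon | None |
| period | None |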

In [32]:
# Predicting on the test set
y_pred = crf.predict(X_test)

#### Evaluating Model Performance

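Because the `O` tag dominates the test set (38,323 of 46,435 tokens), token accuracy alone says little about how well entity tags are recognized, so we report precision, recall, and F1-score for each tag.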

In [33]:
# Weighted F1 score
f1 = flat_f1_score(y_test, y_pred, average='weighted')
print(f"Weighted F1 Score: {f1:.4f}")

Weighted F1 Score: 0.9624


In [34]:
# Detailed classification report for all 9 tags
labels = ['B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'O']
report = flat_classification_report(y_test, y_pred, labels=labels, digits=3)
print(report)

              precision    recall  f1-score   support

       B-PER      0.874     0.869     0.871      1617
       I-PER      0.907     0.957     0.931      1156
       B-ORG      0.817     0.768     0.792      1661
       I-ORG      0.713     0.789     0.749       835
       B-LOC      0.879     0.888     0.883      1668
       I-LOC      0.795     0.770     0.783       257
      B-MISC      0.788     0.791     0.789       702
      I-MISC      0.502     0.685     0.579       216
           O      0.991     0.987     0.989     38323

    accuracy                          0.962     46435
   macro avg      0.807     0.834     0.819     46435
weighted avg      0.963     0.962     0.962     46435
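
The gap between the macro and weighted averages follows directly from the per-tag rows. The sketch below recomputes both from the rounded F1 and support values printed in the report, so the results can differ from the report in the last digit. The variable names are illustrative.

```python
# (f1, support) per tag, copied from the classification report above.
report = {
    "B-PER": (0.871, 1617), "I-PER": (0.931, 1156),
    "B-ORG": (0.792, 1661), "I-ORG": (0.749, 835),
    "B-LOC": (0.883, 1668), "I-LOC": (0.783, 257),
    "B-MISC": (0.789, 702), "I-MISC": (0.579, 216),
    "O": (0.989, 38323),
}

total = sum(support for _, support in report.values())

# Macro average: every tag counts equally, regardless of frequency.
macro_f1 = sum(f1 for f1, _ in report.values()) / len(report)

# Weighted average: each tag is weighted by its number of tokens.
weighted_f1 = sum(f1 * support for f1, support in report.values()) / total

# Macro average over the eight entity tags only (dropping "O").
entity_f1s = [f1 for tag, (f1, _) in report.items() if tag != "O"]
macro_f1_entities = sum(entity_f1s) / len(entity_f1s)

print(f"Share of O tokens      : {report['O'][1] / total:.3f}")
print(f"Macro F1 (all 9 tags)  : {macro_f1:.3f}")
print(f"Weighted F1            : {weighted_f1:.3f}")
print(f"Macro F1 (entity tags) : {macro_f1_entities:.3f}")
```

The weighted figure of about 0.96 is driven mostly by `O`. Averaged over entity tags alone, F1 is closer to 0.80, which is the more informative number for NER quality.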



#### Top Learned Transition Weights

Inspect the transition weights the CRF learned: the highest-weighted tag-to-tag transitions and the most heavily penalized ones.

In [35]:
def top_bottom(d, k=10):
    items = sorted(d.items(), key=lambda x: x[1])
    return items[-k:], items[:k]

topT, botT = top_bottom(crf.transition_features_, 10)
print("Top transitions:")
for (a,b), w in topT:
    print(a,"->",b, w)

print("\nWorst transitions:")
for (a,b), w in botT:
    print(a,"->",b, w)

Top transitions:
O -> B-PER 1.013043
O -> B-LOC 1.182261
I-ORG -> I-ORG 3.621163
I-LOC -> I-LOC 3.753718
B-LOC -> I-LOC 4.077761
I-PER -> I-PER 4.104527
B-ORG -> I-ORG 4.461146
I-MISC -> I-MISC 4.69559
B-MISC -> I-MISC 5.034972
B-PER -> I-PER 6.598406

Worst transitions:
O -> I-ORG -7.761333
O -> I-MISC -6.851659
O -> I-LOC -6.311759
O -> I-PER -5.870169
B-LOC -> I-ORG -4.845515
B-MISC -> I-ORG -4.502389
B-PER -> B-PER -3.294188
B-MISC -> I-LOC -3.109591
I-LOC -> I-ORG -2.983415
B-LOC -> I-MISC -2.95856
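
To show how such transition weights shape a prediction, here is a simplified Viterbi decoding sketch. The three transition weights are copied from the output above. The per-token emission scores are invented for illustration, and the `viterbi` helper is hypothetical rather than part of sklearn-crfsuite.

```python
def viterbi(emissions, transitions, tags):
    """emissions: list of {tag: score}; transitions: {(prev, cur): score}."""
    # best[t][tag] = (score of the best path ending in `tag` at position t, previous tag)
    best = [{tag: (emissions[0].get(tag, 0.0), None) for tag in tags}]
    for emit in emissions[1:]:
        column = {}
        for cur in tags:
            score, prev = max(
                (best[-1][p][0] + transitions.get((p, cur), 0.0), p) for p in tags
            )
            column[cur] = (score + emit.get(cur, 0.0), prev)
        best.append(column)
    # Backtrack from the highest-scoring final tag.
    tag = max(tags, key=lambda t: best[-1][t][0])
    path = [tag]
    for column in reversed(best[1:]):
        tag = column[tag][1]
        path.append(tag)
    return path[::-1]


tags = ["O", "B-ORG", "I-ORG"]
transitions = {  # learned weights taken from the output above
    ("O", "I-ORG"): -7.761333,
    ("B-ORG", "I-ORG"): 4.461146,
    ("I-ORG", "I-ORG"): 3.621163,
}
emissions = [  # hypothetical per-token scores for a two-token sentence
    {"O": 1.0, "B-ORG": 0.8},
    {"I-ORG": 1.0, "O": 0.5},
]

greedy = [max(e, key=e.get) for e in emissions]
print("Greedy per token:", greedy)
print("Viterbi path    :", viterbi(emissions, transitions, tags))
```

Choosing the best tag for each token independently gives `O` followed by `I-ORG`, an invalid sequence. Once the large negative `O -> I-ORG` weight and the positive `B-ORG -> I-ORG` weight are included, the best full path is `B-ORG`, `I-ORG`.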


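#### Entity-Level Evaluation with seqeval

Token-level scores give partial credit for partially tagged entities. seqeval instead scores whole entity spans: a predicted entity counts as correct only if both its type and its boundaries match the gold span exactly.

In [36]:
from seqeval.metrics import f1_score as seq_f1_score, classification_report as seq_classification_report
from sklearn.metrics import accuracy_score

# NOTE: WINDOW and USE_BIGRAMS are only echoed here. word2features above does not
# read them: it hard-codes a window of two tokens on each side plus adjacent-word bigrams.
WINDOW = 1
USE_BIGRAMS = True
print(f"Using WINDOW={WINDOW}, USE_BIGRAMS={USE_BIGRAMS}")
print("Feature data prepared.")
# len(X_test[0]) is the number of tokens in the first test sentence.
print(f"Example sentence feature length: {len(X_test[0])} tokens")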

# Validation set
val_sents = [convert_to_tuples(ex) for ex in dataset["validation"]]
X_val = [sent2features(s) for s in val_sents]
y_val = [sent2labels(s) for s in val_sents]
y_val_pred = crf.predict(X_val)

# Entity-level (seqeval) F1
val_f1 = seq_f1_score(y_val, y_val_pred)
test_f1 = seq_f1_score(y_test, y_pred)
print(f"\nEntity-level (seqeval) F1:")
print(f"  Val F1 : {val_f1:.4f}")
print(f"  Test F1: {test_f1:.2f}")

# Token accuracy (sanity check)
val_acc = accuracy_score([t for s in y_val for t in s], [t for s in y_val_pred for t in s])
test_acc = accuracy_score([t for s in y_test for t in s], [t for s in y_pred for t in s])
print(f"\nToken accuracy (sanity check):")
print(f"  Val token acc : {val_acc:.4f}")
print(f"  Test token acc: {test_acc:.4f}")

# Detailed entity-level report (TEST)
print(f"\nDetailed entity-level report (TEST):")
print(seq_classification_report(y_test, y_pred, digits=2))

Using WINDOW=1, USE_BIGRAMS=True
Feature data prepared.
Example sentence feature length: 12 tokens

Entity-level (seqeval) F1:
  Val F1 : 0.8915
  Test F1: 0.83

Token accuracy (sanity check):
  Val token acc : 0.9793
  Test token acc: 0.9620

Detailed entity-level report (TEST):
              precision    recall  f1-score   support

         LOC       0.87      0.88      0.88      1668
        MISC       0.76      0.76      0.76       702
         ORG       0.79      0.75      0.77      1661
         PER       0.87      0.87      0.87      1617

   micro avg       0.84      0.82      0.83      5648
   macro avg       0.82      0.81      0.82      5648
weighted avg       0.84      0.82      0.83      5648
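
The gap between token accuracy (0.9620) and entity-level F1 (0.83) on the test set comes from how boundary errors are counted. The sketch below illustrates this on a toy sentence with a simplified span matcher. It is not seqeval's implementation (for instance, it ignores stray `I-` tags that have no opening `B-`), and the helper names are hypothetical.

```python
def extract_spans(tags):
    """Return a set of (type, start, end) spans from IOB2 tags; end is exclusive."""
    spans, start, etype = set(), None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel "O" flushes the last span
        continues = tag.startswith("I-") and tag[2:] == etype
        if start is not None and not continues:
            spans.add((etype, start, i))
            start, etype = None, None
        if tag.startswith("B-"):
            start, etype = i, tag[2:]
    return spans


def entity_f1(gold, pred):
    g, p = extract_spans(gold), extract_spans(pred)
    correct = len(g & p)  # exact match on type and both boundaries
    if correct == 0:
        return 0.0
    precision, recall = correct / len(p), correct / len(g)
    return 2 * precision * recall / (precision + recall)


gold = ["B-ORG", "I-ORG", "I-ORG", "O", "B-PER"]
pred = ["B-ORG", "I-ORG", "O", "O", "B-PER"]  # ORG span cut one token short

token_acc = sum(g == p for g, p in zip(gold, pred)) / len(gold)
print(f"Token accuracy: {token_acc:.2f}")
print(f"Entity F1     : {entity_f1(gold, pred):.2f}")
```

One wrong token out of five leaves token accuracy at 0.80, but the truncated ORG span counts as both a missed entity and a spurious one, so entity-level F1 drops to 0.50.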

