In [22]:
import numpy as np
import pandas as pd
import nltk
from sklearn_crfsuite import CRF
from sklearn_crfsuite.metrics import flat_classification_report

In [23]:
def load_conll_data(file_path):
    sentences = []
    sentence = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # New sentence
                if sentence:
                    sentences.append(sentence)
                    sentence = []
            else:
                parts = line.split()
                if len(parts) == 4 and parts[0] != "-DOCSTART-":  # Skip DOCSTART
                    word, pos, chunk, ner = parts
                    sentence.append((word, pos, ner))
    return sentences


In [24]:
data_path = "dataset_NER/train.txt"  # Update with your file path
sentences = load_conll_data(data_path)
print(f"Loaded {len(sentences)} sentences.")
print(sentences[:2])  # Display first two sentences

Loaded 14041 sentences.
[[('EU', 'NNP', 'B-ORG'), ('rejects', 'VBZ', 'O'), ('German', 'JJ', 'B-MISC'), ('call', 'NN', 'O'), ('to', 'TO', 'O'), ('boycott', 'VB', 'O'), ('British', 'JJ', 'B-MISC'), ('lamb', 'NN', 'O'), ('.', '.', 'O')], [('Peter', 'NNP', 'B-PER'), ('Blackburn', 'NNP', 'I-PER')]]


In [25]:
def word_features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]
    
    features = {
        'word': word.lower(),
        'postag': postag,
        'is_upper': word.isupper(),
        'is_title': word.istitle(),
        'is_digit': word.isdigit(),
    }
    if i > 0:
        prev_word = sent[i-1][0]
        prev_postag = sent[i-1][1]
        features.update({
            '-1:word': prev_word.lower(),
            '-1:postag': prev_postag,
        })
    else:
        features['BOS'] = True  # Beginning of Sentence
    
    if i < len(sent)-1:
        next_word = sent[i+1][0]
        next_postag = sent[i+1][1]
        features.update({
            '+1:word': next_word.lower(),
            '+1:postag': next_postag,
        })
    else:
        features['EOS'] = True  # End of Sentence
    
    return features

# Apply feature extraction to sentences
X = [[word_features(sent, i) for i in range(len(sent))] for sent in sentences]
y = [[token[2] for token in sent] for sent in sentences]  # NER labels


In [26]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"Training sentences: {len(X_train)}, Testing sentences: {len(X_test)}")

Training sentences: 11232, Testing sentences: 2809


In [27]:
from sklearn_crfsuite import CRF

# Initialize and train the CRF model
crf = CRF(algorithm='lbfgs', max_iterations=100)
crf.fit(X_train, y_train)

In [28]:
from seqeval.metrics import classification_report

# Make predictions
y_pred = crf.predict(X_test)

# Generate a classification report
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

         LOC       0.88      0.89      0.88      1423
        MISC       0.89      0.75      0.81       675
         ORG       0.82      0.76      0.79      1212
         PER       0.86      0.87      0.87      1254

   micro avg       0.86      0.83      0.84      4564
   macro avg       0.86      0.82      0.84      4564
weighted avg       0.86      0.83      0.84      4564



In [29]:
from collections import Counter

def evaluate_dataset(sentences, y_true, y_pred):
    results = {"PER": Counter(), "LOC": Counter(), "ORG": Counter()}
    for sentence, true_labels, pred_labels in zip(sentences, y_true, y_pred):
        words = [token[0] for token in sentence]
        true_entities = extract_entities(words, true_labels)
        pred_entities = extract_entities(words, pred_labels)

        for entity_type, entity in true_entities:
            if entity_type in results:
                if entity in [e[1] for e in pred_entities if e[0] == entity_type]:
                    results[entity_type]["correct"] += 1
                else:
                    results[entity_type]["missed"] += 1

        for entity_type, entity in pred_entities:
            if entity_type in results:
                if entity not in [e[1] for e in true_entities if e[0] == entity_type]:
                    results[entity_type]["false_positive"] += 1

    return results

# Compute evaluation metrics
results = evaluate_dataset(sentences, y_test, y_pred)
print("Evaluation Results:")
for entity, counts in results.items():
    print(f"{entity}: {counts}")

Evaluation Results:
PER: Counter({'correct': 790, 'false_positive': 110, 'missed': 108})
LOC: Counter({'correct': 868, 'missed': 111, 'false_positive': 104})
ORG: Counter({'correct': 734, 'missed': 184, 'false_positive': 145})


## Extract Specific Entity Predictions

In [30]:
# Function to extract entities from a sentence
def extract_entities(words, labels):
    entities = []
    entity = []
    current_label = None

    for word, label in zip(words, labels):
        if label.startswith("B-"):  # Beginning of a new entity
            if entity:
                entities.append((current_label, " ".join(entity)))
            entity = [word]
            current_label = label[2:]
        elif label.startswith("I-") and label[2:] == current_label:  # Continuation of an entity
            entity.append(word)
        else:
            if entity:
                entities.append((current_label, " ".join(entity)))
            entity = []
            current_label = None

    if entity:
        entities.append((current_label, " ".join(entity)))

    return entities

# Analyze predictions for a sample
sample_index = 3  # Change this to check different samples
sample_words = [token[0] for token in sentences[sample_index]]  # Original words
sample_true = y_test[sample_index]  # True labels
sample_pred = y_pred[sample_index]  # Predicted labels

print("Words:", sample_words)
print("True Entities:", extract_entities(sample_words, sample_true))
print("Predicted Entities:", extract_entities(sample_words, sample_pred))

Words: ['The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 'disagreed', 'with', 'German', 'advice', 'to', 'consumers', 'to', 'shun', 'British', 'lamb', 'until', 'scientists', 'determine', 'whether', 'mad', 'cow', 'disease', 'can', 'be', 'transmitted', 'to', 'sheep', '.']
True Entities: [('LOC', 'The'), ('LOC', 'Commission said')]
Predicted Entities: [('LOC', 'The'), ('LOC', 'Commission said')]
