# Named Entity Recognition on CoNLL-2003 with Transformer-based Models

In [None]:
# !pip install evaluate
# !pip install seqeval

In [2]:
import torch
from datasets import load_dataset
from evaluate import load
import numpy as np
from transformers import (
    AutoTokenizer, 
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq
)
import os
import re
from tqdm import tqdm
import warnings
import logging

os.environ["WANDB_DISABLED"] = "true"
# warnings.filterwarnings("ignore", category=FutureWarning)
# logging.getLogger("transformers").setLevel(logging.ERROR)





# Set random seed
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Model configuration
MODEL_CONFIG = {
    'name': 't5-base',
    'batch_size': 8,
    'learning_rate': 5e-5,
    'num_epochs': 5,
    'weight_decay': 0.01,
    'max_length': 128  
}

def load_and_preprocess_data():
    """Load CoNLL-2003 dataset and prepare label list"""
    datasets = load_dataset("conll2003", trust_remote_code=True)
    
    # # For quick test - use a smaller subset
    # for split in ["train", "validation", "test"]:
    #     datasets[split] = datasets[split].select(range(len(datasets[split]) // 50))
    
    # Get label list from dataset features
    label_list = datasets["train"].features["ner_tags"].feature.names
    print(f"NER Label list: {label_list}")
    print(f"Number of labels: {len(label_list)}")
    return datasets, label_list


def extract_entities_from_tokens_and_tags(tokens, tags, label_list):
    """Extract entity descriptions from tokens and tags"""
    entities = []
    current_entity = None
    current_type = None
    
    for token, tag_idx in zip(tokens, tags):
        tag = label_list[tag_idx]
        
        # If the tag starts with B-, it indicates the beginning of a new entity
        if tag.startswith("B-"):
            # If there is an entity being processed, save it first
            if current_entity is not None:
                entities.append(f"{current_type}: {' '.join(current_entity)}")
            
            # Start a new entity
            current_type = tag[2:]  # Remove the "B-" prefix
            current_entity = [token]
        
        # If the tag starts with I-, it indicates the continuation of an entity
        elif tag.startswith("I-") and current_entity is not None:
            # Ensure the I- tag type matches the current entity type
            if tag[2:] == current_type:
                current_entity.append(token)
        
        # If it's an O tag or other case, it indicates the end of an entity
        elif current_entity is not None:
            entities.append(f"{current_type}: {' '.join(current_entity)}")
            current_entity = None
            current_type = None
    
    # Process the possible last entity
    if current_entity is not None:
        entities.append(f"{current_type}: {' '.join(current_entity)}")
    
    # If there are no entities, return "none"
    if not entities:
        return "none"
        
    return "; ".join(entities)

def tokenize_for_t5(examples, tokenizer, label_list):
    """Convert NER task to a text-to-text task for generating entity descriptions"""
    input_texts = []
    target_texts = []
    
    for tokens, tags in zip(examples["tokens"], examples["ner_tags"]):
        # Improvement 1: Add task prefix to help the model understand the task
        input_text = "extract entities: " + " ".join(tokens)
        
        # Output text is the entity description
        target_text = extract_entities_from_tokens_and_tags(tokens, tags, label_list)
        
        input_texts.append(input_text)
        target_texts.append(target_text)
    
    # Tokenize inputs
    model_inputs = tokenizer(
        input_texts,
        padding="max_length",
        max_length=MODEL_CONFIG['max_length'],
        truncation=True,
        return_tensors="np"
    )
    
    labels = tokenizer(
        text_target=target_texts,
        padding="max_length",
        max_length=MODEL_CONFIG['max_length'],
        truncation=True,
        return_tensors="np"
    ).input_ids
    
    # Replace padding token id with -100
    model_inputs["labels"] = np.where(
        labels == tokenizer.pad_token_id,
        -100,
        labels
    )
    
    return model_inputs

def find_sub_list(sl, l):
    """Find the position of a sublist in the main list"""
    results = []
    sll = len(sl)
    for ind in (i for i, e in enumerate(l) if e == sl[0]):
        if l[ind:ind+sll] == sl:
            results.append((ind, ind+sll-1))
    return results

def generate_labels_from_entities(tokens, entity_text, label_list):
    """Generate IOB format label sequence from entity description text"""
    # Initialize all labels as "O"
    labels = ["O"] * len(tokens)
    
    # If there are no entities or the input is "none", directly return all O labels
    if not entity_text or entity_text.lower() == "none":
        return labels
    
    # Parse entity descriptions
    entities = entity_text.split("; ")
    for entity in entities:
        # Skip improperly formatted entities
        if ": " not in entity:
            continue
        
        entity_type, entity_text = entity.split(": ", 1)
        entity_tokens = entity_text.split()
        
        # Find entities in the original text
        positions = find_sub_list(entity_tokens, tokens)
        
        # If the entity is found, mark the labels
        if positions:
            for start, end in positions:
                # Mark B-prefix (entity start)
                labels[start] = f"B-{entity_type.upper()}"
                # Mark I-prefix (entity interior)
                for i in range(start+1, end+1):
                    labels[i] = f"I-{entity_type.upper()}"
    
    # Ensure all labels are in the label_list
    valid_labels = set(label_list)
    for i, label in enumerate(labels):
        if label not in valid_labels:
            labels[i] = "O"
    
    return labels

def compute_metrics(eval_pred, tokenizer, label_list, metric):
    """Calculate NER metrics using generated entity descriptions"""
    predictions, labels = eval_pred

    # Ensure prediction values are within valid tokenizer range
    vocab_size = tokenizer.vocab_size
    print(f"Tokenizer vocab size: {vocab_size}")
    print(f"Prediction min: {np.min(predictions)}, max: {np.max(predictions)}")
    
    # Clip out-of-range values to valid range
    predictions = np.clip(predictions, 0, vocab_size - 1).astype(np.int32)
    
    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    # Handle -100 in labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    

    # Use tokens from the actual evaluation dataset for more accurate evaluation
    # Since we cannot directly access the original tokens during evaluation, we use a reasonable approximation
    # Here we use dummy tokens with a length of 30 (typically the length of a sentence)
    dummy_tokens = ["token"] * 30  
    
    all_pred_entities = []
    all_true_entities = []
    
    for pred, label in zip(decoded_preds, decoded_labels):
        # Extract entities from prediction and label text
        pred_entities = pred.strip()
        true_entities = label.strip()
        
        # Improve entity extraction
        # Standardize the format of predicted entities
        if pred_entities and not ": " in pred_entities and not ";" in pred_entities:
            # Try to automatically fix non-standard output format
            # For example "AL-AIN, United Arab Emirates" -> "LOC: AL-AIN; LOC: United Arab Emirates"
            for entity_type in ["PER", "ORG", "LOC", "MISC"]:
                if entity_type.lower() in pred_entities.lower():
                    pred_entities = f"{entity_type}: {pred_entities}"
                    break
        
        # Convert entity descriptions to IOB label sequences
        pred_labels = generate_labels_from_entities(dummy_tokens, pred_entities, label_list)
        true_labels = generate_labels_from_entities(dummy_tokens, true_entities, label_list)
        
        # Ensure label sequence lengths match
        min_len = min(len(pred_labels), len(true_labels))
        all_pred_entities.append(pred_labels[:min_len])
        all_true_entities.append(true_labels[:min_len])
    
    # Calculate metrics
    results = metric.compute(
        predictions=all_pred_entities,
        references=all_true_entities,
        zero_division=0
    )
    
    return {
        "precision": float(results["overall_precision"]),
        "recall": float(results["overall_recall"]),
        "f1": float(results["overall_f1"])
    }


        
def evaluate_model_with_real_data(model, tokenizer, test_dataset, label_list, metric):
    """Evaluate the model using real data"""
    # Set the model to evaluation mode
    model.eval()
    
    all_pred_entities = []
    all_true_entities = []
    
    # Create a small batch for evaluation
    batch_size = 16
    for i in tqdm(range(0, len(test_dataset), batch_size)):
        batch = test_dataset[i:i+batch_size]
        
        # Prepare inputs - add task prefix consistent with training
        inputs = tokenizer(
            ["extract entities: " + " ".join(tokens) for tokens in batch["tokens"]],
            padding=True,
            truncation=True,
            max_length=MODEL_CONFIG['max_length'],
            return_tensors="pt"
        ).to(model.device)
        
        # Generate predictions - increase num_beams to improve generation quality
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=MODEL_CONFIG['max_length'],
                num_beams=4,  # Use beam search to improve generation quality
                early_stopping=True
            )
        
        # Decode predictions
        decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        
        # Process each sample
        for j, (tokens, tags, pred) in enumerate(zip(batch["tokens"], batch["ner_tags"], decoded_preds)):
            # Get the true entity description
            true_entities_text = extract_entities_from_tokens_and_tags(tokens, tags, label_list)
            
            # Fix prediction format
            normalized_pred = pred
            
            # If the prediction doesn't contain formatted entities (no ":" or ";"), try to fix
            if pred and not ": " in pred and not ";" in pred:
                entity_matches = []
                
                # Try to match entities in the text
                for entity_type in ["PER", "ORG", "LOC", "MISC"]:
                    # Look for tokens in the original text that appear in the predicted text
                    text = " ".join(tokens)
                    words = pred.split()
                    
                    # Try to find continuous word groups in the prediction
                    for i in range(len(words)):
                        for j in range(i+1, min(i+6, len(words)+1)):  # Look at a window of up to 5 words
                            phrase = " ".join(words[i:j])
                            if phrase in text:
                                entity_matches.append((entity_type, phrase))
                
                if entity_matches:
                    normalized_entities = [f"{t}: {e}" for t, e in entity_matches]
                    normalized_pred = "; ".join(normalized_entities)
            
            # Convert predicted and true entity descriptions to IOB label sequences
            pred_labels = generate_labels_from_entities(tokens, normalized_pred, label_list)
            true_labels = generate_labels_from_entities(tokens, true_entities_text, label_list)
            
            # Ensure label sequence lengths match
            min_len = min(len(pred_labels), len(true_labels), len(tokens))
            all_pred_entities.append(pred_labels[:min_len])
            all_true_entities.append(true_labels[:min_len])
    
    # Calculate metrics
    results = metric.compute(
        predictions=all_pred_entities,
        references=all_true_entities,
        zero_division=0
    )
    
  
    # Output detailed entity type metrics
    if "PER" in results:
        print("\n=== Entity Type Metrics ===")
        for entity_type in ["PER", "ORG", "LOC", "MISC"]:
            if entity_type in results:
                print(f"{entity_type}:")
                print(f"  Precision: {results[entity_type]['precision']:.4f}")
                print(f"  Recall: {results[entity_type]['recall']:.4f}")
                print(f"  F1: {results[entity_type]['f1']:.4f}")
    
    return results



    

#Train and evaluate the model

# Load dataset and prepare labels
datasets, label_list = load_and_preprocess_data()
metric = load("seqeval")

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG['name'])
model = T5ForConditionalGeneration.from_pretrained(MODEL_CONFIG['name'])

# Print model info
print("\n=== Model Info ===")
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Model parameters: {model.num_parameters()}")

# Ensure pad token exists
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Set pad_token to eos_token: {tokenizer.pad_token}")

# Tokenize datasets
tokenized_datasets = datasets.map(
    lambda x: tokenize_for_t5(x, tokenizer, label_list),
    batched=True,
    remove_columns=datasets["train"].column_names,
    load_from_cache_file=False  # Disable caching for debugging
)

# Convert datasets to PyTorch format
for split in tokenized_datasets.keys():
    tokenized_datasets[split].set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Print dataset sizes
print("\n=== Dataset Sizes ===")
print(f"Train dataset size: {len(tokenized_datasets['train'])}")
print(f"Validation dataset size: {len(tokenized_datasets['validation'])}")
print(f"Test dataset size: {len(tokenized_datasets['test'])}")

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=f"{MODEL_CONFIG['name']}-finetuned-ner",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=MODEL_CONFIG['learning_rate'],
    per_device_train_batch_size=MODEL_CONFIG['batch_size'],
    per_device_eval_batch_size=MODEL_CONFIG['batch_size'],
    num_train_epochs=MODEL_CONFIG['num_epochs'],
    weight_decay=MODEL_CONFIG['weight_decay'],
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    save_total_limit=1,
    predict_with_generate=True,
    generation_max_length=MODEL_CONFIG['max_length'],
    generation_num_beams=4,  
    fp16=torch.cuda.is_available(),
    report_to="none"
)


# Initialize data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Add EarlyStoppingCallback
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=3)

# Initialize trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=lambda x: compute_metrics(x, tokenizer, label_list, metric),
    data_collator=data_collator,
    callbacks=[early_stopping_callback]
)

# Train model
print("\n=== Starting Training ===")
trainer.train()

# Evaluate on test set using our custom evaluation function
print("\n=== Evaluating on Test Set ===")
test_results = evaluate_model_with_real_data(
    model, 
    tokenizer, 
    datasets["test"], 
    label_list, 
    metric
)

print("\n=== Test Set Metrics ===")
print(f"Precision: {test_results['overall_precision']:.4f}")
print(f"Recall: {test_results['overall_recall']:.4f}")
print(f"F1 Score: {test_results['overall_f1']:.4f}")



README.md:   0%|          | 0.00/12.3k [00:00<?, ?B/s]

conll2003.py:   0%|          | 0.00/9.57k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/983k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]

NER Label list: ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
Number of labels: 9


Downloading builder script:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]


=== Model Info ===
Vocab size: 32100
Model parameters: 222903552


Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

Map:   0%|          | 0/3250 [00:00<?, ? examples/s]

Map:   0%|          | 0/3453 [00:00<?, ? examples/s]


=== Dataset Sizes ===
Train dataset size: 14041
Validation dataset size: 3250
Test dataset size: 3453





=== Starting Training ===


  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,0.0848,0.059281,0.0,0.0,0.0
2,0.0434,0.054156,0.0,0.0,0.0
3,0.0297,0.052319,0.0,0.0,0.0
4,0.0195,0.055446,0.0,0.0,0.0


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Tr

Tokenizer vocab size: 32100
Prediction min: -100, max: 31978


  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.process

Tokenizer vocab size: 32100
Prediction min: -100, max: 31978


  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.process

Tokenizer vocab size: 32100
Prediction min: -100, max: 31978


  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.process

Tokenizer vocab size: 32100
Prediction min: -100, max: 31978


  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].



=== Evaluating on Test Set ===


100%|██████████| 216/216 [03:34<00:00,  1.01it/s]



=== Entity Type Metrics ===
PER:
  Precision: 0.9749
  Recall: 0.9123
  F1: 0.9426
ORG:
  Precision: 0.8429
  Recall: 0.8906
  F1: 0.8661
LOC:
  Precision: 0.9052
  Recall: 0.9138
  F1: 0.9095
MISC:
  Precision: 0.7865
  Recall: 0.7955
  F1: 0.7910

=== Test Set Metrics ===
Precision: 0.8896
Recall: 0.8918
F1 Score: 0.8907
