# NER Model Testing on Custom Text

This notebook loads trained BiLSTM models and tests them on custom literary texts.

**Requirements:**
- Trained models in `../results/` directory
- Vocabulary and tag mappings in `../data/` directory

**Features:**
- Load and test multiple saved models
- Predict NER tags on custom text
- Visualize entities with color coding
- Compare predictions from different models side-by-side
- Interactive testing interface
- Export predicted entity tokens to CSV/JSON

## 1. Imports and Dependencies

In [None]:
# Import PyTorch for model loading and inference
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

# Import standard libraries
import json
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Import for text processing
import re

# Import for visualization
from IPython.display import HTML, display

# Sanity check: confirm the imports worked and report the PyTorch build and
# whether a CUDA device is visible to it.
print(
    "All imports successful!",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

## 2. Configuration

In [None]:
# ============================================================================
# CONFIGURATION - Paths and Parameters
# ============================================================================

# --- Filesystem layout (relative to the notebook's working directory) ---
DATA_DIR = Path('../data')
RESULTS_DIR = Path('../results')
OUTPUT_DIR = Path('../output')

# Exported predictions are written under OUTPUT_DIR, so create it up front.
OUTPUT_DIR.mkdir(exist_ok=True)

# Saved model weights; the names suggest the embedding initialisation used
# during training (random vs. FastText) -- confirm against the training notebook.
MODEL_PATHS = {
    'Random': RESULTS_DIR / 'BiLSTM_Random_best.pt',
    'FastText': RESULTS_DIR / 'BiLSTM_FastText_best.pt',
}

# Vocabulary and tag mappings produced alongside the training data
VOCAB_PATH = DATA_DIR / 'vocabulary.json'
TAG_MAPPINGS_PATH = DATA_DIR / 'tag_mappings.json'

# Hyperparameters (must match the training configuration). The layer sizes
# (EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS) determine parameter shapes, so a
# mismatch makes loading the saved weights fail.
MAX_LEN = 128  # every input is truncated/padded to this many tokens
EMBEDDING_DIM = 300
HIDDEN_DIM = 256
NUM_LAYERS = 3
DROPOUT = 0.6
BIDIRECTIONAL = True  # NOTE(review): BiLSTMTagger hard-codes bidirectional=True and does not read this

# Run on the GPU when PyTorch can see one, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print("Configuration loaded successfully!")
print(f"Device: {device}")
print(f"Max sequence length: {MAX_LEN}")

## 3. Define Model Architecture

Define the same BiLSTM architecture used during training.

In [None]:
class BiLSTMTagger(nn.Module):
    """
    Bidirectional LSTM model for Named Entity Recognition.

    Architecture:
    1. Embedding layer: converts word indices to dense vectors
       (index 0 is the padding index)
    2. Bidirectional LSTM: processes sequences in both directions
    3. Dropout: applied to the embeddings and again to the LSTM outputs
    4. Fully connected layer: maps LSTM outputs to per-token tag scores

    Args:
        vocab_size: Number of rows of a freshly initialised embedding table.
            Ignored when ``pretrained_embeddings`` is given.
        embedding_dim: Width of a freshly initialised embedding table. When
            ``pretrained_embeddings`` is given, the LSTM input size is taken
            from the pretrained matrix instead.
        hidden_dim: Hidden size of each LSTM direction.
        num_tags: Number of output tag classes.
        num_layers: Number of stacked LSTM layers (default: 2).
        dropout: Dropout probability (default: 0.5). Inter-layer LSTM dropout
            is only used when ``num_layers > 1``.
        pretrained_embeddings: Optional 2-D float tensor
            (num_words, dim) used to initialise a fine-tunable embedding.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags,
                 num_layers=2, dropout=0.5, pretrained_embeddings=None):
        super().__init__()

        # Save dimensions
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Embedding layer
        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(
                pretrained_embeddings,
                freeze=False,   # allow fine-tuning
                padding_idx=0,  # <PAD> token
            )
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # Size the LSTM input from the embedding that was actually built, so a
        # pretrained matrix whose width differs from `embedding_dim` still works.
        lstm_input_dim = self.embedding.embedding_dim

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            lstm_input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies this dropout only between stacked layers
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True,
        )

        # Dropout layer (shared by embeddings and LSTM outputs)
        self.dropout = nn.Dropout(dropout)

        # Forward and backward hidden states are concatenated -> hidden_dim * 2
        self.fc = nn.Linear(hidden_dim * 2, num_tags)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor of word indices, shape (batch_size, seq_len)

        Returns:
            Tag scores, shape (batch_size, seq_len, num_tags)
        """
        # (batch_size, seq_len) -> (batch_size, seq_len, embedding width)
        embedded = self.dropout(self.embedding(x))

        # (batch_size, seq_len, embedding width) -> (batch_size, seq_len, hidden_dim * 2)
        lstm_out, _ = self.lstm(embedded)
        lstm_out = self.dropout(lstm_out)

        # (batch_size, seq_len, hidden_dim * 2) -> (batch_size, seq_len, num_tags)
        return self.fc(lstm_out)

print("BiLSTMTagger model architecture defined successfully!")

## 4. Load Resources

Load vocabulary, tag mappings, and trained models.

In [None]:
# Load vocabulary. JSON text is UTF-8, so pass the encoding explicitly: without
# it, open() uses the OS locale codec and non-ASCII words can fail to decode.
print("Loading vocabulary...")
with open(VOCAB_PATH, 'r', encoding='utf-8') as f:
    vocab_data = json.load(f)
word2idx = vocab_data['word2idx']
vocab_size = vocab_data['vocab_size']

print(f"  Vocabulary size: {vocab_size:,}")

# Load tag mappings
print("\nLoading tag mappings...")
with open(TAG_MAPPINGS_PATH, 'r', encoding='utf-8') as f:
    tag_data = json.load(f)
tag2idx = tag_data['tag2idx']
# JSON object keys are always strings; convert them back to integer indices
idx2tag = {int(k): v for k, v in tag_data['idx2tag'].items()}
PAD_TAG_IDX = tag_data['PAD_TAG_IDX']

num_tags = len(tag2idx)
print(f"  Number of tags: {num_tags}")
print(f"  Tags: {list(tag2idx.keys())}")

# Load models
models = {}
print("\nLoading models...")

for model_name, model_path in MODEL_PATHS.items():
    if not model_path.exists():
        print(f"  ✗ Model file not found: {model_path}")
        continue

    # The embedding matrix is restored from the saved weights below, so no
    # pretrained vectors are needed to build the model here.
    model = BiLSTMTagger(
        vocab_size=vocab_size,
        embedding_dim=EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        num_tags=num_tags,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
        pretrained_embeddings=None
    )

    # NOTE(review): assumes each .pt file holds a bare state_dict (not a
    # checkpoint dict) -- confirm against the training code. torch.load
    # unpickles the file; on PyTorch versions that support it, consider
    # weights_only=True for files of uncertain origin.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()  # disable dropout for inference

    models[model_name] = model
    print(f"  ✓ Loaded {model_name} model from {model_path.name}")

if not models:
    print(f"\n⚠ WARNING: No models loaded! Please ensure model files exist in {RESULTS_DIR}/")
else:
    print(f"\n✓ Successfully loaded {len(models)} model(s)")

# Only claim success when at least one model is actually available
status = "RESOURCES LOADED SUCCESSFULLY" if models else "RESOURCES LOADED (NO MODELS AVAILABLE)"
print("\n" + "="*80)
print(status)
print("="*80)

## 5. Helper Functions

Define functions for prediction, visualization, and comparison.

In [None]:
def simple_tokenize(text):
    """
    Split text into word tokens and standalone punctuation marks.

    A token is either a maximal run of word characters (letters, digits,
    underscore) or a single character that is neither a word character nor
    whitespace. Whitespace only separates tokens and is discarded.

    Args:
        text: Input text string

    Returns:
        List of tokens in the order they occur in ``text``
    """
    word_or_punct = r"\w+|[^\w\s]"
    return [match.group(0) for match in re.finditer(word_or_punct, text)]


def predict_ner(model, text, word2idx, idx2tag, device, max_len=128):
    """
    Tag each token of ``text`` with the model's highest-scoring NER label.

    The text is tokenized with ``simple_tokenize`` and cut to at most
    ``max_len`` tokens. Each token is lowercased and looked up in ``word2idx``.
    Unknown words map to the ``<UNK>`` index, or to 1 if the vocabulary has no
    ``<UNK>`` entry. The sequence is then right-padded with index 0 up to
    ``max_len`` and scored in a single forward pass without gradients.

    Args:
        model: Trained BiLSTM model
        text: Input text string
        word2idx: Dictionary mapping words to indices
        idx2tag: Dictionary mapping tag indices to tag names
        device: torch device (cpu or cuda)
        max_len: Maximum sequence length (default: 128)

    Returns:
        List of (token, tag) tuples, one per kept token. The list is empty
        when ``text`` has no tokens. A note is printed if the text was
        truncated.
    """
    all_tokens = simple_tokenize(text)
    if not all_tokens:
        return []

    kept = all_tokens[:max_len]

    # NOTE(review): lookups are lowercased -- presumably the vocabulary was
    # built from lowercased text; confirm against preprocessing.
    fallback = word2idx.get('<UNK>', 1)
    ids = [word2idx.get(tok.lower(), fallback) for tok in kept]
    ids.extend([0] * (max_len - len(ids)))  # index 0 is <PAD>

    batch = torch.tensor([ids], dtype=torch.long).to(device)  # (1, max_len)

    model.eval()
    with torch.no_grad():
        best = model(batch).argmax(dim=2)[0]  # (max_len,)

    # Drop the padding positions before mapping indices back to tag names
    tags = [idx2tag[i] for i in best[:len(kept)].tolist()]

    if len(all_tokens) > max_len:
        print(f"⚠ Note: Text truncated from {len(all_tokens)} to {max_len} tokens")

    return list(zip(kept, tags))


def get_entity_color(tag):
    """
    Map an NER tag to its highlight colour for the HTML visualisation.

    The entity type is the part of the tag after its last '-' (or the whole
    tag if it has no '-'); types without a dedicated colour get light gray.

    Args:
        tag: NER tag (e.g., 'B-PER', 'I-LOC', 'O')

    Returns:
        HTML colour code string, or None for the outside tag 'O'
    """
    if tag == 'O':
        return None  # not part of an entity: no highlight

    palette = {
        'PER': '#FFB6C1',  # light pink - persons
        'LOC': '#ADD8E6',  # light blue - locations
        'GPE': '#90EE90',  # light green - geopolitical entities
        'ORG': '#FFD700',  # gold - organizations
        'FAC': '#DDA0DD',  # plum - facilities
        'VEH': '#F0E68C',  # khaki - vehicles
    }

    entity_type = tag.rpartition('-')[2]  # 'B-PER' -> 'PER', 'PER' -> 'PER'
    return palette.get(entity_type, '#D3D3D3')


def visualize_entities(predictions, show_all=False):
    """
    Render NER predictions as colour-highlighted HTML in the notebook.

    Each entity token is wrapped in a coloured ``<span>`` whose tooltip shows
    the full tag. With ``show_all=True``, non-entity tokens are rendered too,
    without highlighting. A colour legend is displayed underneath.

    Tokens and tags are HTML-escaped, so text containing ``<``, ``&`` or quotes
    is displayed literally instead of being interpreted as markup.

    Args:
        predictions: List of (token, tag) tuples
        show_all: If True, show all tokens including 'O' tags (default: False)

    Returns:
        None. Output goes through IPython ``display``. If ``predictions`` is
        empty, a message is printed and nothing is displayed.
    """
    from html import escape

    if not predictions:
        print("No predictions to visualize.")
        return

    spans = []
    for token, tag in predictions:
        color = get_entity_color(tag)
        safe_token = escape(token)

        if color:
            # Entity token - highlight with colour; the space before `title`
            # keeps the attributes well-formed HTML.
            spans.append(
                f'<span style="background-color: {color}; padding: 2px 4px; '
                f'margin: 2px; border-radius: 3px; font-weight: bold;" '
                f'title="{escape(tag)}">{safe_token}</span>'
            )
        elif show_all:
            # Non-entity token - show without highlighting
            spans.append(f'<span style="margin: 2px;">{safe_token}</span>')

    markup = '<div style="line-height: 2.5; font-size: 14px;">' + ' '.join(spans) + '</div>'
    display(HTML(markup))

    # Legend; codes and colours mirror the colour map in get_entity_color
    legend_html = '<div style="margin-top: 15px; font-size: 12px;"><b>Legend:</b> '
    legend_items = [
        ('PER', 'Person', '#FFB6C1'),
        ('LOC', 'Location', '#ADD8E6'),
        ('GPE', 'Geo-Political Entity', '#90EE90'),
        ('ORG', 'Organization', '#FFD700'),
        ('FAC', 'Facility', '#DDA0DD'),
        ('VEH', 'Vehicle', '#F0E68C')
    ]

    for code, name, color in legend_items:
        legend_html += (
            f'<span style="background-color: {color}; padding: 2px 8px; '
            f'margin: 2px; border-radius: 3px; font-weight: bold;">{code}</span> '
            f'<span style="margin-right: 10px;">{name}</span> '
        )
    legend_html += '</div>'
    display(HTML(legend_html))


def compare_models(text, models, word2idx, idx2tag, device, max_len=128):
    """
    Print and visualise each model's NER predictions for the same text.

    All models are run first. Then, for each model in turn, the tokens it
    tagged as entities are listed and the highlighted text is rendered.

    Args:
        text: Input text string
        models: Dictionary of model_name -> model
        word2idx: Dictionary mapping words to indices
        idx2tag: Dictionary mapping tag indices to tag names
        device: torch device (cpu or cuda)
        max_len: Maximum sequence length (default: 128)
    """
    if not models:
        print("No models available for comparison.")
        return

    rule = "=" * 80
    print("\n" + rule)
    print("MODEL COMPARISON")
    print(rule)
    print(f"\nInput text: {text}")
    print("\n" + "-" * 80)

    per_model = {
        name: predict_ner(tagger, text, word2idx, idx2tag, device, max_len)
        for name, tagger in models.items()
    }

    for name, preds in per_model.items():
        print(f"\n{name} Model:")
        print("-" * 40)

        # NOTE(review): this lists/counts tagged *tokens*, so a two-token
        # name such as B-PER I-PER is reported as two entities.
        tagged = [pair for pair in preds if pair[1] != 'O']
        if not tagged:
            print("  No entities found.")
        else:
            print(f"Found {len(tagged)} entities:")
            for token, tag in tagged:
                print(f"  {token:20s} -> {tag}")

        print("\nVisualization:")
        visualize_entities(preds)
        print()

    print(rule)


def export_predictions(predictions, output_path, format='csv'):
    """
    Export the entity tokens of ``predictions`` to a CSV or JSON file.

    Only pairs whose tag is not 'O' are written. Files are written as UTF-8.
    JSON output keeps non-ASCII characters (e.g. accented character names)
    readable instead of replacing them with ASCII escape sequences.

    Args:
        predictions: List of (token, tag) tuples
        output_path: Path (or str) of the file to write
        format: 'csv' or 'json' (default: 'csv'). Any other value prints an
            error message and writes nothing.
    """
    # Filter to entities only
    entities = [(token, tag) for token, tag in predictions if tag != 'O']

    if format == 'csv':
        df = pd.DataFrame(entities, columns=['Token', 'Tag'])
        df.to_csv(output_path, index=False, encoding='utf-8')
    elif format == 'json':
        records = [{'token': token, 'tag': tag} for token, tag in entities]
        # Explicit encoding: the default open() codec depends on the OS locale
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(records, f, indent=2, ensure_ascii=False)
    else:
        print(f"✗ Unsupported format: {format}. Use 'csv' or 'json'.")
        return

    print(f"✓ Exported {len(entities)} entities to {output_path}")


# List every helper defined in this cell (get_entity_color was missing)
print("Helper functions defined successfully!")
print("  - simple_tokenize()")
print("  - predict_ner()")
print("  - get_entity_color()")
print("  - visualize_entities()")
print("  - compare_models()")
print("  - export_predictions()")

## 6. Interactive Testing Cell

**Modify the text below to test your own literary passages!**

In [None]:
# ============================================================================
# INTERACTIVE TESTING - Modify this text to test your own passages
# ============================================================================

# Enter your custom text here:
custom_text = "Elizabeth Bennet lived in Longbourn with her family. She often visited her friend Charlotte Lucas in the nearby village."

# Which loaded model to run: 'FastText' or 'Random'
model_name = 'FastText'

if model_name not in models:
    print(f"Model '{model_name}' not found. Available models: {list(models.keys())}")
else:
    print(f"\nTesting {model_name} model...")
    print("=" * 80)
    print(f"Input: {custom_text}")
    print("=" * 80)

    predictions = predict_ner(
        models[model_name], custom_text, word2idx, idx2tag, device, MAX_LEN
    )

    # Keep only the tokens the model tagged as part of an entity
    entities = [pair for pair in predictions if pair[1] != 'O']

    print(f"\nFound {len(entities)} entities:")
    for token, tag in entities:
        print(f"  {token:20s} -> {tag}")
    if not entities:
        print("  No entities found.")

    print("\nVisualization:")
    visualize_entities(predictions)

    # To save the results, uncomment:
    # export_predictions(predictions, OUTPUT_DIR / 'my_predictions.csv', format='csv')

## 7. Batch Testing on Sample Texts

Test multiple sample literary texts and compare model predictions.

In [None]:
# ============================================================================
# SAMPLE LITERARY TEXTS FOR TESTING
# ============================================================================

# One short sentence per work, each naming characters and places
sample_texts = [
    "Elizabeth Bennet lived in Longbourn.",
    "Sherlock Holmes resided at 221B Baker Street in London.",
    "Captain Ahab commanded the whaling ship Pequod.",
    "Hester Prynne lived in Boston during the Puritan era.",
    "Jay Gatsby threw lavish parties at his mansion in West Egg, New York.",
    "Atticus Finch practiced law in the town of Maycomb, Alabama.",
    "Holden Caulfield was expelled from Pencey Prep and wandered around New York City.",
    "The March sisters lived in Concord, Massachusetts during the Civil War.",
]

print("\n" + "=" * 80, "BATCH TESTING ON SAMPLE LITERARY TEXTS", "=" * 80, sep="\n")

for i, text in enumerate(sample_texts, 1):
    # Per-sample header, then every loaded model's predictions for it
    print(f"\n{'=' * 80}\nSample {i}/{len(sample_texts)}\n{'=' * 80}")
    compare_models(text, models, word2idx, idx2tag, device, max_len=MAX_LEN)

print("\n" + "=" * 80, "BATCH TESTING COMPLETE", "=" * 80, sep="\n")

## 8. Detailed Comparison on Single Text

Choose one text for detailed side-by-side comparison of both models.

In [None]:
# A longer, multi-sentence passage for a side-by-side look at every model
detailed_text = (
    "In the spring of 1922, Jay Gatsby moved into a mansion in West Egg, "
    "Long Island, where he threw elaborate parties. His neighbor Nick "
    "Carraway often observed the festivities from his small cottage."
)

print("\n" + "=" * 80, "DETAILED MODEL COMPARISON", "=" * 80, sep="\n")

compare_models(detailed_text, models, word2idx, idx2tag, device, max_len=MAX_LEN)

## 9. Export Predictions (Optional)

Export the predicted entity tokens (every token whose tag is not `O`) to CSV or JSON for further analysis.

In [None]:
# Example: export one model's entity predictions to both CSV and JSON

# Choose text and model
export_text = "Sherlock Holmes and Dr. Watson investigated a case at Buckingham Palace in London."
export_model_name = 'FastText'  # or 'Random'

if export_model_name in models:
    # Get predictions
    predictions = predict_ner(models[export_model_name], export_text, word2idx, idx2tag, device, MAX_LEN)

    # Export to CSV
    csv_path = OUTPUT_DIR / f'{export_model_name}_predictions.csv'
    export_predictions(predictions, csv_path, format='csv')

    # Export to JSON
    json_path = OUTPUT_DIR / f'{export_model_name}_predictions.json'
    export_predictions(predictions, json_path, format='json')

    print("\nPredictions exported to:")
    print(f"  - {csv_path}")
    print(f"  - {json_path}")
else:
    # Same hint as the interactive cell: show which models can be chosen
    print(f"Model '{export_model_name}' not found. Available models: {list(models.keys())}")

## 10. Summary Statistics

Display summary statistics about model predictions across all sample texts.

In [None]:
# Analyze predictions across all sample texts.
# Entities are counted as spans: consecutive tokens of the same type form one
# entity (so "Jay/B-PER Gatsby/I-PER" counts once), and a 'B-' tag always
# starts a new one. The number of tagged tokens is reported alongside.
print("\n" + "="*80)
print("SUMMARY STATISTICS")
print("="*80)

for model_name, model in models.items():
    print(f"\n{model_name} Model:")
    print("-" * 40)

    entity_tokens = 0                # tokens with a non-'O' tag
    entity_types = defaultdict(int)  # entity type -> number of entity spans

    for text in sample_texts:
        predictions = predict_ner(model, text, word2idx, idx2tag, device, MAX_LEN)

        prev_type = None  # entity type of the previous token (None after 'O')
        for _, tag in predictions:
            if tag == 'O':
                prev_type = None
                continue

            entity_tokens += 1
            # Extract entity type (e.g., 'B-PER' -> 'PER')
            entity_type = tag.split('-')[-1] if '-' in tag else tag
            # New span on a 'B-' tag, after an 'O', or when the type changes
            if tag.startswith('B-') or entity_type != prev_type:
                entity_types[entity_type] += 1
            prev_type = entity_type

    total_entities = sum(entity_types.values())
    print(f"Total entities found: {total_entities} (spanning {entity_tokens} tokens)")
    print("\nEntity type distribution:")
    for entity_type, count in sorted(entity_types.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / total_entities * 100) if total_entities > 0 else 0
        print(f"  {entity_type:10s}: {count:3d} ({percentage:5.1f}%)")

print("\n" + "="*80)
print("ANALYSIS COMPLETE")
print("="*80)