# 🧠 InLegalBERT Fine-Tuning for Legal NER

This notebook fine-tunes InLegalBERT for Named Entity Recognition on Indian legal documents.

**Entity Types:** PETITIONER, RESPONDENT, JUDGE, COURT, STATUTE, PROVISION, PRECEDENT, DATE, etc.

**Model:** [law-ai/InLegalBERT](https://huggingface.co/law-ai/InLegalBERT)

In [None]:
# Setup
import sys
sys.path.insert(0, '..')

import json
import torch
import numpy as np
from pathlib import Path

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
from datasets import Dataset
from seqeval.metrics import classification_report, f1_score

from src.models import LegalNERModel
from src.utils import set_seed

# Seed the run (via the project's set_seed helper) so results are repeatable
set_seed(42)

# Choose the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Define Entity Labels

In [None]:
# Entity types covered by the tagger. Each type expands to a B-/I- pair
# (BIO scheme); "O" marks tokens that belong to no entity and sits at id 0.
_ENTITY_TYPES = (
    "PETITIONER",
    "RESPONDENT",
    "JUDGE",
    "LAWYER",
    "COURT",
    "STATUTE",
    "PROVISION",
    "PRECEDENT",
    "CASE_NUMBER",
    "DATE",
    "GPE",
    "ORG",
)

LABELS = ["O"] + [
    f"{prefix}-{entity}" for entity in _ENTITY_TYPES for prefix in ("B", "I")
]

# Two-way lookup between label strings and their integer ids
label2id = {name: index for index, name in enumerate(LABELS)}
id2label = dict(enumerate(LABELS))

print(f"Total labels: {len(LABELS)}")
print(f"Entity types: {sum(name.startswith('B-') for name in LABELS)}")

## 2. Load Model and Tokenizer

In [None]:
# Hugging Face Hub checkpoint that gets fine-tuned
MODEL_NAME = "law-ai/InLegalBERT"

# Tokenizer published alongside the checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Token-classification model sized to the label set; the id <-> label maps
# are passed in so they are available on the loaded model's configuration.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(LABELS),
)

print(f"Model loaded: {MODEL_NAME}")
print(f"Parameters: {model.num_parameters():,}")

## 3. Prepare Training Data

In [None]:
# Placeholder training set -- swap in your own annotated corpus.
# Each record holds parallel lists: "tokens" (words) and "ner_tags"
# (integer ids into LABELS, one per word).

SAMPLE_DATA = [
    {
        "tokens": [
            "In", "Kesavananda", "Bharati", "v.", "State", "of", "Kerala", ",",
            "the", "Supreme", "Court", "examined", "Article", "368", ".",
        ],
        "ner_tags": [
            0,          # In                  -> O
            1, 2,       # Kesavananda Bharati -> B-/I-PETITIONER
            0,          # v.                  -> O
            3, 4, 4,    # State of Kerala     -> B-/I-/I-RESPONDENT
            0, 0,       # , the               -> O
            9, 10,      # Supreme Court       -> B-/I-COURT
            0,          # examined            -> O
            13, 14,     # Article 368         -> B-/I-PROVISION
            0,          # .                   -> O
        ],
    },
    {
        "tokens": [
            "Hon'ble", "Justice", "D.Y.", "Chandrachud",
            "delivered", "the", "judgment", ".",
        ],
        "ner_tags": [
            5, 6, 6, 6,     # Hon'ble Justice D.Y. Chandrachud -> B-/I-JUDGE
            0, 0, 0, 0,     # delivered the judgment .         -> O
        ],
    },
    {
        "tokens": [
            "FIR", "No.", "123/2020", "was", "registered", "under",
            "Section", "302", "of", "the", "IPC", ".",
        ],
        "ner_tags": [
            17, 18, 18,     # FIR No. 123/2020     -> B-/I-/I-CASE_NUMBER
            0, 0, 0,        # was registered under -> O
            13, 14,         # Section 302          -> B-/I-PROVISION
            0, 0,           # of the               -> O
            11,             # IPC                  -> B-STATUTE
            0,              # .                    -> O
        ],
    },
]

# Wrap the records in a datasets.Dataset and show the first one decoded
dataset = Dataset.from_list(SAMPLE_DATA)
first_sample = dataset[0]
print(f"Training samples: {len(dataset)}")
print("\nFirst sample:")
print(f"  Tokens: {first_sample['tokens']}")
print(f"  Tags: {[id2label[tag] for tag in first_sample['ner_tags']]}")

## 4. Tokenization and Alignment

In [None]:
def tokenize_and_align_labels(examples):
    """
    Tokenize pre-split words and align word-level NER tags to subword tokens.

    Only the first subword of each word keeps the word's tag. Special tokens
    and every later subword of the same word get -100 (ignored in loss).

    Sequences are truncated to 512 tokens but NOT padded here: padding every
    example to the maximum length at map time would make each batch 512 wide
    regardless of content. Padding is left to the batch-time
    DataCollatorForTokenClassification, which pads inputs and labels to the
    longest sequence in each batch.

    Args:
        examples: Batch mapping with "tokens" (one list of words per example)
            and "ner_tags" (one parallel list of integer tag ids per example).

    Returns:
        The tokenizer output with an added "labels" entry holding, for each
        example, one label id per produced token.
    """
    tokenized = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        max_length=512,
    )

    labels = []
    for i, word_labels in enumerate(examples["ner_tags"]):
        # word_ids maps each produced token back to the index of its source
        # word (None for special tokens).
        word_ids = tokenized.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []

        for word_idx in word_ids:
            if word_idx is None or word_idx == previous_word_idx:
                # Special token, or a non-first subword of the current word
                label_ids.append(-100)
            else:
                # First subword of a new word carries that word's tag
                label_ids.append(word_labels[word_idx])

            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized["labels"] = labels
    return tokenized

# Apply the alignment function batch-wise; the raw "tokens"/"ner_tags"
# columns are dropped so only the tokenizer outputs and labels remain.
tokenized_dataset = dataset.map(
    tokenize_and_align_labels,
    remove_columns=dataset.column_names,
    batched=True,
)

print("Dataset tokenized!")
print(f"Features: {tokenized_dataset.features}")

## 5. Define Metrics

In [None]:
def compute_metrics(eval_pred):
    """
    Score predictions with seqeval's entity-level F1.

    Args:
        eval_pred: Pair of (predictions, labels). Predictions are reduced
            with argmax over axis 2, so they are assumed to be shaped
            (batch, sequence, num_labels) -- TODO confirm against the caller.

    Returns:
        Dict with a single "f1" entry.
    """
    scores, gold_ids = eval_pred
    predicted_ids = np.argmax(scores, axis=2)

    true_labels = []
    true_predictions = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        # Positions labelled -100 are excluded from scoring
        kept = [(p, g) for p, g in zip(predicted_row, gold_row) if g != -100]
        true_labels.append([id2label[g] for _, g in kept])
        true_predictions.append([id2label[p] for p, _ in kept])

    return {"f1": f1_score(true_labels, true_predictions)}

## 6. Training Configuration

In [None]:
# Where checkpoints and logs for this run are written
OUTPUT_DIR = "../models/legal_ner"

# Hyper-parameters for the fine-tuning run.
#
# The evaluation-strategy keyword is intentionally not passed: its default is
# already "no", and the keyword was renamed from `evaluation_strategy` to
# `eval_strategy` across transformers releases, so spelling it out makes this
# cell fail on releases that only accept the other name. Once you have
# validation data, pass the name your installed version accepts (for example
# eval_strategy="steps") together with an eval_dataset on the Trainer.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is present
    report_to="tensorboard",
)

# Pads inputs and labels per batch, so the dataset itself can stay unpadded
data_collator = DataCollatorForTokenClassification(tokenizer)

# NOTE(review): newer transformers releases expose `processing_class` as the
# replacement for the Trainer `tokenizer` argument -- confirm against the
# installed version before upgrading.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("Trainer configured!")
print(f"Output directory: {OUTPUT_DIR}")

## 7. Train Model

In [None]:
# ⚠️ Training is disabled by default: the bundled sample set is far too small
# for meaningful results. Uncomment the call below once real data is loaded.
# trainer.train()

print("Training code ready!")
print("\nTo train with real data:")
print(
    "1. Replace SAMPLE_DATA with annotated NER data",
    "2. Uncomment trainer.train()",
    "3. Run the cell",
    sep="\n",
)

## 8. Save Model

In [None]:
# ⚠️ Uncomment after training to write the fine-tuned model and its tokenizer
# side by side under OUTPUT_DIR/final
# trainer.save_model(f"{OUTPUT_DIR}/final")
# tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")

print("Model save code ready!")

## 9. Inference Example

In [None]:
def predict_entities(text: str, model, tokenizer, id2label):
    """
    Predict entity spans in raw text with a token-classification model.

    Decoding is done at word level, mirroring the way training labels are
    aligned in this notebook: only the first subword of a word decides that
    word's tag, and the word's remaining subwords merely extend whatever
    entity the word belongs to (their own predictions are ignored).

    An I- tag continues the open entity only when its type matches the
    entity's type. A mismatched I- tag closes the open entity; like an I-
    tag with no open entity, it does not start a new one.

    Args:
        text: Raw input string. Offsets in the result index into this string.
        model: Token-classification model. It is switched to eval mode, called
            with the encoded inputs, and must expose ``.device`` and return an
            object with ``.logits``.
        tokenizer: Tokenizer able to return offset mappings and word ids
            (i.e. a fast tokenizer).
        id2label: Mapping from predicted class id to its BIO label string.

    Returns:
        List of dicts with keys "text", "label", "start" and "end", in order
        of appearance. Input is truncated to 512 tokens; text past the
        truncation point yields no entities.
    """
    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_offsets_mapping=True,
    )

    # word_ids() lives on the tokenizer output object, so read it before the
    # tensors are copied into a plain dict below.
    word_ids = encoding.word_ids(0)
    # The model does not accept offset_mapping, so strip it from the inputs
    offset_mapping = encoding.pop("offset_mapping")[0].tolist()

    # Move to the model's device
    inputs = {k: v.to(model.device) for k, v in encoding.items()}

    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()

    entities = []
    current_entity = None
    previous_word_id = None

    for pred, (start, end), word_id in zip(predictions, offset_mapping, word_ids):
        if word_id is None or start == end:  # Special token
            continue

        if word_id == previous_word_id:
            # Later subword of the same word: inherit the word's decision
            if current_entity is not None:
                current_entity["end"] = end
                current_entity["text"] = text[current_entity["start"]:end]
            continue
        previous_word_id = word_id

        label = id2label[pred]

        if (
            label.startswith("I-")
            and current_entity is not None
            and current_entity["label"] == label[2:]
        ):
            # Same-type continuation: grow the open entity
            current_entity["end"] = end
            current_entity["text"] = text[current_entity["start"]:end]
            continue

        # Anything else (B-, O, mismatched or orphan I-) ends the open entity
        if current_entity is not None:
            entities.append(current_entity)
            current_entity = None

        if label.startswith("B-"):
            current_entity = {
                "text": text[start:end],
                "label": label[2:],
                "start": start,
                "end": end,
            }

    if current_entity is not None:
        entities.append(current_entity)

    return entities

# Sentence to run through predict_entities once a trained model is available
test_text = (
    "In Kesavananda Bharati v. State of Kerala, "
    "the Supreme Court examined Article 368."
)
print(f"Test text: {test_text}")
print("\n(Run prediction after training)")

## Next Steps

1. **Get More Data**: Use the Indian Kanoon scraper and annotate with Label Studio
2. **Train**: Run the training cell with proper data
3. **Evaluate**: Add validation set and check F1 scores
4. **Deploy**: Use the model in the NER pipeline