In [2]:
! pip install transformers datasets torch evaluate seqeval scikit-learn accelerate indic-transliteration tqdm sentencepiece


Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting indic-transliteration
  Downloading indic_transliteration-2.3.75-py3-none-any.whl.metadata (1.4 kB)
Collecting backports.functools-lru-cache (from indic-transliteration)
  Downloading backports.functools_lru_cache-2.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting roman (from indic-transliteration)
  Downloading roman-5.2-py3-none-any.whl.metadata (4.3 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloadin

In [None]:
import gc
import json
import os
import sys
import warnings
from pathlib import Path

import evaluate
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from huggingface_hub import login
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModel,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)

# NOTE: with no category/module filter this silences ALL Python warnings for the
# rest of the session (not only model-loading noise), including deprecation and
# data warnings. Comment it out when debugging.
warnings.filterwarnings("ignore")

# ----------------------------
# Import the transliteration library, installing it first if it is missing
# ----------------------------
try:
    from indic_transliteration import sanscript
    from indic_transliteration.sanscript import transliterate
except Exception:
    try:
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "indic-transliteration", "-q"])
        from indic_transliteration import sanscript
        from indic_transliteration.sanscript import transliterate
    except Exception as e:
        # Best effort: with both names set to None, transliterate_to_devanagari
        # returns tokens unchanged (i.e. still in their source script).
        print(f"Warning: Could not install 'indic-transliteration' ({e}). Transliteration might not work.")
        sanscript = None
        transliterate = None

# ============================================================
# PATH AND DATA CONFIGURATION
# ============================================================
# Folder expected to hold one "<lang>_data.json" file per language code below.
data_dir = Path("cross_lingual_data")  # <-- CHANGE THIS TO YOUR DATA FOLDER
# Languages to load; all are converted to Devanagari ('mr' already is).
languages = ['as', 'bn', 'gu', 'ml', 'mr', 'ta', 'te']
model_name = "ai4bharat/indic-bert"              # Hugging Face Hub model id
output_dir = "./indicbert-devanagari-ner-final"  # final model + tokenizer
checkpoint_dir = "./checkpoints"                 # Trainer checkpoints (resume source)

print("\n".join(["="*70, " CONFIGURATION", "="*70]))
print("Data directory:", data_dir.absolute())
# Abort early: nothing downstream can run without the data folder.
if not data_dir.exists():
    print("FATAL ERROR: data directory not found. Exiting.")
    sys.exit(1)

# ----------------------------
# SCRIPT MAP + Transliteration Helper
# ----------------------------
# Language code -> source-script scheme understood by indic_transliteration.
# Stays empty when the library is unavailable, which makes
# transliterate_to_devanagari a no-op for every language.
if sanscript is None:
    SCRIPT_MAP = {}
else:
    SCRIPT_MAP = {
        # Assamese shares the Bengali-Assamese script.
        # NOTE(review): Assamese-only letters (ৰ, ৱ) may not be covered by the
        # BENGALI scheme — verify transliteration output for 'as'.
        'as': sanscript.BENGALI,
        'bn': sanscript.BENGALI,
        'gu': sanscript.GUJARATI,
        'ml': sanscript.MALAYALAM,
        'mr': sanscript.DEVANAGARI,  # identity; transliterate_to_devanagari short-circuits 'mr'
        'ta': sanscript.TAMIL,
        'te': sanscript.TELUGU,
    }

def transliterate_to_devanagari(text, lang_code):
    """Transliterate a single token from its source script to Devanagari.

    Args:
        text: The token to convert. ``None`` and empty strings are returned unchanged.
        lang_code: Language code; must be a key of SCRIPT_MAP to be converted.

    Returns:
        The Devanagari form of ``text``. ``text`` itself is returned when
        lang_code is 'mr' (already Devanagari), when the language or the
        transliteration library is unavailable, or when transliteration raises.
    """
    if not text or lang_code == 'mr':
        return text
    source_scheme = SCRIPT_MAP.get(lang_code)
    if transliterate is None or source_scheme is None:
        return text
    try:
        # Direct script-to-script conversion (source scheme -> DEVANAGARI); no
        # romanized intermediate such as ITRANS is involved.
        return transliterate(text, source_scheme, sanscript.DEVANAGARI)
    except Exception:
        # Best effort: keep the original token rather than dropping the example.
        return text

# ============================================================
# 1. AUTHENTICATION & MODEL LOADING
# ============================================================

print("\n---")
print("STEP 1: Hugging Face Authentication")
print("---")
try:
    # Use existing token or prompt for login
    login(new_session=False)
    print("✓ Authentication check passed (token found or successfully logged in).")
except Exception as e:
    # Non-fatal: execution continues and STEP 2 reports any access problem.
    print(f"✗ Warning: Authentication failed. Error: {e}")

print("\n---")
print(f"STEP 2: Loading Base Model '{model_name}'")
print("---")

try:
    # keep_accents=True disables the NFKD + accent-stripping normalisation that
    # ALBERT/SentencePiece tokenizers apply by default. That step deletes Indic
    # combining signs (e.g. virama U+094D, nukta U+093C, several vowel signs)
    # and would corrupt the Devanagari text produced in STEP 3.
    # NOTE(review): assumes indic-bert resolves to an ALBERT-family tokenizer — confirm.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, keep_accents=True)
    # tokenize_and_align_labels relies on word_ids(), which only fast tokenizers
    # provide; fail here with a clear message instead of inside dataset.map().
    if not tokenizer.is_fast:
        raise ValueError(
            "a fast tokenizer is required for word_ids(); install 'sentencepiece' "
            "and 'protobuf' so the slow tokenizer can be converted"
        )
    print("✓ Tokenizer loaded.")

    # Downloading the config is enough to verify Hub access. The weights are
    # loaded once, later, by AutoModelForTokenClassification, not twice.
    AutoConfig.from_pretrained(model_name)
    print("✓ Base Model access verified.")

except Exception as e:
    print("\n=============================================================")
    print(f"❌ FATAL ERROR: Could not load model '{model_name}'")
    print("=============================================================")
    print(f"Details: {e}")
    print("\nACTION REQUIRED: Ensure you accepted the license terms on the model page.")
    sys.exit(1)

# ============================================================
# 3. LOAD, CONVERT, AND SPLIT DATA
# ============================================================

def load_and_convert_data(file_path: Path, lang_code: str):
    """Load a JSON-Lines NER file and transliterate its tokens to Devanagari.

    Each non-blank line is parsed as a JSON object holding a token list under
    "words" (or "tokens") and a parallel tag list under "ner" (or "ner_tags").

    Args:
        file_path: Path to the JSON-Lines file (read as UTF-8).
        lang_code: Source-language code passed to transliterate_to_devanagari.

    Returns:
        List of {"tokens": [...], "ner_tags": [...]} dicts. Lines that are not
        valid JSON objects, or whose token and tag lists differ in length, are
        skipped. The number skipped is printed instead of being dropped silently.
    """
    data = []
    skipped = 0
    with open(file_path, "r", encoding="utf-8") as f:
        # First pass only counts non-blank lines so tqdm can show a real total.
        total = sum(1 for line in f if line.strip())
    with open(file_path, "r", encoding="utf-8") as f:
        for line in tqdm(f, total=total, desc=f"Loading {file_path.name}", ncols=80):
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
                words = item.get("words") or item.get("tokens") or []
                ner = item.get("ner") or item.get("ner_tags") or []
                valid = len(words) == len(ner)
            except (ValueError, TypeError, AttributeError):
                # ValueError covers json.JSONDecodeError; AttributeError/TypeError
                # cover records that are not objects or hold non-sequence fields.
                valid = False
            if not valid:
                skipped += 1
                continue
            tokens_dev = [transliterate_to_devanagari(w, lang_code) for w in words]
            data.append({"tokens": tokens_dev, "ner_tags": ner})
    if skipped:
        print(f"  Skipped {skipped:,} malformed or length-mismatched line(s) in {file_path.name}")
    return data

print("\n" + "="*70)
print("STEP 3: Loading and Converting Data to Devanagari")
print("="*70)
# Pool every language's examples into one list; each file is expected at
# <data_dir>/<lang>_data.json.
all_data = []
for lang in languages:
    file_path = data_dir / f"{lang}_data.json"
    if not file_path.exists():
        print("  NOT FOUND:", file_path)
        continue
    lang_examples = load_and_convert_data(file_path, lang)
    print(f"  Loaded {len(lang_examples):,} examples for {lang}")
    all_data.extend(lang_examples)
    del lang_examples
    gc.collect()

# Stop with an explanation instead of a silent exit when nothing was loaded.
if not all_data:
    print("FATAL ERROR: no examples were loaded from", data_dir.absolute(),
          "- check the <lang>_data.json file names and their format. Exiting.")
    sys.exit(1)

# Split data: hold out 10% for test, then 10% of the remainder for validation
# (roughly 81% / 9% / 10% overall), reproducibly via random_state=42.
train_val, test = train_test_split(all_data, test_size=0.10, random_state=42)
train, val = train_test_split(train_val, test_size=0.10, random_state=42)
del all_data; gc.collect()

dataset = DatasetDict({
    "train": Dataset.from_list(train),
    "validation": Dataset.from_list(val),
    "test": Dataset.from_list(test)
})

# Labels: build the inventory from every split, not just train.
# tokenize_and_align_labels maps each tag through label2id for all three
# splits, so a tag occurring only in validation/test would otherwise raise a
# KeyError during tokenization.
all_labels = set()
for split_examples in (train, val, test):
    for item in split_examples:
        all_labels.update(item["ner_tags"])
label_list = sorted(all_labels)
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for i, label in enumerate(label_list)}
print(f"\nLabels: {label_list}")
print(f"Total examples: {len(train):,} (Train), {len(val):,} (Val), {len(test):,} (Test)")

# ============================================================
# 4. TOKENIZE AND ALIGN LABELS
# ============================================================
def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and align word-level NER tags to sub-word tokens.

    Only the first sub-word of each word receives that word's label id. Special
    tokens (word id None) and continuation sub-words get -100. compute_metrics
    skips -100 positions, and by the usual Hugging Face convention the loss
    ignores them too.

    Args:
        examples: Batch dict from ``Dataset.map(batched=True)`` with "tokens"
            (list of word lists) and "ner_tags" (parallel list of tag lists).

    Returns:
        The tokenizer output with an added "labels" entry.
    """
    encoded = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        truncation=True,
        padding=False,
        max_length=512,
    )

    aligned_labels = []
    for batch_idx, tags in enumerate(examples["ner_tags"]):
        seq_labels = []
        last_word = None
        for word in encoded.word_ids(batch_index=batch_idx):
            is_first_subword = word is not None and word != last_word
            seq_labels.append(label2id[tags[word]] if is_first_subword else -100)
            last_word = word
        aligned_labels.append(seq_labels)

    encoded["labels"] = aligned_labels
    return encoded

print("\n".join(["\n" + "="*70, "STEP 4: Tokenizing Dataset", "="*70]))
# Replace the raw "tokens"/"ner_tags" columns of every split with model inputs
# plus aligned labels.
tokenized_dataset = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    batch_size=1000,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing",
)

# Load the pretrained encoder with a token-classification head sized to
# label_list. ignore_mismatched_sizes=True lets loading proceed if any
# checkpoint tensor shape does not match the new head.
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    ignore_mismatched_sizes=True,
    num_labels=len(label_list),
    label2id=label2id,
    id2label=id2label,
)

# ============================================================
# 5. METRICS AND TRAINING
# ============================================================
seqeval = evaluate.load("seqeval")

def compute_metrics(eval_pred):
    """Entity-level precision / recall / F1 for the Trainer, computed with seqeval.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair passed in by the Trainer.
            ``predictions`` are per-token scores; argmax is taken over the last
            axis. ``label_ids`` use -100 for positions to ignore, as produced by
            tokenize_and_align_labels. If ``predictions`` is a tuple (models
            returning extra outputs), its first element is assumed to be the logits.

    Returns:
        dict with "precision", "recall" and "f1" taken from seqeval's overall scores.
    """
    logits, label_ids = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    pred_ids = np.argmax(logits, axis=-1)

    true_predictions = []
    true_labels = []
    for pred_seq, label_seq in zip(pred_ids, label_ids):
        # Keep only scored positions; int() gives plain ints for the id2label lookup.
        pairs = [(int(p), int(l)) for p, l in zip(pred_seq, label_seq) if l != -100]
        true_predictions.append([id2label[p] for p, _ in pairs])
        true_labels.append([id2label[l] for _, l in pairs])

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
    }

Path(checkpoint_dir).mkdir(exist_ok=True)

training_args = TrainingArguments(
    # Output and checkpointing; save_total_limit caps checkpoints kept on disk.
    output_dir=checkpoint_dir,
    overwrite_output_dir=True,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=3,
    # Evaluation and model selection. Evaluation and saving share the same
    # 1000-step cadence, which load_best_model_at_end relies on.
    eval_strategy="steps",
    eval_steps=1000,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    # Optimisation: 16 per batch x 2 accumulation steps = 32 examples per
    # device per optimizer update.
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=5,
    weight_decay=0.01,
    gradient_accumulation_steps=2,
    fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is present
    # Data loading, logging and reproducibility.
    dataloader_num_workers=4,
    logging_dir="./logs",
    logging_steps=100,
    logging_strategy="steps",
    seed=42,
    push_to_hub=False,
    report_to="none",
)
# Pads each batch dynamically (the dataset was tokenized with padding=False).
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Training (with auto-resume if checkpoints exist)
print("\n" + "="*70)
print("STEP 5: Training IndicBERT")
print("="*70)
# Only directories named checkpoint-* count as Trainer checkpoints.
existing_checkpoints = [p for p in Path(checkpoint_dir).glob("checkpoint-*") if p.is_dir()]
if existing_checkpoints:
    print(f"Found {len(existing_checkpoints)} checkpoint(s). Resuming training.")

try:
    # True -> resume from the latest checkpoint in args.output_dir (== checkpoint_dir).
    trainer.train(resume_from_checkpoint=True if existing_checkpoints else None)
    print("\nTraining completed.")
except KeyboardInterrupt:
    # An interrupt writes no checkpoint of its own, and the end-of-training
    # best-model reload does not run, so STEP 6 will save the in-memory state.
    print(f"\nTraining interrupted by user. No checkpoint was written for the "
          f"interrupted steps; the last periodic checkpoint (if any) is in {checkpoint_dir}.")
except Exception as e:
    print("\nTraining failed with exception:", e)
    raise

# ============================================================
# 6. SAVE FINAL MODEL AND EVALUATE
# ============================================================
print("\n".join(["\n" + "="*70, "STEP 6: Saving and Evaluating", "="*70]))

# Write the trainer's current model and the tokenizer side by side in output_dir.
for saver in (trainer.save_model, tokenizer.save_pretrained):
    saver(output_dir)
print(f"\nFinal model saved to: {output_dir}")

# Score the held-out test split. Metric keys carry the "eval_" prefix, and any
# missing key is shown as 0.0.
test_results = trainer.evaluate(tokenized_dataset["test"])
print("\nTEST RESULTS")
print(f"Precision: {test_results.get('eval_precision', 0.0):.4f}")
print(f"Recall:    {test_results.get('eval_recall', 0.0):.4f}")
print(f"F1 Score:  {test_results.get('eval_f1', 0.0):.4f}")

print("\nDONE.")