In [None]:
!pip install -q transformers evaluate seqeval pandas

import json, os, glob, csv
from pathlib import Path
from typing import List
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import evaluate
import pandas as pd

# ---------- SETTINGS ----------
# Every file in base_dir matching projected_hi_*_ner.jsonl is evaluated
# separately; sorting keeps the evaluation / CSV row order deterministic.
base_dir = "."
pattern = os.path.join(base_dir, "projected_hi_*_ner.jsonl")
files = sorted(glob.glob(pattern))
if not files:
    # Nothing to evaluate: stop the run instead of producing empty reports.
    raise SystemExit(f"No files matched {pattern}")

# Pretrained token-classification checkpoint, used as-is (no fine-tuning in this cell).
model_name = "Davlan/xlm-roberta-base-wikiann-ner"

# ---------- load model ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A fast tokenizer is needed because predict_for_tokens relies on word_ids().
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.to(device)
model.eval()  # inference mode (disables dropout)
# Class index -> label string. predict_for_tokens passes these labels through
# bio_to_flat, so they are presumably B-/I- prefixed — confirm for other checkpoints.
id2label = model.config.id2label

print("Model:", model_name, "Device:", device)
print("Files:", files)

# ---------- helpers ----------
def load_jsonl_examples(path: str, token_field="target_tokens", tag_field="target_tags"):
    """Read a JSON-Lines file into a list of ``{"tokens": ..., "tags": ...}`` dicts.

    For every non-blank line, the token sequence is looked up under
    ``token_field``, then ``"target_tokens"``, then ``"tokens"`` (the first
    truthy value wins). Tags are looked up the same way via ``tag_field``,
    ``"target_tags"``, ``"tags"``. Records lacking either sequence, or whose
    sequences differ in length, are skipped silently.
    """
    def first_truthy(record, keys):
        # Same result as `record.get(k1) or record.get(k2) or record.get(k3)`:
        # the first truthy value, else whatever the last key yields.
        found = None
        for key in keys:
            found = record.get(key)
            if found:
                return found
        return found

    collected = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            raw_line = raw_line.strip()
            if raw_line == "":
                continue
            record = json.loads(raw_line)
            seq_tokens = first_truthy(record, (token_field, "target_tokens", "tokens"))
            seq_tags = first_truthy(record, (tag_field, "target_tags", "tags"))
            if seq_tokens is None or seq_tags is None:
                continue
            if len(seq_tokens) != len(seq_tags):
                continue
            collected.append({"tokens": seq_tokens, "tags": seq_tags})
    return collected

def flat_to_bio(flat_tags: List[str]) -> List[str]:
    """Convert prefix-free tags into BIO tags.

    ``None`` and any case variant of ``"o"`` become ``"O"``. A non-O tag equal
    to the immediately preceding tag continues the span (``I-``); otherwise it
    starts a new one (``B-``). Consequently, adjacent entities of the same type
    are merged into a single span.
    """
    converted: List[str] = []
    previous = "O"
    for raw in flat_tags:
        tag = "O" if raw is None else str(raw)
        if tag.upper() == "O":
            converted.append("O")
            previous = "O"
            continue
        prefix = "I-" if tag == previous else "B-"
        converted.append(prefix + tag)
        previous = tag
    return converted

def bio_to_flat(bio_tag: str) -> str:
    """Drop a leading ``B-``/``I-`` prefix from a tag.

    ``None`` and ``"O"`` map to ``"O"``; a tag without a BIO prefix is
    returned unchanged.
    """
    if bio_tag is None or bio_tag == "O":
        return "O"
    if not bio_tag.startswith(("B-", "I-")):
        return bio_tag
    _, _, entity_type = bio_tag.partition("-")
    return entity_type

@torch.no_grad()
def predict_for_tokens(tokens: List[str]) -> List[str]:
    """Predict one prefix-free label per input word with the global model.

    Each word takes the label predicted for its first sub-word token, mapped
    through ``id2label`` and ``bio_to_flat``. Words that produced no sub-word
    tokens, or that fell past the tokenizer's truncation limit, get ``"O"``.

    Args:
        tokens: the pre-split words of one example.

    Returns:
        A list of exactly ``len(tokens)`` flat labels, aligned by word index.
    """
    # Tokenize once: the BatchEncoding already carries word_ids(), so a second
    # (non-tensor) tokenization just for alignment is unnecessary.
    encoding = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoding.word_ids(0)
    inputs = {k: v.to(device) for k, v in encoding.items()}
    logits = model(**inputs).logits[0]  # (seq_len, num_labels)
    pred_ids = logits.argmax(dim=-1).tolist()

    # Write each prediction at its word's index instead of appending in order:
    # if a word yields no sub-words, appending would shift every later
    # prediction onto the wrong word.
    preds_per_word = ["O"] * len(tokens)
    previous_word = None
    for idx, w in enumerate(word_ids):
        if w is not None and w != previous_word and w < len(tokens):
            preds_per_word[w] = bio_to_flat(id2label[pred_ids[idx]])
        previous_word = w
    return preds_per_word

# metrics helpers
# seqeval metric (loaded through `evaluate`); compute_entity_metrics feeds it
# per-example BIO sequences to score entity spans.
seqeval = evaluate.load("seqeval")
def compute_token_metrics(gold_flat, pred_flat, ignore_label="O"):
    """Token-level precision/recall/F1 (macro and micro) plus accuracy.

    ``ignore_label`` (default ``"O"``) is left out of the P/R/F1 averages.
    In single-label multiclass scoring, micro-averaging over *all* labels
    makes precision, recall and F1 identical to accuracy. Including the
    dominant ``"O"`` class would therefore make the micro scores redundant and
    inflate the macro scores. Accuracy is still computed over every token.
    Pass ``ignore_label=None`` to score every label.

    Args:
        gold_flat: flat (prefix-free) gold labels, one per token.
        pred_flat: flat predicted labels, aligned with ``gold_flat``.
        ignore_label: label excluded from P/R/F1, or ``None``.

    Returns:
        Dict with ``precision_/recall_/f1_`` ``macro``/``micro`` and
        ``accuracy``. All values are 0.0 for empty input, and P/R/F1 are 0.0
        when only the ignored label occurs.
    """
    if len(gold_flat) == 0:
        return {"precision_macro": 0.0, "recall_macro": 0.0, "f1_macro": 0.0,
                "precision_micro": 0.0, "recall_micro": 0.0, "f1_micro": 0.0,
                "accuracy": 0.0}

    scored_labels = sorted(set(gold_flat) | set(pred_flat))
    if ignore_label is not None:
        scored_labels = [lab for lab in scored_labels if lab != ignore_label]

    acc = accuracy_score(gold_flat, pred_flat)
    if not scored_labels:
        # Only the ignored label occurs, so there is no entity class to score.
        p_mac = r_mac = f1_mac = p_mi = r_mi = f1_mi = 0.0
    else:
        p_mac, r_mac, f1_mac, _ = precision_recall_fscore_support(
            gold_flat, pred_flat, labels=scored_labels, average="macro", zero_division=0)
        p_mi, r_mi, f1_mi, _ = precision_recall_fscore_support(
            gold_flat, pred_flat, labels=scored_labels, average="micro", zero_division=0)
    return {"precision_macro": p_mac, "recall_macro": r_mac, "f1_macro": f1_mac,
            "precision_micro": p_mi, "recall_micro": r_mi, "f1_micro": f1_mi,
            "accuracy": acc}

def compute_entity_metrics(gold_bios, pred_bios):
    """Score entity spans with the global seqeval metric.

    Returns ``(overall, full)``. ``full`` is the raw seqeval result.
    ``overall`` holds only the ``overall_precision/recall/f1/accuracy``
    entries that are present in it.
    """
    full_result = seqeval.compute(predictions=pred_bios, references=gold_bios)
    summary_keys = ("overall_precision", "overall_recall", "overall_f1", "overall_accuracy")
    overall = {}
    for key in summary_keys:
        if key in full_result:
            overall[key] = full_result[key]
    return overall, full_result

# ---------- per-file evaluation ----------
# Produces one metrics row per input file. Flat token labels from every file
# are also pooled into aggregate_gold / aggregate_pred for the aggregate report.
rows = []
aggregate_gold=[]
aggregate_pred=[]

for fpath in files:
    examples = load_jsonl_examples(fpath)
    if not examples:
        print("Skipping", fpath, "- no examples")
        continue

    # Per-file accumulators. Flat tags are concatenated over all examples for
    # token-level scoring; BIO tags stay per example because seqeval scores
    # spans within each sequence.
    gold_flat_all=[]
    pred_flat_all=[]
    gold_bio_seqs=[]
    pred_bio_seqs=[]

    for ex in examples:
        toks = ex["tokens"]
        gold_tags = ex["tags"]
        # normalize gold
        # If any tag in this example has a B-/I- prefix, the sequence is treated
        # as BIO and prefixes are stripped for the flat view. Otherwise the tags
        # are treated as flat and BIO is rebuilt with flat_to_bio.
        if any(isinstance(t,str) and (t.startswith("B-") or t.startswith("I-")) for t in gold_tags):
            gold_bio = [str(t) for t in gold_tags]
            gold_flat = [bio_to_flat(t) for t in gold_bio]
        else:
            gold_flat = [str(t) for t in gold_tags]
            gold_bio = flat_to_bio(gold_flat)

        # NOTE(review): gold and predicted label strings are compared verbatim,
        # so this assumes the projected tag names match the model's label names
        # (after prefix stripping) — confirm for each file.
        preds_flat = predict_for_tokens(toks)
        # Predictions come back prefix-free, so BIO is rebuilt here. Adjacent
        # predicted entities of the same type are therefore merged into one span.
        preds_bio = flat_to_bio(preds_flat)

        gold_flat_all.extend(gold_flat)
        pred_flat_all.extend(preds_flat)
        gold_bio_seqs.append(gold_bio)
        pred_bio_seqs.append(preds_bio)

    token_metrics = compute_token_metrics(gold_flat_all, pred_flat_all)
    entity_overall, entity_full = compute_entity_metrics(gold_bio_seqs, pred_bio_seqs)

    # e.g. "projected_hi_bn_ner.jsonl" -> "bn"
    lang_name = os.path.basename(fpath).replace("projected_hi_","").replace("_ner.jsonl","")
    row = {
        "file": os.path.basename(fpath),
        "lang": lang_name,
        "n_examples": len(examples),
        "n_tokens": len(gold_flat_all),
        # token-level
        "token_precision_macro": token_metrics["precision_macro"],
        "token_recall_macro": token_metrics["recall_macro"],
        "token_f1_macro": token_metrics["f1_macro"],
        "token_precision_micro": token_metrics["precision_micro"],
        "token_recall_micro": token_metrics["recall_micro"],
        "token_f1_micro": token_metrics["f1_micro"],
        "token_accuracy": token_metrics["accuracy"],
        # entity-level (seqeval overall)
        "entity_precision": entity_overall.get("overall_precision", None),
        "entity_recall": entity_overall.get("overall_recall", None),
        "entity_f1": entity_overall.get("overall_f1", None),
        "entity_accuracy": entity_overall.get("overall_accuracy", None),
    }
    rows.append(row)

    # Pool this file's flat labels for the cross-file aggregate.
    aggregate_gold.extend(gold_flat_all)
    aggregate_pred.extend(pred_flat_all)

    # quick print per language
    print(f"\n=== {lang_name} === examples: {len(examples)} tokens: {len(gold_flat_all)}")
    print(f"Token F1 (micro): {token_metrics['f1_micro']:.4f}  Precision (micro): {token_metrics['precision_micro']:.4f}  Recall (micro): {token_metrics['recall_micro']:.4f}  Acc: {token_metrics['accuracy']:.4f}")
    print(f"Entity F1: {entity_overall.get('overall_f1', 'N/A')}  P: {entity_overall.get('overall_precision', 'N/A')}  R: {entity_overall.get('overall_recall', 'N/A')}")

# ---------- aggregated metrics ----------
# Score the flat token labels pooled from every evaluated file in one go.
# agg_token is only bound when at least one file contributed tokens.
if aggregate_gold:
    agg_token = compute_token_metrics(aggregate_gold, aggregate_pred)
    total_tokens = len(aggregate_gold)
    print("\n=== AGGREGATE ACROSS ALL LANGUAGES ===")
    print(f"Total tokens: {total_tokens}")
    aggregate_line = (
        f"Token F1 (micro): {agg_token['f1_micro']:.4f}  "
        f"P: {agg_token['precision_micro']:.4f}  "
        f"R: {agg_token['recall_micro']:.4f}  "
        f"Acc: {agg_token['accuracy']:.4f}"
    )
    print(aggregate_line)

# ---------- save results ----------
# One CSV row per evaluated file, plus a JSON file that bundles the pooled
# token metrics (None when no file contributed tokens) with the same rows.
df = pd.DataFrame(rows)
csv_out = "per_file_metrics.csv"
df.to_csv(csv_out, index=False)

json_payload = {
    "aggregate_token_metrics": agg_token if aggregate_gold else None,
    "files": rows,
}
with open("aggregate_metrics.json", "w", encoding="utf-8") as fo:
    json.dump(json_payload, fo, ensure_ascii=False, indent=2)

print(f"\nPer-file metrics written to {csv_out}")
print("Aggregate saved to aggregate_metrics.json")


TEMPORAL ENTITY

In [None]:
# Recursively read all xmls in HindiTimeBank subfolders, train (TIMEX-only) with oversampling, evaluate.
!pip install -q transformers datasets evaluate seqeval

import os, glob, re, xml.etree.ElementTree as ET, json
from pathlib import Path
from sklearn.model_selection import StratifiedShuffleSplit
from collections import Counter
import numpy as np
import torch
import evaluate
from datasets import Dataset as HFDataset, DatasetDict as HFDatasetDict
from transformers import (
    AutoTokenizer, AutoModelForTokenClassification,
    DataCollatorForTokenClassification, TrainingArguments, Trainer
)
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# ---------------- CONFIG - change these ----------------
BASE_DIR = "/content/HindiTimeBank"   # your main folder containing subfolders with XMLs
# Starting checkpoint. The model is loaded below with num_labels=len(labels) and
# ignore_mismatched_sizes=True, so its classification head is re-created for
# this task's label set whenever the sizes differ.
MODEL_NAME = "Davlan/xlm-roberta-base-wikiann-ner"
OUTPUT_DIR = "./timebank_recursive_timex"   # Trainer checkpoints + final model/tokenizer
ID_START, ID_END = 0, 99999   # we'll auto-collect all files; use range if you want to filter numeric IDs
# NOTE(review): ID_START / ID_END are not referenced anywhere in this cell, so no ID filtering happens.
RANDOM_SEED = 42   # used for both document splits and TrainingArguments(seed=...)

TIMEX_ONLY = True                  # collapse everything non-TIMEX -> O
TIMEX_OVERSAMPLE_FACTOR = 3        # duplicate docs with TIMEX in train (total copies per doc, not extra copies)
# NOTE(review): with DO_TRAIN=False the evaluated model keeps the re-created
# (untrained) classification head, so its scores are not a meaningful baseline.
DO_TRAIN = True                    # set False to skip training (will still evaluate with base model)
EPOCHS = 6
BATCH_SIZE = 4                     # lower if GPU memory limited (eval batch is 2x this)
LR = 1e-5
# -----------------------------------------------------

# === find xml files recursively ===
# With recursive=True, "**" matches BASE_DIR itself and every nested subfolder.
xml_glob = os.path.join(BASE_DIR, "**", "*.xml")
xml_paths = sorted(glob.glob(xml_glob, recursive=True))
if len(xml_paths) == 0:
    raise SystemExit(f"No XML files found under {BASE_DIR}. Make sure folder is uploaded.")
print("Found XML files (recursively):", len(xml_paths))

def extract_id_from_path(path):
    """Return the first run of digits in the file name (extension removed) as an int.

    Falls back to the bare file stem (a string) when the name has no digits,
    so the return type is ``int | str``. Not called in this cell; kept for
    optional numeric-ID filtering.
    """
    stem = Path(path).stem
    digits = re.search(r"(\d+)", stem)
    if digits is None:
        return stem
    return int(digits.group(1))

# === parse TimeBank-style XML (token ids -> tokens + BIO tags) ===
def parse_timebank_xml(path):
    """Parse one TimeBank-style XML file into word tokens and BIO tags.

    Tokens are the stripped texts of the root's direct ``<token>`` children,
    in document order, keyed by their ``id`` attribute. Direct ``EVENT``,
    ``STATE`` and ``TIMEX`` children list comma-separated token ids in their
    ``tokens`` (or ``token``) attribute. Each annotation tags its known tokens
    with ``B-``/``I-`` plus the element name, suffixed with ``_<class>`` when a
    ``class`` attribute is set. Unknown ids are ignored. Annotation types are
    applied in the order EVENT, STATE, TIMEX, so later ones overwrite earlier
    tags on shared tokens.

    NOTE(review): only *direct* ``<token>`` children are read; tokens nested
    inside other elements would be missed — confirm against the corpus schema.

    Returns:
        dict with ``tokens``, ``tags`` and ``doc`` (path relative to BASE_DIR).
    """
    root = ET.parse(path).getroot()

    tokens = []
    index_of = {}
    for node in root.findall("token"):
        index_of[node.attrib.get("id")] = len(tokens)
        tokens.append((node.text or "").strip())

    tags = ["O"] * len(tokens)

    def mark_span(token_refs, label):
        # Tag the referenced tokens: lowest position gets B-, the rest I-.
        if not token_refs:
            return
        refs = (part.strip() for part in token_refs.split(","))
        positions = sorted(index_of[ref] for ref in refs if ref and ref in index_of)
        if not positions:
            return
        first, *rest = positions
        tags[first] = "B-" + label
        for pos in rest:
            tags[pos] = "I-" + label

    for element_name in ("EVENT", "STATE", "TIMEX"):
        for annotation in root.findall(element_name):
            refs_attr = annotation.attrib.get("tokens") or annotation.attrib.get("token")
            cls = annotation.attrib.get("class")
            mark_span(refs_attr, f"{element_name}_{cls}" if cls else element_name)

    return {"tokens": tokens, "tags": tags, "doc": os.path.relpath(path, BASE_DIR)}

# parse all xmls
# One {"tokens", "tags", "doc"} record per XML file, in xml_paths order.
examples_raw = list(map(parse_timebank_xml, xml_paths))
print("Parsed documents:", len(examples_raw))

# optionally collapse to TIMEX-only labels
def collapse_to_timex_only(example):
    """Reduce an example's tags to the three-label TIMEX scheme.

    Any tag starting with ``B-TIMEX``/``I-TIMEX`` (whatever its class suffix)
    becomes plain ``B-TIMEX``/``I-TIMEX``; every other tag becomes ``"O"``.
    Returns a new dict that shares the original ``tokens`` and ``doc`` values.
    """
    def collapse(tag):
        for timex_prefix in ("B-TIMEX", "I-TIMEX"):
            if tag.startswith(timex_prefix):
                return timex_prefix
        return "O"

    collapsed = [collapse(tag) for tag in example["tags"]]
    return {"tokens": example["tokens"], "tags": collapsed, "doc": example["doc"]}

# Pick the label scheme used for the rest of the cell.
if not TIMEX_ONLY:
    examples = examples_raw
else:
    examples = [collapse_to_timex_only(raw_doc) for raw_doc in examples_raw]
    print("Collapsed labels to TIMEX-only.")

# build stratification label (presence of TIMEX)
# has_timex[i] is 1 when document i contains any B-/I-TIMEX tag, else 0.
has_timex = [int(any(t.startswith("B-TIMEX") or t.startswith("I-TIMEX") for t in ex["tags"])) for ex in examples]
print("Documents containing TIMEX:", sum(has_timex), "/", len(has_timex))

# do stratified split: train(80)/val(10)/test(10) by document
# NOTE(review): StratifiedShuffleSplit raises ValueError when either class
# (docs with / without TIMEX) has fewer than 2 members, e.g. on a tiny corpus.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=RANDOM_SEED)
train_val_idx, test_idx = next(sss.split(np.arange(len(examples)), has_timex))
train_val = [examples[i] for i in train_val_idx]
test_docs = [examples[i] for i in test_idx]
# split train_val -> train/val: 0.1111 of the remaining ~90% is ~10% of all
# documents, leaving ~80% for training (the 80/10/10 target above).
tv_has = [has_timex[i] for i in train_val_idx]
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.1111, random_state=RANDOM_SEED)
train_idx_rel, val_idx_rel = next(sss2.split(np.arange(len(train_val)), tv_has))
# Indices from the second split are relative to train_val, not to examples.
train_docs = [train_val[i] for i in train_idx_rel]
val_docs = [train_val[i] for i in val_idx_rel]

print("Split sizes -> train:", len(train_docs), "val:", len(val_docs), "test:", len(test_docs))

# Oversample TIMEX docs in training
def contains_timex(ex):
    """True when any tag in ``ex["tags"]`` starts with B-TIMEX or I-TIMEX."""
    return any(tag.startswith(("B-TIMEX", "I-TIMEX")) for tag in ex["tags"])


oversampled_train = []
for train_doc in train_docs:
    if not contains_timex(train_doc):
        oversampled_train.append(train_doc)
        continue
    # TIMEX docs appear TIMEX_OVERSAMPLE_FACTOR times in total; each copy is a
    # fresh dict, but the token/tag lists are shared with the original.
    oversampled_train.extend(
        {"tokens": train_doc["tokens"], "tags": train_doc["tags"], "doc": train_doc["doc"]}
        for _ in range(TIMEX_OVERSAMPLE_FACTOR)
    )
print("Train size before:", len(train_docs), "after oversample:", len(oversampled_train))

# create HF DatasetDict
# Split name -> list of {"tokens", "tags", "doc"} records (train is oversampled).
split_records = {
    "train": oversampled_train,
    "validation": val_docs,
    "test": test_docs,
}
ds = HFDatasetDict({name: HFDataset.from_list(records) for name, records in split_records.items()})

# label set (TIMEX-only or full set)
# Every tag seen in the corpus, sorted alphabetically, with "O" (if present)
# moved to the last position.
tag_vocab = {tag for doc_record in examples for tag in doc_record["tags"]}
labels = sorted(tag_vocab - {"O"})
if "O" in tag_vocab:
    labels.append("O")
label2id = {lab: i for i, lab in enumerate(labels)}
id2label = {i: lab for i, lab in enumerate(labels)}
print("Labels:", labels)

# tokenizer + alignment
# Fast tokenizer: label alignment and test-time decoding below both call word_ids().
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

def tokenize_and_align_labels(batch, label_all_tokens=False):
    """Tokenize pre-split documents and align word-level tags to sub-words.

    Intended for ``Dataset.map(..., batched=True)``: ``batch["tokens"]`` and
    ``batch["tags"]`` are lists of per-document word / tag lists.

    By default only the first sub-word of each word gets the word's label id.
    Later sub-words and special tokens get -100, so the loss and
    ``compute_metrics`` ignore them. Copying the word's tag onto every
    sub-word (``label_all_tokens=True``) repeats ``B-`` tags inside one word.
    seqeval then counts those as separate entities in validation F1, and it
    no longer matches test-time decoding, which reads only the first sub-word.

    Args:
        batch: mapping with ``tokens`` and ``tags`` lists.
        label_all_tokens: label every sub-word instead of only the first.

    Returns:
        The tokenizer's BatchEncoding with an added ``labels`` column.
    """
    tokenized = tokenizer(batch["tokens"], is_split_into_words=True, truncation=True)
    all_labels = []
    for i, word_tags in enumerate(batch["tags"]):
        label_ids = []
        previous_word = None
        for w in tokenized.word_ids(batch_index=i):
            if w is None:
                label_ids.append(-100)              # special tokens
            elif w != previous_word or label_all_tokens:
                label_ids.append(label2id[word_tags[w]])
            else:
                label_ids.append(-100)              # continuation sub-word
            previous_word = w
        all_labels.append(label_ids)
    tokenized["labels"] = all_labels
    return tokenized

# Tokenize each non-empty split in place; raw text columns are dropped so only
# model inputs and aligned labels remain.
for split_name in ("train", "validation", "test"):
    if len(ds[split_name]) == 0:
        continue
    ds[split_name] = ds[split_name].map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=["tokens", "tags", "doc"],
    )
print("Tokenized sizes:", {name: len(ds[name]) for name in ds})

# load model (ignore mismatch so classifier re-init to new num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, num_labels=len(labels), ignore_mismatched_sizes=True)
# Store this task's label maps on the config: compute_metrics and the test
# loop below decode predicted ids through model.config.id2label.
model.config.id2label = id2label
model.config.label2id = label2id
model.to(device)

# metrics for Trainer
seqeval = evaluate.load("seqeval")
def compute_metrics(eval_pred):
    """Trainer metric hook: entity-level F1 (seqeval) and token-level micro F1.

    ``eval_pred`` unpacks to ``(logits, label_ids)``. Positions labelled -100
    (special tokens, padding, and sub-words excluded during alignment) are
    dropped before scoring, and ids are decoded through model.config.id2label.

    ``token_f1_micro`` is computed over the non-"O" labels only. With "O"
    included, micro-averaged F1 over every class equals plain token accuracy.

    Returns:
        dict with ``entity_overall_f1`` (0.0 if unavailable) and ``token_f1_micro``.
    """
    logits, label_ids = eval_pred
    preds = np.argmax(logits, axis=-1)
    decode = model.config.id2label

    true_preds = []
    true_labels = []
    for pred_row, lab_row in zip(preds, label_ids):
        pseq = []
        lseq = []
        for p, l in zip(pred_row, lab_row):
            if l == -100:
                continue
            pseq.append(decode[int(p)])
            lseq.append(decode[int(l)])
        true_preds.append(pseq)
        true_labels.append(lseq)

    res = seqeval.compute(predictions=true_preds, references=true_labels) if true_labels else {}

    def strip_bio(tag):
        # "B-TIMEX" -> "TIMEX"; "O" stays "O".
        return "O" if tag == "O" else tag.split("-", 1)[1]

    flat_gold = [strip_bio(g) for seq in true_labels for g in seq]
    flat_pred = [strip_bio(p) for seq in true_preds for p in seq]
    entity_labels = sorted((set(flat_gold) | set(flat_pred)) - {"O"})
    if flat_gold and entity_labels:
        _, _, f1_micro, _ = precision_recall_fscore_support(
            flat_gold, flat_pred, labels=entity_labels, average="micro", zero_division=0)
    else:
        f1_micro = 0.0
    return {"entity_overall_f1": res.get("overall_f1", 0.0), "token_f1_micro": f1_micro}

# train
if DO_TRAIN:
    data_collator = DataCollatorForTokenClassification(tokenizer)
    # Per-epoch evaluation and best-checkpoint selection need a validation
    # split. Trainer rejects eval_strategy="epoch" without an eval_dataset, so
    # with an empty validation split, train without evaluation instead.
    has_validation = len(ds["validation"]) > 0
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="epoch" if has_validation else "no",
        save_strategy="epoch",
        learning_rate=LR,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=max(1, BATCH_SIZE*2),
        num_train_epochs=EPOCHS,
        weight_decay=0.01,
        logging_steps=20,
        fp16=torch.cuda.is_available(),
        load_best_model_at_end=has_validation,
        metric_for_best_model="entity_overall_f1" if has_validation else None,
        greater_is_better=True,
        seed=RANDOM_SEED
    )
    trainer = Trainer(
        model=model, args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"] if has_validation else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )
    trainer.train()
    # With load_best_model_at_end the best epoch's weights are restored at the
    # end of training, so this saves the best checkpoint, not the last one.
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("Saved model to", OUTPUT_DIR)
else:
    # NOTE(review): the classification head was re-created for the new label
    # set (ignore_mismatched_sizes=True), so without training it is likely
    # untrained and the test scores below are not a meaningful baseline.
    print("DO_TRAIN is False — skipping training (model is base checkpoint).")

# evaluate on test docs and save predictions
model.eval()
all_gold_flat = []
all_pred_flat = []
gold_bio_seqs = []
pred_bio_seqs = []
out_preds = []


def bio_flat_to_tag(b):
    """Strip the B-/I- prefix from a BIO tag ("O" stays "O")."""
    return "O" if b == "O" else b.split("-", 1)[1]


for orig in test_docs:
    toks = orig["tokens"]
    gold_bio = orig["tags"]
    # One predicted BIO tag per word, written at the word's own index. Words
    # with no sub-words, or past the tokenizer's truncation limit, stay "O"
    # instead of shifting every later prediction onto the wrong word.
    pred_bio = ["O"] * len(toks)
    if toks:  # skip the model entirely for a document with no tokens
        # Tokenize once; the BatchEncoding keeps word_ids() for alignment.
        enc = tokenizer(toks, is_split_into_words=True, return_tensors="pt", truncation=True)
        word_ids = enc.word_ids(0)
        with torch.no_grad():
            logits = model(**{k: v.to(device) for k, v in enc.items()}).logits[0]
        pred_ids = logits.argmax(dim=-1).tolist()
        last = None
        for idx, w in enumerate(word_ids):
            # First sub-word of each word carries that word's prediction.
            if w is not None and w != last and w < len(toks):
                pred_bio[w] = model.config.id2label.get(pred_ids[idx], "O")
            last = w
    # flat TIMEX vs O
    pred_flat = [bio_flat_to_tag(x) for x in pred_bio]
    gold_flat = [bio_flat_to_tag(x) for x in gold_bio]
    all_gold_flat.extend(gold_flat)
    all_pred_flat.extend(pred_flat)
    gold_bio_seqs.append(gold_bio)
    pred_bio_seqs.append(pred_bio)
    out_preds.append({"doc": orig["doc"], "tokens": toks, "gold_bio": gold_bio, "pred_bio": pred_bio})

with open("test_predictions.jsonl", "w", encoding="utf-8") as fo:
    fo.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in out_preds)
print("Wrote test_predictions.jsonl")

# print counters + one missed example
# Tag frequencies over all test documents, gold vs predicted.
gold_counts = Counter(tag for seq in gold_bio_seqs for tag in seq)
pred_counts = Counter(tag for seq in pred_bio_seqs for tag in seq)
print("\nGold BIO counts:", gold_counts)
print("Pred BIO counts:", pred_counts)
# Counter returns 0 for absent keys without inserting them.
print("Pred B-TIMEX:", pred_counts["B-TIMEX"], "I-TIMEX:", pred_counts["I-TIMEX"])

# token-level metrics
# P/R/F1 are computed over the entity labels only ("O" excluded). With every
# label included, micro P/R/F1 all collapse to token accuracy, which is
# printed separately; macro would be inflated by the dominant "O" class.
entity_labels = sorted((set(all_gold_flat) | set(all_pred_flat)) - {"O"})
if all_gold_flat and entity_labels:
    p_macro, r_macro, f1_macro, _ = precision_recall_fscore_support(
        all_gold_flat, all_pred_flat, labels=entity_labels, average="macro", zero_division=0)
    p_micro, r_micro, f1_micro, _ = precision_recall_fscore_support(
        all_gold_flat, all_pred_flat, labels=entity_labels, average="micro", zero_division=0)
else:
    p_macro = r_macro = f1_macro = p_micro = r_micro = f1_micro = 0.0
acc = accuracy_score(all_gold_flat, all_pred_flat) if all_gold_flat else 0.0

print("\nTOKEN-LEVEL (flat labels, 'O' excluded from P/R/F1):")
print(f"Macro   P: {p_macro:.4f}  R: {r_macro:.4f}  F1: {f1_macro:.4f}")
print(f"Micro   P: {p_micro:.4f}  R: {r_micro:.4f}  F1: {f1_micro:.4f}")
print(f"Accuracy: {acc:.4f}")

# entity-level
# Span-level scores over the test documents, followed by the per-type breakdown.
if not gold_bio_seqs:
    print("\nNo entity sequences present for evaluation.")
else:
    ent_res = seqeval.compute(predictions=pred_bio_seqs, references=gold_bio_seqs)
    print("\nENTITY-LEVEL (seqeval) overall:")
    overall_keys = ("overall_precision", "overall_recall", "overall_f1", "overall_accuracy")
    for metric_name in overall_keys:
        if metric_name in ent_res:
            print(f"{metric_name}: {ent_res[metric_name]}")
    print("\nEntity breakdown:")
    per_type = {name: value for name, value in ent_res.items() if not name.startswith("overall_")}
    for name, value in per_type.items():
        print(name, ":", value)

# print first missed-TIMEX doc
# NOTE: this rebinds the module-level name `has_timex`, which earlier held the
# per-document 0/1 stratification list.
def has_timex(seq):
    """True when any tag in ``seq`` starts with B-TIMEX or I-TIMEX."""
    return any(tag.startswith(("B-TIMEX", "I-TIMEX")) for tag in seq)


# First test document whose gold has a TIMEX but whose prediction has none.
missed_doc = next(
    (rec for rec in out_preds if has_timex(rec["gold_bio"]) and not has_timex(rec["pred_bio"])),
    None,
)
if missed_doc is None:
    print("\nNo missed-TIMEX docs found (every gold-TIMEX doc has at least one predicted TIMEX token).")
else:
    print("\n--- Missed TIMEX doc:", missed_doc["doc"])
    for tok, g, p in zip(missed_doc["tokens"], missed_doc["gold_bio"], missed_doc["pred_bio"]):
        print(f"{tok:20s} GOLD={g:10s} PRED={p:10s}")

print("\nDone.")
