# SciX Enrichment NER Model Training

Fine-tune **SciBERT** (`allenai/scibert_scivocab_uncased`) on the SciX enrichment dataset
for **token-classification NER** with BIO tags.

## Entity types
- **topic** (UAT, SWEET, GCMD vocabularies)
- **institution** (ROR)
- **author**
- **date_range**

## BIO label set (9 tags)
```
O, B-topic, I-topic, B-institution, I-institution,
B-author, I-author, B-date_range, I-date_range
```
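
The notebook covers training and evaluation only; it has no inference step. When the model is used, its per-token tags usually have to be collapsed back into entity spans. The sketch below shows one way to do that with a hypothetical `bio_to_spans` helper (not part of the training script). It treats a stray `I-` tag, one with no matching open entity, as the start of a new entity, which is a common lenient convention rather than the only option.

```python
from __future__ import annotations


def bio_to_spans(tags: list[str]) -> list[tuple[str, int, int]]:
    """Collapse BIO tags into (type, start, end) spans over token indices.

    Hypothetical helper: `end` is exclusive, and a stray I- tag opens a new entity.
    """
    spans: list[tuple[str, int, int]] = []
    current_type: str | None = None
    start = 0
    for i, tag in enumerate(tags):
        prefix, _, etype = tag.partition("-")  # "B-date_range" -> ("B", "-", "date_range")
        # An I- tag of the same type continues the open entity.
        if prefix == "I" and current_type == etype:
            continue
        # Anything else (B-, O, or an I- of another type) closes the open entity.
        if current_type is not None:
            spans.append((current_type, start, i))
            current_type = None
        if prefix in ("B", "I"):
            current_type, start = etype, i
    if current_type is not None:
        spans.append((current_type, start, len(tags)))
    return spans


tags = ["B-author", "I-author", "O", "B-institution", "I-institution",
        "I-institution", "O", "B-date_range"]
print(bio_to_spans(tags))
# [('author', 0, 2), ('institution', 3, 6), ('date_range', 7, 8)]
```

Mapping token indices back to character offsets additionally needs the tokenizer's offset mapping, which the training script already requests.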

## Prerequisites
1. **Runtime**: GPU (T4 or A100) via `Runtime > Change runtime type`
2. **Data**: Upload `enrichment_train.jsonl` and `enrichment_val.jsonl` via Google Drive or direct upload
3. **Training script**: `scripts/train_enrichment_model.py` from the repository
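
The training script silently drops records with empty text and spans with an unknown type or empty offsets, and it truncates spans that run past the end of the text. Checking the data before a long GPU run can catch such annotation problems early. The sketch below uses a hypothetical `validate_record` helper that checks only the fields the script reads: `text`, plus `start`, `end` (character offsets, end-exclusive), and `type` on each span. The example record is made up.

```python
from __future__ import annotations

import json

ENTITY_TYPES = ("topic", "institution", "author", "date_range")


def validate_record(record: dict) -> list[str]:
    """Return the problems found in one enrichment record (empty list if OK).

    Hypothetical helper: flags spans the training script would drop
    (unknown type, start >= end) or truncate (end past the text).
    """
    problems: list[str] = []
    text = record.get("text", "")
    if not text:
        problems.append("missing or empty 'text' (record is skipped)")
    for i, span in enumerate(record.get("spans", [])):
        start, end = span.get("start", 0), span.get("end", 0)
        etype = span.get("type", "")
        if etype not in ENTITY_TYPES:
            problems.append(f"span {i}: unknown type {etype!r}")
        if start >= end:
            problems.append(f"span {i}: start {start} >= end {end}")
        elif end > len(text):
            problems.append(f"span {i}: end {end} exceeds text length {len(text)}")
    return problems


# Made-up record in the assumed format, round-tripped through one JSONL line.
example = {
    "text": "Jane Doe (Harvard University), 2001-2005",
    "spans": [
        {"start": 0, "end": 8, "type": "author"},
        {"start": 10, "end": 28, "type": "institution"},
        {"start": 31, "end": 40, "type": "date_range"},
    ],
}
record = json.loads(json.dumps(example))
print(validate_record(record))  # []
for span in record["spans"]:
    print(span["type"], "->", record["text"][span["start"]:span["end"]])
```

Running the helper over every line of both files and printing any non-empty results gives a quick report before uploading.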

## 1. Install Dependencies

In [None]:
!pip install -q transformers datasets torch accelerate numpy

In [None]:
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available:  {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU:             {gpu_name}")
    print(f"GPU memory:      {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Training will be very slow.")
    print("Go to Runtime > Change runtime type > GPU")

## 2. Upload / Download Data

Choose **one** of the options below:
- **Option A**: Mount Google Drive (recommended, since files and outputs persist across sessions)
- **Option B**: Direct file upload from your local machine

### Option A: Google Drive

In [None]:
from google.colab import drive

drive.mount("/content/drive")

# Adjust these paths to match your Drive folder structure
DRIVE_DATA_DIR = "/content/drive/MyDrive/scix-enrichment"

TRAIN_FILE = f"{DRIVE_DATA_DIR}/enrichment_train.jsonl"
VAL_FILE = f"{DRIVE_DATA_DIR}/enrichment_val.jsonl"
OUTPUT_DIR = f"{DRIVE_DATA_DIR}/enrichment_model"

### Option B: Direct Upload

In [None]:
# Uncomment and run this cell if uploading directly

# from google.colab import files
# uploaded = files.upload()  # Select enrichment_train.jsonl and enrichment_val.jsonl
#
# TRAIN_FILE = "enrichment_train.jsonl"
# VAL_FILE = "enrichment_val.jsonl"
# OUTPUT_DIR = "output/enrichment_model"

In [None]:
import json
from pathlib import Path

for label, fpath in [("Train", TRAIN_FILE), ("Val", VAL_FILE)]:
    p = Path(fpath)
    if not p.exists():
        raise FileNotFoundError(f"{label} file not found: {fpath}")
    with open(p, encoding="utf-8") as f:
        count = sum(1 for line in f if line.strip())
    print(f"{label}: {count} records ({fpath})")

## 3. Training Script

The cell below contains the full training logic inlined from
`scripts/train_enrichment_model.py` so the notebook is self-contained.

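It converts character-level span annotations to BIO-tagged token sequences,
then fine-tunes SciBERT with the Hugging Face `Trainer`. Each JSONL line must be
an object with a `text` string and a `spans` list; every span carries character
offsets `start` and `end` (end-exclusive) and a `type` from the four entity types.
Records with empty text are skipped, as are spans with an unknown type or invalid offsets.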
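
In the script, the conversion happens in two steps. `_char_labels` writes a tag onto every character (`B-` on a span's first character, `I-` on the rest), and `align_labels_to_tokens` then gives each token the tag of its first character via the tokenizer's offset mapping. Special and padding tokens, whose offsets are `(0, 0)`, get `-100` so the loss ignores them. The sketch below reproduces that rule in pure Python. The hand-written offsets stand in for a fast tokenizer's `offset_mapping` and are illustrative, not SciBERT's actual WordPiece output.

```python
from __future__ import annotations

ENTITY_TYPES = ("topic", "institution", "author", "date_range")
BIO_LABELS = ["O"] + [f"{p}-{t}" for t in ENTITY_TYPES for p in ("B", "I")]
LABEL2ID = {label: i for i, label in enumerate(BIO_LABELS)}
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss


def char_tags(text: str, spans: list[dict]) -> list[str]:
    """Same rule as _char_labels: B- on a span's first char, I- on the rest."""
    tags = ["O"] * len(text)
    for span in spans:
        start, end, etype = span["start"], span["end"], span["type"]
        if etype not in ENTITY_TYPES or start >= end or start >= len(text):
            continue
        end = min(end, len(text))
        if tags[start] != "O":  # overlapping span: the first one wins
            continue
        tags[start] = f"B-{etype}"
        for ci in range(start + 1, end):
            if tags[ci] == "O":
                tags[ci] = f"I-{etype}"
    return tags


def token_labels(text: str, spans: list[dict], offsets: list[tuple[int, int]]) -> list[int]:
    """Each token takes the tag of its first character; (0, 0) offsets are ignored."""
    tags = char_tags(text, spans)
    labels: list[int] = []
    for start, end in offsets:
        if start == 0 and end == 0:  # [CLS], [SEP], [PAD]
            labels.append(IGNORE_INDEX)
        else:
            labels.append(LABEL2ID[tags[start]] if start < len(tags) else LABEL2ID["O"])
    return labels


text = "Jane Doe at Harvard"
spans = [{"start": 0, "end": 8, "type": "author"},
         {"start": 12, "end": 19, "type": "institution"}]
# Illustrative offsets for: [CLS] jane doe at harv ##ard [SEP] [PAD]
offsets = [(0, 0), (0, 4), (5, 8), (9, 11), (12, 16), (16, 19), (0, 0), (0, 0)]
labels = token_labels(text, spans, offsets)
print(labels)  # [-100, 5, 6, 0, 3, 4, -100, -100]
print([BIO_LABELS[i] if i != IGNORE_INDEX else "-100" for i in labels])
```

Under this rule every subword piece is labeled, not just the first piece of each word, so a continuation piece such as `##ard` receives `I-institution`. The alternative convention of labeling only first pieces and masking the rest with `-100` would need a change to `align_labels_to_tokens`.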

In [None]:
%%writefile train_enrichment_model.py
"""Training script for the SciX enrichment NER model.

Fine-tunes SciBERT (allenai/scibert_scivocab_uncased) on enrichment_labels
for token-classification NER with BIO tags.

BIO label set (9 tags, in label-ID order):
    O, B-topic, I-topic, B-institution, I-institution,
    B-author, I-author, B-date_range, I-date_range
"""

from __future__ import annotations

import argparse
import json
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

ENTITY_TYPES = ("topic", "institution", "author", "date_range")

BIO_LABELS: list[str] = ["O"]
for _etype in ENTITY_TYPES:
    BIO_LABELS.append(f"B-{_etype}")
    BIO_LABELS.append(f"I-{_etype}")

LABEL2ID: dict[str, int] = {label: i for i, label in enumerate(BIO_LABELS)}
ID2LABEL: dict[int, str] = {i: label for i, label in enumerate(BIO_LABELS)}
NUM_LABELS: int = len(BIO_LABELS)


@dataclass(frozen=True)
class TrainConfig:
    model_name: str = "allenai/scibert_scivocab_uncased"
    train_file: str = "data/enrichment_train.jsonl"
    val_file: str = "data/enrichment_val.jsonl"
    output_dir: str = "output/enrichment_model"
    learning_rate: float = 2e-5
    batch_size: int = 16
    num_epochs: int = 10
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_seq_length: int = 256
    gradient_accumulation_steps: int = 1
    early_stopping_patience: int = 3
    seed: int = 42
    fp16: bool = False
    bf16: bool = False
    logging_steps: int = 50
    eval_steps: int = 200
    save_steps: int = 200
    save_total_limit: int = 3
    log_format: str = "json"


def load_enrichment_records(path: Path) -> list[dict[str, Any]]:
    records: list[dict[str, Any]] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def _char_labels(text: str, spans: list[dict[str, Any]]) -> list[str]:
    char_tags = ["O"] * len(text)
    for span in spans:
        start = span.get("start", 0)
        end = span.get("end", 0)
        entity_type = span.get("type", "")
        if entity_type not in ENTITY_TYPES:
            continue
        if start >= end or start >= len(text):
            continue
        end = min(end, len(text))
        if char_tags[start] != "O":
            continue
        char_tags[start] = f"B-{entity_type}"
        for ci in range(start + 1, end):
            if char_tags[ci] == "O":
                char_tags[ci] = f"I-{entity_type}"
    return char_tags


def align_labels_to_tokens(
    text: str,
    spans: list[dict[str, Any]],
    tokenizer: Any,
    max_length: int = 256,
) -> dict[str, Any]:
    char_tags = _char_labels(text, spans)
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_offsets_mapping=True,
        return_tensors=None,
    )
    offsets = encoding.pop("offset_mapping")
    labels: list[int] = []
    for token_idx, (start, end) in enumerate(offsets):
        if start == 0 and end == 0:
            labels.append(-100)
            continue
        tag = char_tags[start] if start < len(char_tags) else "O"
        labels.append(LABEL2ID.get(tag, LABEL2ID["O"]))
    encoding["labels"] = labels
    return encoding


def prepare_dataset(
    records: list[dict[str, Any]],
    tokenizer: Any,
    max_length: int = 256,
) -> Any:
    from datasets import Dataset
    all_input_ids: list[list[int]] = []
    all_attention_masks: list[list[int]] = []
    all_labels: list[list[int]] = []
    for record in records:
        text = record.get("text", "")
        spans = record.get("spans", [])
        if not text:
            continue
        encoded = align_labels_to_tokens(text, spans, tokenizer, max_length)
        all_input_ids.append(encoded["input_ids"])
        all_attention_masks.append(encoded["attention_mask"])
        all_labels.append(encoded["labels"])
    return Dataset.from_dict({
        "input_ids": all_input_ids,
        "attention_mask": all_attention_masks,
        "labels": all_labels,
    })


def build_compute_metrics(id2label: dict[int, str]) -> Any:
    import numpy as np
    def compute_metrics(eval_pred: Any) -> dict[str, float]:
        predictions, label_ids = eval_pred
        preds = np.argmax(predictions, axis=-1)
        true_labels: list[str] = []
        pred_labels: list[str] = []
        for seq_true, seq_pred in zip(label_ids, preds):
            for t, p in zip(seq_true, seq_pred):
                if t == -100:
                    continue
                true_labels.append(id2label.get(int(t), "O"))
                pred_labels.append(id2label.get(int(p), "O"))
        entity_labels = {lbl for lbl in set(true_labels) | set(pred_labels) if lbl != "O"}
        tp = 0
        fp = 0
        fn = 0
        per_label: dict[str, dict[str, int]] = {}
        for label in entity_labels:
            per_label[label] = {"tp": 0, "fp": 0, "fn": 0}
        for true, pred in zip(true_labels, pred_labels):
            if true == pred:
                if true != "O":
                    tp += 1
                    per_label[true]["tp"] += 1
                continue
            # Mismatch: a non-O prediction is a false positive for the predicted
            # label, and a non-O gold tag is a false negative for the gold label.
            # A mislabeled entity token therefore counts as both.
            if pred != "O":
                fp += 1
                per_label[pred]["fp"] += 1
            if true != "O":
                fn += 1
                per_label[true]["fn"] += 1
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        metrics: dict[str, float] = {
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1": round(f1, 4),
        }
        for label in sorted(entity_labels):
            ltp = per_label[label]["tp"]
            lfp = per_label[label]["fp"]
            lfn = per_label[label]["fn"]
            lp = ltp / (ltp + lfp) if (ltp + lfp) > 0 else 0.0
            lr = ltp / (ltp + lfn) if (ltp + lfn) > 0 else 0.0
            lf1 = 2 * lp * lr / (lp + lr) if (lp + lr) > 0 else 0.0
            metrics[f"{label}_f1"] = round(lf1, 4)
        return metrics
    return compute_metrics


def save_training_log(
    output_dir: Path,
    config: TrainConfig,
    train_result: Any,
    eval_metrics: dict[str, float] | None,
    elapsed_seconds: float,
) -> None:
    log: dict[str, Any] = {
        "model_name": config.model_name,
        "num_labels": NUM_LABELS,
        "bio_labels": BIO_LABELS,
        "label2id": LABEL2ID,
        "id2label": {str(k): v for k, v in ID2LABEL.items()},
        "hyperparameters": {
            "learning_rate": config.learning_rate,
            "batch_size": config.batch_size,
            "num_epochs": config.num_epochs,
            "warmup_ratio": config.warmup_ratio,
            "weight_decay": config.weight_decay,
            "max_seq_length": config.max_seq_length,
            "gradient_accumulation_steps": config.gradient_accumulation_steps,
            "seed": config.seed,
            "fp16": config.fp16,
            "bf16": config.bf16,
        },
        "training": {},
        "elapsed_seconds": round(elapsed_seconds, 2),
    }
    if train_result is not None:
        metrics = getattr(train_result, "metrics", {})
        log["training"] = {
            k: round(v, 6) if isinstance(v, float) else v
            for k, v in metrics.items()
        }
    if eval_metrics is not None:
        log["eval"] = {
            k: round(v, 6) if isinstance(v, float) else v
            for k, v in eval_metrics.items()
        }
    log_path = output_dir / "training_log.json"
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2, ensure_ascii=False)


def train(config: TrainConfig) -> int:
    start_time = time.time()
    import torch
    from transformers import (
        AutoModelForTokenClassification,
        AutoTokenizer,
        EarlyStoppingCallback,
        Trainer,
        TrainingArguments,
    )
    train_path = Path(config.train_file)
    val_path = Path(config.val_file)
    if not train_path.exists():
        print(f"Error: training file not found: {train_path}")
        return 1
    if not val_path.exists():
        print(f"Error: validation file not found: {val_path}")
        return 1
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Loading training data from {train_path}...")
    train_records = load_enrichment_records(train_path)
    print(f"  {len(train_records)} training records")
    print(f"Loading validation data from {val_path}...")
    val_records = load_enrichment_records(val_path)
    print(f"  {len(val_records)} validation records")
    print(f"Loading tokenizer: {config.model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(
        config.model_name, model_max_length=config.max_seq_length,
    )
    print("Tokenizing and aligning BIO labels (train)...")
    train_dataset = prepare_dataset(train_records, tokenizer, config.max_seq_length)
    print(f"  {len(train_dataset)} training examples")
    print("Tokenizing and aligning BIO labels (val)...")
    val_dataset = prepare_dataset(val_records, tokenizer, config.max_seq_length)
    print(f"  {len(val_dataset)} validation examples")
    print(f"Loading model: {config.model_name}...")
    model = AutoModelForTokenClassification.from_pretrained(
        config.model_name,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  {trainable_params:,} trainable / {total_params:,} total parameters")
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=config.num_epochs,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size * 2,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        warmup_ratio=config.warmup_ratio,
        weight_decay=config.weight_decay,
        fp16=config.fp16,
        bf16=config.bf16,
        logging_strategy="steps",
        logging_steps=config.logging_steps,
        eval_strategy="steps",
        eval_steps=config.eval_steps,
        save_strategy="steps",
        save_steps=config.save_steps,
        save_total_limit=config.save_total_limit,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        seed=config.seed,
        report_to="none",
        remove_unused_columns=False,
    )
    callbacks = []
    if config.early_stopping_patience > 0:
        callbacks.append(
            EarlyStoppingCallback(early_stopping_patience=config.early_stopping_patience)
        )
    compute_metrics_fn = build_compute_metrics(ID2LABEL)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_fn,
        callbacks=callbacks,
    )
    print("\n" + "=" * 60)
    print("Starting training...")
    print(f"  Model:          {config.model_name}")
    print(f"  Labels:         {NUM_LABELS} BIO tags")
    print(f"  Train examples: {len(train_dataset)}")
    print(f"  Val examples:   {len(val_dataset)}")
    print(f"  Epochs:         {config.num_epochs}")
    print(f"  Batch size:     {config.batch_size}")
    print(f"  Learning rate:  {config.learning_rate}")
    print(f"  Early stopping: patience={config.early_stopping_patience}")
    print("=" * 60 + "\n")
    train_result = trainer.train()
    print("\nRunning final evaluation...")
    eval_metrics = trainer.evaluate()
    for key, value in sorted(eval_metrics.items()):
        formatted = f"{value:.4f}" if isinstance(value, float) else str(value)
        print(f"  {key}: {formatted}")
    print(f"\nSaving model and tokenizer to {output_dir}...")
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    elapsed = time.time() - start_time
    save_training_log(output_dir, config, train_result, eval_metrics, elapsed)
    print(f"Training log saved to {output_dir / 'training_log.json'}")
    print(f"\nDone! Elapsed: {elapsed:.1f}s")
    return 0


def parse_args() -> TrainConfig:
    parser = argparse.ArgumentParser(
        description="Train the SciX enrichment NER model (SciBERT token classification)."
    )
    parser.add_argument("--model-name", type=str, default="allenai/scibert_scivocab_uncased")
    parser.add_argument("--train-file", type=str, required=True)
    parser.add_argument("--val-file", type=str, required=True)
    parser.add_argument("--output-dir", type=str, default="output/enrichment_model")
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--warmup-ratio", type=float, default=0.1)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-seq-length", type=int, default=256)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--early-stopping-patience", type=int, default=3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--logging-steps", type=int, default=50)
    parser.add_argument("--eval-steps", type=int, default=200)
    parser.add_argument("--save-steps", type=int, default=200)
    parser.add_argument("--save-total-limit", type=int, default=3)
    args = parser.parse_args()
    return TrainConfig(
        model_name=args.model_name,
        train_file=args.train_file,
        val_file=args.val_file,
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        max_seq_length=args.max_seq_length,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        early_stopping_patience=args.early_stopping_patience,
        seed=args.seed,
        fp16=args.fp16,
        bf16=args.bf16,
        logging_steps=args.logging_steps,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
    )


if __name__ == "__main__":
    sys.exit(train(parse_args()))

## 4. Run Training

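Detect the GPU type and set the matching mixed-precision flag:
- **T4** (and V100/P100): use `--fp16` (no native bf16 support)
- **A100/H100/L4**: use `--bf16` (bf16 keeps fp32's exponent range, so it is less prone to overflow than fp16 and needs no loss scaling)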
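
The name matching in the next cell only recognizes the GPUs it lists, so an Ampere card such as an A10G would fall back to `--fp16`. That still trains, but it forgoes bf16. As an alternative, the choice can be keyed on the CUDA compute capability. The sketch below assumes native bf16 starts with Ampere (compute capability 8.0). `precision_flag` is a hypothetical helper; `torch.cuda.get_device_capability` is the PyTorch call that returns the `(major, minor)` pair.

```python
from __future__ import annotations


def precision_flag(capability: tuple[int, int] | None) -> str:
    """Map a CUDA compute capability to a mixed-precision flag for the script.

    Assumption: native bf16 starts with Ampere (8.0), so A100 (8.0), A10G (8.6),
    L4 (8.9) and H100 (9.0) get --bf16, while T4 (7.5), V100 (7.0) and
    P100 (6.0) get --fp16. None means no GPU, so no flag is passed.
    """
    if capability is None:
        return ""
    major, _minor = capability
    return "--bf16" if major >= 8 else "--fp16"


try:
    import torch
    cap = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else None
except ImportError:  # lets the sketch run on machines without PyTorch
    cap = None
print(f"capability={cap} -> {precision_flag(cap) or '(no flag)'}")
```

Assigning `PRECISION_FLAG = precision_flag(cap)` would then replace the name-matching branch.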

In [None]:
import torch

# Auto-detect GPU precision support
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    # A100/H100/L4 support bf16; T4/V100/P100 use fp16
    if any(arch in gpu_name for arch in ["A100", "H100", "L4", "L40"]):
        PRECISION_FLAG = "--bf16"
    else:
        PRECISION_FLAG = "--fp16"
    print(f"GPU: {gpu_name} -> using {PRECISION_FLAG}")
else:
    PRECISION_FLAG = ""
    print("No GPU detected, training without mixed precision")

In [None]:
!python train_enrichment_model.py \
    --model-name allenai/scibert_scivocab_uncased \
    --train-file "{TRAIN_FILE}" \
    --val-file "{VAL_FILE}" \
    --output-dir "{OUTPUT_DIR}" \
    --num-epochs 10 \
    --batch-size 16 \
    --learning-rate 2e-5 \
    --warmup-ratio 0.1 \
    --early-stopping-patience 3 \
    {PRECISION_FLAG}

## 5. Inspect Results
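
The `precision`, `recall`, and `f1` reported by `compute_metrics` are token-level scores, micro-averaged over non-`O` tags. Each non-special token counts once, `B-` and `I-` tags are scored as separate labels, and a predicted entity with a wrong boundary still earns credit for its correctly tagged tokens. `Trainer.evaluate` prefixes the keys with `eval_` (for example `eval_f1` and `eval_B-author_f1`). The sketch below is a simplified stand-in for the standard token-level micro computation, not the script's own code, and shows how such numbers arise on a toy sequence.

```python
from __future__ import annotations


def token_micro_prf(true: list[str], pred: list[str]) -> tuple[float, float, float]:
    """Token-level micro precision/recall/F1 over non-O tags.

    A matching non-O tag is a true positive. On a mismatch, a non-O prediction
    is a false positive and a non-O gold tag is a false negative, so one
    mislabeled token can count as both.
    """
    tp = fp = fn = 0
    for t, p in zip(true, pred):
        if t == p:
            if t != "O":
                tp += 1
            continue
        if p != "O":
            fp += 1
        if t != "O":
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


true = ["B-author", "I-author", "O", "B-institution", "I-institution"]
pred = ["B-author", "O", "O", "B-institution", "I-author"]
print(tuple(round(x, 4) for x in token_micro_prf(true, pred)))  # (0.6667, 0.5, 0.5714)
```

The toy prediction scores an F1 of about 0.57 even though, under exact-span matching, neither gold entity is recovered: the predicted author and institution spans are each one token short. Token-level scores can therefore differ noticeably from entity-level ones, which is worth keeping in mind when comparing against other NER results.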

In [None]:
import json
from pathlib import Path

log_path = Path(OUTPUT_DIR) / "training_log.json"
if log_path.exists():
    with open(log_path) as f:
        log = json.load(f)

    print("=" * 50)
    print("Training Results")
    print("=" * 50)
    print(f"Model:   {log['model_name']}")
    print(f"Labels:  {log['num_labels']} BIO tags")
    print(f"Elapsed: {log['elapsed_seconds']:.1f}s")
    print()

    if "eval" in log:
        print("Evaluation metrics:")
        for k, v in sorted(log["eval"].items()):
            formatted = f"{v:.4f}" if isinstance(v, float) else str(v)
            print(f"  {k}: {formatted}")
else:
    print(f"Training log not found at {log_path}")
    print("Training may not have completed successfully.")

## 6. Save Checkpoint to Google Drive

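Copy the trained model files to a persistent Google Drive location. This requires Drive to be mounted (run `drive.mount("/content/drive")` from Option A first); otherwise the target path is an ordinary local folder that disappears with the runtime. If `OUTPUT_DIR` already points into Drive, as it does in Option A, the model is already persisted and this step is not needed.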

In [None]:
import shutil
from pathlib import Path

DRIVE_SAVE_DIR = "/content/drive/MyDrive/scix-enrichment/enrichment_model"
drive_path = Path(DRIVE_SAVE_DIR)
model_path = Path(OUTPUT_DIR)

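if not model_path.exists():
    print(f"Model directory not found: {model_path}")
    print("Run training first (Section 4).")
elif model_path.resolve() == drive_path.resolve():
    # Option A writes OUTPUT_DIR directly to Drive; copying a file onto
    # itself would raise shutil.SameFileError, so there is nothing to copy.
    print(f"Model is already saved on Google Drive at {model_path}")
else: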
    drive_path.mkdir(parents=True, exist_ok=True)

    # Copy all model files
    files_copied = []
    for src_file in model_path.iterdir():
        if src_file.is_file():
            dst_file = drive_path / src_file.name
            shutil.copy2(src_file, dst_file)
            size_mb = dst_file.stat().st_size / (1024 * 1024)
            files_copied.append((src_file.name, size_mb))

    print(f"Saved {len(files_copied)} files to {DRIVE_SAVE_DIR}:")
    for name, size in sorted(files_copied):
        print(f"  {name} ({size:.1f} MB)")

## 7. Push to HuggingFace Hub (Optional)

Requires the `HF_TOKEN` environment variable to hold a write-access token
(create one at https://huggingface.co/settings/tokens). Set `HF_REPO_ID` to your own repository before running.

In [None]:
import os
from pathlib import Path

HF_TOKEN = os.environ.get("HF_TOKEN", "")
HF_REPO_ID = "your-username/scix-enrichment-ner"  # Change to your repo

if not HF_TOKEN:
    print("HF_TOKEN not set. Skipping HuggingFace Hub push.")
    print("To push, set HF_TOKEN:")
    print('  import os; os.environ["HF_TOKEN"] = "<your-token>"')
else:
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    model_path = Path(OUTPUT_DIR)
    if not model_path.exists():
        print(f"Model not found at {model_path}. Run training first.")
    else:
        print(f"Loading model from {model_path}...")
        model = AutoModelForTokenClassification.from_pretrained(str(model_path))
        tokenizer = AutoTokenizer.from_pretrained(str(model_path))

        print(f"Pushing to {HF_REPO_ID}...")
        model.push_to_hub(HF_REPO_ID, token=HF_TOKEN)
        tokenizer.push_to_hub(HF_REPO_ID, token=HF_TOKEN)
        print(f"Model pushed to https://huggingface.co/{HF_REPO_ID}")