# NEREL Joint NER + Document Classification (Multi-Task)

This notebook implements a single-encoder model that performs token-level NER (BIO) and document-level multi-label classification on NEREL.

- Encoder: Hugging Face Transformer (small Russian BERT)
- Heads: token classification head (NER), doc classification head (multi-label)
- Loss: sum or uncertainty-weighted
- Metrics: token-level macro-F1 and doc micro-F1 (both computed with sklearn)



In [None]:
# Environment and imports
from __future__ import annotations
import os
import json
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import (
    AutoTokenizer,
    AutoModel,
    DataCollatorForTokenClassification,
)

from sklearn.metrics import f1_score, precision_score, recall_score

# Seed every RNG used in this notebook: Python, NumPy, PyTorch CPU and CUDA.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Request deterministic cuDNN kernels and turn off cuDNN benchmarking.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder checkpoint; can be overridden with the RU_BASE_MODEL environment variable.
MODEL_NAME = os.environ.get("RU_BASE_MODEL", "cointegrated/rubert-tiny2")
USE_UNCERTAINTY = True  # forwarded to JointModel(use_uncertainty=...)
MAX_LENGTH = 256  # tokenizer max_length used for truncation
BATCH_SIZE = 8  # DataLoader batch size
LR = 2e-5  # AdamW learning rate
EPOCHS = 3  # default number of training epochs
ACCUM_STEPS = 1  # batches accumulated per optimizer step
GRAD_CLIP_NORM = 1.0  # max norm passed to clip_grad_norm_

# Local paths of the three JSONL splits.
DATA_DIR = "."
TRAIN_PATH = os.path.join(DATA_DIR, "train.jsonl")
DEV_PATH = os.path.join(DATA_DIR, "dev.jsonl")
TEST_PATH = os.path.join(DATA_DIR, "test.jsonl")



In [None]:
# Tokenizer setup
# The fast tokenizer is requested explicitly; tokenize_and_align_labels calls
# .word_ids() on the encoding this tokenizer returns.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)



In [None]:
# Download NEREL JSONL files if missing
import urllib.request

NEREL_BASE = "https://huggingface.co/datasets/iluvvatar/NEREL/resolve/main/data"
# Maps each local split path to the URL it is fetched from.
FILES = {
    TRAIN_PATH: f"{NEREL_BASE}/train.jsonl",
    DEV_PATH: f"{NEREL_BASE}/dev.jsonl",
    TEST_PATH: f"{NEREL_BASE}/test.jsonl",
}

# Fetch every split that is not on disk yet. Each download goes to a temporary
# ".part" file that is renamed only after the transfer completes, so an
# interrupted download never leaves a truncated file that the existence check
# would accept on the next run.
for local_path, url in FILES.items():
    if os.path.exists(local_path):
        continue
    partial_path = f"{local_path}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, local_path)
    finally:
        # Only reached with a leftover file when the download or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
local_paths = list(FILES)
local_paths


In [None]:
# Data utils: loading JSONL and simple EDA helpers

def read_jsonl(path: str) -> List[Dict[str, Any]]:
    """Load a JSON Lines file into a list of parsed records.

    Lines that are empty after stripping whitespace are skipped. A missing
    file yields an empty list instead of raising, so the notebook can be run
    before the data has been downloaded.

    Args:
        path: Location of the ``.jsonl`` file (read as UTF-8).

    Returns:
        One parsed JSON value per non-blank line, in file order.
    """
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped_lines if entry]

# Load the raw documents of each split; read_jsonl returns [] for a missing file.
train_raw = read_jsonl(TRAIN_PATH)
dev_raw = read_jsonl(DEV_PATH)
test_raw = read_jsonl(TEST_PATH)

# Notebook display: number of documents per split.
len(train_raw), len(dev_raw), len(test_raw)


In [None]:
# EDA: lightweight counts (safe if dataset not yet downloaded)
from collections import Counter
from typing import Counter as TCounter

def top_k(counter: TCounter[Any], k: int = 15):
    """Return the ``k`` most frequent entries of ``counter`` as ``(item, count)`` pairs."""
    ranked = counter.most_common(k)
    return ranked

# Frequency tables filled by the loop below.
entity_type_counter: TCounter[str] = Counter()
text_len_counter: TCounter[int] = Counter()
entities_per_doc_counter: TCounter[int] = Counter()

# Only the first 2000 training documents are scanned.
for doc in train_raw[:2000]:
    text = doc.get("text", "")
    ents = doc.get("entities", [])
    # NOTE(review): assumes each entity is a dict exposing "type" — confirm against the dataset schema.
    for e in ents:
        entity_type_counter[e.get("type", "UNKNOWN")] += 1
    # Document length measured in whitespace-separated words.
    text_len_counter[len(text.split())] += 1
    entities_per_doc_counter[len(ents)] += 1

# Notebook display: the 15 most common values of each table.
# "text_len_bins" holds exact word counts; no bucketing is applied.
{
    "top_entity_types": top_k(entity_type_counter, 15),
    "text_len_bins": top_k(text_len_counter, 15),
    "entities_per_doc": top_k(entities_per_doc_counter, 15),
}


In [None]:
# Parsing utilities: whitespace tokens, spans, BIO, doc labels

def whitespace_tokenize_with_spans(text: str) -> Tuple[List[str], List[Tuple[int, int]]]:
    """Split ``text`` on whitespace and record where each token sits.

    Args:
        text: Input string.

    Returns:
        ``(tokens, spans)`` where ``spans[k]`` is the ``(start, end)`` character
        range of ``tokens[k]`` in ``text``; ``end`` is exclusive.
    """
    tokens: List[str] = []
    spans: List[Tuple[int, int]] = []
    token_start = None  # index where the current token began, or None between tokens
    for pos, char in enumerate(text):
        if char.isspace():
            if token_start is not None:
                tokens.append(text[token_start:pos])
                spans.append((token_start, pos))
                token_start = None
        elif token_start is None:
            token_start = pos
    # Close a token that runs to the end of the string.
    if token_start is not None:
        tokens.append(text[token_start:])
        spans.append((token_start, len(text)))
    return tokens, spans

# Immutable, slotted record for one entity span (slots=True needs Python 3.10+).
# NOTE(review): not referenced elsewhere in the visible code, which reads
# entities as plain dicts through .get("start") / .get("end") / .get("type").
@dataclass(frozen=True, slots=True)
class Entity:
    """Entity span described by a start offset, an end offset and a type name."""

    start: int  # presumably a character offset into the document text — confirm
    end: int  # presumably the exclusive end offset — confirm
    type: str  # entity type name

# Build label spaces

def build_label_spaces(docs: List[Dict[str, Any]], top_k_events: int = 30):
    """Derive the NER tag set and the document-label set from ``docs``.

    NER tags are ``"O"`` followed by ``B-``/``I-`` pairs for every entity type
    seen, in alphabetical order of the type name. Document labels are the
    ``top_k_events`` most frequent ``"type"`` values found across each
    document's ``"relations"`` and ``"events"`` lists; ties keep first-seen
    order. Items without a ``"type"`` are counted as ``"UNKNOWN"``.

    Args:
        docs: Raw documents. Assumes entities, relations and events are dicts
            exposing ``.get("type")`` — TODO confirm against the dataset schema.
        top_k_events: Maximum number of document labels to keep.

    Returns:
        ``(ner_label_to_id, ner_id_to_label, doc_label_to_id, doc_id_to_label)``.
    """
    seen_entity_types = set()
    doc_label_counts: Dict[str, int] = {}

    for doc in docs:
        seen_entity_types.update(ent.get("type", "UNKNOWN") for ent in doc.get("entities", []))
        # Relations are counted before events so first-seen order matches per document.
        for field in ("relations", "events"):
            for item in doc.get(field, []):
                label = item.get("type", "UNKNOWN")
                doc_label_counts[label] = doc_label_counts.get(label, 0) + 1

    ner_tags: List[str] = ["O"]
    for ent_type in sorted(seen_entity_types):
        ner_tags.extend((f"B-{ent_type}", f"I-{ent_type}"))
    ner_label_to_id = {tag: idx for idx, tag in enumerate(ner_tags)}
    ner_id_to_label = dict(enumerate(ner_tags))

    # sorted() is stable, so equally frequent labels stay in first-seen order.
    ranked = sorted(doc_label_counts.items(), key=lambda pair: pair[1], reverse=True)
    top_events = [label for label, _ in ranked[:top_k_events]]
    doc_label_to_id = {label: idx for idx, label in enumerate(top_events)}
    doc_id_to_label = dict(enumerate(top_events))

    return ner_label_to_id, ner_id_to_label, doc_label_to_id, doc_id_to_label

# Convert entities to BIO over whitespace tokens

def entities_to_bio(tokens: List[str], spans: List[Tuple[int, int]], ents: List[Dict[str, Any]], ner_label_to_id: Dict[str, int]) -> List[int]:
    """Project character-level entity spans onto tokens as BIO label ids.

    Every token whose span overlaps an entity receives ``B-<type>`` (first
    overlapping token) or ``I-<type>`` (subsequent ones). Entities are applied
    in list order, so a later entity overwrites labels written by an earlier
    one. A tag missing from ``ner_label_to_id`` is written as ``"O"``.

    Args:
        tokens: Tokens of the document (only the count is used).
        spans: ``(start, end)`` character range per token; the early exit
            below relies on them being in ascending order.
        ents: Entity dicts read through ``"start"``, ``"end"`` and ``"type"``;
            missing offsets default to ``-1``, which matches no token.
        ner_label_to_id: Tag-to-id mapping; must contain ``"O"``.

    Returns:
        One label id per token.
    """
    outside_id = ner_label_to_id["O"]
    labels = [outside_id] * len(tokens)
    for entity in ents:
        ent_start = int(entity.get("start", -1))
        ent_end = int(entity.get("end", -1))
        ent_type = entity.get("type", "UNKNOWN")
        prefix = "B"  # switches to "I" after the first overlapping token
        for tok_idx, (tok_start, tok_end) in enumerate(spans):
            if tok_start >= ent_end:
                # Tokens from here on start after the entity ends.
                break
            if tok_end > ent_start:
                labels[tok_idx] = ner_label_to_id.get(f"{prefix}-{ent_type}", outside_id)
                prefix = "I"
    return labels

# Build examples from raw docs

def build_examples_from_nerel(raw_docs: List[Dict[str, Any]], ner_label_to_id: Dict[str, int], doc_label_to_id: Dict[str, int]):
    """Turn raw documents into training examples for both tasks.

    Args:
        raw_docs: Raw documents, read through the keys ``"text"``,
            ``"entities"``, ``"relations"`` and ``"events"``.
        ner_label_to_id: BIO tag-to-id mapping passed to ``entities_to_bio``.
        doc_label_to_id: Document-label-to-index mapping; relation and event
            types absent from it are ignored.

    Returns:
        A list with one dict per document, holding ``"tokens"`` (whitespace
        tokens), ``"token_spans"`` (their character ranges), ``"bio"`` (label id
        per token) and ``"cls_vec"`` (float32 multi-hot vector over
        ``doc_label_to_id``).
    """
    num_doc_labels = len(doc_label_to_id)

    def _to_example(doc: Dict[str, Any]) -> Dict[str, Any]:
        tokens, spans = whitespace_tokenize_with_spans(doc.get("text", ""))
        bio = entities_to_bio(tokens, spans, doc.get("entities", []), ner_label_to_id)
        cls_vec = np.zeros(num_doc_labels, dtype=np.float32)
        # Relations and events both contribute to the same multi-hot vector.
        for field in ("relations", "events"):
            for item in doc.get(field, []):
                label_idx = doc_label_to_id.get(item.get("type", "UNKNOWN"))
                if label_idx is not None:
                    cls_vec[label_idx] = 1.0
        return {
            "tokens": tokens,
            "token_spans": spans,
            "bio": bio,
            "cls_vec": cls_vec,
        }

    return [_to_example(doc) for doc in raw_docs]

# Label spaces
# Built from the training split only. Downstream, entity tags missing from
# ner_label_to_id are written as "O" and relation/event types missing from
# doc_label_to_id are ignored.
ner_label_to_id, ner_id_to_label, doc_label_to_id, doc_id_to_label = build_label_spaces(train_raw)
# Notebook display: (number of NER tags, number of document labels).
len(ner_label_to_id), len(doc_label_to_id)


In [None]:
# Build train/dev/test examples
# All three splits share the label spaces derived from the training data.
train_examples = build_examples_from_nerel(train_raw, ner_label_to_id, doc_label_to_id)
dev_examples = build_examples_from_nerel(dev_raw, ner_label_to_id, doc_label_to_id)
test_examples = build_examples_from_nerel(test_raw, ner_label_to_id, doc_label_to_id)

# Notebook display: number of examples per split.
len(train_examples), len(dev_examples), len(test_examples)


In [None]:
# Dataset fix and simple collate
def tokenize_and_align_labels(tokens: List[str], bio_labels: List[int]):
    """Sub-word tokenize pre-split ``tokens`` and expand ``bio_labels`` to match.

    Defined here, ahead of ``JointDataset``, because ``__getitem__`` calls it
    while the training loop iterates the loaders. In the file as laid out, the
    only other definition sits in the last cell, after training has already
    run, so a top-to-bottom execution would otherwise hit a ``NameError``.

    Args:
        tokens: Whitespace tokens of one document.
        bio_labels: One label id per entry of ``tokens``.

    Returns:
        ``(encoding, labels)``: the tokenizer output (truncated to
        ``MAX_LENGTH``, not padded) and one label per sub-token. Positions whose
        word id is ``None`` get ``-100``; every other position gets the label of
        the word it belongs to.
    """
    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        truncation=True,
        padding=False,
        max_length=MAX_LENGTH,
        return_tensors=None,
    )
    labels = [-100 if word_idx is None else bio_labels[word_idx] for word_idx in encoding.word_ids()]
    return encoding, labels


class JointDataset(Dataset):
    """Map-style dataset that tokenizes one example on access.

    Each item is the tokenizer encoding extended with ``"labels"`` (label id per
    sub-token, ``-100`` where ignored) and ``"cls_labels"`` (float32 document
    label vector).
    """

    def __init__(self, examples: List[Dict[str, Any]]):
        """Store ``examples``; items are read through ``"tokens"``, ``"bio"`` and ``"cls_vec"``."""
        super().__init__()
        self.examples = examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize example ``idx`` and attach both label sets."""
        ex = self.examples[idx]
        enc, lab = tokenize_and_align_labels(ex["tokens"], ex["bio"])
        enc["labels"] = lab
        enc["cls_labels"] = ex["cls_vec"].astype(np.float32)
        return enc

# One dataset per split; tokenization happens when an item is accessed, not here.
train_ds = JointDataset(train_examples)
dev_ds = JointDataset(dev_examples)
test_ds = JointDataset(test_examples)


def collate_batch(features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Pad per-example feature dicts to a common length and stack them.

    Args:
        features: Items with list-valued ``"input_ids"``, ``"attention_mask"``
            and ``"labels"`` of equal length per item, plus a ``"cls_labels"``
            vector of the same size for every item.

    Returns:
        A dict of tensors: ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"`` (all ``long``, padded on the right to the longest item in
        the batch) and ``"cls_labels"`` (``float32``). Padding uses the
        tokenizer's pad id, ``0`` for the mask and ``-100`` for labels.

    Raises:
        ValueError: If ``features`` is empty.
    """
    pad_id = tokenizer.pad_token_id
    max_len = max(len(f["input_ids"]) for f in features)
    input_ids, attention_mask, labels = [], [], []
    for f in features:
        pad_len = max_len - len(f["input_ids"])
        input_ids.append(f["input_ids"] + [pad_id] * pad_len)
        attention_mask.append(f["attention_mask"] + [0] * pad_len)
        labels.append(f["labels"] + [-100] * pad_len)
    # Stack into a single ndarray first: building a tensor straight from a
    # list of ndarrays is the slow path that PyTorch warns about.
    cls_labels = np.stack([np.asarray(f["cls_labels"], dtype=np.float32) for f in features])
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "cls_labels": torch.from_numpy(cls_labels),
    }

# Only the training loader shuffles; dev and test keep dataset order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
dev_loader = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)


In [None]:
# Joint model: encoder + two heads + optional uncertainty weighting
class JointModel(nn.Module):
    """Shared encoder with a token-classification head and a document head.

    The token head maps every position's hidden state to NER tag logits. The
    document head maps a pooled representation to multi-label logits. With
    ``use_uncertainty=True`` two learnable scalars (``log_sigma_token``,
    ``log_sigma_cls``) weight the two losses in ``compute_loss``.
    """

    def __init__(self, base_model_name: str, num_ner_labels: int, num_doc_labels: int, use_uncertainty: bool = True):
        """Load the encoder and create both linear heads.

        Args:
            base_model_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
            num_ner_labels: Output size of the token head.
            num_doc_labels: Output size of the document head.
            use_uncertainty: Whether to create and use the learnable loss weights.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.token_cls = nn.Linear(hidden, num_ner_labels)
        self.cls_cls = nn.Linear(hidden, num_doc_labels)
        self.use_uncertainty = use_uncertainty
        if use_uncertainty:
            self.log_sigma_token = nn.Parameter(torch.tensor(0.0))
            self.log_sigma_cls = nn.Parameter(torch.tensor(0.0))

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the encoder and both heads.

        Returns:
            ``{"token_logits": ..., "cls_logits": ...}``; token logits have one
            row per input position, document logits one row per sequence.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        hidden_states = outputs.last_hidden_state
        token_logits = self.token_cls(self.dropout(hidden_states))

        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            # Fall back to the first position's state. It is taken before
            # dropout so that dropout is applied exactly once (below), the same
            # as on the pooler_output path.
            pooled = hidden_states[:, 0]
        cls_logits = self.cls_cls(self.dropout(pooled))
        return {"token_logits": token_logits, "cls_logits": cls_logits}

    def compute_loss(self, token_logits: torch.Tensor, token_labels: torch.Tensor, cls_logits: torch.Tensor, cls_labels: torch.Tensor) -> torch.Tensor:
        """Combine the token cross-entropy and the document BCE into one scalar.

        Token positions labelled ``-100`` are ignored. With uncertainty
        weighting each task loss ``L`` contributes
        ``exp(-2 * log_sigma) * L + log_sigma``; otherwise the two losses are
        summed.

        A task with nothing to score (every token label is ``-100``, or the
        document label space is empty) contributes a zero that is still
        attached to the graph, instead of the NaN that a mean over zero
        elements would produce.
        """
        if bool((token_labels != -100).any()):
            num_tags = token_logits.size(-1)
            token_loss = nn.CrossEntropyLoss(ignore_index=-100)(
                token_logits.view(-1, num_tags), token_labels.view(-1)
            )
        else:
            token_loss = token_logits.sum() * 0.0

        if cls_logits.numel() > 0:
            cls_loss = nn.BCEWithLogitsLoss()(cls_logits, cls_labels)
        else:
            cls_loss = cls_logits.sum() * 0.0

        if getattr(self, "use_uncertainty", False):
            loss_token_term = torch.exp(-2.0 * self.log_sigma_token) * token_loss + self.log_sigma_token
            loss_cls_term = torch.exp(-2.0 * self.log_sigma_cls) * cls_loss + self.log_sigma_cls
            return loss_token_term + loss_cls_term
        return token_loss + cls_loss

# Head sizes follow the label spaces built from the training split.
num_ner_labels = len(ner_label_to_id)
num_doc_labels = len(doc_label_to_id)
# Instantiate the joint model and move it to the selected device.
model = JointModel(MODEL_NAME, num_ner_labels, num_doc_labels, use_uncertainty=USE_UNCERTAINTY).to(device)



In [None]:
# Training & evaluation utilities
from torch.cuda.amp import autocast, GradScaler

# Module-level optimizer and gradient scaler; train() reads both as globals.
# A single learning rate covers every model parameter.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# The scaler is constructed disabled when CUDA is unavailable.
scaler = GradScaler(enabled=torch.cuda.is_available())


def decode_token_preds(logits: np.ndarray, labels: np.ndarray) -> Tuple[List[str], List[str]]:
    """Flatten a batch of token logits and gold ids into two tag-string lists.

    Positions whose gold label is ``-100`` are dropped from both lists; ids are
    mapped to tag names through the module-level ``ner_id_to_label``.

    Args:
        logits: Per-position tag scores; the arg-max is taken over the last axis.
        labels: Gold label ids, iterated row by row alongside the predictions.

    Returns:
        ``(y_true, y_pred)``: parallel lists of tag strings.
    """
    predicted_ids = logits.argmax(-1)
    kept_pairs = [
        (int(gold), int(pred))
        for gold_row, pred_row in zip(labels, predicted_ids)
        for gold, pred in zip(gold_row, pred_row)
        if gold != -100
    ]
    y_true = [ner_id_to_label[gold] for gold, _ in kept_pairs]
    y_pred = [ner_id_to_label[pred] for _, pred in kept_pairs]
    return y_true, y_pred


def evaluate(model: JointModel, loader: DataLoader) -> Dict[str, float]:
    """Score ``model`` on ``loader`` for both tasks.

    Token task: macro-averaged F1 over the per-sub-token tag strings returned
    by ``decode_token_preds`` (positions labelled ``-100`` are excluded; the
    ``"O"`` tag counts as a class).

    Document task: micro-averaged precision, recall and F1 of the positive
    labels. All (document, label) decisions are pooled into one flat 0/1 list
    and scored with ``average="binary"``, which is what micro-averaging over a
    multi-label indicator matrix amounts to. Calling sklearn with
    ``average="micro"`` on such a flat list would also pool the negative class
    and make all three numbers equal to plain accuracy.

    Args:
        model: Model to evaluate; switched to eval mode here.
        loader: Yields batches with ``"input_ids"``, ``"attention_mask"``,
            ``"labels"`` and ``"cls_labels"``.

    Returns:
        Dict with ``"token_f1"``, ``"cls_micro_f1"``, ``"cls_precision"`` and
        ``"cls_recall"``; every value is ``0.0`` when nothing was scored.
    """
    model.eval()
    doc_true_flat: List[int] = []
    doc_pred_flat: List[int] = []

    token_true_labels: List[str] = []
    token_pred_labels: List[str] = []

    with torch.inference_mode():
        for batch in loader:
            out = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            token_logits = out["token_logits"].cpu().numpy()
            cls_logits = out["cls_logits"].cpu().numpy()

            # Gold labels are only needed on the host, so they are never moved to the device.
            y_true_tokens, y_pred_tokens = decode_token_preds(token_logits, batch["labels"].cpu().numpy())
            token_true_labels.extend(y_true_tokens)
            token_pred_labels.extend(y_pred_tokens)

            y_true_doc = (batch["cls_labels"].cpu().numpy() > 0.5).astype(int)
            # sigmoid(x) >= 0.5 is the same test as x >= 0, without evaluating exp().
            y_pred_doc = (cls_logits >= 0.0).astype(int)
            doc_true_flat.extend(y_true_doc.reshape(-1).tolist())
            doc_pred_flat.extend(y_pred_doc.reshape(-1).tolist())

    token_f1 = f1_score(token_true_labels, token_pred_labels, average="macro") if token_true_labels else 0.0
    if doc_true_flat:
        doc_kwargs = {"average": "binary", "zero_division": 0}
        micro_f1 = f1_score(doc_true_flat, doc_pred_flat, **doc_kwargs)
        micro_p = precision_score(doc_true_flat, doc_pred_flat, **doc_kwargs)
        micro_r = recall_score(doc_true_flat, doc_pred_flat, **doc_kwargs)
    else:
        micro_f1 = micro_p = micro_r = 0.0
    return {"token_f1": float(token_f1), "cls_micro_f1": float(micro_f1), "cls_precision": float(micro_p), "cls_recall": float(micro_r)}


def train(model: JointModel, train_loader: DataLoader, dev_loader: DataLoader, epochs: int = EPOCHS):
    """Train ``model`` and print dev metrics after every epoch.

    Relies on the module-level ``optimizer``, ``scaler``, ``device``,
    ``ACCUM_STEPS`` and ``GRAD_CLIP_NORM``. Mixed precision is enabled only
    when CUDA is available. Gradients are accumulated over ``ACCUM_STEPS``
    batches; a trailing group shorter than ``ACCUM_STEPS`` is still applied at
    the end of the epoch instead of being discarded (its gradient keeps the
    ``1 / ACCUM_STEPS`` scaling).

    Args:
        model: Model to optimise; must be the one ``optimizer`` was built for.
        train_loader: Training batches.
        dev_loader: Batches passed to ``evaluate`` after each epoch.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The highest dev ``"cls_micro_f1"`` seen (``-1.0`` if ``epochs`` is 0).
        No checkpoint is saved or restored.
    """
    use_amp = torch.cuda.is_available()

    def _apply_accumulated_gradients() -> None:
        # Unscale before clipping so the norm is measured on true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    best_dev = -1.0
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        pending = 0  # batches accumulated since the last optimizer step
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            cls_labels = batch["cls_labels"].to(device)

            # torch.autocast is the entry point that takes ``device_type``; the
            # torch.cuda.amp.autocast class does not accept that keyword.
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                out = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = model.compute_loss(out["token_logits"], labels, out["cls_logits"], cls_labels)
                loss = loss / ACCUM_STEPS

            scaler.scale(loss).backward()
            pending += 1
            if pending == ACCUM_STEPS:
                _apply_accumulated_gradients()
                pending = 0

        if pending:
            # Flush the last, shorter accumulation group of the epoch.
            _apply_accumulated_gradients()

        dev_metrics = evaluate(model, dev_loader)
        print({"epoch": epoch + 1, **dev_metrics})
        if dev_metrics["cls_micro_f1"] > best_dev:
            best_dev = dev_metrics["cls_micro_f1"]
    return best_dev

# Run training now; best_dev is the highest dev "cls_micro_f1" across epochs.
best_dev = train(model, train_loader, dev_loader, EPOCHS)


In [None]:
# Test evaluation on test set
# `model` holds the weights from the end of training; the visible code does not
# save or restore a best-epoch checkpoint.
metrics_test = evaluate(model, test_loader)
metrics_test


In [None]:
# Inference function

def predict(text: str) -> Dict[str, Any]:
    """Run the module-level ``model`` on one raw string.

    Args:
        text: Input text; it is split on whitespace before sub-word tokenization.

    Returns:
        A dict with:

        * ``"tokens"``: the whitespace tokens of ``text``.
        * ``"token_labels"``: one predicted tag per **sub-token** position of
          the encoded input, with ``"IGN"`` at positions that carry no word
          (label ``-100``). This list does not line up with ``"tokens"``.
        * ``"word_labels"``: one predicted tag per entry of ``"tokens"``, taken
          from each word's first sub-token; words cut off by truncation get
          ``"O"``.
        * ``"cls"``: up to ten ``(document_label, probability)`` pairs, highest
          probability first.
    """
    model.eval()
    tokens, _ = whitespace_tokenize_with_spans(text)
    enc, lab = tokenize_and_align_labels(tokens, [ner_label_to_id["O"]] * len(tokens))
    word_ids = enc.word_ids()
    batch = collate_batch([{**enc, "labels": lab, "cls_labels": np.zeros(len(doc_label_to_id), dtype=np.float32)}])

    # Send inputs to wherever the model currently lives rather than to the
    # global `device`: the quantization cell calls model.cpu() on this object.
    model_device = next(model.parameters()).device
    with torch.inference_mode():
        out = model(batch["input_ids"].to(model_device), batch["attention_mask"].to(model_device))
        token_logits = out["token_logits"][0].detach().cpu().numpy()
        cls_probs = torch.sigmoid(out["cls_logits"][0].float()).cpu().tolist()

    token_ids = token_logits.argmax(-1).tolist()
    token_labels = [ner_id_to_label[i] if lab_i != -100 else "IGN" for i, lab_i in zip(token_ids, batch["labels"][0].tolist())]

    # Collapse sub-token predictions to one tag per whitespace token.
    word_labels = ["O"] * len(tokens)
    labelled_words = set()
    for position, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx in labelled_words:
            continue
        labelled_words.add(word_idx)
        word_labels[word_idx] = ner_id_to_label[token_ids[position]]

    ranked_cls = sorted(
        [(doc_id_to_label[i], float(p)) for i, p in enumerate(cls_probs)],
        key=lambda x: x[1],
        reverse=True,
    )
    return {
        "tokens": tokens,
        "token_labels": token_labels,
        "word_labels": word_labels,
        "cls": ranked_cls[:10],
    }

# Example
# Smoke test on one Russian sentence; the returned dict is shown as the cell output.
predict("Компания А подписала договор с Банком Б 2020 года.")


In [None]:
# Simple dynamic quantization demo (CPU only)
# quantize_dynamic returns a converted copy, but model.cpu() moves `model`
# itself in place. The original is therefore moved back to `device` afterwards;
# otherwise later calls that send inputs to `device` (evaluate, predict) would
# hit a device mismatch on GPU machines.
model.eval()  # make sure the copy is taken in inference mode (dropout off)
qm = torch.quantization.quantize_dynamic(model.cpu(), {nn.Linear}, dtype=torch.qint8)
model.to(device)
qm


In [None]:
# Tokenize and align labels

def tokenize_and_align_labels(tokens: List[str], bio_labels: List[int]):
    """Sub-word tokenize pre-split ``tokens`` and expand ``bio_labels`` to match.

    Args:
        tokens: Whitespace tokens of one document.
        bio_labels: One label id per entry of ``tokens``.

    Returns:
        ``(encoding, labels)``: the tokenizer output (truncated to
        ``MAX_LENGTH``, not padded) and one label per sub-token. Positions whose
        word id is ``None`` get ``-100``; every other position, including
        continuation pieces of a word, gets that word's label.
    """
    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        truncation=True,
        padding=False,
        max_length=MAX_LENGTH,
        return_tensors=None,
    )
    aligned = [
        -100 if word_idx is None else bio_labels[word_idx]
        for word_idx in encoding.word_ids()
    ]
    return encoding, aligned

class JointDataset(Dataset):
    """Map-style dataset that tokenizes one example on access.

    This rebinds the ``JointDataset`` name defined in an earlier cell with the
    same behaviour. Each item is the tokenizer encoding extended with
    ``"labels"`` (label id per sub-token, ``-100`` where ignored) and
    ``"cls_labels"`` (float32 document label vector).
    """

    def __init__(self, examples: List[Dict[str, Any]]):
        """Store ``examples``; items are read through ``"tokens"``, ``"bio"`` and ``"cls_vec"``."""
        # Kept in line with the earlier definition, which also initialises the base class.
        super().__init__()
        self.examples = examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize example ``idx`` and attach both label sets."""
        ex = self.examples[idx]
        enc, lab = tokenize_and_align_labels(ex["tokens"], ex["bio"])
        enc["labels"] = lab
        enc["cls_labels"] = ex["cls_vec"].astype(np.float32)
        return enc

# Rebuild the datasets from the class defined just above. The DataLoaders
# created in the earlier cell still hold the dataset objects built there;
# they are not re-created here.
train_ds = JointDataset(train_examples)
dev_ds = JointDataset(dev_examples)
test_ds = JointDataset(test_examples)

# NOTE(review): `collator` is not referenced anywhere else in the visible code;
# the loaders use collate_batch instead.
collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
