In [1]:
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = "/data/model_hub/mdeberta-v3-base/"

In [2]:
# model = AutoModel.from_pretrained("/data/model_hub/mdeberta-v3-base/")

Loading weights: 100%|██████████████████████████████████████████████████████████████████████████████████| 198/198 [00:00<00:00, 2530.16it/s, Materializing param=encoder.rel_embeddings.weight]
[1mDebertaV2Model LOAD REPORT[0m from: /data/model_hub/mdeberta-v3-base/
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
mask_predictions.classifier.bias           | UNEXPECTED |  | 
lm_predictions.lm_head.LayerNorm.bias      | UNEXPECTED |  | 
lm_predictions.lm_head.dense.weight        | UNEXPECTED |  | 
mask_predictions.dense.bias                | UNEXPECTED |  | 
lm_predictions.lm_head.bias                | UNEXPECTED |  | 
lm_predictions.lm_head.dense.bias          | UNEXPECTED |  | 
mask_predictions.LayerNorm.bias            | UNEXPECTED |  | 
mask_predictions.LayerNorm.weight          | UNEXPECTED |  | 
mask_predictions.classifier.weight         | UNEXPECTED |  | 
deberta.embeddings.word_embeddings._weight | UNEXPE

In [2]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


def make_mlp(input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1) -> nn.Module:
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, output_dim),
    )


@dataclass
class SpanBatch:
    # Text
    text_input_ids: torch.LongTensor         # (B, N)
    text_attention_mask: torch.LongTensor    # (B, N)

    # Labels (label+desc) as a batch-of-batches flattened to (B*M, L)
    label_input_ids: torch.LongTensor        # (B*M, L)
    label_attention_mask: torch.LongTensor   # (B*M, L)

    # Number of labels per sample (M). Assumed fixed M across batch for simplicity.
    num_labels: int

    # Optional supervision (multi-label over spans):
    # span_targets: (B, num_spans, M) with 0/1, or float in [0,1]
    span_targets: Optional[torch.FloatTensor] = None


class CrossAttentionFusion(nn.Module):
    """Label->Text cross-attention. Q attends to H."""
    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, q: torch.Tensor, h: torch.Tensor, h_key_padding_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
        """
        q: (B, M, d)
        h: (B, N, d)
        h_key_padding_mask: (B, N) True for PAD (to mask out)
        """
        ctx, _ = self.attn(query=q, key=h, value=h, key_padding_mask=h_key_padding_mask, need_weights=False)
        return self.ln(q + ctx)


class SpanEnumerator:
    """Enumerate all spans up to max_width for each sequence length N (excluding pads by attention mask later)."""
    def __init__(self, max_width: int):
        self.max_width = int(max_width)

    def enumerate(self, seq_len: int, device: torch.device) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
        """
        Returns:
          start_idx: (S,)
          end_idx:   (S,)
          width:     (S,)  (end-start)
        where S = sum_{w=1..max_width} (seq_len - w + 1)
        """
        starts = []
        ends = []
        widths = []
        for w in range(1, self.max_width + 1):
            s = torch.arange(0, seq_len - w + 1, device=device, dtype=torch.long)
            e = s + (w - 1)
            starts.append(s)
            ends.append(e)
            widths.append(torch.full_like(s, w - 1))
        start_idx = torch.cat(starts, dim=0)
        end_idx = torch.cat(ends, dim=0)
        width = torch.cat(widths, dim=0)
        return start_idx, end_idx, width


class DebertaSchemaSpanModel(nn.Module):
    """
    Main model:
      Text encoder -> token reps
      Label encoder -> label reps
      Cross-attn fusion -> fused label reps
      Span scoring -> logits(span, label)
    """
    def __init__(
        self,
        backbone_name: str = "microsoft/mdeberta-v3-base",
        share_encoders: bool = True,
        use_width_embedding: bool = True,
        max_span_width: int = 12,
        num_heads: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.text_encoder = AutoModel.from_pretrained(backbone_name)
        if share_encoders:
            self.label_encoder = self.text_encoder
        else:
            self.label_encoder = AutoModel.from_pretrained(backbone_name)

        d_model = self.text_encoder.config.hidden_size
        self.d_model = d_model
        self.max_span_width = int(max_span_width)
        self.span_enum = SpanEnumerator(max_width=max_span_width)

        self.fuse = CrossAttentionFusion(d_model=d_model, num_heads=num_heads, dropout=dropout)

        self.use_width_embedding = bool(use_width_embedding)
        if self.use_width_embedding:
            self.width_emb = nn.Embedding(max_span_width, d_model)  # width in [0..max_span_width-1]
            span_in = d_model * 2 + d_model
        else:
            span_in = d_model * 2

        self.span_ffn = make_mlp(span_in, hidden_dim=d_model * 4, output_dim=d_model, dropout=dropout)

        # Optional: bias term per label (helps calibration when many negatives)
        self.label_bias = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _pool_cls(last_hidden_state: torch.Tensor) -> torch.Tensor:
        """Use first token as CLS. Shape: (B, L, d) -> (B, d)."""
        return last_hidden_state[:, 0, :]

    def encode_text(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.Tensor:
        out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        return out.last_hidden_state  # (B, N, d)

    def encode_labels(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        batch_size: int,
        num_labels: int,
    ) -> torch.Tensor:
        """
        input_ids/attention_mask: (B*M, L)
        returns Q: (B, M, d)
        """
        out = self.label_encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._pool_cls(out.last_hidden_state)  # (B*M, d)
        q = pooled.view(batch_size, num_labels, -1)     # (B, M, d)
        return q

    def forward(self, batch: SpanBatch) -> Dict[str, torch.Tensor]:
        """
        Returns:
          logits: (B, S, M) where S = num_spans based on N and max_span_width
          loss: optional
          span_mask: (B, S) valid spans under attention_mask
          start_idx/end_idx/width: (S,)
        """
        B, N = batch.text_input_ids.shape
        M = int(batch.num_labels)
        device = batch.text_input_ids.device

        # 1) Encode text
        H = self.encode_text(batch.text_input_ids, batch.text_attention_mask)  # (B, N, d)

        # 2) Encode labels (label+desc)
        Q = self.encode_labels(batch.label_input_ids, batch.label_attention_mask, batch_size=B, num_labels=M)  # (B, M, d)

        # 3) Cross-attn fusion (labels attend to text)
        # key_padding_mask expects True for padding positions
        pad_mask = batch.text_attention_mask == 0  # (B, N) bool
        Qf = self.fuse(Q, H, h_key_padding_mask=pad_mask)  # (B, M, d)

        # 4) Enumerate spans (based on full N, later masked by attention_mask)
        start_idx, end_idx, width = self.span_enum.enumerate(seq_len=N, device=device)  # (S,)
        S = start_idx.numel()

        # 5) Build span representations from token reps
        # Gather start/end token reps: (B, S, d)
        H_start = H.index_select(dim=1, index=start_idx)  # (B, S, d)
        H_end = H.index_select(dim=1, index=end_idx)      # (B, S, d)

        if self.use_width_embedding:
            W = self.width_emb(width).unsqueeze(0).expand(B, S, self.d_model)  # (B, S, d)
            span_in = torch.cat([H_start, H_end, W], dim=-1)                   # (B, S, 3d)
        else:
            span_in = torch.cat([H_start, H_end], dim=-1)                      # (B, S, 2d)

        span_vec = self.span_ffn(span_in)  # (B, S, d)

        # 6) Compute logits via dot-product with fused label vectors
        # logits[b, s, m] = <span_vec[b, s], Qf[b, m]>
        logits = torch.einsum("bsd,bmd->bsm", span_vec, Qf) + self.label_bias  # (B, S, M)

        # 7) Span validity mask (exclude spans that touch padding tokens)
        # Valid if both start and end positions are within attention_mask=1
        attn = batch.text_attention_mask.bool()  # (B, N)
        valid_start = attn.index_select(dim=1, index=start_idx)  # (B, S)
        valid_end = attn.index_select(dim=1, index=end_idx)      # (B, S)
        span_mask = valid_start & valid_end                       # (B, S)

        out: Dict[str, torch.Tensor] = {
            "logits": logits,             # (B, S, M)
            "span_mask": span_mask,       # (B, S)
            "start_idx": start_idx,       # (S,)
            "end_idx": end_idx,           # (S,)
            "width": width,               # (S,)
        }

        # 8) Optional loss (multi-label BCE over spans x labels)
        if batch.span_targets is not None:
            # span_targets expected shape: (B, S, M)
            if batch.span_targets.shape != logits.shape:
                raise ValueError(f"span_targets shape {batch.span_targets.shape} must match logits {logits.shape}")

            # Mask out invalid spans by setting them to ignore (we'll zero their loss weight)
            # BCEWithLogitsLoss supports per-element weights.
            weight = span_mask.unsqueeze(-1).float()  # (B, S, 1)
            loss = F.binary_cross_entropy_with_logits(
                logits,
                batch.span_targets,
                weight=weight,
                reduction="sum",
            )
            denom = weight.sum().clamp_min(1.0) * M
            out["loss"] = loss / denom

        return out

In [3]:
B, N = 2, 128
M, L = 20, 64
model = DebertaSchemaSpanModel(
    backbone_name="/data/model_hub/mdeberta-v3-base",
    share_encoders=True,
    use_width_embedding=True,
    max_span_width=12,
    num_heads=8,
    dropout=0.1,
)


Loading weights: 100%|██████████████████████████████████████████████████████████████████████████████████| 198/198 [00:00<00:00, 2605.51it/s, Materializing param=encoder.rel_embeddings.weight]
[1mDebertaV2Model LOAD REPORT[0m from: /data/model_hub/mdeberta-v3-base
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
mask_predictions.LayerNorm.bias            | UNEXPECTED |  | 
deberta.embeddings.word_embeddings._weight | UNEXPECTED |  | 
mask_predictions.LayerNorm.weight          | UNEXPECTED |  | 
lm_predictions.lm_head.bias                | UNEXPECTED |  | 
mask_predictions.dense.weight              | UNEXPECTED |  | 
lm_predictions.lm_head.dense.bias          | UNEXPECTED |  | 
lm_predictions.lm_head.dense.weight        | UNEXPECTED |  | 
mask_predictions.classifier.bias           | UNEXPECTED |  | 
mask_predictions.dense.bias                | UNEXPECTED |  | 
lm_predictions.lm_head.LayerNorm.bias      | UNEXPEC

In [5]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  = "cpu"
model.to(device).half()

batch = SpanBatch(
    text_input_ids=torch.randint(0, 1000, (B, N), device=device),
    text_attention_mask=torch.ones(B, N, device=device, dtype=torch.long),
    label_input_ids=torch.randint(0, 1000, (B * M, L), device=device),
    label_attention_mask=torch.ones(B * M, L, device=device, dtype=torch.long),
    num_labels=M,
    span_targets=None,
)

out = model(batch)
print(out["logits"].shape)  # (B, S, M)

torch.Size([2, 1470, 20])


In [10]:
def build_batch(samples, tokenizer, max_span_len=12, neg_ratio=1, hard_bank=None):
    # hard_bank: dict[(text_id,label)] -> list[(span, score)]  可选

    texts = [s["sentence"] if "sentence" in s else s["text"] for s in samples]
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_offsets_mapping=True,
        return_tensors="pt"
    )

    label_instances = []

    for i, s in enumerate(samples):
        text = texts[i]
        entities = s["entities"]
        label_desc_map = {d["label"]: d["definitions"] for d in s["description"]}  # 你的格式

        gold_by_label = group_by_label([
            {"start": e["pos"][0], "end": e["pos"][1], "label": e["type"]}
            for e in entities
        ])

        # 1) 生成候选 spans（你可替换为更强的 proposal）
        cand_spans = propose_spans(enc["offset_mapping"][i], max_span_len=max_span_len)
        # cand_spans: list[(char_s,char_e)]

        # 2) 对该样本每个 label，生成一个 label_instance
        for label, pos_spans in gold_by_label.items():
            dx, desc = sample_desc(label_desc_map, label)

            # 2.1 负例池：候选中排除与任一正例重叠
            pos_set = pos_spans
            neg_pool = [sp for sp in cand_spans if all(not overlap(sp, p) for p in pos_set)]

            # 2.2 先随机负例
            need_neg = max(len(pos_spans) * neg_ratio, 1)
            neg_rand = random.sample(neg_pool, k=min(need_neg, len(neg_pool)))

            # 2.3 hard negatives（如果有）
            neg_hard = []
            if hard_bank is not None:
                key = (s.get("id", i), label)
                # hard_bank[key] = list[((s,e),score), ...]
                if key in hard_bank:
                    hard_sorted = [sp for (sp,score) in sorted(hard_bank[key], key=lambda x:-x[1])]
                    # 排除与正例重叠
                    hard_sorted = [sp for sp in hard_sorted if all(not overlap(sp,p) for p in pos_set)]
                    neg_hard = hard_sorted[: min(need_neg, len(hard_sorted))]

            # 2.4 混合：hard优先 + 少量随机补齐
            neg_final = []
            if neg_hard:
                hard_take = int(need_neg * 0.7)
                neg_final.extend(neg_hard[:hard_take])
                rest = need_neg - len(neg_final)
                if rest > 0:
                    # 从随机里补齐
                    neg_final.extend(neg_rand[:rest])
            else:
                neg_final = neg_rand

            label_instances.append({
                "sample_idx": i,
                "label": label,
                "desc_key": dx,
                "desc": desc,
                "pos_spans": pos_spans,
                "neg_spans": neg_final,
            })

    return {
        "texts": texts,
        "tokenized": enc,
        "label_instances": label_instances
    }
def propose_spans(offset_mapping, max_span_len=12):
    # offset_mapping: list[(cs,ce)] token->char
    spans = []
    n = len(offset_mapping)
    for i in range(n):
        if offset_mapping[i] == (0,0):  # padding
            continue
        for j in range(i, min(n, i+max_span_len)):
            cs, _ = offset_mapping[i]
            _, ce = offset_mapping[j]
            if ce <= cs: 
                continue
            spans.append((cs, ce))
    # 可去重
    spans = list(dict.fromkeys(spans))
    return spans


import random

def sample_desc(label_desc_map, label):
    defs = label_desc_map[label]  # {"D1":..., ...}
    k = random.choice(list(defs.keys()))  # D1~D6
    return k, defs[k]

def overlap(a,b):
    (s1,e1),(s2,e2)=a,b
    return not (e1 <= s2 or e2 <= s1)
from collections import defaultdict

def group_by_label(entities):
    mp = defaultdict(list)
    for e in entities:
        mp[e["label"]].append((e["start"], e["end"]))
    return mp


In [5]:
import json
import os
import tqdm
parent_dir = "../dataset/"
describe_filename_list = []
for item in os.listdir(parent_dir):
    if item.startswith("instruct_uie_ner_converted_description_"):
        describe_filename_list.append(item)
data = []
for item in tqdm.tqdm( describe_filename_list):
    file_path = os.path.join(parent_dir, item)
    with open(file_path,"r", encoding="utf-8") as f:
        content = f.readlines()
    for line in content:
        obj = json.loads(line.strip())
        data.append(obj)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 612/612 [00:04<00:00, 152.05it/s]


In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_path)

The tokenizer you are loading from '/data/model_hub/mdeberta-v3-base/' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.


In [11]:
build_batch(data[0:2],tokenizer)

{'texts': ['<resources>\n    <!-- Example customization of dimensions originally defined in res/values/dimens.xml\n         (such as screen margins) for screens with more than 820dp of available width. This\n         would include 7" and 10" devices in landscape (~960dp and ~1280dp respectively). -->\n    <dimen name="activity_horizontal_margin">64dp</dimen>\n</resources>',
 'tokenized': {'input_ids': tensor([[     1,   1043, 128743,    670,    260, 102962,    260,  43141,  24088,
           14535,    305,    260,  98979,   4704,    485,    260,  54629,    282,
            8602,    276,  22346,    276,    286, 102474,    261,  11395,    275,
          103545,    528,  10989,    260,   2017,    264,    272,    333,  10989,
             264,    515,   1098,   2422,  91114,    286,    326,    305,   4636,
             260,   1494,    261,   1495,    260,   2221,   9453,    618,    312,
             306,    476,    312,  33822,    282,    260,  33365, 148760,  16978,
             286,    3

NameError: name 'model' is not defined