<a href="https://colab.research.google.com/github/s1300200/GT/blob/main/Reproduction444444444.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install fugashi ipadic unidic-lite transformers>=4.41.0 accelerate torch pandas numpy tqdm

In [2]:
import unicodedata as ud
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
import pandas as pd, numpy as np, re
from tqdm import tqdm


In [3]:
@dataclass
class Span:
    start: int
    end:   int
    label: str              # 互換用（使わなくてもOK）
    attrs: Dict[str, str]   # certainty/state/type 等
    tag:   str              # 生タグ（"m-key","r","timex3",...）


In [4]:
def tag_to_label(tag: str, attrs: Dict[str, str]) -> str:
    # 旧ラベル（互換用）。以降は canonical を使うので最低限。
    if tag == "d":
        cert = attrs.get("certainty")
        return {"positive":"D_POS","negative":"D_NEG","suspicious":"D_SUS","general":"D_GEN"}.get(cert,"D")
    if tag in {"r","cc"}:    return "RX_EXECUTED"
    if tag == "t-test":      return "TEST"
    if tag == "t-key":       return "TEST_KEY"
    if tag == "t-val":       return "TEST_VAL"
    if tag == "timex3":      return "TIME_" + attrs.get("type","TIME").upper()
    if tag == "m-key":       return "MED"
    if tag == "m-val":       return "MED"
    if tag == "a":           return "ANAT"
    if tag == "f":           return "FIND_MOD"
    if tag == "c":           return "COURSE"
    return tag.upper()

def extract_text_and_spans(elem: ET.Element) -> Tuple[str, List[Span]]:
    def norm(s: Optional[str]) -> str:
        return ud.normalize("NFKC", s) if s else ""
    text_parts: List[str] = []
    spans: List[Span] = []
    cursor = 0
    def append_text(t: Optional[str]):
        nonlocal cursor
        nt = norm(t)
        if nt:
            text_parts.append(nt)
            cursor += len(nt)
    def walk(node: ET.Element):
        nonlocal cursor
        append_text(node.text)
        for child in list(node):
            s0 = cursor
            walk(child)
            s1 = cursor
            if s1 > s0:
                spans.append(Span(s0, s1, tag_to_label(child.tag, child.attrib),
                                  dict(child.attrib), child.tag))
            append_text(child.tail)
    walk(elem)
    return "".join(text_parts), spans


In [6]:
def parse_medtxt_cr(xml_path: str):
    tree = ET.parse(xml_path)
    root = tree.getroot()
    rows_article, rows_span = [], []
    for art in root.findall(".//article"):
        aid = art.attrib.get("id")
        title = art.attrib.get("title","")
        text, spans = extract_text_and_spans(art)
        rows_article.append({"article_id": aid, "title": title, "text": text})
        for sp in spans:
            rows_span.append({
                "article_id": aid,
                "start": sp.start,
                "end":   sp.end,
                "mention": text[sp.start:sp.end],
                "label": sp.label,    # 互換用
                "tag":   sp.tag,      # ★ 生タグ
                **{f"attr_{k}": v for k,v in sp.attrs.items()}  # certainty/state/type など
            })
    return pd.DataFrame(rows_article), pd.DataFrame(rows_span)

# ★ XMLパスを合わせて実行
XML_PATH = "/content/drive/MyDrive/学士/MedTxt-CR-JA-training-pub.xml"
df_article, df_span = parse_medtxt_cr(XML_PATH)
df_article.to_csv("/content/drive/MyDrive/学士/articles.csv", index=False)
df_span.to_csv("/content/drive/MyDrive/学士/spans.csv", index=False)
len(df_article), len(df_span)


(148, 8477)

In [7]:
def gold_to_canonical(tag: str, row: dict) -> str:
    t = tag
    state     = str(row.get("attr_state",     "")).lower()
    certainty = str(row.get("attr_certainty", "")).lower()
    timetype  = str(row.get("attr_type",      "")).lower()

    if t in {"a","c","f","t-val"}:
        return t

    if t == "d":
        if certainty in {"general","negative","positive","suspicious"}:
            return f"d_{certainty}"
        return "d"

    if t in {"cc","m-key","m-val","r","t-key","t-test"}:
        if state in {"executed","negated","other","scheduled"}:
            return f"{t}_{state}"
        return t

    if t == "timex3":
        if timetype in {"age","date","duration","med","misc","set","time"}:
            return f"timex3_{timetype}"
        return "timex3"

    return t


In [9]:
def split_sentences(text: str):
    spans, start = [], 0
    for m in re.finditer(r"(.*?。)", text, flags=re.DOTALL):
        s,e = m.span(); spans.append((s,e,text[s:e])); start = e
    if start < len(text):
        spans.append((start, len(text), text[start:]))
    return spans

@dataclass
class GSpan:
    start: int
    end:   int
    canon: str  # ★ 正規形

ART_CSV = "/content/drive/MyDrive/学士/articles.csv"
SPN_CSV = "/content/drive/MyDrive/学士/spans.csv"
df_article = pd.read_csv(ART_CSV)
df_span    = pd.read_csv(SPN_CSV)

assert {"article_id","text"}.issubset(df_article.columns)
assert {"article_id","start","end","tag"}.issubset(df_span.columns)

samples = []
for _, a in df_article.iterrows():
    aid, text = a["article_id"], str(a["text"])
    sents = split_sentences(text)
    gold_rows = df_span[df_span["article_id"]==aid]
    for sid, (s0,s1,stext) in enumerate(sents):
        gspans = []
        for _, gr in gold_rows.iterrows():
            gs, ge = int(gr["start"]), int(gr["end"])
            if ge <= s0 or s1 <= gs: continue
            ss = max(gs, s0) - s0
            ee = min(ge, s1) - s0
            if ee <= ss: continue
            canon = gold_to_canonical(str(gr["tag"]), gr.to_dict())
            gspans.append(GSpan(ss, ee, canon))
        samples.append({"article_id": aid, "sent_id": sid, "text": stext, "gold_spans": gspans})
len(samples)


1247

In [10]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

MODEL = "sociocom/MedTXTNER"
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
model     = AutoModelForTokenClassification.from_pretrained(MODEL)
device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
id2label = model.config.id2label

def _greedy_offsets_from_tokens(text: str, tokens: List[str]) -> List[Tuple[int,int]]:
    offs, i, n = [], 0, len(text)
    for tok in tokens:
        if tok == "[UNK]":
            j = min(i+1, n); offs.append((i,j)); i=j; continue
        core = tok[2:] if tok.startswith("##") else tok
        if not core: offs.append((i,i)); continue
        found = False
        while i < n:
            if text.startswith(core, i):
                offs.append((i, i+len(core))); i += len(core); found = True; break
            i += 1
        if not found:
            j = min(i+1, n); offs.append((i,j)); i=j
    return offs

def _tokenize_with_offsets(text: str):
    if getattr(tokenizer, "is_fast", False):
        enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
        return tokenizer.convert_ids_to_tokens(enc["input_ids"]), enc["offset_mapping"], enc["input_ids"]
    else:
        toks = tokenizer.tokenize(text)
        offs = _greedy_offsets_from_tokens(text, toks)
        ids  = tokenizer.convert_tokens_to_ids(toks)
        return toks, offs, ids

def bio_to_spans(offsets: List[Tuple[int,int]], bio: List[str]) -> List[Tuple[int,int,str]]:
    spans=[]; cur=None; st=None
    def close(i):
        nonlocal cur, st
        s=offsets[st][0]; e=offsets[i-1][1]; spans.append((s,e,cur)); cur=None; st=None
    for i,lab in enumerate(bio):
        if lab=="O":
            if cur is not None: close(i); continue
        elif lab.startswith("B-"):
            if cur is not None: close(i)
            cur=lab[2:]; st=i
        elif lab.startswith("I-"):
            ln=lab[2:]
            if cur is None or ln!=cur:
                if cur is not None: close(i)
                cur=ln; st=i
    if cur is not None: close(len(bio))
    return spans

def predict_spans_windows(text: str, max_len=480, stride=120) -> List[Tuple[int,int,str]]:
    toks, offs, ids = _tokenize_with_offsets(text)
    if not ids: return []
    spans_pred=[]
    i=0
    while i < len(ids):
        j = min(i+max_len, len(ids))
        ids_sub  = ids[i:j]
        offs_sub = offs[i:j]
        with torch.no_grad():
            out  = model(torch.tensor([ids_sub]).to(device))
            bio  = [id2label[k] for k in out.logits.argmax(-1)[0].detach().cpu().tolist()]
        spans_pred.extend(bio_to_spans(offs_sub, bio))
        if j == len(ids): break
        i = j - stride
    return list(set(spans_pred))  # 重複除去


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/443M [00:00<?, ?B/s]

In [11]:
def pred_to_canonical(lbl: str) -> str:
    y = lbl.strip().lower().replace(" ", "")
    y = (y.replace("t_key","t-key").replace("tkey","t-key")
           .replace("t_val","t-val").replace("tval","t-val")
           .replace("m_key","m-key").replace("mval","m-val").replace("m_val","m-val")
           .replace("timex3_", "timex3-")
           .replace("__","_").replace("-_","_").replace("_-","_"))
    # d
    if y in {"d"}: return "d"
    if y in {"d_positive","d+","d-pos","d-positive"}:   return "d_positive"
    if y in {"d_negative","d-","d-neg","d-negative"}:   return "d_negative"
    if y in {"d_suspicious","d-suspicious","d_sus"}:    return "d_suspicious"
    if y in {"d_general","d-gen"}:                      return "d_general"
    # 属性なし
    if y in {"a"}: return "a"
    if y in {"c"}: return "c"
    if y in {"f"}: return "f"
    if y in {"t-val"}: return "t-val"
    # state系
    for base in ["cc","m-key","m-val","r","t-key","t-test"]:
        if y.startswith(base):
            part = y[len(base):].lstrip("-_")
            if part in {"executed","negated","other","scheduled"}:
                return f"{base}_{part}"
            return base
    # timex3
    if y.startswith("timex3"):
        part = y.split("-",1)[1] if "-" in y else ""
        if part in {"age","date","duration","med","misc","set","time"}:
            return f"timex3_{part}"
        return "timex3"
    return y


In [12]:
from collections import Counter, defaultdict

def prf(tp, fp, fn):
    P = tp/(tp+fp) if (tp+fp)>0 else 0.0
    R = tp/(tp+fn) if (tp+fn)>0 else 0.0
    F = 2*P*R/(P+R) if (P+R)>0 else 0.0
    return P,R,F

def to_type_only(canon: str) -> str:
    return canon.split("_", 1)[0]  # "m-key_executed"→"m-key", "d_positive"→"d", "timex3_date"→"timex3"

def evaluate(samples, mode="strict"):
    micro_tp=micro_fp=micro_fn=0
    for s in tqdm(samples, desc=f"eval:{mode}"):
        text = s["text"]; gold = s["gold_spans"]  # GSpan.canon を持つ
        pred_raw = predict_spans_windows(text)
        pred = [(ps,pe, pred_to_canonical(pl)) for (ps,pe,pl) in pred_raw]

        if mode == "strict":
            gold_use = [(g.start,g.end,g.canon) for g in gold]
            pred_use = [(ps,pe,l)               for (ps,pe,l) in pred]
        elif mode == "type":
            gold_use = [(g.start,g.end,to_type_only(g.canon)) for g in gold]
            pred_use = [(ps,pe,to_type_only(l))               for (ps,pe,l) in pred]
        else:
            raise ValueError("mode must be 'strict' or 'type'")

        GS, PS = set(gold_use), set(pred_use)
        TP = len(GS & PS); FP = len(PS - GS); FN = len(GS - PS)
        micro_tp += TP; micro_fp += FP; micro_fn += FN

    P,R,F = prf(micro_tp,micro_fp,micro_fn)
    return {"micro":{"precision":P,"recall":R,"f1":F,"tp":micro_tp,"fp":micro_fp,"fn":micro_fn}}


In [13]:
res_strict = evaluate(samples, mode="strict")
res_type   = evaluate(samples, mode="type")

print("=== Strict (attribute-aware) ===", res_strict)
print("=== Type-only (attribute-agnostic) ===", res_type)


eval:strict: 100%|██████████| 1247/1247 [06:01<00:00,  3.45it/s]
eval:type: 100%|██████████| 1247/1247 [05:38<00:00,  3.69it/s]

=== Strict (attribute-aware) === {'micro': {'precision': 0.5180149558123726, 'recall': 0.5393417482599976, 'f1': 0.5284632722649252, 'tp': 4572, 'fp': 4254, 'fn': 3905}}
=== Type-only (attribute-agnostic) === {'micro': {'precision': 0.7145289443813848, 'recall': 0.7425976170815147, 'f1': 0.7282929368889918, 'tp': 6295, 'fp': 2515, 'fn': 2182}}



