In [1]:

import os, torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = r"/home/st426/system/global_graph/Biomistral-Calme-Instruct-7b"

# 4bit 量化設定
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,   
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


device_map = {"": "cuda:0"}

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map=device_map,        
    torch_dtype=torch.float16,   
    low_cpu_mem_usage=True,
)
print("hf_device_map:", getattr(model, "hf_device_map", None))
print("any param device:", next(model.parameters()).device)
print("cuda available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0))
print("Capability:", torch.cuda.get_device_capability(0))
print("GPU memory:", torch.cuda.get_device_properties(0).total_memory / 1024**3, "GB")



Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

hf_device_map: {'': 'cuda:0'}
any param device: cuda:0
cuda available: True
GPU: NVIDIA TITAN V
Capability: (7, 0)
GPU memory: 11.77166748046875 GB


In [None]:
# ==== English-only test: surgical margin extraction (robust) ====
import re, json, torch
from transformers import StoppingCriteria, StoppingCriteriaList
from contextlib import nullcontext
from transformers import PreTrainedTokenizerBase
from jinja2 import TemplateError

if getattr(tokenizer, "pad_token_id", None) is None:
    tokenizer.pad_token = tokenizer.eos_token

rules_dict = {
    "0": "Positive surgical margin. Pathology report describes 'involved'. Less than 1 mm, margin explicitly reported as positive.",
    "1~979": "Negative surgical margin. Record the actual margin distance in 0.1 mm units. Example: 10 mm = 100, 0.1 mm = 001.",
    "980": "Margin distance greater than 98 mm.",
    "987": "Described only as 'very close' or 'may not be free', but no exact margin distance provided.",
    "988": "Not applicable: No primary tumor resection performed.",
    "990": "No residual tumor after re-resection, margin distance unclear.",
    "991": "Surgical margin described as non-invasive carcinoma (carcinoma in situ or similar).",
    "999": "Unknown if patient received primary tumor resection or margin distance not documented."
}

test_report ="""

Pathologic diagnosis: Pathologic diagnosis: Lung  right lower lobe  permanent section of frozen section --- 1. Lung  right lower lobe  wedge resection ---- Adenocarcinoma  acinar predominant (acinar pattern: 90%; lepidic pattern: 10 %). The Atypical glands. resection margin and visceral pleura are free of tumor. 2. Lung  right lower lobe 2  wedge resection ---- Intrapulmonary Tentative frozen section diagnosis: Atypical glands. lymph node (0/1). 3. Lymph node  group 2  dissection ---- Anthracosis and Gross description: granulomatous inflammation  no tumor seen (0/4). The specimen consists of a piece of tan soft tissue  0.1x0.1x0.1 cm  4. Lymph node  group 4  dissection ---- Anthracosis and labeled as """"right lower lobe"""" for frozen section. The tentative granulomatous inflammation  no tumor seen (0/8). frozen section diagnosis is """"Atypical glands./An"""". All for section. 5. Lymph node  group 7  dissection ---- Anthracosis and granulomatous inflammation  no tumor seen (0/12). 6. Lymph node  group 9  dissection ---- Anthracosis  no tumor seen Microscopic description: Sections show pulmonary tissue with atypical glands. (0/5). #T-28000_2 #M-09350_2 2 3082 VT000F Ancillary study for diagnosis: 1. Special stains for VVG (no destruction) done for section B. 2. Immunohistochemical stains for thyroid transcription factor-1 (+) done for section B. 3. PD-L1 IHC staining done for section B  result: <1 % positive (1+: <1%  2+: 0%  3+: 0%) Prognostic and predictive factor: 1. Total tumor size/Invasive tumor size (required only if invasive nonmucinous adenocarcinomas with lepidic component is present): 1.1 cm/ 0.9cm. 2. Tumor focality: Single tumor. 3. Direct Invasion of Adjacent Structures: Not applicable. 4. Lymphovascular invasion: Not identified. 5. Perineural invasion: Not identified. 6. Surgical margin: (1) Bronchial Margin: Not applicable. (2) Vascular Margin: Not applicable. (3) Parenchymal Margin: Uninvolved by invasive carcinoma. 7. Visceral Pleural status: Uninvolved (PL0). 8. Regional lymph nodes: Number of Lymph Nodes Involved: 0/Number of Lymph Nodes Examined: 30. 9. Extranodal extension: Not identified. 10. Treatment effect: Not applicable. 11. Spread Through Air Spaces (STAS): Not identified. 12. Pathological TNM stage: pT1aN0 (According to the eighth edition  American Joint Committee on Cancer Staging Guidelines for Tumors). 13. TNM descriptors: Not applicable. Gross description: The specimen consists of 1) """"RLL""""  a wedge-shaped lung tissue  4.5x3x2.5 cm. One incision is noted over the visceral pleura with exposure of a tan soft to firm tumor  1.1x0.9 cm. The lesion is localized near the visceral pleura. The distance of tumor to parenchymal margin is 10 mm grossly. The residual lung parenchyma is grossly unremarkable. 2) """"RLL 2""""  a wedge-shaped lung tissue  3.5x1x1 cm. One incision is noted over the visceral pleura with exposure of a tan soft to firm nodule  0.7x0.5 cm. The lesion is localized near the visceral pleura. The distance of tumor to parenchymal margin is 5 mm grossly. 3) LN2  3 pieces of tan brown black soft to firm tissue  up to 2x1.3x1.3 cm. 4) LN4  2 pieces of tan brown black soft to firm tissue  up to 2.5x1x1 cm. 5) LN7  3 pieces of tan brown black soft to firm tissue  up to 3.5x2.5x1 cm. 6) LN9  4 pieces of tan brown black soft to firm tissue  up to 1x1x1 cm. Representative sections taken: A) parenchymal margin of specimen 1 B-C) lesion of specimen 1 D) lung parenchyma of specimen 1 E) parenchymal margin of specimen 2 F) all lung tissue of specimen 2 G) specimen 3 H) specimen 4 J) specimen 5 K) specimen 6. Note: 1. Material: Specimen: Formalin-fixed paraffin embedded (10% neutral buffered formalin). Time to fixation: Between 6 and 72 hours. 2. Method: Clone: SP263 monoclonal antibody. Staining System: Ventana medical system. 3. This case has been peer reviewed by two doctors. #T-28400_2 #M-81403_2 2 1103 00000F #T-C4300_2 #D2-53020_2 0 1100 000000


"""

rules_text = "\n".join([f"{c}: {desc}" for c, desc in rules_dict.items()])


system_msg = (
    "You are a pathology coding assistant. Respond with ONLY a valid JSON array."
    " Do not include any explanations or extra text."
    " Begin your answer with '[' and end it with ']'."
)
user_msg = f"""
Here are the surgical margin coding rules:
{rules_text}

Report (already pre-filtered for margin-related sentences):
{test_report}

Instructions:
- Read the report carefully.
- Only output sentences that mention surgical margins or distances.
- For each, create an object with:
  - "sentence": the exact sentence
  - "distance_raw": the numeric distance if present (e.g. "9 mm", "0.5 cm"), otherwise null
  - "code": assign the proper rule code (from the list above)

Output requirements:
- Output MUST be ONLY a valid JSON array.
- Each element must follow: {{"sentence": "...", "code": "..."}}.
"""




def build_prompt(tokenizer: PreTrainedTokenizerBase, system_msg: str, fewshot_json: str, user_msg: str) -> str:
    # messages：user/assistant/user 交替；system 放第一個
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": "Follow the instructions and reply ONLY with a valid JSON array."},
        {"role": "assistant", "content": fewshot_json},  
        {"role": "user", "content": user_msg},
    ]
    try:
        if getattr(tokenizer, "chat_template", None):
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        raise TemplateError("no chat_template")
    except Exception as e:
        print("chat_template failed, falling back to [INST]:", e)
        # Fallback：Mistral/Llama 指令格式
        return (
            f"<s>[INST] {system_msg}\n\n"
            f"{fewshot_json}\n\n"
            f"{user_msg} [/INST]"
        )


fewshot_json = """[
  {"sentence": "Example sentence.", "code": "10"}
]"""

prompt_text = build_prompt(tokenizer, system_msg, fewshot_json, user_msg)



class StopOnValidJsonArray(StoppingCriteria):
    def __init__(self, tokenizer, prompt_len):
        self.tok = tokenizer
        self.prompt_len = prompt_len
    def __call__(self, input_ids, scores, **kwargs):
        gen_ids = input_ids[0][self.prompt_len:]
        text = self.tok.decode(gen_ids, skip_special_tokens=True)
        m = re.search(r"\[\s*(?:.|\n)*\]", text)
        if not m:
            return False
        try:
            json.loads(m.group(0))
            return True
        except Exception:
            return False

def extract_json_only(gen_text: str):
    m = re.search(r"\[\s*(?:.|\n)*\]", gen_text)
    if not m:
        return []
    return json.loads(m.group(0))


inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
stopper = StopOnValidJsonArray(tokenizer, prompt_len=inputs["input_ids"].shape[1])

sdpa_ctx = nullcontext()
if hasattr(torch.backends.cuda, "sdp_kernel"):
    sdpa_ctx = torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    )
if hasattr(torch.backends.cuda, "enable_flash_sdp"):
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)


with torch.no_grad(), sdpa_ctx:
    out_ids = model.generate(
        **inputs,
        max_new_tokens=1024,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([stopper]),
    )

gen_text = tokenizer.decode(out_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

print(gen_text)



chat_template failed, falling back to [INST]: Conversation roles must alternate user/assistant/user/assistant/...
