In [1]:
"""
Demo: BioClinicalBERT symptom classifier prototype

Steps:
1. Install packages
2. Load BioClinicalBERT model + tokenizer
3. Load symptom label mapping (label_mapping.json)
4. Run predict_symptoms() on example texts
"""

import json
from pathlib import Path


In [2]:
!pip install -q "transformers>=4.40.0" torch


In [3]:
# Download label_mapping.json from the GitHub repo into the current working
# directory (/content on Colab). -O overwrites any existing local copy.
!wget -O label_mapping.json https://raw.githubusercontent.com/toxiclee/medBert-triage-prototype/main/label_mapping.json


--2025-11-19 02:53:10--  https://raw.githubusercontent.com/toxiclee/medBert-triage-prototype/main/label_mapping.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 454 [text/plain]
Saving to: ‘label_mapping.json’


2025-11-19 02:53:11 (8.68 MB/s) - ‘label_mapping.json’ saved [454/454]



In [4]:
import torch
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

# 1. Choose the base model (medBERT = BioClinicalBERT)
BASE_MODEL = "emilyalsentzer/Bio_ClinicalBERT"

# 2. Load the project-specific symptom label mapping.
#    JSON is UTF-8 by specification, so decode explicitly instead of relying
#    on the platform's default locale encoding.
with open("label_mapping.json", "r", encoding="utf-8") as f:
    mapping = json.load(f)

label2id = mapping["label2id"]                      # {'abdominal_pain': 0, ...}
# JSON object keys are always strings, so convert the ids back to int.
id2label = {int(k): v for k, v in mapping["id2label"].items()}  # {0: 'abdominal_pain', ...}

num_labels = len(label2id)

# Fail fast on a malformed mapping: predict_symptoms() indexes id2label with
# logit positions 0..num_labels-1, so a gap or a mismatch between the two
# dictionaries would silently mislabel predictions (or raise KeyError later).
if sorted(id2label) != list(range(num_labels)) or any(
    label2id.get(label) != idx for idx, label in id2label.items()
):
    raise ValueError(
        "label_mapping.json is inconsistent: id2label must cover ids "
        f"0..{num_labels - 1} and be the exact inverse of label2id"
    )

print("Loaded labels:", label2id)

# 3. Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# 4. Configure + load the classification model:
#    - encoder weights come from BioClinicalBERT
#    - the classification head has num_labels outputs (untrained, prototype only)
#    - problem_type marks the head as multi-label so that any loss computed by
#      the model matches the sigmoid-based inference used in this notebook
config = AutoConfig.from_pretrained(
    BASE_MODEL,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    problem_type="multi_label_classification",
)

model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    config=config,
)
# Inference mode: disables dropout so repeated calls give the same output.
model.eval()

print("Model loaded.")

# 5. Prediction function (multi-label: one independent sigmoid per label)
def predict_symptoms(text: str, threshold: float = 0.3):
    """Score a free-text complaint against every symptom label.

    Args:
        text: The input text to classify.
        threshold: A label is reported only if its sigmoid probability is
            strictly greater than this value.

    Returns:
        A dict mapping each label name whose probability exceeds ``threshold``
        to that probability as a Python float.

    Note:
        The classification head loaded in this notebook is newly initialized
        (untrained), so the returned scores are not meaningful until the model
        has been fine-tuned.
    """
    # truncation=True alone is a no-op for this tokenizer because it has no
    # predefined maximum length (transformers warns "Default to no
    # truncation"). Cap the sequence at the encoder's position-embedding size
    # so over-long inputs are truncated instead of overflowing the model.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits          # [1, num_labels]
        probs = torch.sigmoid(logits)[0] # (num_labels,)

    # Indices whose probability is above the threshold count as "symptom present".
    idx = (probs > threshold).nonzero(as_tuple=True)[0].tolist()
    return {id2label[i]: float(probs[i]) for i in idx}

# 6. Demo: run the classifier on a handful of example complaints
tests = [
    "n/v/d and abd pain",
    "chest pain and shortness of breath",
    "fever, sore throat, cough",
    "ear pain for two days",
    "migraine, dizziness, and nausea",
]

for t in tests:
    # Echo the input first, then score it, so each result sits under its text.
    print("\nInput:", t)
    scores = predict_symptoms(t)
    print("Pred:", scores)


Loaded labels: {'abdominal_pain': 0, 'back_pain': 1, 'chest_pain': 2, 'diarrhea': 3, 'dizziness': 4, 'ear_pain': 5, 'fever': 6, 'headache': 7, 'nausea': 8, 'shortness_of_breath': 9, 'throat_pain': 10, 'vomiting': 11}


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Model loaded.

Input: n/v/d and abd pain
Pred: {'abdominal_pain': 0.48409751057624817, 'back_pain': 0.4219190776348114, 'chest_pain': 0.4552541673183441, 'diarrhea': 0.5531032681465149, 'dizziness': 0.3884950876235962, 'ear_pain': 0.6057761311531067, 'fever': 0.40676701068878174, 'headache': 0.6029843688011169, 'nausea': 0.6829037070274353, 'shortness_of_breath': 0.4457143545150757, 'throat_pain': 0.4903886616230011, 'vomiting': 0.5709765553474426}

Input: chest pain and shortness of breath
Pred: {'abdominal_pain': 0.47913312911987305, 'back_pain': 0.39677825570106506, 'chest_pain': 0.44267547130584717, 'diarrhea': 0.5576533675193787, 'dizziness': 0.37364351749420166, 'ear_pain': 0.5813724398612976, 'fever': 0.39573290944099426, 'headache': 0.6217104196548462, 'nausea': 0.703911304473877, 'shortness_of_breath': 0.42962196469306946, 'throat_pain': 0.509831428527832, 'vomiting': 0.5769065022468567}

Input: fever, sore throat, cough
Pred: {'abdominal_pain': 0.47081294655799866, 'back_pain