In [3]:
"""
Colab Version: build_triage_datasets_colab.py

步骤：
1. 手动上传 triage.csv.gz 到 Colab (/content/triage.csv.gz)
2. 运行这段代码
3. 会在 /content 生成：
    - symptom_dataset_triage.csv
    - urgency_dataset_triage.csv
"""

import pandas as pd
from pathlib import Path

# ========= 你上传之后的文件路径（不用改） =========
TRIAGE_PATH = "/content/triage.csv.gz"

# ========= 输出路径 =========
OUT_DIR = Path("/content/triage_processed")
OUT_DIR.mkdir(parents=True, exist_ok=True)


# ========= acuity → urgency 映射函数 =========

def map_acuity_to_urgency(acuity):
    try:
        a = int(acuity)
    except (ValueError, TypeError):
        return None

    if a in (1, 2):
        return "Emergency"
    elif a == 3:
        return "Urgent"
    elif a == 4:
        return "Soon"
    elif a == 5:
        return "Routine"
    else:
        return None


# ========= 弱标签规则：自然语言 → symptom labels =========

def extract_symptoms(text: str):
    if not isinstance(text, str):
        return []

    t = text.lower()
    labels = []

    # chest pain
    if "chest" in t or "cp " in t or "cp," in t:
        labels.append("chest_pain")

    # abdominal pain
    if "abd" in t or "abdominal" in t or "belly pain" in t:
        labels.append("abdominal_pain")

    # back pain
    if "back pain" in t or "low back" in t or "lbp" in t:
        labels.append("back_pain")

    # headache
    if "headache" in t or " ha " in f" {t} ":
        labels.append("headache")

    # ear & throat
    if "ear pain" in t or "otalgia" in t:
        labels.append("ear_pain")
    if "sore throat" in t or "throat pain" in t:
        labels.append("throat_pain")

    # SOB
    if "sob" in t or "shortness of breath" in t or "dyspnea" in t:
        labels.append("shortness_of_breath")

    # GI (n/v/d)
    if "n/v/d" in t or "n/v" in t or "nv" in t or "nausea" in t:
        labels.append("nausea")
    if "n/v/d" in t or "n/v" in t or "vomit" in t or "emesis" in t:
        labels.append("vomiting")
    if "n/v/d" in t or "diarrhea" in t or "diarrhoea" in t:
        labels.append("diarrhea")

    # dizziness
    if "dizz" in t or "vertigo" in t:
        labels.append("dizziness")

    # fever
    if "fever" in t or "febrile" in t or "pyrexia" in t:
        labels.append("fever")

    return sorted(set(labels))


# ========= 主处理函数 =========

def process_triage():
    print(f"Loading triage file: {TRIAGE_PATH}")

    tri = pd.read_csv(
        TRIAGE_PATH,
        usecols=["subject_id", "stay_id", "acuity", "chiefcomplaint"],
        compression="infer"
    )

    # 去掉没有文本的行
    tri = tri.dropna(subset=["chiefcomplaint"])

    # 统一文本列
    tri["text"] = tri["chiefcomplaint"].astype(str)

    # 提取症状标签
    tri["symptom_labels"] = tri["text"].apply(extract_symptoms)

    # 映射 urgency（根据 acuity）
    tri["urgency"] = tri["acuity"].apply(map_acuity_to_urgency)

    # 只保留至少有一个 symptom 的样本
    tri = tri[tri["symptom_labels"].map(len) > 0]

    return tri


# ========= 导出两个数据集 =========

tri_df = process_triage()

# symptom dataset
symptom_out = OUT_DIR / "symptom_dataset_triage.csv"
tri_df[["subject_id", "stay_id", "text", "symptom_labels"]].to_csv(symptom_out, index=False)

# urgency dataset（要求 urgency 不为空）
urgency_df = tri_df[tri_df["urgency"].notna()]
urgency_out = OUT_DIR / "urgency_dataset_triage.csv"
urgency_df[["subject_id", "stay_id", "text", "symptom_labels", "urgency"]].to_csv(urgency_out, index=False)

print("Saved symptom dataset:", symptom_out)
print("Saved urgency dataset:", urgency_out)

print("\nUrgency distribution:")
print(urgency_df["urgency"].value_counts())



Loading triage file: /content/triage.csv.gz
Saved symptom dataset: /content/triage_processed/symptom_dataset_triage.csv
Saved urgency dataset: /content/triage_processed/urgency_dataset_triage.csv

Urgency distribution:
urgency
Urgent       97337
Emergency    64969
Soon          4761
Routine         72
Name: count, dtype: int64


In [4]:
!pip install -q transformers datasets accelerate scikit-learn


In [6]:
!pip install -q "transformers>=4.40.0"


In [9]:
import ast
import numpy as np
import pandas as pd
from pathlib import Path

import torch
from datasets import Dataset
from sklearn.metrics import f1_score

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)

# ========= 路径 =========
DATA_PATH = Path("/content/triage_processed/symptom_dataset_triage.csv")
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"   # BioClinicalBERT
OUTPUT_DIR = Path("/content/bioclinicalbert_symptom_model")


# ========= 1. 读取数据 =========
df = pd.read_csv(DATA_PATH)

print("Raw columns:", df.columns)
print("Example rows:")
print(df.head(5))

# symptom_labels 在 CSV 里是字符串形如 "['abdominal_pain', 'nausea']"
# 需要转回 Python list
def parse_labels(s):
    try:
        return ast.literal_eval(s)
    except Exception:
        return []

df["symptom_labels"] = df["symptom_labels"].apply(parse_labels)

# 只保留有标签的样本（理论上已经过滤过了，这里再保险一下）
df = df[df["symptom_labels"].map(len) > 0].reset_index(drop=True)
print("Samples after filtering:", len(df))


# ========= 2. 构建 label space（从数据里自动提取） =========
all_labels = set()
for labs in df["symptom_labels"]:
    all_labels.update(labs)

label_list = sorted(all_labels)
label2id = {l: i for i, l in enumerate(label_list)}
id2label = {i: l for l, i in label2id.items()}
num_labels = len(label_list)

print("Symptom labels:", label_list)
print("Number of labels:", num_labels)


# ========= 3. 把多标签转成 multi-hot 向量 =========
def labels_to_multihot(labels):
    vec = np.zeros(num_labels, dtype=np.float32)
    for l in labels:
        if l in label2id:
            vec[label2id[l]] = 1.0
    return vec

df["label_vec"] = df["symptom_labels"].apply(labels_to_multihot)


# ========= 4. 划分 train / valid =========
# 简单划分：前 90% 做 train，后 10% 做 valid
split_idx = int(0.9 * len(df))
train_df = df.iloc[:split_idx].reset_index(drop=True)
valid_df = df.iloc[split_idx:].reset_index(drop=True)

print("Train size:", len(train_df), "Valid size:", len(valid_df))


# ========= 5. 转成 HuggingFace Dataset =========
train_dataset = Dataset.from_pandas(
    train_df[["text", "label_vec"]],
    preserve_index=False
)
valid_dataset = Dataset.from_pandas(
    valid_df[["text", "label_vec"]],
    preserve_index=False
)

# ========= 6. 加载 tokenizer & model =========
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_batch(batch):
    enc = tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=64,   # chief complaint 很短，64 足够，后面可调
    )
    # 把 label_vec 作为 labels 传入模型
    enc["labels"] = batch["label_vec"]
    return enc

train_dataset_tok = train_dataset.map(tokenize_batch, batched=True)
valid_dataset_tok = valid_dataset.map(tokenize_batch, batched=True)

# 把 label_vec 列删除，避免 Trainer 报错
train_dataset_tok = train_dataset_tok.remove_columns(["label_vec", "text"])
valid_dataset_tok = valid_dataset_tok.remove_columns(["label_vec", "text"])

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    problem_type="multi_label_classification",
)


# ========= 7. 定义评价指标（micro/macro F1） =========
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # BCEWithLogitsLoss 默认输出 raw logits，需要过 sigmoid
    probs = 1 / (1 + np.exp(-logits))
    preds = (probs > 0.5).astype(int)

    micro_f1 = f1_score(labels, preds, average="micro", zero_division=0)
    macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)

    return {"micro_f1": micro_f1, "macro_f1": macro_f1}


# ========= 8. 设置训练参数 =========
training_args = TrainingArguments(
    output_dir=str(OUTPUT_DIR),
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    num_train_epochs=1,      # 只跑 1 轮
    max_steps=500,           # 最多 500 个 step，就停
    weight_decay=0.01,
    logging_steps=50,
    do_eval=True,
    report_to="none",
)



# ========= 9. 构建 Trainer 并开训 =========
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_tok,
    eval_dataset=valid_dataset_tok,
    compute_metrics=compute_metrics,
)

trainer.train()

# ========= 10. 保存模型和标签映射 =========
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# 保存 label2id / id2label 方便部署
import json
with open(OUTPUT_DIR / "label_mapping.json", "w") as f:
    json.dump({"label2id": label2id, "id2label": id2label}, f)

print("Model & tokenizer saved to:", OUTPUT_DIR)


Raw columns: Index(['subject_id', 'stay_id', 'text', 'symptom_labels'], dtype='object')
Example rows:
   subject_id   stay_id                                      text  \
0    10000032  33258284            Abd pain, Abdominal distention   
1    10000032  35968195                           n/v/d, Abd pain   
2    10000032  38112554                      Abdominal distention   
3    10000032  39399961  Abdominal distention, Abd pain, LETHAGIC   
4    10000285  36555703                                  Abd pain   

                                      symptom_labels  
0                                 ['abdominal_pain']  
1  ['abdominal_pain', 'diarrhea', 'nausea', 'vomi...  
2                                 ['abdominal_pain']  
3                                 ['abdominal_pain']  
4                                 ['abdominal_pain']  
Samples after filtering: 167550
Symptom labels: ['abdominal_pain', 'back_pain', 'chest_pain', 'diarrhea', 'dizziness', 'ear_pain', 'fever', 'headache', '

Map:   0%|          | 0/150795 [00:00<?, ? examples/s]

Map:   0%|          | 0/16755 [00:00<?, ? examples/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss
50,0.4404
100,0.2545
150,0.1867
200,0.1421
250,0.123
300,0.1072
350,0.0983
400,0.0879


Step,Training Loss
50,0.4404
100,0.2545
150,0.1867
200,0.1421
250,0.123
300,0.1072
350,0.0983
400,0.0879
450,0.0866
500,0.0841


Model & tokenizer saved to: /content/bioclinicalbert_symptom_model


In [10]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

MODEL_DIR = "/content/bioclinicalbert_symptom_model"

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [13]:
import json
import torch
from transformers import BertTokenizer, BertForSequenceClassification

MODEL_DIR = "/content/bioclinicalbert_symptom_model"

# 1. 加载 tokenizer 和 model
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()

# 2. 正确加载 label_mapping.json
with open(MODEL_DIR + "/label_mapping.json", "r") as f:
    mapping = json.load(f)

# 里面有两个 key： "label2id" 和 "id2label"
label2id = mapping["label2id"]                        # {'abdominal_pain': 0, ...}
id2label = {int(k): v for k, v in mapping["id2label"].items()}  # {0: 'abdominal_pain', ...}

print("label2id:", label2id)
print("id2label:", id2label)

# 3. 预测函数（多标签，用 sigmoid）
def predict_symptoms(text, threshold=0.3):
    inputs = tokenizer(text, return_tensors="pt", truncation=True)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits          # [1, num_labels]
        probs = torch.sigmoid(logits)[0] # 变成概率

    # 大于阈值的 index 当作“有这个症状”
    idx = (probs > threshold).nonzero(as_tuple=True)[0].tolist()
    labels = [id2label[i] for i in idx]

    # 也可以一起返回概率，方便你以后分析
    return {id2label[i]: float(probs[i]) for i in idx}

# 4. 简单测几句
tests = [
    "n/v/d and abd pain",
    "chest pain and shortness of breath",
    "fever and sore throat",
    "ear pain for two days",
]

for t in tests:
    print("\nInput:", t)
    print("Pred:", predict_symptoms(t))


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


label2id: {'abdominal_pain': 0, 'back_pain': 1, 'chest_pain': 2, 'diarrhea': 3, 'dizziness': 4, 'ear_pain': 5, 'fever': 6, 'headache': 7, 'nausea': 8, 'shortness_of_breath': 9, 'throat_pain': 10, 'vomiting': 11}
id2label: {0: 'abdominal_pain', 1: 'back_pain', 2: 'chest_pain', 3: 'diarrhea', 4: 'dizziness', 5: 'ear_pain', 6: 'fever', 7: 'headache', 8: 'nausea', 9: 'shortness_of_breath', 10: 'throat_pain', 11: 'vomiting'}

Input: n/v/d and abd pain
Pred: {'abdominal_pain': 0.8479625582695007, 'diarrhea': 0.6592486500740051, 'nausea': 0.9087455868721008, 'vomiting': 0.9086993336677551}

Input: chest pain and shortness of breath
Pred: {'chest_pain': 0.8963233828544617, 'shortness_of_breath': 0.4727950096130371}

Input: fever and sore throat
Pred: {'fever': 0.7907466888427734, 'throat_pain': 0.451372891664505}

Input: ear pain for two days
Pred: {}


In [14]:
print(predict_symptoms("n/v/d and abd pain"))
print(predict_symptoms("chest pain and shortness of breath"))
print(predict_symptoms("fever, sore throat, cough"))
print(predict_symptoms("ear pain for two days"))
print(predict_symptoms("migraine, dizziness, and nausea"))


{'abdominal_pain': 0.8479625582695007, 'diarrhea': 0.6592486500740051, 'nausea': 0.9087455868721008, 'vomiting': 0.9086993336677551}
{'chest_pain': 0.8963233828544617, 'shortness_of_breath': 0.4727950096130371}
{'fever': 0.8454817533493042, 'throat_pain': 0.5239660739898682}
{}
{'dizziness': 0.7996644377708435, 'nausea': 0.6824076771736145, 'vomiting': 0.3083231449127197}
