In [2]:
%pip install pandas scikit-learn snorkel transformers

# **Cell 2: Load Data**
import pandas as pd

# Make sure 'MTS-Dialog-TrainingSet.csv' is uploaded to the notebook's working directory
df = pd.read_csv("MTS-Dialog-TrainingSet.csv")
df["text"] = df["section_text"].fillna("") + " " + df["dialogue"].fillna("")

# **Cell 3: Define Specialties & Seed Keywords**
SPECIALTIES = [
    "Anesthesiology","Cardiology","Dermatology","Emergency medicine",
    "Endocrinology","Gastroenterology","General practice","Geriatrics",
    "Gynecology","Hematology","Infectious disease","Neurology",
    "Nephrology","Oncology","Ophthalmology","Orthopedics",
    "Otolaryngology","Pathology","Pediatrics","Pulmonology",
    "Psychiatry","Radiology","Rheumatology","Surgery","Urology"
]

KEYWORDS = {
    "Anesthesiology":     ["anesthesia","analgesia","preoperative","pain management"],
    "Cardiology":         ["chest pain","palpitations","shortness of breath","edema","syncope","angina","tachycardia"],
    "Dermatology":        ["rash","pruritus","eczema","psoriasis","lesion"],
    "Emergency medicine": ["trauma","acute","dyspnea","unconscious"],
    "Endocrinology":      ["polyuria","polydipsia","weight loss","thyroid","hyperglycemia"],
    "Gastroenterology":   ["abdominal pain","nausea","vomiting","diarrhea","jaundice"],
    "General practice":   ["routine check","vaccination","primary care","screening"],
    "Geriatrics":         ["memory loss","fall","frailty","polypharmacy","mobility"],
    "Gynecology":         ["pelvic pain","menstruation","bleeding","breast lump"],
    "Hematology":         ["anemia","bleeding","bruise","thrombosis","lymphadenopathy"],
    "Infectious disease": ["fever","infection","sepsis","cough","diarrhea"],
    "Neurology":          ["headache","seizure","numbness","dizziness","weakness"],
    "Nephrology":         ["edema","proteinuria","hematuria","renal failure","electrolyte"],
    "Oncology":           ["mass","tumor","cancer","metastasis","weight loss"],
    "Ophthalmology":      ["vision change","red eye","blurred vision","floaters"],
    "Orthopedics":        ["fracture","joint pain","swelling","mobility"],
    "Otolaryngology":     ["ear pain","hearing loss","tinnitus","sore throat"],
    "Pathology":          ["biopsy","histology","malignancy","inflammation"],
    "Pediatrics":         ["fever","growth","development","vaccination"],
    "Pulmonology":        ["cough","wheezing","dyspnea","asthma","COPD"],
    "Psychiatry":         ["depression","anxiety","hallucination","suicidal"],
    "Radiology":          ["imaging","x-ray","ct","mri","ultrasound"],
    "Rheumatology":       ["joint pain","swelling","autoimmune","stiffness"],
    "Surgery":            ["operative","incision","surgical","resection"],
    "Urology":            ["dysuria","hematuria","incontinence","prostate"]
}

# **Cell 4: Augment Keywords via TF-IDF**
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

vectorizer = TfidfVectorizer(stop_words="english", max_features=5000)
vectorizer.fit(df["text"])
terms = vectorizer.get_feature_names_out()

for spec, seeds in list(KEYWORDS.items()):
    mask = df["text"].str.contains("|".join(seeds), case=False, na=False)
    if mask.sum() == 0: 
        continue
    sub = vectorizer.transform(df.loc[mask, "text"])
    scores = np.asarray(sub.sum(axis=0)).ravel()
    top5 = [terms[i] for i in scores.argsort()[::-1] if terms[i] not in seeds][:5]
    KEYWORDS[spec] += top5

print("✨ Example augmented Cardiology keywords:", KEYWORDS["Cardiology"])

# **Cell 5: Weak Supervision via Snorkel with Debug Coverage**
from snorkel.labeling import labeling_function, PandasLFApplier
from snorkel.labeling.model.label_model import LabelModel
from transformers import pipeline
import torch
# Cell 4: Weak Supervision via Snorkel
from snorkel.labeling import labeling_function, PandasLFApplier
from snorkel.labeling.model.label_model import LabelModel
from transformers import pipeline
import torch

ABSTAIN = -1
THRESH   = 0.60  # zero-shot confidence threshold

# Helper: keyword lookup
def keyword_lookup(x, keywords, label):
    txt = x.text.lower()
    return label if any(kw.lower() in txt for kw in keywords) else ABSTAIN

# Build keyword LFs
lfs = []
for spec in SPECIALTIES:
    idx = SPECIALTIES.index(spec)
    kws = KEYWORDS[spec]
    @labeling_function(name=f"lf_{spec.replace(' ','_')}")
    def lf(x, kws=kws, idx=idx):
        return keyword_lookup(x, kws, idx)
    lfs.append(lf)

# Zero-shot LF
device = 0 if torch.cuda.is_available() else -1
zs = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)

@labeling_function(name="lf_zero_shot")
def lf_zero(x):
    out = zs(x.text, SPECIALTIES)
    lbl, score = out["labels"][0], out["scores"][0]
    return SPECIALTIES.index(lbl) if score > THRESH else ABSTAIN

lfs.append(lf_zero)

# Apply LFs → label matrix L
applier = PandasLFApplier(lfs=lfs)
L       = applier.apply(df)

# Debug: coverage per LF
coverage = (L != ABSTAIN).mean(axis=0) * 100
for lf, cov in zip(lfs, coverage):
    print(f"{lf.name:20s}: {cov:5.1f}%")
print("Rows with no LF fired:", (L.sum(axis=1) == ABSTAIN * len(lfs)).sum())

# Train Snorkel LabelModel
label_model = LabelModel(cardinality=len(SPECIALTIES), verbose=True)
label_model.fit(L_train=L, n_epochs=200, log_freq=50, seed=42)
# Cell 5: Map Predictions → Specialist Names & Save
# 1) Get raw predictions (numpy array)
preds = label_model.predict(L)

# 2) Map to strings via list comprehension
df["specialist_label"] = [SPECIALTIES[i] for i in preds]

# 3) (Optional) drop numeric pseudo_label if present
if "pseudo_label" in df.columns:
    df.drop(columns=["pseudo_label"], inplace=True)

# 4) Save and inspect
out_path = "MTS-Dialog-with-specialist-labels_2.csv"
df.to_csv(out_path, index=False)
print(f"✅ Saved labeled CSV: {out_path}")
print(df["specialist_label"].value_counts())


Note: you may need to restart the kernel to use updated packages.
✨ Example augmented Cardiology keywords: ['chest pain', 'palpitations', 'shortness of breath', 'edema', 'syncope', 'angina', 'tachycardia', 'patient', 'doctor', 'chest', 'pain', 'breath']


Device set to use cuda:0
100%|██████████| 1201/1201 [11:42<00:00,  1.71it/s]
INFO:root:Computing O...
INFO:root:Estimating \mu...


lf_Anesthesiology   : 100.0%
lf_Cardiology       : 100.0%
lf_Dermatology      : 100.0%
lf_Emergency_medicine: 100.0%
lf_Endocrinology    : 100.0%
lf_Gastroenterology : 100.0%
lf_General_practice : 100.0%
lf_Geriatrics       : 100.0%
lf_Gynecology       : 100.0%
lf_Hematology       : 100.0%
lf_Infectious_disease: 100.0%
lf_Neurology        : 100.0%
lf_Nephrology       : 100.0%
lf_Oncology         : 100.0%
lf_Ophthalmology    : 100.0%
lf_Orthopedics      : 100.0%
lf_Otolaryngology   : 100.0%
lf_Pathology        : 100.0%
lf_Pediatrics       : 100.0%
lf_Pulmonology      : 100.0%
lf_Psychiatry       : 100.0%
lf_Radiology        : 100.0%
lf_Rheumatology     : 100.0%
lf_Surgery          : 100.0%
lf_Urology          : 100.0%
lf_zero_shot        :   2.5%
Rows with no LF fired: 0


  0%|          | 0/200 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=624.262]
 25%|██▌       | 50/200 [00:00<00:00, 165.84epoch/s]INFO:root:[50 epochs]: TRAIN:[loss=5.499]
 44%|████▎     | 87/200 [00:00<00:00, 174.64epoch/s]INFO:root:[100 epochs]: TRAIN:[loss=0.064]
 71%|███████   | 142/200 [00:00<00:00, 177.56epoch/s]INFO:root:[150 epochs]: TRAIN:[loss=0.000]
100%|██████████| 200/200 [00:01<00:00, 174.21epoch/s]
INFO:root:Finished Training


✅ Saved labeled CSV: MTS-Dialog-with-specialist-labels_2.csv
specialist_label
Surgery               1182
Psychiatry               5
Neurology                3
Gastroenterology         2
Infectious disease       2
General practice         1
Oncology                 1
Orthopedics              1
Hematology               1
Urology                  1
Dermatology              1
Emergency medicine       1
Name: count, dtype: int64
