In [1]:
#Imports and basic setup
import os
import re
import json

import pandas as pd
import numpy as np

from sklearn.feature_extraction.text import TfidfVectorizer, ENGLISH_STOP_WORDS

# We will use these later for modeling, but ok to import now
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

import joblib
from tqdm import tqdm

pd.set_option("display.max_colwidth", 200)


In [2]:
#Load train, validation, test datasets
train_df = pd.read_json("train.json", lines=True)
val_df   = pd.read_json("validation.json", lines=True)
test_df  = pd.read_json("test.json", lines=True)

print("Train shape:", train_df.shape)
print("Val shape:  ", val_df.shape)
print("Test shape: ", test_df.shape)

train_df.head()

Train shape: (9250, 5)
Val shape:   (500, 5)
Test shape:  (250, 5)


Unnamed: 0,dialogue,soap,prompt,messages,messages_nosystem
0,"Doctor: Hello, how can I help you today?\nPatient: My son has been having some issues with speech and development. He's 13 years old now.\nDoctor: I see. Can you tell me more about his symptoms? D...",S: The patient's mother reports that her 13-year-old son has mild to moderate speech and developmental delays and has been diagnosed with attention deficit disorder. She denies any issues with mus...,"Create a Medical SOAP note summary from the dialogue, following these guidelines:\n S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Re...","[{'role': 'system', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:...","[{'role': 'user', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:, ..."
1,"Doctor: Hello, what brings you in today?\nPatient: Hi, my 21-month-old son has been experiencing weakness in his lower extremities and lumbar pain after a mild upper respiratory tract infection.\n...","S: The patient, a 21-month-old male, presented with weakness in his lower extremities and lumbar pain following a mild upper respiratory tract infection. Initial treatment with anti-inflammatory t...","Create a Medical SOAP note summary from the dialogue, following these guidelines:\n S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Re...","[{'role': 'system', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:...","[{'role': 'user', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:, ..."
2,"Doctor: Hello, how can I help you today?\nPatient: Hi, doctor. I came here because for the past 2 months, I've been experiencing fatigue, night sweats, weight loss, loss of appetite, and mild abdo...","S: Patient reports experiencing fatigue, night sweats, weight loss, loss of appetite, and mild abdominal discomfort for the past 2 months. No fever, chills, cough, nausea, vomiting, itching, or ur...","Create a Medical SOAP note summary from the dialogue, following these guidelines:\n S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Re...","[{'role': 'system', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:...","[{'role': 'user', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:, ..."
3,"Doctor: Hello, Patient D. How are you feeling today?\nPatient D: I'm feeling fine, doc, no complaints. Just here for a regular check-up.\nDoctor: That's good to hear. I see that you are a 60-year-...","S: Patient D, a 60-year-old African American male, reports no current symptoms and is visiting for a routine check-up. He has a family history significant for prostate cancer, as his 62-year-old b...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:...","[{'role': 'user', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:, ..."
4,"Doctor: Hello, I see that you have a history of two early miscarriages and no long-term pregnancy. Can you please tell me about any symptoms you've been experiencing?\nPatient: Yeah, I've been hav...","S: The patient, a married woman with a 7-year history of infertility, reports irregular menstruation and excessive body hair growth. She has a history of two early miscarriages, intrauterine insem...","Create a Medical SOAP note summary from the dialogue, following these guidelines:\n S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Re...","[{'role': 'system', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:...","[{'role': 'user', 'content': 'You are an expert medical professor assisting in the creation of medically accurate SOAP summaries. Please ensure the response follows the structured format: S:, O:, ..."


In [3]:
#Combine splits + rename 'soap' to 'soap_note'
train_df["split"] = "train"
val_df["split"]   = "validation"
test_df["split"]  = "test"

df = pd.concat([train_df, val_df, test_df], ignore_index=True)

# Use 'soap_note' as the main clinical text column
df = df.rename(columns={"soap": "soap_note"})

print(df.shape)
df[["split", "soap_note"]].head()

(10000, 6)


Unnamed: 0,split,soap_note
0,train,S: The patient's mother reports that her 13-year-old son has mild to moderate speech and developmental delays and has been diagnosed with attention deficit disorder. She denies any issues with mus...
1,train,"S: The patient, a 21-month-old male, presented with weakness in his lower extremities and lumbar pain following a mild upper respiratory tract infection. Initial treatment with anti-inflammatory t..."
2,train,"S: Patient reports experiencing fatigue, night sweats, weight loss, loss of appetite, and mild abdominal discomfort for the past 2 months. No fever, chills, cough, nausea, vomiting, itching, or ur..."
3,train,"S: Patient D, a 60-year-old African American male, reports no current symptoms and is visiting for a routine check-up. He has a family history significant for prostate cancer, as his 62-year-old b..."
4,train,"S: The patient, a married woman with a 7-year history of infertility, reports irregular menstruation and excessive body hair growth. She has a history of two early miscarriages, intrauterine insem..."


In [4]:
#Building abbreviation dictionary
abbr_rows = [
    # Patient demographics
    ("M/F", "Male/Female"),
    ("Wt", "Weight"),
    ("Ht", "Height"),
    ("BMI", "Body Mass Index"),
    ("IP No", "In-Patient number"),
    ("OP No", "Out-Patient number"),
    ("DOA", "Date of Admission"),
    ("DOD", "Date of Discharge"),
    ("DOB", "Date of Birth"),
    ("MRN", "Medical Record Number"),
    ("Eth", "Ethnicity"),

    # Case details 
    ("A/P", "Anterior / Posterior"),
    ("A.R.O.M", "Active Range of Motion"),
    ("B/L", "Bilateral"),
    ("BR", "Bed Rest"),
    ("bs", "Bowel Sounds"),
    ("CA", "Cardiac Arrest"),
    ("CC", "Chief Complaints"),
    ("C/O", "Complaints of"),
    ("CP", "Chest Pain"),
    ("cal", "Calorie"),
    ("cath", "Catheter"),
    ("CHI", "Closed Head Injury"),
    ("CMT", "Continuing Medication and Treatment"),
    ("CXR", "Chest X-Ray"),
    ("d", "Day"),
    ("d/c", "Discontinue"),
    ("DC", "Discharge"),
    ("DOE", "Dyspnea on Exertion"),
    ("DNK", "Do Not Know"),
    ("DNKA", "Did Not Keep Appointment"),
    ("d/t", "Due to"),
    ("ED", "Emergency Department"),
    ("ER", "Emergency Room"),
    ("EENT", "Eyes, Ears, Nose, Throat"),
    ("ext", "External / Exterior"),
    ("F/C", "Fever and Chills"),
    ("fl", "Fluid"),
    ("fld", "Fluid"),
    ("FOB", "Foot of Bed"),
    ("f/u", "Follow-up"),
    ("FWB", "Full Weight Bearing"),
    ("Fx", "Fracture"),
    ("gen", "General"),
    ("gest.", "Gestation"),
    ("G.I.", "Gastro Intestinal"),
    ("gluc", "Glucose"),
    ("Gt. tr.", "Gait Training"),
    ("h", "Hour"),
    ("HOB", "Head of Bed"),
    ("H&P", "History and Physical"),
    ("imp.", "Impression"),
    ("incr.", "Increased"),
    ("inf", "Infusion / Inferior"),

    ("int.", "Internal"),
    ("I&O", "Intake and Output"),
    ("irreg.", "Irregular"),
    ("J", "Joint"),
    ("jt.", "Joint"),
    ("Lt", "Left"),
    ("LM", "Loose Motions"),
    ("LOE", "Loss of Energy"),
    ("lat.", "Lateral"),
    ("LBW", "Low Birth Weight"),
    ("LOC", "Loss of Consciousness"),
    ("LOS", "Length of Stay"),
    ("Lx", "Larynx"),
    ("L&W", "Living and Well"),
    ("mdnt.", "Midnight"),
    ("min", "Minute"),
    ("mod", "Moderate"),
    ("mss", "Massage"),
    ("MVA", "Motor Vehicle Accident"),
    ("n.", "Nerve"),
    ("N/V", "Nausea/Vomiting"),
    ("NAD", "No Abnormality Detected / No Apparent Distress"),
    ("neg.", "Negative"),
    ("neur.", "Neurology"),
    ("NG", "Nasogastric"),
    ("NKA", "No Known Allergies"),
    ("NOS", "Not Otherwise Specified"),
    ("NPO", "Nil Per Oral"),
    ("NBM", "Nil by Mouth / Nothing by Mouth"),
    ("NSA", "No Specific Abnormality"),
    ("NST", "Non Stress Test"),
    ("NVD", "Nausea, Vomiting, Diarrhea"),
    ("NWB", "Non Weight Bearing"),
    ("NYD", "Not Yet Diagnosed"),
    ("Obs", "Observation"),
    ("O/E", "On Examination"),
    ("OOB", "Out of Bed"),
    ("Op.", "Operation"),
    ("ot.", "Ear"),
    ("p&a", "Percussion and Auscultation"),
    ("palp.", "Palpate / Palpated / Palpable"),
    ("Path", "Pathology"),
    ("PA view", "Posterior-anterior view on X-ray"),
    ("PI", "Present Illness / Pulmonary Insufficiency"),

    ("P. M", "Afternoon / Post-mortem"),
    ("PN", "Poorly Nourished"),
    ("p.o.", "Per Oral / By Mouth"),
    ("p.o.d.", "Post Operative Day"),
    ("pos.", "Positive"),
    ("post.", "Posterior"),
    ("Pre-op", "Pre-operative"),
    ("Post-op", "Post-operative"),
    ("prep.", "Prepare for"),
    ("p.r.m.", "According to Circumstances"),
    ("p.r.n.", "As Often as Necessary / As Needed"),
    ("PRN", "As Often as Necessary / As Needed"),
    ("prod.", "Productive"),
    ("Prog.", "Prognosis"),
    ("PROM", "Passive Range of Motion"),
    ("prosth.", "Prosthesis"),
    ("PTA", "Prior to Admission"),
    ("Px", "Physical Examination"),
    ("PX", "Physical Examination"),
    ("q", "Every"),
    ("rad.", "Radial"),
    ("r.a.m.", "Rapid Alternating Movements"),
    ("R.A.S.", "Right Arm Sitting"),
    ("RD", "Respiratory Distress"),
    ("RO", "Rule Out"),
    ("R/O", "Rule Out"),
    ("ROS", "Review of Symptoms"),
    ("Rt", "Right"),
    ("RAtx", "Radiation Therapy"),
    ("RUE", "Right Upper Extremity"),
    ("LUE", "Left Upper Extremity"),
    ("RLE", "Right Lower Extremity"),
    ("LLE", "Left Lower Extremity"),
    ("Scc", "Squamous Cell Carcinoma"),
    ("SCD", "Sudden Cardiac Death"),
    ("SOB", "Shortness of Breath"),
    ("S/P", "Status Post"),
    ("s/p", "Status Post"),
    ("Sp. fl.", "Spinal Fluid"),
    ("STD", "Sexually Transmitted Disease"),
    ("Sx", "Symptoms"),
    ("Syst.", "Systolic"),
    ("sp. cd.", "Spinal Cord"),

    ("spont.", "Spontaneous"),
    ("s/s", "Signs and Symptoms"),
    ("sup.", "Superior"),
    ("surg.", "Surgery / Surgical"),
    ("Sys.", "System"),
    ("Sz", "Seizures"),
    ("THR", "Total Hip Replacement"),
    ("TIA", "Transient Ischemic Attack"),
    ("TKR", "Total Knee Replacement"),
    ("TPN", "Total Parenteral Nutrition"),
    ("T&A", "Tonsils and Adenoids / Tonsillectomy and Adenoidectomy"),
    ("TB", "Tuberculosis"),
    ("Unilat.", "Unilateral"),
    ("u/o", "Under Observation"),
    ("Ur.", "Urine"),
    ("URD", "Upper Respiratory Disease"),
    ("URTI", "Upper Respiratory Tract Infection"),
    ("URI", "Upper Respiratory Infection"),
    ("UTI", "Urinary Tract Infection"),
    ("U/A", "Urine Analysis"),
    ("UCD", "Usual Childhood Diseases"),
    ("UCHD", "Usual Childhood Diseases"),
    ("Urol.", "Urology"),
    ("UE", "Upper Extremity"),
    ("VA", "Visual Acuity"),
    ("vag", "Vagina / Vaginal"),
    ("vent.", "Ventilator"),
    ("VF", "Visual Fields / Ventricular Fibrillation"),
    ("VC", "Vital Capacity"),
    ("VD", "Venereal Disease"),
    ("w", "Week"),
    ("wk", "Week"),
    ("w/n", "Within"),
    ("WNL", "Within Normal Limits"),
    ("w/u", "Workup"),
    ("x", "Since / Times"),
    ("y.o", "Years Old"),
    ("yrs.", "Years"),
]

abbr_df = pd.DataFrame(abbr_rows, columns=["abbr", "expanded"])

# Normalize (lowercase keys)
abbr_df["abbr"] = abbr_df["abbr"].str.strip()
abbr_df["expanded"] = abbr_df["expanded"].str.strip()

os.makedirs("data", exist_ok=True)
abbr_path = os.path.join("data", "abbreviations_from_pdf.csv")
abbr_df.to_csv(abbr_path, index=False)

abbr_df.head(), abbr_path

(    abbr           expanded
 0    M/F        Male/Female
 1     Wt             Weight
 2     Ht             Height
 3    BMI    Body Mass Index
 4  IP No  In-Patient number,
 'data\\abbreviations_from_pdf.csv')

In [5]:
#Preprocessing functions using abbrev dict
ABBR_PATTERN = re.compile(r"\b([A-Za-z][A-Za-z0-9./]{0,15})\b")

def load_abbreviation_dict(path: str):
    df_abbr = pd.read_csv(path)
    df_abbr["abbr"] = df_abbr["abbr"].str.lower().str.strip()
    df_abbr["expanded"] = df_abbr["expanded"].str.strip()
    return dict(zip(df_abbr["abbr"], df_abbr["expanded"]))

def basic_clean(text: str) -> str:
    if not isinstance(text, str):
        return ""
    text = text.strip()
    text = re.sub(r"\s+", " ", text)
    text = text.lower()
    return text

def expand_abbreviations(text: str, abbr_dict):
    def repl(match):
        token = match.group(1)
        key = token.lower()
        return abbr_dict.get(key, token)
    return ABBR_PATTERN.sub(repl, text)

def simple_tokenize(text: str):
    # keep only letters and spaces for modeling
    text = re.sub(r"[^a-z\s]", " ", text)
    tokens = text.split()
    tokens = [t for t in tokens if t not in ENGLISH_STOP_WORDS and len(t) > 1]
    return " ".join(tokens)

def preprocess_note(text: str, abbr_dict) -> str:
    text = basic_clean(text)
    text = expand_abbreviations(text, abbr_dict)
    text = simple_tokenize(text)
    return text

In [6]:
#Apply preprocessing and save the processed dataset
abbr_dict = load_abbreviation_dict("data/abbreviations_from_pdf.csv")

df["note_clean"] = df["soap_note"].astype(str).apply(
    lambda x: preprocess_note(x, abbr_dict)
)

df[["soap_note", "note_clean"]].head()


Unnamed: 0,soap_note,note_clean
0,S: The patient's mother reports that her 13-year-old son has mild to moderate speech and developmental delays and has been diagnosed with attention deficit disorder. She denies any issues with mus...,patient mother reports year old son mild moderate speech developmental delays diagnosed attention deficit disorder denies issues muscle tone hypotonia patient exhibits certain physical characteris...
1,"S: The patient, a 21-month-old male, presented with weakness in his lower extremities and lumbar pain following a mild upper respiratory tract infection. Initial treatment with anti-inflammatory t...",patient month old male presented weakness lower extremities lumbar pain following mild upper respiratory tract infection initial treatment anti inflammatory therapy suspected transient hips arthri...
2,"S: Patient reports experiencing fatigue, night sweats, weight loss, loss of appetite, and mild abdominal discomfort for the past 2 months. No fever, chills, cough, nausea, vomiting, itching, or ur...",patient reports experiencing fatigue night sweats weight loss loss appetite mild abdominal discomfort past months fever chills cough nausea vomiting itching urinary bowel issues past medical histo...
3,"S: Patient D, a 60-year-old African American male, reports no current symptoms and is visiting for a routine check-up. He has a family history significant for prostate cancer, as his 62-year-old b...",patient ay year old african american male reports current symptoms visiting routine check family history significant prostate cancer year old brother diagnosed treated radiation patient currently ...
4,"S: The patient, a married woman with a 7-year history of infertility, reports irregular menstruation and excessive body hair growth. She has a history of two early miscarriages, intrauterine insem...",patient married woman year history infertility reports irregular menstruation excessive body hair growth history early miscarriages intrauterine insemination attempts intramural myoma myomectomy d...


In [7]:
#saving processed dataset
import os

os.makedirs("data/processed", exist_ok=True)

clean_path = os.path.join("data", "processed", "soap_notes_clean.csv")
df.to_csv(clean_path, index=False)

clean_path

'data\\processed\\soap_notes_clean.csv'

In [8]:
df["len_raw"] = df["soap_note"].str.len()
df["len_clean"] = df["note_clean"].str.len()

df[["len_raw", "len_clean"]].describe()

Unnamed: 0,len_raw,len_clean
count,10000.0,10000.0
mean,1734.2124,1348.3444
std,284.205267,232.168747
min,710.0,525.0
25%,1544.0,1191.75
50%,1736.0,1349.0
75%,1927.0,1508.0
max,2855.0,2324.0


In [9]:
#Defining  10 disease categories + keyword dictionaries
disease_categories = {
    "respiratory": [
        "cough", "sob", "shortness of breath", "asthma", "wheezing",
        "bronchitis", "pneumonia", "uri", "urti", "respiratory"
    ],
    "cardiac": [
        "chest pain", "cp", "palpitations", "cad", "hypertension", 
        "htn", "heart failure", "angina"
    ],
    "gi_gu": [
        "abdominal", "diarrhea", "vomiting", "nausea", "constipation",
        "gastric", "ulcer", "dyspepsia", "urinary", "uti"
    ],
    "neurological": [
        "headache", "migraine", "seizure", "dizziness", "vertigo",
        "stroke", "tia", "weakness", "numbness", "tingling"
    ],
    "endocrine_metabolic": [
        "diabetes", "thyroid", "hypothyroid", "hyperthyroid", "pcos",
        "weight gain", "obesity"
    ],
    "musculoskeletal": [
        "joint pain", "back pain", "neck pain", "arthritis",
        "sprain", "strain", "fracture", "hip pain", "knee pain"
    ],
    "infectious": [
        "fever", "infection", "viral", "bacterial", "covid",
        "influenza", "sore throat", "strep", "uti", "sepsis"
    ],
    "women_health": [
        "pregnancy", "menstruation", "pcos", "fibroid", "miscarriage",
        "vaginal", "pelvic pain", "gyn", "gynecologic"
    ],
    "pediatrics": [
        "child", "infant", "toddler", "pediatric", "developmental",
        "adhd", "speech delay"
    ],
    "general": [
        "fatigue", "tired", "check-up", "routine", "follow-up",
        "annual exam", "no symptoms", "general"
    ]
}

# Flatten keywords for faster search
category_keywords = {
    cat: [kw.lower() for kw in kws]
    for cat, kws in disease_categories.items()
}

category_keywords

{'respiratory': ['cough',
  'sob',
  'shortness of breath',
  'asthma',
  'wheezing',
  'bronchitis',
  'pneumonia',
  'uri',
  'urti',
  'respiratory'],
 'cardiac': ['chest pain',
  'cp',
  'palpitations',
  'cad',
  'hypertension',
  'htn',
  'heart failure',
  'angina'],
 'gi_gu': ['abdominal',
  'diarrhea',
  'vomiting',
  'nausea',
  'constipation',
  'gastric',
  'ulcer',
  'dyspepsia',
  'urinary',
  'uti'],
 'neurological': ['headache',
  'migraine',
  'seizure',
  'dizziness',
  'vertigo',
  'stroke',
  'tia',
  'weakness',
  'numbness',
  'tingling'],
 'endocrine_metabolic': ['diabetes',
  'thyroid',
  'hypothyroid',
  'hyperthyroid',
  'pcos',
  'weight gain',
  'obesity'],
 'musculoskeletal': ['joint pain',
  'back pain',
  'neck pain',
  'arthritis',
  'sprain',
  'strain',
  'fracture',
  'hip pain',
  'knee pain'],
 'infectious': ['fever',
  'infection',
  'viral',
  'bacterial',
  'covid',
  'influenza',
  'sore throat',
  'strep',
  'uti',
  'sepsis'],
 'women_health':

In [10]:
#Auto-label SOAP notes into categories
def assign_category(text):
    """
    Match keywords to assign a disease category.
    If multiple match, the first category wins.
    If none match, assign 'general'.
    """
    t = text.lower()

    for category, kws in category_keywords.items():
        for kw in kws:
            if kw in t:
                return category

    return "general"  # fallback


df["category"] = df["note_clean"].apply(assign_category)

df[["note_clean", "category"]].head(10)

Unnamed: 0,note_clean,category
0,patient mother reports year old son mild moderate speech developmental delays diagnosed attention deficit disorder denies issues muscle tone hypotonia patient exhibits certain physical characteris...,neurological
1,patient month old male presented weakness lower extremities lumbar pain following mild upper respiratory tract infection initial treatment anti inflammatory therapy suspected transient hips arthri...,respiratory
2,patient reports experiencing fatigue night sweats weight loss loss appetite mild abdominal discomfort past months fever chills cough nausea vomiting itching urinary bowel issues past medical histo...,respiratory
3,patient ay year old african american male reports current symptoms visiting routine check family history significant prostate cancer year old brother diagnosed treated radiation patient currently ...,gi_gu
4,patient married woman year history infertility reports irregular menstruation excessive body hair growth history early miscarriages intrauterine insemination attempts intramural myoma myomectomy d...,gi_gu
5,patient long term sufferer acromegaly reports experiencing shortness breath dyspnea exertion past years episode presyncope weeks ago patient treated octreotide approximately years undergone stereo...,cardiac
6,patient reports history abdominal pain approximately year worsening months particularly postprandial localized right upper abdomen additionally patient experienced jaundice generalized itching pas...,gi_gu
7,patient experiencing right ankle swelling age reports pain movement limitation finds swelling annoying patient history intermittent high fever bilateral knee right ankle swelling month prior admis...,respiratory
8,patient past history sigmoid volvulus treated laparotomy detwisting sigmoidopexy elective sigmoidectomy presents symptoms starting months ago including abdominal distension constipation vomiting l...,respiratory
9,patient history asthma unresponsive husband initiated cpr reported using nebulizer prior event patient pulseless electrical activity arrest received cpr approximately minutes ems arrival given dos...,respiratory


In [11]:
#checking distribution
df["category"].value_counts()

category
respiratory            4263
gi_gu                  2184
neurological           2054
cardiac                1270
general                 134
musculoskeletal          33
infectious               25
women_health             16
endocrine_metabolic      12
pediatrics                9
Name: count, dtype: int64

In [12]:
#Merge 10 raw categories into 6 final labels
def merge_category(row):
    cat = row["category"]
    text = row["note_clean"].lower()

    # Keep main big classes as-is
    if cat in ["respiratory", "gi_gu", "neurological", "cardiac", "musculoskeletal", "general"]:
        return cat

    # Women’s health → GI/GU (gyn, pelvic, menstrual etc. are GU/pelvic)
    if cat == "women_health":
        return "gi_gu"

    # Endocrine + pediatrics → general (low counts, heterogeneous)
    if cat in ["endocrine_metabolic", "pediatrics"]:
        return "general"

    # Infectious → respiratory or GI/GU depending on hints
    if cat == "infectious":
        if "uti" in text or "urinary" in text or "dysuria" in text:
            return "gi_gu"
        else:
            return "respiratory"

    # Fallback
    return "general"


df["label"] = df.apply(merge_category, axis=1)

df[["category", "label"]].head(20)

Unnamed: 0,category,label
0,neurological,neurological
1,respiratory,respiratory
2,respiratory,respiratory
3,gi_gu,gi_gu
4,gi_gu,gi_gu
5,cardiac,cardiac
6,gi_gu,gi_gu
7,respiratory,respiratory
8,respiratory,respiratory
9,respiratory,respiratory


In [13]:
#final checking
df["label"].value_counts()

label
respiratory        4288
gi_gu              2200
neurological       2054
cardiac            1270
general             155
musculoskeletal      33
Name: count, dtype: int64

Clinical notes in real-world Electronic Health Records (EHRs) contain a wide variety of symptoms, complaints, and diagnostic clues, but they rarely include a clean, structured diagnosis field. To build a clinically meaningful machine-learning model from the SOAP notes in this dataset, it was necessary to convert the unstructured text into high-level disease categories that a classifier can reliably learn.

Initially, we auto-assigned each SOAP note to 10 granular disease categories (respiratory, GI/GU, neurological, cardiac, women’s health, infectious, endocrine, musculoskeletal, pediatrics, general) using medically-informed keyword rules. However, several categories had very low sample sizes (fewer than 50 notes), which would lead to poor model performance, class imbalance issues, and unstable predictions.

To create a clinically robust and machine-learning-friendly target label, these smaller categories were merged into related broader categories, resulting in six final disease classes:

Respiratory

GI/GU

Neurological

Cardiac

Musculoskeletal

General

This merging approach reflects standard practice in healthcare NLP, where small, heterogeneous classes are consolidated to ensure statistical reliability while still preserving clinical meaning. The final 6-class setup provides a balanced, interpretable, and realistic framework for training a text classification model that mimics high-level triage or categorization tasks used in real clinical workflows.

In [15]:
#Train TF-IDF + Logistic Regression classifier
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score

# Features and labels
X = df["note_clean"]
y = df["label"]

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    stratify=y,
    random_state=42
)

print("Train size:", X_train.shape[0])
print("Test size: ", X_test.shape[0])

# TF-IDF vectorizer
vectorizer = TfidfVectorizer(
    ngram_range=(1, 2),
    min_df=5,
    max_features=20000
)

X_train_vec = vectorizer.fit_transform(X_train)
X_test_vec = vectorizer.transform(X_test)

# Logistic Regression model
clf = LogisticRegression(
    max_iter=2000,
    class_weight="balanced",
    n_jobs=-1
)

clf.fit(X_train_vec, y_train)

y_pred = clf.predict(X_test_vec)

print("Classification report:\n")
print(classification_report(y_test, y_pred))

micro_f1 = f1_score(y_test, y_pred, average="micro")
macro_f1 = f1_score(y_test, y_pred, average="macro")

print(f"Micro F1: {micro_f1:.3f}")
print(f"Macro F1: {macro_f1:.3f}")

Train size: 8000
Test size:  2000
Classification report:

                 precision    recall  f1-score   support

        cardiac       0.53      0.61      0.57       254
        general       0.17      0.29      0.21        31
          gi_gu       0.60      0.60      0.60       440
musculoskeletal       0.24      0.71      0.36         7
   neurological       0.50      0.64      0.56       411
    respiratory       0.83      0.65      0.73       857

       accuracy                           0.63      2000
      macro avg       0.48      0.58      0.50      2000
   weighted avg       0.66      0.63      0.64      2000

Micro F1: 0.625
Macro F1: 0.505


In [16]:
#Confusion matrix 
from sklearn.metrics import confusion_matrix

labels_sorted = sorted(df["label"].unique())

cm = confusion_matrix(y_test, y_pred, labels=labels_sorted)

cm_df = pd.DataFrame(cm, index=[f"true_{l}" for l in labels_sorted],
                        columns=[f"pred_{l}" for l in labels_sorted])

cm_df

Unnamed: 0,pred_cardiac,pred_general,pred_gi_gu,pred_musculoskeletal,pred_neurological,pred_respiratory
true_cardiac,154,2,25,3,36,34
true_general,4,9,2,0,15,1
true_gi_gu,22,7,265,2,103,41
true_musculoskeletal,0,0,0,5,2,0
true_neurological,24,23,58,7,263,36
true_respiratory,87,12,91,4,108,555


In [17]:
import os
import joblib

os.makedirs("models", exist_ok=True)

joblib.dump(vectorizer, "models/tfidf_vectorizer_disease_categories.joblib")
joblib.dump(clf, "models/logreg_disease_category_model.joblib")

"Artifacts saved in models/ folder."

'Artifacts saved in models/ folder.'

In [18]:
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report, f1_score

# Features and labels
X = df["note_clean"]
y = df["label"]

# Encode labels into integers for XGBoost
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded,
    test_size=0.2,
    stratify=y_encoded,
    random_state=42
)

print("Train size:", X_train.shape[0])
print("Test size: ", X_test.shape[0])

# TF-IDF vectorizer (slightly reduced for speed)
vectorizer = TfidfVectorizer(
    ngram_range=(1, 2),
    min_df=5,
    max_features=15000 
)

X_train_vec = vectorizer.fit_transform(X_train)
X_test_vec = vectorizer.transform(X_test)

# XGBoost classifier (lighter but strong)
xgb_clf = xgb.XGBClassifier(
    objective="multi:softmax",
    num_class=len(label_encoder.classes_),
    eval_metric="mlogloss",
    max_depth=6,
    learning_rate=0.12,   # eta
    n_estimators=160,
    subsample=0.7,
    colsample_bytree=0.7,
    tree_method="hist",
    n_jobs=-1
)

xgb_clf.fit(X_train_vec, y_train)

# Predictions
y_pred = xgb_clf.predict(X_test_vec)

print("\n=== XGBoost Classification Report (6 disease classes) ===\n")
print(classification_report(label_encoder.inverse_transform(y_test),
                            label_encoder.inverse_transform(y_pred)))

print("Macro F1:", f1_score(y_test, y_pred, average='macro'))
print("Micro F1:", f1_score(y_test, y_pred, average='micro'))

Train size: 8000
Test size:  2000

=== XGBoost Classification Report (6 disease classes) ===

                 precision    recall  f1-score   support

        cardiac       0.95      0.87      0.91       254
        general       0.62      0.16      0.26        31
          gi_gu       0.90      0.89      0.90       440
musculoskeletal       0.00      0.00      0.00         7
   neurological       0.75      1.00      0.86       411
    respiratory       1.00      0.91      0.95       857

       accuracy                           0.90      2000
      macro avg       0.70      0.64      0.64      2000
   weighted avg       0.91      0.90      0.90      2000

Macro F1: 0.6446726185734101
Micro F1: 0.9035


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [19]:
import os
import joblib

os.makedirs("models", exist_ok=True)

joblib.dump(vectorizer, "models/tfidf_vectorizer_6class_xgb.joblib")
joblib.dump(xgb_clf, "models/xgboost_6class_disease_model.joblib")
joblib.dump(label_encoder, "models/label_encoder_6class.joblib")

['models/label_encoder_6class.joblib']

In [20]:
# Setting path for Option B CSV
clinical_cases_path = "clinical_cases_optionB_dataset.csv"

# Loading clinical cases dataset
clinical_cases = pd.read_csv(clinical_cases_path)

# Checking shape and first few rows
print("Shape of clinical_cases:", clinical_cases.shape)
clinical_cases.head()

Shape of clinical_cases: (27, 3)


Unnamed: 0,category,case_title,clean_text
0,general,Carbuncle,"Adult with painful swelling, redness, and pus over lower back for 10 days, fever present. Abscess with induration and discharge. Labs: elevated glucose and platelets. Treated with IV antibiotics, ..."
1,general,Cellulitis – Right Foot,"Adult with pain, redness, swelling, fever, and serous discharge from right leg. Labs: leukocytosis, neutrophilia, low albumin. Managed with IV fluids, antibiotics, and fasciotomy with drainage."
2,general,Acute-on-Chronic Dacryocystitis,"Adult with pain, swelling, tearing, and tenderness near medial canthus for 3 days. Mild fever. Treated with IV antibiotics, topical therapy, NSAIDs."
3,general,Dysmenorrhea With Ovarian Cyst,"Reproductive-age woman with severe lower abdominal pain radiating to back and heavy bleeding. Ultrasound shows simple ovarian cyst. Treated with antifibrinolytics, NSAIDs, and supportive care."
4,general,Broad Ligament Fibroid,Reproductive-age woman with pelvic pain or heavy bleeding. Ultrasound confirms broad ligament fibroid. Managed with analgesia and planned surgery.


In [21]:
# Checking category distribution in clinical_cases

print(clinical_cases['category'].value_counts())


category
general            13
gi_gu               5
respiratory         4
musculoskeletal     4
cardiac             1
Name: count, dtype: int64


In [53]:
# Ensure text is string
clinical_cases["clean_text"] = clinical_cases["clean_text"].fillna("").astype(str)

# Apply SAME preprocessing pipeline
clinical_cases["note_clean"] = clinical_cases["clean_text"].apply(
    lambda x: preprocess_note(x, abbr_dict)
)

clinical_cases[["category", "note_clean"]].head()

Unnamed: 0,category,note_clean
0,general,adult painful swelling redness pus lower days fever present abscess induration discharge labs elevated glucose platelets treated iv antibiotics analgesics incision drainage
1,general,adult pain redness swelling fever serous discharge right leg labs leukocytosis neutrophilia low albumin managed iv fluids antibiotics fasciotomy drainage
2,general,adult pain swelling tearing tenderness near medial canthus days mild fever treated iv antibiotics topical therapy nsaids
3,general,reproductive age woman severe lower abdominal pain radiating heavy bleeding ultrasound shows simple ovarian cyst treated antifibrinolytics nsaids supportive care
4,general,reproductive age woman pelvic pain heavy bleeding ultrasound confirms broad ligament fibroid managed analgesia planned surgery


In [55]:
def reviewed_6class_label(text):
    t = text.lower()

    # Rule B: GYN / pelvic → GI/GU
    if any(k in t for k in [
        "ovarian", "fibroid", "pelvic", "uterine", "menstrual",
        "bleeding", "dysmenorrhea"
    ]):
        return "gi_gu"

    # Rule C: respiratory
    if any(k in t for k in [
        "cough", "shortness of breath", "sob", "pneumonia",
        "asthma", "wheezing", "respiratory"
    ]):
        return "respiratory"

    # Rule D: cardiac
    if any(k in t for k in [
        "chest pain", "heart failure", "angina", "palpitations",
        "myocardial", "cardiac"
    ]):
        return "cardiac"

    # Rule E: musculoskeletal
    if any(k in t for k in [
        "joint", "back pain", "limb", "swelling", "fracture",
        "sprain", "strain"
    ]) and "fever" not in t:
        return "musculoskeletal"

    # Rule A: infections / abscess / cellulitis → general
    if any(k in t for k in [
        "abscess", "cellulitis", "infection", "pus",
        "iv antibiotics", "fever"
    ]):
        return "general"

    # Rule F: fallback
    return "general"


In [61]:
label_map_6 = {
    "cardiac": "cardiac",
    "respiratory": "respiratory",
    "gi_gu": "gi_gu",
    "neurological": "neurological",          # if exists later
    "musculoskeletal": "musculoskeletal",
    "general": "general"
}

clinical_cases["true_label_6"] = (
    clinical_cases["category"]
    .astype(str).str.strip().str.lower()
    .map(label_map_6)
    .fillna("general")
)

clinical_cases["true_label_6"].value_counts()


true_label_6
general            13
gi_gu               5
respiratory         4
musculoskeletal     4
cardiac             1
Name: count, dtype: int64

In [63]:
clinical_cases["reviewed_6class_label"] = clinical_cases["note_clean"].apply(
    reviewed_6class_label
)

clinical_cases[[
    "case_title",
    "category",
    "true_label_6",
    "reviewed_6class_label"
]]


Unnamed: 0,case_title,category,true_label_6,reviewed_6class_label
0,Carbuncle,general,general,general
1,Cellulitis – Right Foot,general,general,general
2,Acute-on-Chronic Dacryocystitis,general,general,general
3,Dysmenorrhea With Ovarian Cyst,general,general,gi_gu
4,Broad Ligament Fibroid,general,general,gi_gu
5,Term Pregnancy – LSCS,general,general,general
6,Acute Fissure in Ano,general,general,gi_gu
7,Lymphangitis,general,general,general
8,Right Eye Dacryostenosis,general,general,general
9,Left Eye Cataract,general,general,general


In [59]:
print(clinical_cases.columns.tolist())


['category', 'case_title', 'clean_text', 'note_clean', 'reviewed_6class_label']


After evaluating model predictions on the external Option B clinical cases, it became evident that several records were affected by label noise due to overly broad or generic category assignments (particularly cases labeled as “general”). To address this, a clinician-informed relabeling strategy (v2) was introduced, mapping cases into six triage-oriented disease classes: respiratory, GI/GU, neurological, cardiac, musculoskeletal, and general. The v2 rules prioritized clinically dominant organ systems using keyword patterns, corrected major mismatches such as cardiac conditions labeled as general, and ensured gynecologic and hepatobiliary cases were appropriately grouped under GI/GU. This step significantly improved semantic alignment between clinical content and model-intended label structure without retraining the model.

In [66]:
def reviewed_6class_label_v2(text: str) -> str:
    t = str(text).lower()

    # --- CARDIAC  ---
    if any(k in t for k in [
        "carditis", "myocard", "heart failure", "murmur", "mitral", "mr ",
        "rheumatic", "angina", "palpitations", "arrhythm", "cardiac", "chest pain"
    ]):
        return "cardiac"

    # --- NEURO  ---
    if any(k in t for k in [
        "seizure", "stroke", "tia", "hemip", "weakness", "numbness",
        "headache", "migraine", "vertigo"
    ]):
        return "neurological"

    # --- MUSCULOSKELETAL / TRAUMA ---
    if any(k in t for k in [
        "fracture", "compound", "subluxation", "dislocation", "sprain", "strain",
        "joint", "p i p", "pip joint", "bone", "orthopedic", "trauma", "cast"
    ]):
        return "musculoskeletal"

    # --- GI / GU / GYN ---
    if any(k in t for k in [
        "abdominal", "pancrea", "gastro", "diarrhea", "vomit", "nausea",
        "cirrhosis", "portal", "bilirubin", "jaundice", "hepat", "liver",
        "cystitis", "uti", "urinary", "dysuria",
        "hernia", "inguinal",
        "pelvic", "ovarian", "fibroid", "uterine", "menstrual", "dysmenorrhea", "bleeding"
    ]):
        return "gi_gu"

    # --- RESPIRATORY ---
    if any(k in t for k in [
        "copd", "asthma", "wheez", "pneumonia", "bronch", "bronchiol",
        "cough", "shortness of breath", "sob", "l r t i", "lrti", "respiratory"
    ]):
        return "respiratory"

    # --- GENERAL / INFECTIOUS / DERM / OPHTH ---
    if any(k in t for k in [
        "abscess", "cellulitis", "infection", "pus", "fever", "antibiotic",
        "dacryo", "cataract", "tonsill", "anaemia", "anemia", "purpura", "lymphang"
    ]):
        return "general"

    return "general"


In [68]:
clinical_cases["reviewed_6class_label_v2"] = clinical_cases["note_clean"].apply(reviewed_6class_label_v2)

clinical_cases[[
    "case_title","category","true_label_6","reviewed_6class_label","reviewed_6class_label_v2"
]]


Unnamed: 0,case_title,category,true_label_6,reviewed_6class_label,reviewed_6class_label_v2
0,Carbuncle,general,general,general,general
1,Cellulitis – Right Foot,general,general,general,general
2,Acute-on-Chronic Dacryocystitis,general,general,general,general
3,Dysmenorrhea With Ovarian Cyst,general,general,gi_gu,gi_gu
4,Broad Ligament Fibroid,general,general,gi_gu,gi_gu
5,Term Pregnancy – LSCS,general,general,general,general
6,Acute Fissure in Ano,general,general,gi_gu,gi_gu
7,Lymphangitis,general,general,general,general
8,Right Eye Dacryostenosis,general,general,general,general
9,Left Eye Cataract,general,general,general,general


Building on v2, a final reviewed labeling framework (v3) was implemented to resolve remaining edge cases and ensure full clinical consistency. This version refined rule priority and expanded musculoskeletal coverage to include orthopedic trauma, hernias, and limb-related vascular conditions (e.g., deep vein thrombosis), reflecting realistic triage workflows in settings where dedicated vascular or surgical classes are unavailable. The v3 framework represents a stable, governance-ready labeling protocol that minimizes label ambiguity, preserves clinical intent, and serves as a reliable reference standard for external validation and future model retraining decisions.

In [72]:
def reviewed_6class_label_v3(text: str) -> str:
    t = str(text).lower()

    # --- CARDIAC (high priority) ---
    if any(k in t for k in [
        "carditis", "myocard", "heart failure", "murmur", "mitral", " mr",
        "rheumatic", "angina", "palpitations", "arrhythm", "cardiac", "chest pain"
    ]):
        return "cardiac"

    # --- NEURO ---
    if any(k in t for k in [
        "seizure", "stroke", "tia", "hemip", "weakness", "numbness",
        "headache", "migraine", "vertigo"
    ]):
        return "neurological"

    # --- MUSCULOSKELETAL / TRAUMA / VASCULAR LIMB ---
    if any(k in t for k in [
        "fracture", "compound", "subluxation", "dislocation", "sprain", "strain",
        "joint", "pip joint", "bone", "orthopedic", "trauma", "cast",
        "deep vein thrombosis", "dvt", "thrombo", "embol",
        "inguinal hernia", "hernia"
    ]):
        return "musculoskeletal"

    # --- GI / GU / GYN ---
    if any(k in t for k in [
        "abdominal", "pancrea", "gastro", "diarrhea", "vomit", "nausea",
        "cirrhosis", "portal", "bilirubin", "jaundice", "hepat", "liver",
        "cystitis", "uti", "urinary", "dysuria",
        "pelvic", "ovarian", "fibroid", "uterine", "menstrual", "dysmenorrhea", "bleeding"
    ]):
        return "gi_gu"

    # --- RESPIRATORY ---
    if any(k in t for k in [
        "copd", "asthma", "wheez", "pneumonia", "bronch", "bronchiol",
        "cough", "shortness of breath", "sob", "lrti", "respiratory"
    ]):
        return "respiratory"

    # --- GENERAL / INFECTIOUS / DERM / OPHTH ---
    if any(k in t for k in [
        "abscess", "cellulitis", "infection", "pus", "fever", "antibiotic",
        "dacryo", "cataract", "tonsill", "anaemia", "anemia", "purpura", "lymphang"
    ]):
        return "general"

    return "general"


In [74]:
clinical_cases["reviewed_6class_label_v3"] = clinical_cases["note_clean"].apply(reviewed_6class_label_v3)

clinical_cases.loc[
    clinical_cases["case_title"].isin(["Inguinal Hernia", "Deep Vein Thrombosis"]),
    ["case_title", "true_label_6", "reviewed_6class_label_v2", "reviewed_6class_label_v3"]
]


Unnamed: 0,case_title,true_label_6,reviewed_6class_label_v2,reviewed_6class_label_v3
25,Inguinal Hernia,musculoskeletal,gi_gu,musculoskeletal
26,Deep Vein Thrombosis,musculoskeletal,general,musculoskeletal


**Prediction for optionB dataset**

In [80]:
import joblib

# Load trained artifacts
vectorizer = joblib.load("models/tfidf_vectorizer_6class_xgb.joblib")
xgb_model = joblib.load("models/xgboost_6class_disease_model.joblib")
label_encoder = joblib.load("models/label_encoder_6class.joblib")

# Vectorize Option B notes
X_optb_vec = vectorizer.transform(clinical_cases["note_clean"])

# Predict (encoded -> label)
y_pred_encoded = xgb_model.predict(X_optb_vec)
clinical_cases["pred_label_6"] = label_encoder.inverse_transform(y_pred_encoded)

clinical_cases[["case_title", "true_label_6", "pred_label_6"]].head(10)


Unnamed: 0,case_title,true_label_6,pred_label_6
0,Carbuncle,general,neurological
1,Cellulitis – Right Foot,general,neurological
2,Acute-on-Chronic Dacryocystitis,general,neurological
3,Dysmenorrhea With Ovarian Cyst,general,gi_gu
4,Broad Ligament Fibroid,general,neurological
5,Term Pregnancy – LSCS,general,neurological
6,Acute Fissure in Ano,general,gi_gu
7,Lymphangitis,general,neurological
8,Right Eye Dacryostenosis,general,neurological
9,Left Eye Cataract,general,neurological


In [82]:
from sklearn.metrics import classification_report

print("=== Model vs OptionB original labels ===")
print(classification_report(clinical_cases["true_label_6"], clinical_cases["pred_label_6"], zero_division=0))

print("\n=== Model vs Clinician-reviewed labels (v3) ===")
print(classification_report(clinical_cases["reviewed_6class_label_v3"], clinical_cases["pred_label_6"], zero_division=0))


=== Model vs OptionB original labels ===
                 precision    recall  f1-score   support

        cardiac       0.00      0.00      0.00         1
        general       0.00      0.00      0.00        13
          gi_gu       0.40      0.40      0.40         5
musculoskeletal       0.00      0.00      0.00         4
   neurological       0.00      0.00      0.00         0
    respiratory       0.43      0.75      0.55         4

       accuracy                           0.19        27
      macro avg       0.14      0.19      0.16        27
   weighted avg       0.14      0.19      0.15        27


=== Model vs Clinician-reviewed labels (v3) ===
                 precision    recall  f1-score   support

        cardiac       0.00      0.00      0.00         1
        general       0.00      0.00      0.00         9
          gi_gu       0.80      0.57      0.67         7
musculoskeletal       0.00      0.00      0.00         6
   neurological       0.00      0.00      0.00     

In [84]:
errors = clinical_cases[clinical_cases["reviewed_6class_label_v3"] != clinical_cases["pred_label_6"]]
errors[["case_title","reviewed_6class_label_v3","pred_label_6","clean_text"]]


Unnamed: 0,case_title,reviewed_6class_label_v3,pred_label_6,clean_text
0,Carbuncle,general,neurological,"Adult with painful swelling, redness, and pus over lower back for 10 days, fever present. Abscess with induration and discharge. Labs: elevated glucose and platelets. Treated with IV antibiotics, ..."
1,Cellulitis – Right Foot,general,neurological,"Adult with pain, redness, swelling, fever, and serous discharge from right leg. Labs: leukocytosis, neutrophilia, low albumin. Managed with IV fluids, antibiotics, and fasciotomy with drainage."
2,Acute-on-Chronic Dacryocystitis,general,neurological,"Adult with pain, swelling, tearing, and tenderness near medial canthus for 3 days. Mild fever. Treated with IV antibiotics, topical therapy, NSAIDs."
4,Broad Ligament Fibroid,gi_gu,neurological,Reproductive-age woman with pelvic pain or heavy bleeding. Ultrasound confirms broad ligament fibroid. Managed with analgesia and planned surgery.
5,Term Pregnancy – LSCS,general,neurological,"Pregnant woman at term undergoing caesarean section for obstetric indications. Post-op care with antibiotics, fluids, uterotonics."
7,Lymphangitis,general,neurological,"Adult with painful red streak along limb from distal infection, fever possible. Treated with systemic antibiotics, limb elevation."
8,Right Eye Dacryostenosis,general,neurological,"Chronic tearing with regurgitation on pressure; nasolacrimal duct obstruction. Managed with massage, topical antibiotics, surgical probing if needed."
9,Left Eye Cataract,general,neurological,Gradual painless visual loss; lens opacity on exam. Managed with cataract extraction and IOL implantation.
10,Chronic Tonsillitis,general,neurological,"Recurrent sore throats, halitosis, tonsillar enlargement with crypts. Managed with antibiotics or tonsillectomy."
11,Henoch–Schönlein Purpura,musculoskeletal,respiratory,"Child with purpuric rash, joint pain, abdominal pain. Possible hematuria. Managed with hydration, monitoring, steroids if severe."


In [86]:
print("Label encoder classes:", list(label_encoder.classes_))
print("Num classes:", len(label_encoder.classes_))


Label encoder classes: ['cardiac', 'general', 'gi_gu', 'musculoskeletal', 'neurological', 'respiratory']
Num classes: 6


In [88]:
print("XGB num_class param:", xgb_model.get_params().get("num_class"))


XGB num_class param: 6


In [90]:
import numpy as np
pred_encoded = xgb_model.predict(X_optb_vec)
unique, counts = np.unique(pred_encoded, return_counts=True)
print("Predicted encoded class distribution:", dict(zip(unique, counts)))


Predicted encoded class distribution: {2: 5, 4: 15, 5: 7}


In [92]:
nnz = X_optb_vec.getnnz(axis=1)  # non-zero tfidf features per doc
print(pd.Series(nnz).describe())

clinical_cases["nnz"] = nnz
clinical_cases.sort_values("nnz").head(10)[["case_title","nnz","note_clean"]]


count    27.000000
mean     17.962963
std       6.022276
min       7.000000
25%      13.000000
50%      17.000000
75%      20.500000
max      34.000000
dtype: float64


Unnamed: 0,case_title,nnz,note_clean
10,Chronic Tonsillitis,7,recurrent sore throats halitosis tonsillar enlargement crypts managed antibiotics tonsillectomy
16,Acute Bronchiolitis,12,infant cough rapid breathing wheeze viral etiology likely managed oxygen hydration nasal suction
7,Lymphangitis,12,adult painful red streak limb distal infection fever possible treated systemic antibiotics limb elevation
12,Severe Anaemia,12,fatigue dyspnea pallor labs low hemoglobin managed supplementation transfusion identifying cause
9,Left Eye Cataract,13,gradual painless visual loss lens opacity exam managed cataract extraction iol implantation
5,Term Pregnancy – LSCS,13,pregnant woman term undergoing caesarean section obstetric indications post op care antibiotics fluids uterotonics
6,Acute Fissure in Ano,13,adult severe anal pain defecation fresh bleeding exam linear anal ulcer managed sitz bath stool softeners topical anesthetics
8,Right Eye Dacryostenosis,13,chronic tearing regurgitation pressure nasolacrimal duct obstruction managed massage topical antibiotics surgical probing needed
17,Bronchopneumonia,14,child fever cough crepitations patchy infiltrates imaging treated antibiotics supportive care
25,Inguinal Hernia,14,adult groin swelling abdominal discomfort reducible mass ultrasound confirms inguinal hernia managed elective hernioplasty
