In [1]:
# ===== Notebook 3: Clustering Analysis =====
# Axora: Clinical Document Classification

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import load_dataset
from scipy.sparse import csr_matrix
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics import (
    homogeneity_score, completeness_score, v_measure_score,
    adjusted_rand_score, adjusted_mutual_info_score, silhouette_score
)

from collections import Counter
import textwrap
import warnings, os, sys, re, math, random

# ---- Plot aesthetics (seaborn theme first, then the matplotlib override) ----
sns.set(style="whitegrid", font_scale=1.05)
plt.rcParams["figure.figsize"] = (10, 6)

# ---- Reproducibility ----
# One seed shared by every stochastic component (KMeans below reuses it).
RANDOM_STATE = 42
random.seed(RANDOM_STATE)        # Python's stdlib RNG
np.random.seed(RANDOM_STATE)     # NumPy's legacy global RNG

# NOTE(review): this silences *all* warnings (library deprecation notices,
# convergence warnings, ...) for the rest of the session; consider narrowing it.
warnings.filterwarnings("ignore")
print("✅ Imports ready.")

✅ Imports ready.


In [2]:
# Load the same dataset used in Notebook 1; its train / validation / test
# splits are indexed by name below.
dataset = load_dataset("hpe-ai/medical-cases-classification-tutorial")

# Keep only the text and its gold label
cols = ["transcription", "medical_specialty"]
train_df, val_df, test_df = (
    pd.DataFrame(dataset[split_key])[cols].copy()
    for split_key in ("train", "validation", "test")
)

# Pool every split for unsupervised clustering; the `split` column records the
# origin of each row, and the true labels are kept for evaluation only.
split_frames = {"train": train_df, "val": val_df, "test": test_df}
all_df = pd.concat(
    [frame.assign(split=tag) for tag, frame in split_frames.items()],
    axis=0,
    ignore_index=True,
)

# Basic overview
n_train, n_val, n_test = (len(frame) for frame in split_frames.values())
print("=== Dataset Overview (Notebook 3) ===")
print(f"Total documents: {len(all_df)}")
print(f"  Train: {n_train} | Val: {n_val} | Test: {n_test}")
print(f"Unique specialties: {all_df['medical_specialty'].nunique()}")
print(all_df["medical_specialty"].value_counts().head())

Repo card metadata block was not found. Setting CardData to empty.


=== Dataset Overview (Notebook 3) ===
Total documents: 2464
  Train: 1724 | Val: 370 | Test: 370
Unique specialties: 13
medical_specialty
Cardiovascular / Pulmonary    742
Orthopedic                    408
Neurology                     282
Gastroenterology              222
Obstetrics / Gynecology       182
Name: count, dtype: int64


In [3]:
# Minimal cleaner for robustness (keep consistent with earlier logic)
def clean_text(text: str) -> str:
    """Normalize a transcription before TF-IDF vectorization.

    Lowercases the text, replaces every character that is not an ASCII
    letter/digit or whitespace with a space, then collapses whitespace runs
    into single spaces and strips the ends.

    Missing scalars (``None``, ``NaN``, ``pd.NA``) map to ``""`` so they do not
    inject spurious tokens such as ``"nan"`` / ``"na"`` into the vocabulary.
    Any other non-string value is converted with ``str()`` first.

    Args:
        text: the raw transcription (normally a ``str``).

    Returns:
        The cleaned string, possibly empty.
    """
    if not isinstance(text, str):
        # str(float("nan")) == "nan" would otherwise survive as a real token
        if pd.api.types.is_scalar(text) and pd.isna(text):
            return ""
        text = str(text)
    text = text.lower()
    # keep letters, digits, and whitespace; everything else becomes a space
    text = re.sub(r"[^a-z0-9\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

# Clean every transcription once; downstream cells vectorize `text_clean`.
all_df["text_clean"] = all_df["transcription"].map(clean_text)
preview = textwrap.shorten(all_df["text_clean"].iloc[0], width=200)
print("Sample cleaned text:\n", preview)

Sample cleaned text:
 procedure note pacemaker icd interrogation history of present illness the patient is a 67 year old gentleman who was admitted to the hospital he has had icd pacemaker implantation this is a st [...]


In [4]:
# Match Notebook 1 settings: max_features=5000, english stopwords, ngram_range=(1,2)
tfidf_params = dict(
    ngram_range=(1, 2),    # unigrams + bigrams
    stop_words="english",
    max_features=5000,     # keep the 5000 terms with highest corpus frequency
    norm="l2",             # unit-length rows: squared euclidean = 2 - 2*cosine
    lowercase=False,       # clean_text already lowercased the input
)
vectorizer = TfidfVectorizer(**tfidf_params)

X_all = vectorizer.fit_transform(all_df["text_clean"])
feature_names = np.array(vectorizer.get_feature_names_out())

print("=== TF-IDF Matrix ===")
for label, value in (("Shape:", X_all.shape), ("Features (first 10):", feature_names[:10])):
    print(label, value)

=== TF-IDF Matrix ===
Shape: (2464, 5000)
Features (first 10): ['00' '000' '000 epinephrine' '01' '02' '03' '04' '05' '06' '07']


In [5]:
# k = 13 per spec (the overview above also reports 13 unique specialties)
K = 13
kmeans = KMeans(
    n_clusters=K,
    n_init=20,             # 20 centroid initialisations; lowest inertia is kept
    max_iter=500,
    random_state=RANDOM_STATE,
    verbose=0,
)
# fit_predict fits the model in place and returns the same labels as .labels_
cluster_labels = kmeans.fit_predict(X_all)
all_df["cluster"] = cluster_labels

cluster_sizes = pd.Series(cluster_labels).value_counts().sort_index()
print("Cluster sizes:", cluster_sizes.to_dict())

Cluster sizes: {0: 48, 1: 81, 2: 228, 3: 120, 4: 205, 5: 555, 6: 142, 7: 52, 8: 533, 9: 132, 10: 84, 11: 243, 12: 41}


In [6]:
def top_terms_per_cluster(centers, feat_names, topn=15):
    """Return the highest-weighted features of every cluster centroid.

    Args:
        centers: (K, D) array of cluster centroids, one row per cluster.
        feat_names: np.array of length D; entry j names column j of ``centers``.
        topn: number of features to keep per cluster (fewer if D < topn).

    Returns:
        dict mapping cluster index (0..K-1) to a list of
        ``(feature_name, centroid_weight)`` pairs, heaviest first.
    """
    ranked = {}
    for cluster_id, weights in enumerate(centers):
        # negate so argsort's ascending order puts the largest weights first
        heaviest = np.argsort(-weights)[:topn]
        ranked[cluster_id] = list(zip(feat_names[heaviest], weights[heaviest]))
    return ranked

# k-means centers live in dense space; np.asarray makes the ndarray contract
# explicit (a no-op when cluster_centers_ is already an ndarray).
centers = np.asarray(kmeans.cluster_centers_)
cluster_top = top_terms_per_cluster(centers, feature_names, topn=15)

# Pretty print: iterate over what was actually computed (one entry per centroid
# row) instead of K, so the loop cannot drift out of sync with the fitted model.
for c, pairs in cluster_top.items():
    terms = ", ".join(t for t, _ in pairs)
    print(f"[Cluster {c}] Top terms: {terms}")

[Cluster 0] Top terms: l5, l4, l4 l5, disc, s1, l5 s1, nerve root, lumbar, central canal, l3, root, disk, neural, nerve, central
[Cluster 1] Top terms: c5, c6, c4, c5 c6, cervical, c4 c5, c3, c7, c6 c7, c3 c4, anterior, anterior cervical, disc, disk, plate
[Cluster 2] Top terms: mg, patient, daily, history, mg daily, day, mg day, dr, discharge, past, status, disease, hospital, medications, years
[Cluster 3] Top terms: scope, procedure, colon, bronchoscope, patient, lobe, biopsies, esophagus, bronchoscopy, normal, stomach, mg, advanced, versed, withdrawn
[Cluster 4] Top terms: artery, coronary, left, right, aortic, coronary artery, valve, stenosis, catheter, right coronary, french, vessel, descending, circumflex, normal
[Cluster 5] Top terms: patient, procedure, placed, right, left, incision, using, anesthesia, used, diagnosis, removed, closed, skin, operating, fashion
[Cluster 6] Top terms: fetal, uterus, cervix, uterine, placed, patient, placenta, incision, cm, vicryl, delivery, weeks

In [10]:
# Crosstab for inspecting alignment between unsupervised clusters and actual specialties
ct = pd.crosstab(all_df["cluster"], all_df["medical_specialty"]).sort_index()
ct_total = ct.sum(axis=1)            # documents per cluster (row totals)
ct_pct = ct.div(ct_total, axis=0)    # row-normalized: divide each row by its total
# crosstab only emits clusters that occur, so every total is > 0 and rows sum to 1
assert np.allclose(ct_pct.sum(axis=1), 1.0), "row normalization failed"

print("=== Cluster vs Specialty (counts) ===")
display(ct)

print("=== Cluster vs Specialty (row % normalized) ===")
display(ct_pct.style.format("{:.2%}"))

=== Cluster vs Specialty (counts) ===


medical_specialty,Cardiovascular / Pulmonary,ENT - Otolaryngology,Gastroenterology,Hematology - Oncology,Nephrology,Neurology,Neurosurgery,Obstetrics / Gynecology,Ophthalmology,Orthopedic,Pediatrics - Neonatal,Psychiatry / Psychology,Radiology
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,0,0,0,0,0,9,12,0,0,27,0,0,0
1,0,0,0,0,0,2,38,0,0,41,0,0,0
2,70,0,13,18,17,38,2,1,0,5,0,64,0
3,50,4,66,0,0,0,0,0,0,0,0,0,0
4,192,0,1,0,2,3,1,0,0,0,0,0,6
5,158,45,38,30,15,27,54,17,46,125,0,0,0
6,2,0,18,0,1,0,0,113,0,0,0,0,8
7,2,0,0,0,0,0,0,0,0,50,0,0,0
8,126,21,62,61,24,65,0,38,6,66,64,0,0
9,2,0,0,0,0,85,0,0,0,17,0,0,28


=== Cluster vs Specialty (row % normalized) ===


medical_specialty,Cardiovascular / Pulmonary,ENT - Otolaryngology,Gastroenterology,Hematology - Oncology,Nephrology,Neurology,Neurosurgery,Obstetrics / Gynecology,Ophthalmology,Orthopedic,Pediatrics - Neonatal,Psychiatry / Psychology,Radiology
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,0.00%,0.00%,0.00%,0.00%,0.00%,18.75%,25.00%,0.00%,0.00%,56.25%,0.00%,0.00%,0.00%
1,0.00%,0.00%,0.00%,0.00%,0.00%,2.47%,46.91%,0.00%,0.00%,50.62%,0.00%,0.00%,0.00%
2,30.70%,0.00%,5.70%,7.89%,7.46%,16.67%,0.88%,0.44%,0.00%,2.19%,0.00%,28.07%,0.00%
3,41.67%,3.33%,55.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%
4,93.66%,0.00%,0.49%,0.00%,0.98%,1.46%,0.49%,0.00%,0.00%,0.00%,0.00%,0.00%,2.93%
5,28.47%,8.11%,6.85%,5.41%,2.70%,4.86%,9.73%,3.06%,8.29%,22.52%,0.00%,0.00%,0.00%
6,1.41%,0.00%,12.68%,0.00%,0.70%,0.00%,0.00%,79.58%,0.00%,0.00%,0.00%,0.00%,5.63%
7,3.85%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,0.00%,96.15%,0.00%,0.00%,0.00%
8,23.64%,3.94%,11.63%,11.44%,4.50%,12.20%,0.00%,7.13%,1.13%,12.38%,12.01%,0.00%,0.00%
9,1.52%,0.00%,0.00%,0.00%,0.00%,64.39%,0.00%,0.00%,0.00%,12.88%,0.00%,0.00%,21.21%
