In [None]:
import pandas as pd
import pickle
from tqdm import tqdm
import ollama
import re


In [5]:
excel_path = 'PUBMED_cohort.xlsx'

df = pd.read_excel(excel_path)
print(df.columns)
df.head()


Index(['Source Identifier', 'PMID', 'OMIM Diagnosis', 'Diagnoses', 'Symptoms'], dtype='object')


Unnamed: 0,Source Identifier,PMID,OMIM Diagnosis,Diagnoses,Symptoms
0,"De Maria et al., 2021 (CHD2)",34713950.0,DEVELOPMENTAL AND EPILEPTIC ENCEPHALOPATHY 94;...,CHD2-related disease; Epilepsy with myoclonic-...,Epilepsy (present in 92.2% of patients); Media...
1,"Perinelli et al., 2024 (CDKL5)",38660576.0,DEVELOPMENTAL AND EPILEPTIC ENCEPHALOPATHY 2; ...,CDKL5 deficiency disorder (CDKL5-DEE),Cortical visual impairment (CVI); Deficits in ...
2,"Whitney et al., 2023 (ST3GAL3)",37067065.0,DEVELOPMENTAL AND EPILEPTIC ENCEPHALOPATHY 15;...,ST3GAL3-related developmental and epileptic en...,Early-onset epilepsy; Seizure types: epileptic...
3,"Monfrini et al., 2023 (GABRB1)",37518907.0,DEVELOPMENTAL AND EPILEPTIC ENCEPHALOPATHY 45;...,Developmental and epileptic encephalopathy 45 ...,Hypotonia (at birth); Focal apneic seizures (a...
4,"Di Micco et al., 2024 (CACNA1E)",38780451.0,DEVELOPMENTAL AND EPILEPTIC ENCEPHALOPATHY 69;...,CACNA1E developmental and epileptic encephalop...,EEG: Polymorphic delta activity; Multifocal in...


**Extract and normalize symptom strings**

* Convert to string
* Split one cell into multiple symptoms
* Strip whitespace
* Drop empty entries
* Deduplicate
* Splits on ;
* Example: Epilepsy (present in 92.2% of patients); Median age at seizure onset: 2 years 6 months; Multiple seizure types (generalized onset tonic-clonic, absences, myoclonic, focal-onset, atonic, myoclonic-atonic, tonic, epileptic spasms, myoclonic-clonic); Single seizure types (generalized onset tonic-clonic, absences, focal, myoclonic, myoclonic-atonic); History of febrile seizures; Status epilepticus (SE) (convulsive, nonconvulsive, or both); Prevalent epilepsy type: generalized (75.5%), combined generalized and focal (22.3%), focal (2.2%); Aggressive behavior; Attention Deficit Hyperactivity Disorder (ADHD); Autism spectrum disorder (ASD) / autistic features; Specific Learning Disorder (SLD); Obsessive-Compulsive Disorder (OCD); Psychotic symptoms; Tourette's syndrome; Bipolar disorder; Childhood-onset schizophrenia; Brain MRI findings: normal (majority), or cerebellar hypoplasia, cerebellar vermis hypoplasia, cerebellar atrophy, Arnold-Chiari type I malformation, white matter hyperintensity.

In [11]:
# 1) Ensure column is string
df['Symptoms'] = df['Symptoms'].astype(str)

all_symptoms = []

for raw in df['Symptoms']:
    # Split by ';'
    parts = re.split(';', raw)
    for p in parts:
        s = p.strip()
        if s:
            all_symptoms.append(s)

print("Total symptom entries (with duplicates):", len(all_symptoms))

# 2) Deduplicate while preserving order
seen = set()
unique_symptoms = []
for s in all_symptoms:
    if s not in seen:
        seen.add(s)
        unique_symptoms.append(s)

print("Unique symptoms:", len(unique_symptoms))
print("First 10 unique symptoms:", unique_symptoms[:10])


Total symptom entries (with duplicates): 1383
Unique symptoms: 1205
First 10 unique symptoms: ['Epilepsy (present in 92.2% of patients)', 'Median age at seizure onset: 2 years 6 months', 'Multiple seizure types (generalized onset tonic-clonic, absences, myoclonic, focal-onset, atonic, myoclonic-atonic, tonic, epileptic spasms, myoclonic-clonic)', 'Single seizure types (generalized onset tonic-clonic, absences, focal, myoclonic, myoclonic-atonic)', 'History of febrile seizures', 'Status epilepticus (SE) (convulsive, nonconvulsive, or both)', 'Prevalent epilepsy type: generalized (75.5%), combined generalized and focal (22.3%), focal (2.2%)', 'Aggressive behavior', 'Attention Deficit Hyperactivity Disorder (ADHD)', 'Autism spectrum disorder (ASD) / autistic features']


In [12]:
# embed each unique symptom with Ollama
def embed_symptom(text: str, model: str = "mxbai-embed-large"):
    res = ollama.embeddings(model=model, prompt=text)
    return res["embedding"]

symptom_embeddings = {}

tqdm_bar = tqdm(unique_symptoms, desc="Embedding symptoms")

for symptom in tqdm_bar:
    try:
        emb = embed_symptom(symptom)
        symptom_embeddings[symptom] = emb
    except Exception as e:
        print(f"Error embedding '{symptom}': {e}")
        # you could also decide to continue silently or log this


Embedding symptoms: 100%|██████████| 1205/1205 [00:16<00:00, 71.54it/s]


In [13]:
print("Number of keys in dict:", len(symptom_embeddings))

first_symptom = next(iter(symptom_embeddings.keys()))
print("Example symptom:", first_symptom)
print("Embedding length:", len(symptom_embeddings[first_symptom]))
print("First 5 dims:", symptom_embeddings[first_symptom][:5])


Number of keys in dict: 1205
Example symptom: Epilepsy (present in 92.2% of patients)
Embedding length: 1024
First 5 dims: [-0.4378899037837982, 0.868465006351471, -0.28697219491004944, 1.1439778804779053, -0.30676478147506714]


In [14]:
output_path = "pubmed_symptom_embeddings_dict_mxbai.pkl"

with open(output_path, "wb") as f:
    pickle.dump(symptom_embeddings, f)

print("Saved symptom embedding dict to:", output_path)


Saved symptom embedding dict to: pubmed_symptom_embeddings_dict_mxbai.pkl


In [24]:
import pickle

with open(r"symptom_embeddings_dict_mxbai.pkl", "rb") as f:
    pubmed_symptom_dict = pickle.load(f)

for i, (sym, emb) in enumerate(pubmed_symptom_dict.items()):
    print(i+1, "Symptom:", sym)
    print("   Embedding length:", len(emb))
    print("   First 5 dims:", emb[:5])
    if i == 4:
        break


1 Symptom: Delayed development, variable severity, from birth in some patients
   Embedding length: 1024
   First 5 dims: [-0.456976    0.28462172 -0.14181492 -0.18142879 -0.47356337]
2 Symptom: Developmental regression in about 50% of patients
   Embedding length: 1024
   First 5 dims: [ 0.15248707  0.08386248 -0.17723799  0.3902016  -0.579988  ]
3 Symptom: Normal development in some patients
   Embedding length: 1024
   First 5 dims: [ 0.04364679  0.13133353  0.28363776  0.07896416 -0.39208084]
4 Symptom: Seizures, convulsive
   Embedding length: 1024
   First 5 dims: [-0.33330166 -0.4741196  -0.17374632  0.52419686 -0.57776725]
5 Symptom: Seizures, tonic-clonic
   Embedding length: 1024
   First 5 dims: [-0.06894398  0.07931923 -0.15660666  0.59845024 -1.29251409]


**Cluster Centroid Assignment**

For each cluster:
* Embed its phrases
* Average them → cluster centroid embedding
* Save an output file with cluster number, cluster symptoms and cluster centroid

For each PubMed symptom:
* Compute distance to all cluster centroids
* Assign it to the closest cluster

In [2]:
import pickle
import numpy as np
from collections import defaultdict
from tqdm.auto import tqdm
import ollama
import pandas as pd


In [3]:
pubmed_pkl_path = "pubmed_symptom_embeddings_dict_mxbai.pkl"

with open(pubmed_pkl_path, "rb") as f:
    pubmed_sym2emb = pickle.load(f)

print("Number of PubMed symptoms:", len(pubmed_sym2emb))

# Turn into ordered lists + matrix
pubmed_symptoms = list(pubmed_sym2emb.keys())
pubmed_matrix = np.array([pubmed_sym2emb[s] for s in pubmed_symptoms], dtype=np.float32)

print("PubMed embedding matrix shape:", pubmed_matrix.shape)  # (n_symptoms, dim)


Number of PubMed symptoms: 1205
PubMed embedding matrix shape: (1205, 1024)


Parse the cluster file into cluster_id -> [symptom1, symptom2,...]

In [4]:
cluster_file_path = "clustered_terms_10000_mxbai_cleaned_jun17.txt"

cluster_terms = defaultdict(list)
current_cluster = None

with open(cluster_file_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        if line.startswith("Cluster "):
            # e.g. "Cluster 0:" → extract 0
            num_str = line.split("Cluster", 1)[1].split(":", 1)[0].strip()
            current_cluster = int(num_str)
        else:
            # cluster member term
            if current_cluster is not None:
                cluster_terms[current_cluster].append(line)

print("Number of clusters:", len(cluster_terms))
print("Cluster 0 examples:", cluster_terms[0][:5])


Number of clusters: 10000
Cluster 0 examples: ['dysmorphic facial features', 'dysmorphic features', 'dysmorphic face']


In [5]:
# Collect all unique cluster terms
all_cluster_terms = sorted({term for terms in cluster_terms.values() for term in terms})
print("Unique cluster terms:", len(all_cluster_terms))

Unique cluster terms: 14505


In [6]:
def embed_text(text: str, model: str = "mxbai-embed-large"):
    res = ollama.embeddings(model=model, prompt=text)
    return res["embedding"]

cluster_term2emb = {}

for term in tqdm(all_cluster_terms, desc="Embedding cluster terms"):
    try:
        cluster_term2emb[term] = embed_text(term)
    except Exception as e:
        print(f"Error embedding '{term}': {e}")

Embedding cluster terms:   0%|          | 0/14505 [00:00<?, ?it/s]

Compute centroid embedding for each cluster

In [7]:
# get embedding dimension
example_emb = next(iter(cluster_term2emb.values()))
dim = len(example_emb)
print("Embedding dimension:", dim)

cluster_ids = sorted(cluster_terms.keys())
num_clusters = len(cluster_ids)

centroids = np.zeros((num_clusters, dim), dtype=np.float32)

for idx, cid in enumerate(cluster_ids):
    embs = [cluster_term2emb[t] for t in cluster_terms[cid] if t in cluster_term2emb]
    if not embs:
        # Leave centroid as zeros if no embeddings (shouldn't happen here)
        continue
    embs = np.array(embs, dtype=np.float32)
    centroids[idx] = embs.mean(axis=0)

print("Centroids shape:", centroids.shape)  # (num_clusters, dim)


Embedding dimension: 1024
Centroids shape: (10000, 1024)


In [9]:
# save the cluster id, list of symptoms and centroids info to a output json file
cluster_info = {}

for cluster_id in cluster_ids:  # cluster_ids is the sorted list of cluster numbers
    cluster_info[cluster_id] = {
        "cluster_id": cluster_id,
        "terms": cluster_terms[cluster_id],            # list of original cluster terms
        "centroid": centroids[cluster_id].tolist()     # convert numpy array to list
    }

import json

output_json_path = "cluster_centroids_mxbai_10000.json"

with open(output_json_path, "w", encoding="utf-8") as f:
    json.dump(cluster_info, f, indent=2)

print("Saved JSON:", output_json_path)

Saved JSON: cluster_centroids_mxbai_10000.json


Assign each PubMed symptom to the nearest cluster

In [10]:
def normalize_rows(mat: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return mat / norms

pubmed_norm = normalize_rows(pubmed_matrix)
centroids_norm = normalize_rows(centroids)

# similarity[i, j] = cosine similarity between pubmed_symptoms[i] and cluster_ids[j]
similarity = pubmed_norm @ centroids_norm.T   # (n_symptoms, n_clusters)

print("Similarity matrix shape:", similarity.shape)


Similarity matrix shape: (1205, 10000)


In [11]:
# Get the closest cluster for each symptom
best_cluster_idx = similarity.argmax(axis=1)     # index in cluster_ids list
best_similarity = similarity.max(axis=1)

assigned_cluster_ids = [cluster_ids[i] for i in best_cluster_idx]


In [12]:
assign_df = pd.DataFrame({
    "symptom": pubmed_symptoms,
    "cluster_id": assigned_cluster_ids,
    "similarity": best_similarity,
})

assign_df.head()

output_assign_path = "pubmed_symptom_cluster_assignments.csv"
assign_df.to_csv(output_assign_path, index=False)
print("Saved assignments to:", output_assign_path)


Saved assignments to: pubmed_symptom_cluster_assignments.csv
