In [None]:
import pickle

In [None]:
# Load the pickled knowledge graph; later cells iterate it as
# {source: {target: weight}}.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open("/path-to-folder/SD++.p", "rb") as kg_file:
    original_kg = pickle.load(kg_file)

In [None]:
# Load the pickled symptom-name -> symptom-ID mapping.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open("/path-to-folder/SID.p", "rb") as sid_file:
    symptomID = pickle.load(sid_file)

In [None]:
# Load the pickled disease-name -> disease-ID mapping (e.g. 'Acne' -> 'D_53').
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open("/path-to-folder/DID.p", "rb") as did_file:
    diseaseID = pickle.load(did_file)

In [None]:
# Show the full disease-name -> ID mapping for a visual sanity check.
diseaseID

{'Acanthosis nigricans': 'D_3',
 'Acariasis': 'D_85',
 'Acne': 'D_53',
 'Actinic keratosis': 'D_12',
 'Acute glaucoma': 'D_48',
 'Acute kidney injury': 'D_86',
 'Acute stress reaction': 'D_76',
 'Adhesive capsulitis of the shoulder': 'D_42',
 'Adjustment reaction': 'D_33',
 'Air embolism': 'D_78',
 'Alcohol intoxication': 'D_47',
 'Allergy': 'D_28',
 'Alzheimer disease': 'D_79',
 'Amyloidosis': 'D_59',
 'Amyotrophic lateral sclerosis ALS': 'D_4',
 'Ankylosing spondylitis': 'D_75',
 'Anxiety': 'D_82',
 'Aphakia': 'D_50',
 'Carbon monoxide poisoning': 'D_38',
 'Carcinoid syndrome': 'D_44',
 'Carpal tunnel syndrome': 'D_64',
 'Cat scratch disease': 'D_58',
 'Central retinal artery or vein occlusion': 'D_80',
 'Cerebral edema': 'D_51',
 'Chagas disease': 'D_61',
 'Chalazion': 'D_37',
 'Chancroid': 'D_74',
 'Chickenpox': 'D_27',
 'Chlamydia': 'D_23',
 'Chondromalacia of the patella': 'D_84',
 'Chronic back pain': 'D_19',
 'Chronic kidney disease': 'D_7',
 'Chronic pain disorder': 'D_56',
 '

In [None]:
# Case-insensitive lookup tables: same IDs, keyed by the lower-cased name.
symptomID_lower = dict((name.lower(), sid) for name, sid in symptomID.items())
diseaseID_lower = dict((name.lower(), did) for name, did in diseaseID.items())

def map_to_ids(dialogs, disease_map, symptom_map):
    """Translate disease and symptom names in ``dialogs`` to their IDs.

    Args:
        dialogs: iterable of dicts of the form ``{disease_name: [symptom_name, ...]}``.
        disease_map: dict from lower-cased disease name to disease ID.
        symptom_map: dict from lower-cased symptom name to symptom ID.

    Returns:
        A list with one single-entry dict ``{disease: [symptoms]}`` per
        (dialog, disease) pair, in input order. Lookups are done on the
        lower-cased name; any name missing from its map is kept unchanged.
    """
    return [
        {
            disease_map.get(disease.lower(), disease): [
                symptom_map.get(symptom.lower(), symptom) for symptom in symptoms
            ]
        }
        for dialog in dialogs
        for disease, symptoms in dialog.items()
    ]

# mapped_dialogs = map_to_ids(dialogs, diseaseID_lower, symptomID_lower)

In [None]:
# Load the per-dialog symptom lists; iterated below as {dialog id: symptoms}.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open("/path-to-folder/new_dialog_symptom_map.p", "rb") as symptom_map_file:
    dialog_symptoms = pickle.load(symptom_map_file)

In [None]:
# Load the self-report records; each item is read below via its 'dialog_id'
# and 'disease_tag' keys.
# The path carries the same leading "/" as every other data file in this
# notebook, so it no longer depends on the kernel's working directory.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open("/path-to-folder/new_all_self_report_file.p","rb") as f:
    test = pickle.load(f)

In [None]:
# Index the self-report records: dialog id -> disease tag.
# If a dialog id occurs more than once, the last record wins.
dialog_disease_map = {record['dialog_id']: record['disease_tag'] for record in test}

In [None]:
# Pair each dialog's symptom list with its disease tag as {disease: symptoms}.
# Dialogs whose id has no (or an empty) disease tag are skipped.
dialogs = [
    {dialog_disease_map[dialog_id]: reported_symptoms}
    for dialog_id, reported_symptoms in dialog_symptoms.items()
    if dialog_disease_map.get(dialog_id)
]

In [None]:
# Persist the name-based {disease: symptoms} dialogs (before ID mapping).
with open("/path-to-folder/esmmd_disease_symptom_dialog_wise.p", "wb") as out_file:
    pickle.dump(dialogs, out_file)

In [None]:
# Replace disease/symptom names with IDs; names missing from the maps are
# kept as-is by map_to_ids. Note that this rebinds `dialogs` in place.
dialogs = map_to_ids(
    dialogs,
    disease_map=diseaseID_lower,
    symptom_map=symptomID_lower,
)
# new_dialogs

In [None]:
import networkx as nx
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [None]:
# Build symptom ID -> [(disease ID, weight), ...] from the knowledge graph,
# keeping only S_* -> D_* edges with a strictly positive weight.
symptom_disease_map = {}
for src, neighbours in original_kg.items():
    for dst, edge_weight in neighbours.items():
        if src.startswith('S_') and dst.startswith('D_') and edge_weight > 0:
            symptom_disease_map.setdefault(src, []).append((dst, edge_weight))

In [None]:
# Baseline evaluation: for each dialog build a small graph
# User -> reported symptoms -> KG diseases, run Personalized PageRank with
# all restart mass on "User", and predict the highest-scoring disease node.
y_true = []
y_pred = []

for dialog in dialogs:
    # The first key of the dialog dict is taken as the true disease.
    disease_true = list(dialog)[0]
    symptoms = dialog[disease_true]

    G = nx.DiGraph()

    # User -> symptom edges first, so node order is: User, symptoms, diseases.
    G.add_weighted_edges_from(("User", symptom, 1.0) for symptom in symptoms)

    # symptom -> disease edges, weighted by the knowledge graph.
    G.add_weighted_edges_from(
        (symptom, disease, weight)
        for symptom in symptoms
        for disease, weight in symptom_disease_map.get(symptom, [])
    )

    ppr = nx.pagerank(G, personalization={"User": 1.0}, alpha=0.85, weight='weight')

    # Only disease nodes (IDs prefixed 'D_') are prediction candidates.
    disease_scores = {}
    for node, score in ppr.items():
        if node.startswith('D_'):
            disease_scores[node] = score

    # None marks a dialog for which no disease node was reachable.
    predicted_disease = max(disease_scores, key=disease_scores.get, default=None)

    y_true.append(disease_true)
    y_pred.append(predicted_disease)

# Dialogs without a prediction are dropped, so the metrics below are computed
# only over dialogs that received a prediction.
filtered_true = [t for t, p in zip(y_true, y_pred) if p is not None]
filtered_pred = [p for p in y_pred if p is not None]

macro_kwargs = dict(average='macro', zero_division=0)
accuracy = accuracy_score(filtered_true, filtered_pred)
precision = precision_score(filtered_true, filtered_pred, **macro_kwargs)
recall = recall_score(filtered_true, filtered_pred, **macro_kwargs)
f1 = f1_score(filtered_true, filtered_pred, **macro_kwargs)

print("Evaluation Results:")
print(f"Accuracy : {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall   : {recall:.4f}")
print(f"F1 Score : {f1:.4f}")


Evaluation Results:
Accuracy : 0.6569
Precision: 0.6701
Recall   : 0.6559
F1 Score : 0.6106


## With disease pruning

In [None]:
# Pruned evaluation: same Personalized PageRank set-up as the baseline, but a
# disease is only added to a dialog's graph if the knowledge graph links it to
# EVERY reported symptom (intersection-based filtering).

# Build symptom ID -> [(disease ID, weight), ...] from the knowledge graph,
# keeping only S_* -> D_* edges with a strictly positive weight.
symptom_disease_map = {}
for source, targets in original_kg.items():
    for target, weight in targets.items():
        if source.startswith('S_') and target.startswith('D_') and weight > 0:
            symptom_disease_map.setdefault(source, []).append((target, weight))

y_true = []
y_pred = []

for dialog in dialogs:
    # The first key of the dialog dict is taken as the true disease.
    disease_true = list(dialog.keys())[0]
    symptoms = dialog[disease_true]

    # Create graph
    G = nx.DiGraph()

    # Add user and symptom edges
    for symptom in symptoms:
        G.add_edge("User", symptom, weight=1.0)

    # Find diseases connected to ALL symptoms (intersection-based filtering).
    # A symptom with no KG entry contributes an empty set, which empties the
    # intersection and leaves the dialog without any candidate disease.
    if symptoms:
        valid_diseases = {disease for disease, _ in symptom_disease_map.get(symptoms[0], [])}
        for symptom in symptoms[1:]:
            valid_diseases &= {disease for disease, _ in symptom_disease_map.get(symptom, [])}
    else:
        valid_diseases = set()

    # Add edges from symptoms to diseases ONLY if disease is in valid_diseases
    for symptom in symptoms:
        for disease, weight in symptom_disease_map.get(symptom, []):
            if disease in valid_diseases:
                G.add_edge(symptom, disease, weight=weight)

    # Run Personalized PageRank with all restart mass on the "User" node
    ppr = nx.pagerank(G, personalization={"User": 1.0}, alpha=0.85, weight='weight')

    # Filter for only disease nodes
    disease_scores = {node: score for node, score in ppr.items() if node.startswith('D_')}

    # None marks a dialog for which no disease survived the pruning.
    # (A stray pasted statement used to sit in this else-branch and made the
    # cell unparsable; the branch now only selects the best-scoring disease.)
    if not disease_scores:
        predicted_disease = None
    else:
        predicted_disease = max(disease_scores, key=disease_scores.get)

    y_true.append(disease_true)
    y_pred.append(predicted_disease)

# Dialogs without a prediction are dropped, so the metrics below are computed
# only over dialogs that received a prediction.
filtered_true = []
filtered_pred = []

for t, p in zip(y_true, y_pred):
    if p is not None:
        filtered_true.append(t)
        filtered_pred.append(p)

accuracy = accuracy_score(filtered_true, filtered_pred)
precision = precision_score(filtered_true, filtered_pred, average='macro', zero_division=0)
recall = recall_score(filtered_true, filtered_pred, average='macro', zero_division=0)
f1 = f1_score(filtered_true, filtered_pred, average='macro', zero_division=0)

print("Evaluation Results:")
print(f"Accuracy : {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall   : {recall:.4f}")
print(f"F1 Score : {f1:.4f}")

Evaluation Results:
Accuracy : 0.7445
Precision: 0.7746
Recall   : 0.7471
F1 Score : 0.7130


## On MDDIAL

In [None]:
# Load the MDDIAL dialogs; they are passed to map_to_ids further down.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open("/path-to-folder/mddial_dialog_wise_disease_symptoms.p", "rb") as dialogs_file:
    mddial_dialogs = pickle.load(dialogs_file)

In [None]:
# Load the MDDIAL knowledge graph, a nested dict {name: {name: weight}}.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open("/path-to-folder/mddial_kg.p", "rb") as kg_file:
    mddial_kg = pickle.load(kg_file)

In [None]:
# Read the symptom names (one per non-blank line) while the file is open.
# Previously the file object itself was stored, which is already closed once
# the `with` block exits and therefore unusable afterwards.
with open("/path-to-folder/symptom.txt","r") as f:
    mddial_symptoms = [line.strip() for line in f if line.strip()]

In [None]:
# Build the MDDIAL vocabularies: one lower-cased name per non-blank line,
# with IDs assigned by line position ('S_<i>' for symptoms, 'D_<i>' for diseases).
symptom_file_path = "/path-to-folder/symptom.txt"
disease_file_path = "/path-to-folder/disease.txt"

with open(symptom_file_path, "r") as symptom_file:
    symptoms = [name for name in (raw.strip().lower() for raw in symptom_file) if name]

with open(disease_file_path, "r") as disease_file:
    diseases = [name for name in (raw.strip().lower() for raw in disease_file) if name]

# If a name appears twice in a file, the later position's ID is kept.
symptom_dict = {name: f"S_{idx}" for idx, name in enumerate(symptoms)}
disease_dict = {name: f"D_{idx}" for idx, name in enumerate(diseases)}

In [None]:
# Replace MDDIAL disease/symptom names with their IDs; names missing from the
# dictionaries are kept as-is by map_to_ids.
mddial_dialogs_mapped = map_to_ids(
    mddial_dialogs,
    disease_map=disease_dict,
    symptom_map=symptom_dict,
)

In [None]:
# Sanity check: number of neighbours stored for the 'acid reflux' KG entry.
len(mddial_kg['acid reflux'])

59

In [None]:
# Single name -> ID lookup covering symptoms and diseases. If a name exists in
# both dictionaries, the disease ID overrides the symptom ID.
merged_symptom_disease_map = dict(symptom_dict)
merged_symptom_disease_map.update(disease_dict)

In [None]:
# Show the combined name -> ID mapping for a visual sanity check.
merged_symptom_disease_map

{'dizziness': 'S_0',
 'chest tightness': 'S_1',
 'burning sensation behind the breastbone': 'S_2',
 'chest tightness and shortness of breath': 'S_3',
 'pain behind the breastbone': 'S_4',
 'acid reflux': 'S_5',
 'nausea': 'S_6',
 'vomiting': 'S_7',
 'hard to swallow': 'S_8',
 'stomach ache': 'S_9',
 'bloating': 'S_10',
 'pharynx discomfort': 'S_11',
 'expectoration': 'S_12',
 'cough': 'S_13',
 'fever': 'S_14',
 'palpitations': 'S_15',
 'diarrhea': 'S_16',
 'feel sick and vomit': 'S_17',
 'loss of appetite': 'S_18',
 'hiccup': 'S_19',
 'thin': 'S_20',
 'anorexia': 'S_21',
 'increased stool frequency': 'S_22',
 'edema': 'S_23',
 'constipation': 'S_24',
 'bitter': 'S_25',
 'thirst': 'S_26',
 'hiccough': 'S_27',
 'hemoptysis': 'S_28',
 'runny nose': 'S_29',
 'twitch': 'S_30',
 'suppuration': 'S_31',
 'chills and fever': 'S_32',
 'black stool': 'S_33',
 'sweating': 'S_34',
 'shortness of breath': 'S_35',
 'poor spirits': 'S_36',
 'poor sleep': 'S_37',
 'stuffy nose': 'S_38',
 'hoarse': 'S_3

In [None]:
# Persist the MDDIAL symptom-name -> ID mapping.
# NOTE(review): this is the same path the notebook reads `symptomID` from near
# the top, so a later Restart & Run All would load this MDDIAL mapping there -
# confirm the two files are really meant to share a location.
with open("/path-to-folder/SID.p", "wb") as sid_file:
    pickle.dump(symptom_dict, sid_file)

In [None]:
# Persist the MDDIAL disease-name -> ID mapping.
# NOTE(review): this is the same path the notebook reads `diseaseID` from near
# the top, so a later Restart & Run All would load this MDDIAL mapping there -
# confirm the two files are really meant to share a location.
with open("/path-to-folder/DID.p", "wb") as did_file:
    pickle.dump(disease_dict, did_file)

In [None]:
# Show the raw (name-keyed) MDDIAL knowledge graph before it is remapped to IDs.
mddial_kg

{'shortness of breath': {'esophagitis': 0.0,
  'enteritis': 0.0,
  'asthma': 0.0,
  'coronary heart disease': 0.0,
  'pneumonia': 0.0,
  'rhinitis': 0.0,
  'thyroiditis': 0.0,
  'traumatic brain injury': 0.0,
  'dermatitis': 0.0,
  'external otitis': 0.0,
  'conjunctivitis': 0.0,
  'mastitis': 0.0,
  'pain behind the breastbone': 0.0710128055878929,
  'expectoration': 0.07159487776484284,
  'chest tightness and shortness of breath': 0.13465987968533086,
  'chest tightness': 0.09742539496781744,
  'hemoptysis': 0.109375,
  'cough': 0.0629353831371607,
  'pharynx discomfort': 0.04317868626550299,
  'acid reflux': 0.02471169686985173,
  'hiccup': 0.024096385542168676,
  'vomiting': 0.03190460844344183,
  'nausea': 0.024828314844162706,
  'sweating': 0.036458333333333336,
  'diarrhea': 0.03563373928732522,
  'stomach ache': 0.03532490187527257,
  'hiccough': 0.037037037037037035,
  'black stool': 0.034482758620689655,
  'palpitations': 0.060836501901140684,
  'fever': 0.043275942298743604,

In [None]:
# Re-key the MDDIAL knowledge graph from names to IDs on both levels.
# Names missing from merged_symptom_disease_map are kept unchanged; if two
# names map to the same ID, the later entry overwrites the earlier one.
new_mddial_kg = {
    merged_symptom_disease_map.get(outer_name, outer_name): {
        merged_symptom_disease_map.get(inner_name, inner_name): edge_weight
        for inner_name, edge_weight in neighbours.items()
    }
    for outer_name, neighbours in mddial_kg.items()
}
# From here on `mddial_kg` refers to the ID-keyed graph.
mddial_kg = new_mddial_kg


In [None]:
# Persist the ID-keyed MDDIAL knowledge graph.
# NOTE(review): this is the same path the notebook reads `original_kg` from
# near the top, so a later Restart & Run All would load this MDDIAL graph
# there - confirm the two files are really meant to share a location.
with open("/path-to-folder/SD++.p", "wb") as kg_file:
    pickle.dump(mddial_kg, kg_file)

In [None]:
# MDDIAL evaluation with intersection pruning: a disease is only added to a
# dialog's graph if the knowledge graph links it to EVERY reported symptom;
# Personalized PageRank (all restart mass on "User") then ranks the survivors.

# symptom ID -> [(disease ID, weight), ...], keeping only S_* -> D_* edges
# with a strictly positive weight.
symptom_disease_map = {}
for source, targets in mddial_kg.items():
    for target, weight in targets.items():
        if source.startswith('S_') and target.startswith('D_') and weight > 0:
            symptom_disease_map.setdefault(source, []).append((target, weight))

y_true = []
y_pred = []

for dialog in mddial_dialogs_mapped:
    # The first key of the dialog dict is taken as the true disease.
    disease_true = list(dialog)[0]
    symptoms = dialog[disease_true]

    G = nx.DiGraph()

    # User -> symptom edges first, so node order is: User, symptoms, diseases.
    G.add_weighted_edges_from(("User", symptom, 1.0) for symptom in symptoms)

    # Diseases linked to ALL symptoms. A symptom with no KG entry contributes
    # an empty set, which empties the intersection.
    candidate_sets = [
        {disease for disease, _ in symptom_disease_map.get(symptom, [])}
        for symptom in symptoms
    ]
    valid_diseases = set.intersection(*candidate_sets) if candidate_sets else set()

    # symptom -> disease edges, restricted to the surviving diseases.
    G.add_weighted_edges_from(
        (symptom, disease, weight)
        for symptom in symptoms
        for disease, weight in symptom_disease_map.get(symptom, [])
        if disease in valid_diseases
    )

    ppr = nx.pagerank(G, personalization={"User": 1.0}, alpha=0.85, weight='weight')

    # Only disease nodes (IDs prefixed 'D_') are prediction candidates.
    disease_scores = {}
    for node, score in ppr.items():
        if node.startswith('D_'):
            disease_scores[node] = score

    # None marks a dialog for which no disease survived the pruning.
    predicted_disease = max(disease_scores, key=disease_scores.get, default=None)

    y_true.append(disease_true)
    y_pred.append(predicted_disease)

# Dialogs without a prediction are dropped, so the metrics below are computed
# only over dialogs that received a prediction.
filtered_true = [t for t, p in zip(y_true, y_pred) if p is not None]
filtered_pred = [p for p in y_pred if p is not None]

macro_kwargs = dict(average='macro', zero_division=0)
accuracy = accuracy_score(filtered_true, filtered_pred)
precision = precision_score(filtered_true, filtered_pred, **macro_kwargs)
recall = recall_score(filtered_true, filtered_pred, **macro_kwargs)
f1 = f1_score(filtered_true, filtered_pred, **macro_kwargs)

print("Evaluation Results:")
print(f"Accuracy : {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall   : {recall:.4f}")
print(f"F1 Score : {f1:.4f}")

Evaluation Results:
Accuracy : 0.9529
Precision: 0.8179
Recall   : 0.7965
F1 Score : 0.8034
