# Symptom -> Disease Prediction Model

This notebook builds a multi-class classifier that maps binary symptom indicators to likely diseases.

## Objectives
- Validate the disease and symptom coverage in `data/dataset.csv`.
- Transform the long symptom format into a binary feature matrix (133 symptoms expected).
- Train and evaluate an ensemble model that predicts 41 diseases from symptoms.
- Persist the trained pipeline so Crew AI tools can perform real inference.
- Provide helper utilities for text-based inference to plug into agents.

In [3]:
from __future__ import annotations

import json
import re
from datetime import datetime
from pathlib import Path
from typing import Iterable, List

import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, top_k_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer

from src.processors.symptom_data_processor import SymptomDataProcessor

pd.set_option('display.max_columns', 120)
pd.set_option('display.width', 160)

In [None]:
def locate_project_root(marker: str = 'data/dataset.csv') -> Path:
    current = Path.cwd().resolve()
    for _ in range(5):
        candidate = current / marker
        if candidate.exists():
            return current
        current = current.parent
    raise FileNotFoundError(
        'Could not locate project root containing data/dataset.csv. Please run this notebook inside the repo.'
    )

PROJECT_ROOT = locate_project_root()
DATASET_FILE = PROJECT_ROOT / 'data' / 'dataset.csv'
MODEL_OUTPUT_DIR = PROJECT_ROOT / 'outputs' / 'task3_tool1'
MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f'Project root: {PROJECT_ROOT}')
print(f'Using dataset: {DATASET_FILE.relative_to(PROJECT_ROOT)}')
print(f'Model artifacts will be stored in: {MODEL_OUTPUT_DIR.relative_to(PROJECT_ROOT)}')

Project root: /Users/iainrodericktay/SC4020-Group-Project-2
Using dataset: data/dataset.csv
Model artifacts will be stored in: outputs/models


In [3]:
raw_df = pd.read_csv(DATASET_FILE)
symptom_cols = [col for col in raw_df.columns if col.lower().startswith('symptom')]

print(f'Total rows: {len(raw_df):,}')
print(f'Symptom columns detected: {len(symptom_cols)}')
raw_df.head()

Total rows: 4,920
Symptom columns detected: 17


Unnamed: 0,Disease,Symptom_1,Symptom_2,Symptom_3,Symptom_4,Symptom_5,Symptom_6,Symptom_7,Symptom_8,Symptom_9,Symptom_10,Symptom_11,Symptom_12,Symptom_13,Symptom_14,Symptom_15,Symptom_16,Symptom_17
0,Fungal infection,itching,skin_rash,nodal_skin_eruptions,dischromic _patches,,,,,,,,,,,,,
1,Fungal infection,skin_rash,nodal_skin_eruptions,dischromic _patches,,,,,,,,,,,,,,
2,Fungal infection,itching,nodal_skin_eruptions,dischromic _patches,,,,,,,,,,,,,,
3,Fungal infection,itching,skin_rash,dischromic _patches,,,,,,,,,,,,,,
4,Fungal infection,itching,skin_rash,nodal_skin_eruptions,,,,,,,,,,,,,,


## Symptom normalization
The dataset mixes underscores, spaces, and spelling variations (e.g., `dischromic _patches`).
We will re-use `SymptomDataProcessor` plus extra regex cleaning to map every entry onto a stable
feature name.

In [4]:
symptom_processor = SymptomDataProcessor(data_path=str(DATASET_FILE))

def to_feature_name(value: str) -> str:
    if not isinstance(value, str):
        return ''
    normalized = symptom_processor.normalize_symptom(value)
    normalized = normalized.lower().strip()
    normalized = normalized.replace('-', ' ')
    normalized = re.sub(r'[^a-z0-9\s]', ' ', normalized)
    normalized = re.sub(r'\s+', '_', normalized)
    return normalized.strip('_')

def collect_symptoms(row: pd.Series) -> list[str]:
    seen = set()
    cleaned: list[str] = []
    for value in row.values:
        feature = to_feature_name(value)
        if feature and feature not in seen:
            seen.add(feature)
            cleaned.append(feature)
    return cleaned

raw_df['symptom_list'] = raw_df[symptom_cols].apply(collect_symptoms, axis=1)
symptom_universe = sorted({symptom for items in raw_df['symptom_list'] for symptom in items})

print('Detected diseases:', raw_df['Disease'].nunique())
print('Unique normalized symptoms:', len(symptom_universe), '(expected ~133)')
symptom_universe[:15]

Detected diseases: 41
Unique normalized symptoms: 118 (expected ~133)


['abdominal_pain',
 'abnormal_menstruation',
 'acidity',
 'acute_liver_failure',
 'altered_sensorium',
 'anal_discomfort',
 'anxiety',
 'back_pain',
 'blackheads',
 'bladder_discomfort',
 'blister',
 'blood_in_sputum',
 'bloody_stool',
 'blurred_and_distorted_vision',
 'breathlessness']

In [5]:
mlb = MultiLabelBinarizer(classes=symptom_universe)
symptom_matrix = mlb.fit_transform(raw_df['symptom_list'])
symptom_frame = pd.DataFrame(symptom_matrix, columns=mlb.classes_, index=raw_df.index).astype(np.uint8)
dataset_binary = pd.concat([raw_df[['Disease']], symptom_frame], axis=1)

print(f'Final feature matrix shape: {symptom_frame.shape}')
dataset_binary.head()

Final feature matrix shape: (4920, 118)


Unnamed: 0,Disease,abdominal_pain,abnormal_menstruation,acidity,acute_liver_failure,altered_sensorium,anal_discomfort,anxiety,back_pain,blackheads,bladder_discomfort,blister,blood_in_sputum,bloody_stool,blurred_and_distorted_vision,breathlessness,brittle_nails,bruising,chest_pain,chills,cold_hands_and_feet,coma,congestion,constipation,continuous_feel_of_urine,continuous_sneezing,cough,cramps,dark_urine,dehydration,depression,diarrhoea,dischromic_patches,distention_of_abdomen,dizziness,drying_and_tingling_lips,enlarged_thyroid,extra_marital_contacts,family_history,fast_heart_rate,fatigue,fluid_overload,foul_smell_of_urine,headache,high_fever,hip_joint_pain,history_of_alcohol_consumption,increased_appetite,indigestion,inflammatory_nails,irregular_sugar_level,irritability,itching,joint_pain,knee_pain,lack_of_concentration,loss_of_appetite,loss_of_smell,mild_fever,mood_swings,movement_stiffness,mucoid_sputum,muscle_pain,muscle_wasting,muscle_weakness,neck_pain,nodal_skin_eruptions,obesity,pain_behind_the_eyes,pain_during_bowel_movements,painful_urination,painful_walking,palpitations,passage_of_gases,patches_in_throat,phlegm,polyuria,prominent_veins_on_calf,puffy_face_and_eyes,pus_filled_pimples,receiving_blood_transfusion,receiving_unsterile_injections,red_sore_around_nose,red_spots_over_body,redness_of_eyes,restlessness,runny_nose,rusty_sputum,shivering,silver_like_dusting,sinus_pressure,skin_peeling,skin_rash,skin_scaling,slurred_speech,small_dents_in_nails,spotting_during_urination,stiff_neck,stomach_bleeding,sunken_eyes,sweating,swelled_lymph_nodes,swelling_joints,swollen_blood_vessels,swollen_extremities,throat_irritation,toxic_look_typhus,ulcers_on_tongue,unsteadiness,visual_disturbances,vomiting,watering_from_eyes,weakness_of_one_body_side,weight_gain,weight_loss,yellow_crust_ooze,yellow_urine,yellowing_of_eyes,yellowish_skin
0,Fungal infection,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,Fungal infection,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,Fungal infection,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,Fungal infection,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,Fungal infection,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [6]:
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(raw_df['Disease'])
X = symptom_frame.values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

print(f'Train size: {X_train.shape[0]}, Test size: {X_test.shape[0]}')

Train size: 3936, Test size: 984


In [7]:
rf_clf = RandomForestClassifier(
    n_estimators=600,
    max_depth=None,
    min_samples_leaf=1,
    class_weight='balanced_subsample',
    random_state=42,
    n_jobs=-1
)
rf_clf.fit(X_train, y_train)

print('RandomForest training complete.')

RandomForest training complete.


In [8]:
train_pred = rf_clf.predict(X_train)
test_pred = rf_clf.predict(X_test)
test_proba = rf_clf.predict_proba(X_test)

metrics = {
    'train_accuracy': accuracy_score(y_train, train_pred),
    'test_accuracy': accuracy_score(y_test, test_pred),
    'top3_accuracy': top_k_accuracy_score(y_test, test_proba, k=3, labels=rf_clf.classes_),
}
metrics

{'train_accuracy': 1.0, 'test_accuracy': 1.0, 'top3_accuracy': 1.0}

In [9]:
print('Classification report (test set):')
print(classification_report(y_test, test_pred, target_names=label_encoder.classes_))

importance_df = (
    pd.DataFrame({'symptom': mlb.classes_, 'importance': rf_clf.feature_importances_})
    .sort_values('importance', ascending=False)
    .head(20)
)
importance_df

Classification report (test set):
                                         precision    recall  f1-score   support

(vertigo) Paroymsal  Positional Vertigo       1.00      1.00      1.00        24
                                   AIDS       1.00      1.00      1.00        24
                                   Acne       1.00      1.00      1.00        24
                    Alcoholic hepatitis       1.00      1.00      1.00        24
                                Allergy       1.00      1.00      1.00        24
                              Arthritis       1.00      1.00      1.00        24
                       Bronchial Asthma       1.00      1.00      1.00        24
                   Cervical spondylosis       1.00      1.00      1.00        24
                            Chicken pox       1.00      1.00      1.00        24
                    Chronic cholestasis       1.00      1.00      1.00        24
                            Common Cold       1.00      1.00      1.00    

Unnamed: 0,symptom,importance
61,muscle_pain,0.021288
109,vomiting,0.019433
116,yellowing_of_eyes,0.017681
107,unsteadiness,0.017497
17,chest_pain,0.016553
0,abdominal_pain,0.016211
51,itching,0.016178
32,distention_of_abdomen,0.015934
39,fatigue,0.01592
5,anal_discomfort,0.01536


In [10]:
artifact = {
    'model': rf_clf,
    'symptom_names': mlb.classes_.tolist(),
    'label_encoder_classes': label_encoder.classes_.tolist(),
    'created_at': datetime.utcnow().isoformat() + 'Z',
    'metrics': metrics,
}
artifact_path = MODEL_OUTPUT_DIR / 'disease_prediction_model.pkl'
joblib.dump(artifact, artifact_path)
print(f'Model artifact saved to {artifact_path.relative_to(PROJECT_ROOT)}')

Model artifact saved to outputs/models/disease_prediction_model.pkl


  'created_at': datetime.utcnow().isoformat() + 'Z',


In [1]:
symptom_index = {name: idx for idx, name in enumerate(mlb.classes_)}

def predict_from_text(symptom_phrases: Iterable[str], top_k: int = 5) -> pd.DataFrame:
    vector = np.zeros(len(symptom_index), dtype=np.float32)
    normalized_inputs: list[str] = []
    for phrase in symptom_phrases:
        feature = to_feature_name(phrase)
        if feature and feature in symptom_index:
            normalized_inputs.append(feature)
            vector[symptom_index[feature]] = 1.0
    if not normalized_inputs:
        raise ValueError('No recognizable symptoms provided.')
    probs = rf_clf.predict_proba(vector.reshape(1, -1))[0]
    ranked = np.argsort(probs)[::-1][:top_k]
    return pd.DataFrame({
        'disease': label_encoder.inverse_transform(ranked),
        'probability': probs[ranked]
    })

example_predictions = predict_from_text(['I have pyrexia', 'running nose', 'yellow pee'])
example_predictions

NameError: name 'mlb' is not defined

In [6]:
from collections.abc import Mapping, Sequence
import pickle

FREQ_ITEMSETS_PATH = PROJECT_ROOT / "outputs" / "disease_frequent_itemsets.pkl"
print(f"Loading frequent itemsets from: {FREQ_ITEMSETS_PATH.relative_to(PROJECT_ROOT)}")
if not FREQ_ITEMSETS_PATH.exists():
    raise FileNotFoundError(f"Missing file: {FREQ_ITEMSETS_PATH}")

with FREQ_ITEMSETS_PATH.open("rb") as fh:
    frequent_itemsets = pickle.load(fh)

print("Object type:", type(frequent_itemsets))
if isinstance(frequent_itemsets, Mapping):
    print("Mapping keys sample:", list(frequent_itemsets.keys())[:5])
elif isinstance(frequent_itemsets, Sequence) and not isinstance(frequent_itemsets, (str, bytes, bytearray)):
    print("Sequence length:", len(frequent_itemsets))
    print("First element example:", frequent_itemsets[:1])
else:
    print("Preview:", frequent_itemsets)

frequent_itemsets


Loading frequent itemsets from: outputs/disease_frequent_itemsets.pkl
Object type: <class 'pandas.core.frame.DataFrame'>
Preview:                              itemset   support                                  disease
0               (headache, vomiting)  0.146341  (vertigo) Paroymsal  Positional Vertigo
1          (unsteadiness, dizziness)  0.073171  (vertigo) Paroymsal  Positional Vertigo
2         (abdominal pain, vomiting)  0.219512                      Alcoholic hepatitis
3   (abdominal pain, yellowish skin)  0.170732                      Alcoholic hepatitis
4         (yellowish skin, vomiting)  0.170732                      Alcoholic hepatitis
..                               ...       ...                                      ...
74             (fatigue, high fever)  0.219512                                  Typhoid
75               (vomiting, fatigue)  0.195122                                  Typhoid
76        (abdominal pain, vomiting)  0.219512                              he

Unnamed: 0,itemset,support,disease
0,"(headache, vomiting)",0.146341,(vertigo) Paroymsal Positional Vertigo
1,"(unsteadiness, dizziness)",0.073171,(vertigo) Paroymsal Positional Vertigo
2,"(abdominal pain, vomiting)",0.219512,Alcoholic hepatitis
3,"(abdominal pain, yellowish skin)",0.170732,Alcoholic hepatitis
4,"(yellowish skin, vomiting)",0.170732,Alcoholic hepatitis
...,...,...,...
74,"(fatigue, high fever)",0.219512,Typhoid
75,"(vomiting, fatigue)",0.195122,Typhoid
76,"(abdominal pain, vomiting)",0.219512,hepatitis A
77,"(vomiting, loss of appetite)",0.195122,hepatitis A
