In [1]:
import pandas as pd
import numpy as np
import pickle
import xgboost as xgb
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import ADASYN

In [2]:
preprocess_file = './model/saved_data/preprocessed_data.pkl'

# Load preprocessed data
with open(preprocess_file, "rb") as f:
    X, y, label_encoder = pickle.load(f)

In [3]:
# Class labels for all 11 classes
class_labels = ["Hypertension",
                "Cardiovascular Disease (CVD)",
                "Chronic Fatigue Syndrome (CFS)",
                "Stress-related Disorders",
                "Healthy",
                "Diabetes",
                "Anaemia",
                "Atherosclerosis",
                "Arrhythmia",
                "Respiratory Disease (COPD or Asthma)",
                "Autonomic Dysfunction"
                ]

# Using ADASYN instead of SMOTE
1. Adapts the sampling process based on the distribution of minority samples
- it creates more synthetic data for regions where minority samples are sparse or harder to learn.
- fewer samples are generated in well-represented regions of the minority class.
2. Reduces the risk of overfitting by focusing on under-represented regions, improving the classifier's generalisation.
3. Prioritises generating synthetic samples where the model struggles, potentially leading to more realistic samples for challenging cases.

In [4]:
# Applying ADASYN for balancing the class distribution
adasyn = ADASYN(random_state=42)
X_resampled, y_resampled = adasyn.fit_resample(X, y)

# Display the class distribution after resampling
print("Class Distribution After ADASYN:")
print(pd.Series(y_resampled).value_counts())

Class Distribution After ADASYN:
Disease Classification
10    223681
5     221591
7     221207
6     220665
2     220503
1     220201
9     220016
3     219614
8     219610
0     217900
4     211734
Name: count, dtype: int64


In [5]:
# Calculate class weightss for 'scale_pos_weight'
class_counts = pd.Series(y_resampled).value_counts()
total = len(y_resampled)
class_weights = {
    cls: total / count for cls, count in class_counts.items()
}

In [6]:
def evaluate_xgb_model(model, X_test, y_test):
    dtest = xgb.DMatrix(X_test)
    predictions = model.predict(dtest)

    if predictions.ndim > 1:
        predictions = np.argmax(predictions, axis=1)
    else:
        predictions = predictions.astype(int)

    accuracy = accuracy_score(y_test, predictions)
    report = classification_report(
        y_test,
        predictions,
        target_names=class_labels,
        labels=np.unique(y_test),
        zero_division=0
    )

    return accuracy, report

In [7]:
def stratified_subsample(X, y, train_size, test_size):
    """
    Creates a stratified subset of data for training and testing based on sample sizes.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_size, test_size=test_size, stratify=y, random_state=42
    )
    return X_train, X_test, y_train, y_test

In [8]:
best_params = {
    "max_depth": 6,
    "learning_rate": 0.1,
    "n_estimators": 100,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "objective": "multi:softmax",
    "num_class": len(np.unique(y_resampled)),
    "scale_pos_weight": [class_weights.get(cls, 1) for cls in np.unique(y_resampled)]
    }

print(f"Best Hyperparameters: {best_params}")

Best Hyperparameters: {'max_depth': 6, 'learning_rate': 0.1, 'n_estimators': 100, 'subsample': 0.8, 'colsample_bytree': 0.8, 'objective': 'multi:softmax', 'num_class': 11, 'scale_pos_weight': [11.09096833409821, 10.975072774419735, 10.960041359981497, 11.004407733568899, 11.413953356570037, 10.906228141034608, 10.951995105703215, 10.925160596183664, 11.00460816902691, 10.984301141735147, 10.804324015003509]}


In [9]:
# Container for results
best_accuracy = 0
best_model = None
best_sample_size = 0

In [10]:
# Iterative training loop
for sample_size in [25, 50, 75, 100, 250, 500, 750, 1000, 2500, 5000, 7500, 10000, 20000, 30000, 40000, 50000]:
    if len(X_resampled) < sample_size:
        print(f"Skipping sample size {sample_size} due to insufficient data.")
        continue

    X_train_subset, X_test_subset, y_train_subset, y_test_subset = stratified_subsample(
        X_resampled, y_resampled, train_size=sample_size, test_size=sample_size // 2
    )

    # Train XGBoost model
    dtrain = xgb.DMatrix(X_train_subset, label=y_train_subset)
    xgb_model = xgb.train(best_params, dtrain, num_boost_round=100, verbose_eval=False)

    # Evaluate model
    accuracy, report = evaluate_xgb_model(xgb_model, X_test_subset, y_test_subset)
    print(f"Sample Size {sample_size}: Accuracy {accuracy:.4f}")
    print("Classification Report:")
    print(report)

    # Store results and track the best model
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_model = xgb_model
        best_sample_size = sample_size

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 25: Accuracy 0.4167
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.33      1.00      0.50         1
        Cardiovascular Disease (CVD)       0.00      0.00      0.00         1
      Chronic Fatigue Syndrome (CFS)       0.00      0.00      0.00         1
            Stress-related Disorders       0.00      0.00      0.00         1
                             Healthy       0.00      0.00      0.00         1
                            Diabetes       1.00      1.00      1.00         1
                             Anaemia       0.00      0.00      0.00         1
                     Atherosclerosis       1.00      1.00      1.00         1
                          Arrhythmia       1.00      1.00      1.00         1
Respiratory Disease (COPD or Asthma)       0.50      1.00      0.67         1
               Autonomic Dysfunction       0.00      0.00      0.00         2

       

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 50: Accuracy 0.5200
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.00      0.00      0.00         2
        Cardiovascular Disease (CVD)       1.00      0.50      0.67         2
      Chronic Fatigue Syndrome (CFS)       0.67      1.00      0.80         2
            Stress-related Disorders       1.00      1.00      1.00         2
                             Healthy       0.00      0.00      0.00         2
                            Diabetes       0.50      1.00      0.67         3
                             Anaemia       1.00      1.00      1.00         2
                     Atherosclerosis       0.33      0.67      0.44         3
                          Arrhythmia       0.00      0.00      0.00         2
Respiratory Disease (COPD or Asthma)       0.50      0.50      0.50         2
               Autonomic Dysfunction       0.00      0.00      0.00         3

       

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 75: Accuracy 0.8108
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.50      0.33      0.40         3
        Cardiovascular Disease (CVD)       0.60      1.00      0.75         3
      Chronic Fatigue Syndrome (CFS)       1.00      1.00      1.00         3
            Stress-related Disorders       1.00      0.67      0.80         3
                             Healthy       0.50      0.33      0.40         3
                            Diabetes       0.80      1.00      0.89         4
                             Anaemia       0.80      1.00      0.89         4
                     Atherosclerosis       1.00      1.00      1.00         4
                          Arrhythmia       1.00      0.67      0.80         3
Respiratory Disease (COPD or Asthma)       0.75      1.00      0.86         3
               Autonomic Dysfunction       1.00      0.75      0.86         4

       

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 100: Accuracy 0.7800
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       1.00      0.75      0.86         4
        Cardiovascular Disease (CVD)       1.00      1.00      1.00         5
      Chronic Fatigue Syndrome (CFS)       1.00      0.60      0.75         5
            Stress-related Disorders       1.00      0.75      0.86         4
                             Healthy       0.80      1.00      0.89         4
                            Diabetes       1.00      1.00      1.00         5
                             Anaemia       1.00      1.00      1.00         5
                     Atherosclerosis       0.50      1.00      0.67         5
                          Arrhythmia       1.00      0.50      0.67         4
Respiratory Disease (COPD or Asthma)       0.67      0.50      0.57         4
               Autonomic Dysfunction       0.33      0.40      0.36         5

      

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 250: Accuracy 0.9120
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.62      0.91      0.74        11
        Cardiovascular Disease (CVD)       0.90      0.82      0.86        11
      Chronic Fatigue Syndrome (CFS)       1.00      0.82      0.90        11
            Stress-related Disorders       1.00      1.00      1.00        11
                             Healthy       0.85      1.00      0.92        11
                            Diabetes       1.00      1.00      1.00        12
                             Anaemia       1.00      0.83      0.91        12
                     Atherosclerosis       0.92      0.92      0.92        12
                          Arrhythmia       1.00      0.82      0.90        11
Respiratory Disease (COPD or Asthma)       1.00      1.00      1.00        11
               Autonomic Dysfunction       0.92      0.92      0.92        12

      

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 500: Accuracy 0.9560
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.83      0.91      0.87        22
        Cardiovascular Disease (CVD)       0.96      0.96      0.96        23
      Chronic Fatigue Syndrome (CFS)       0.88      0.96      0.92        23
            Stress-related Disorders       0.92      1.00      0.96        23
                             Healthy       1.00      0.95      0.98        22
                            Diabetes       1.00      1.00      1.00        23
                             Anaemia       1.00      0.91      0.95        23
                     Atherosclerosis       1.00      0.96      0.98        23
                          Arrhythmia       1.00      1.00      1.00        22
Respiratory Disease (COPD or Asthma)       0.96      0.96      0.96        23
               Autonomic Dysfunction       1.00      0.91      0.95        23

      

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 750: Accuracy 0.9547
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.83      0.88      0.86        34
        Cardiovascular Disease (CVD)       0.94      1.00      0.97        34
      Chronic Fatigue Syndrome (CFS)       0.94      0.97      0.96        34
            Stress-related Disorders       0.94      1.00      0.97        34
                             Healthy       0.94      1.00      0.97        33
                            Diabetes       1.00      1.00      1.00        35
                             Anaemia       0.97      1.00      0.99        34
                     Atherosclerosis       0.94      0.94      0.94        34
                          Arrhythmia       1.00      0.91      0.95        34
Respiratory Disease (COPD or Asthma)       1.00      0.97      0.99        34
               Autonomic Dysfunction       1.00      0.83      0.91        35

      

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 1000: Accuracy 0.9760
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.88      0.96      0.91        45
        Cardiovascular Disease (CVD)       0.98      1.00      0.99        46
      Chronic Fatigue Syndrome (CFS)       0.98      0.96      0.97        46
            Stress-related Disorders       0.98      1.00      0.99        45
                             Healthy       1.00      1.00      1.00        44
                            Diabetes       1.00      0.98      0.99        46
                             Anaemia       1.00      1.00      1.00        46
                     Atherosclerosis       0.96      0.96      0.96        46
                          Arrhythmia       0.98      0.98      0.98        45
Respiratory Disease (COPD or Asthma)       1.00      0.96      0.98        45
               Autonomic Dysfunction       1.00      0.96      0.98        46

     

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 2500: Accuracy 0.9864
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.91      0.97      0.94       113
        Cardiovascular Disease (CVD)       0.97      0.99      0.98       114
      Chronic Fatigue Syndrome (CFS)       1.00      1.00      1.00       114
            Stress-related Disorders       1.00      1.00      1.00       114
                             Healthy       0.98      1.00      0.99       109
                            Diabetes       1.00      1.00      1.00       115
                             Anaemia       0.99      1.00      1.00       114
                     Atherosclerosis       1.00      0.96      0.98       114
                          Arrhythmia       1.00      0.97      0.99       113
Respiratory Disease (COPD or Asthma)       1.00      0.98      0.99       114
               Autonomic Dysfunction       1.00      0.97      0.98       116

     

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 5000: Accuracy 0.9868
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.92      0.96      0.94       226
        Cardiovascular Disease (CVD)       1.00      0.98      0.99       228
      Chronic Fatigue Syndrome (CFS)       1.00      1.00      1.00       228
            Stress-related Disorders       1.00      1.00      1.00       227
                             Healthy       1.00      1.00      1.00       219
                            Diabetes       1.00      1.00      1.00       229
                             Anaemia       1.00      1.00      1.00       228
                     Atherosclerosis       0.99      0.98      0.98       229
                          Arrhythmia       1.00      1.00      1.00       227
Respiratory Disease (COPD or Asthma)       0.98      1.00      0.99       228
               Autonomic Dysfunction       0.99      0.95      0.97       231

     

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 7500: Accuracy 0.9845
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.93      0.95      0.94       338
        Cardiovascular Disease (CVD)       0.98      1.00      0.99       342
      Chronic Fatigue Syndrome (CFS)       0.97      1.00      0.98       342
            Stress-related Disorders       0.99      1.00      1.00       341
                             Healthy       0.99      0.99      0.99       329
                            Diabetes       0.99      1.00      1.00       344
                             Anaemia       1.00      1.00      1.00       342
                     Atherosclerosis       0.99      0.97      0.98       343
                          Arrhythmia       1.00      0.99      0.99       341
Respiratory Disease (COPD or Asthma)       0.99      0.99      0.99       341
               Autonomic Dysfunction       0.99      0.96      0.97       347

     

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 10000: Accuracy 0.9866
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.93      0.96      0.94       451
        Cardiovascular Disease (CVD)       0.98      0.99      0.99       456
      Chronic Fatigue Syndrome (CFS)       0.98      1.00      0.99       456
            Stress-related Disorders       1.00      1.00      1.00       454
                             Healthy       1.00      0.99      1.00       438
                            Diabetes       1.00      1.00      1.00       458
                             Anaemia       1.00      1.00      1.00       457
                     Atherosclerosis       1.00      0.97      0.98       458
                          Arrhythmia       0.99      0.99      0.99       454
Respiratory Disease (COPD or Asthma)       0.99      0.99      0.99       455
               Autonomic Dysfunction       0.99      0.97      0.98       463

    

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 20000: Accuracy 0.9885
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.94      0.96      0.95       902
        Cardiovascular Disease (CVD)       0.99      0.99      0.99       911
      Chronic Fatigue Syndrome (CFS)       0.99      1.00      0.99       912
            Stress-related Disorders       1.00      1.00      1.00       909
                             Healthy       1.00      0.99      1.00       876
                            Diabetes       0.99      0.99      0.99       917
                             Anaemia       1.00      1.00      1.00       913
                     Atherosclerosis       0.99      0.97      0.98       915
                          Arrhythmia       0.99      0.99      0.99       909
Respiratory Disease (COPD or Asthma)       1.00      1.00      1.00       910
               Autonomic Dysfunction       0.99      0.97      0.98       926

    

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 30000: Accuracy 0.9884
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.93      0.96      0.95      1352
        Cardiovascular Disease (CVD)       0.99      0.99      0.99      1367
      Chronic Fatigue Syndrome (CFS)       0.98      0.99      0.99      1369
            Stress-related Disorders       1.00      1.00      1.00      1363
                             Healthy       1.00      0.99      0.99      1314
                            Diabetes       0.99      1.00      1.00      1375
                             Anaemia       1.00      1.00      1.00      1370
                     Atherosclerosis       0.99      0.98      0.99      1373
                          Arrhythmia       1.00      1.00      1.00      1363
Respiratory Disease (COPD or Asthma)       0.99      0.99      0.99      1366
               Autonomic Dysfunction       1.00      0.97      0.98      1388

    

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 40000: Accuracy 0.9894
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.95      0.96      0.96      1803
        Cardiovascular Disease (CVD)       0.99      1.00      0.99      1822
      Chronic Fatigue Syndrome (CFS)       0.98      1.00      0.99      1825
            Stress-related Disorders       1.00      1.00      1.00      1818
                             Healthy       0.99      0.99      0.99      1752
                            Diabetes       1.00      0.99      0.99      1834
                             Anaemia       1.00      1.00      1.00      1826
                     Atherosclerosis       0.99      0.98      0.99      1831
                          Arrhythmia       0.99      0.99      0.99      1817
Respiratory Disease (COPD or Asthma)       0.99      1.00      0.99      1821
               Autonomic Dysfunction       0.99      0.98      0.99      1851

    

Parameters: { "n_estimators", "scale_pos_weight" } are not used.



Sample Size 50000: Accuracy 0.9891
Classification Report:
                                      precision    recall  f1-score   support

                        Hypertension       0.95      0.96      0.95      2254
        Cardiovascular Disease (CVD)       0.99      0.99      0.99      2278
      Chronic Fatigue Syndrome (CFS)       0.98      1.00      0.99      2281
            Stress-related Disorders       1.00      1.00      1.00      2272
                             Healthy       0.99      0.99      0.99      2190
                            Diabetes       1.00      0.99      0.99      2292
                             Anaemia       1.00      1.00      1.00      2283
                     Atherosclerosis       0.99      0.98      0.99      2288
                          Arrhythmia       0.99      0.99      0.99      2272
Respiratory Disease (COPD or Asthma)       0.99      1.00      1.00      2276
               Autonomic Dysfunction       0.99      0.97      0.98      2314

    

In [11]:
# Save the best model
if best_model is not None:
    best_model_file = f"xgb_model.pkl"
    with open(best_model_file, "wb") as model_file:
        pickle.dump(best_model, model_file)
    print(f"\nBest model saved as {best_model_file} with accuracy {best_accuracy:.4f}")


Best model saved as xgb_model.pkl with accuracy 0.9894
