Importuri


In [40]:
import numpy as np
import os
from pathlib import Path
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold, KFold
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, make_scorer
from sklearn.multioutput import MultiOutputClassifier
import warnings
import joblib

from constants import PROCESSED_DATA_DIR, FILTERED_DATA_DIR, NUM_SAMPLES, SAMPLE_RATE, SNOMED_DICT, LEADS, NUM_LEADS, PLOT_DIR, CLASSIFIER_DATA_DIR, PLOT_DIR



load_classifier_data

In [42]:
def load_classifier_data(selected_folders, data_dir=None):
    """
    Load per-record feature vectors and multi-label targets from batch files.

    Every selected folder is scanned (non-recursively) for ``.npy`` files, which
    are processed in sorted path order. Each file must contain a pickled dict
    mapping ``record_name -> {"features": ..., "labels": ...}``.

    Parameters
    ----------
    selected_folders : iterable of str
        Sub-folder names to read, relative to ``data_dir``. Folders that do not
        exist are reported on stdout and skipped.
    data_dir : path-like, optional
        Root directory containing the folders. Defaults to the module-level
        ``CLASSIFIER_DATA_DIR`` when omitted.

    Returns
    -------
    X : np.ndarray
        All feature entries stacked with ``np.vstack`` (one row per record when
        each "features" entry is a 1-D vector).
    Y : list of list
        One list of labels per record (multi-label targets).
    record_names : list
        Record names, in the same order as the rows of ``X``.

    Raises
    ------
    ValueError
        If no record could be loaded (all folders missing, or no batch file
        contained any record).
    """
    root = CLASSIFIER_DATA_DIR if data_dir is None else Path(data_dir)
    # Materialise once so a generator argument can still be reported on error.
    folders = list(selected_folders)

    X_all = []
    Y_all = []
    record_names_all = []

    for folder_name in folders:
        folder_path = root / folder_name
        if not folder_path.exists():
            print(f"[!!!] Folderul {folder_name} nu există în {root}.skipp")
            continue

        # Enumerate the .npy batch files of the current folder in a stable order.
        batch_files = sorted(
            f for f in folder_path.iterdir() if f.is_file() and f.suffix == ".npy"
        )
        for bf in batch_files:
            # NOTE(security): allow_pickle=True unpickles arbitrary Python objects;
            # only load batch files produced by a trusted pipeline.
            data_dict = np.load(bf, allow_pickle=True).item()
            for record_name, rec_data in data_dict.items():
                labels = rec_data["labels"]
                if isinstance(labels, np.ndarray):
                    # atleast_1d keeps a 0-d array from collapsing to a bare scalar.
                    labels = np.atleast_1d(labels).tolist()
                else:
                    # Accept labels already stored as a plain list/tuple.
                    labels = list(labels)

                X_all.append(rec_data["features"])
                Y_all.append(labels)
                record_names_all.append(record_name)

    if not X_all:
        raise ValueError(
            f"no data folder: no records were loaded from {root} for folders {folders}"
        )

    return np.vstack(X_all), Y_all, record_names_all


sanity

In [43]:
def sanity_check(X, Y):
    """
    Print a short diagnostic report about the loaded dataset.

    Reported on stdout, in order: the shape of ``X``, its number of columns,
    an example list of feature names, the count of NaN entries in ``X``, the
    number of entries in ``Y``, a warning when ``X`` and ``Y`` differ in
    length, and the first element of ``Y``.

    Parameters
    ----------
    X : array with a ``shape`` of at least two dimensions, accepted by ``np.isnan``.
    Y : indexable, non-empty sequence of per-record labels.

    Returns
    -------
    None
    """
    # Illustrative feature names only; they are printed, not checked against X.
    example_feature_names = [
        'ventricular_rate', 'atrial_rate', 'qrs_duration',
        'qt_interval', 'qrs_count', 'mean_r_onset_sec',
        'mean_r_offset_sec', 'sex_binary', 'age'
    ]

    print("== Sanity check pe date ==")
    print(f"X shape = {X.shape}")
    print("Număr de features per înregistrare:", X.shape[1])
    print("Features prezente (exemplu):", example_feature_names)

    # Missing values in the feature matrix.
    nan_mask = np.isnan(X)
    nan_total = nan_mask.sum()
    print(f"Număr de NaN în X: {nan_total}")

    # Label side: sample count and consistency with X.
    n_label_rows = len(Y)
    print(f"Număr eșantioane (Y) = {n_label_rows}")
    lengths_match = len(X) == n_label_rows
    if not lengths_match:
        print("[WARNING] X și Y nu au aceeași lungime!")
    print("Exemplu Y[0]:", Y[0])


train

In [44]:

# Train an ExtraTrees-based multi-label classifier with a grid search and
# persist the refitted best model together with its label binarizer.

# Choose the data folders used for training: zero-padded names "01" .. "15".
selected_folders_train = [f"{i:02d}" for i in range(1, 16)]
X, Y_list, record_names = load_classifier_data(selected_folders_train)
sanity_check(X, Y_list)
# NOTE(review): sanity_check only *reports* NaNs; X is passed to the estimator
# unchanged. Confirm the installed scikit-learn handles missing values for
# ExtraTreesClassifier, or impute upstream.

# Every label that occurs at least once; sorting fixes the column order of Y_bin.
all_labels = sorted({label for lab_list in Y_list for label in lab_list})

mlb = MultiLabelBinarizer(classes=all_labels)
Y_bin = mlb.fit_transform(Y_list)
print(f"Y_bin shape = {Y_bin.shape}")

# One ExtraTrees forest is fitted per label column via MultiOutputClassifier.
clf = ExtraTreesClassifier(random_state=42)
multi_clf = MultiOutputClassifier(clf)

# "estimator__" routes each parameter to the wrapped ExtraTreesClassifier.
final_param_grid = {
    "estimator__n_estimators": [300, 500, 800],
    "estimator__criterion": ["gini", "entropy"],
    "estimator__bootstrap": [True, False],
    "estimator__max_features": ["sqrt", "log2", None]
}

scorer = make_scorer(f1_score, average="micro")

gs = GridSearchCV(
    multi_clf,
    param_grid=final_param_grid,
    scoring=scorer,
    cv=KFold(n_splits=3, shuffle=True, random_state=42),
    refit=True,
    n_jobs=-1,
    verbose=1
)

# Silence UserWarning while fitting.
# NOTE(review): with n_jobs=-1 the fits run in worker processes, which may not
# inherit this filter — confirm if warnings still show up.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    gs.fit(X, Y_bin)

# Candidates whose mean CV score is NaN (GridSearchCV's default error_score
# marks failed fits this way) are listed so they are not silently ignored.
fail_indices = np.where(np.isnan(gs.cv_results_["mean_test_score"]))[0]
if len(fail_indices) > 0:
    print("Seturi de parametri care au generat erori sau warning:")
    for idx in fail_indices:
        print("Index:", idx, "Parametri:", gs.cv_results_["params"][idx])

print(f"Cel mai bun scor (cv) obținut: {gs.best_score_}")
print(f"Parametrii cei mai buni: {gs.best_params_}")

# Score on the training data itself: an optimistic fit check, not a
# generalisation estimate (compare with the CV score above).
best_model = gs.best_estimator_
Y_pred = best_model.predict(X)
final_f1 = f1_score(Y_bin, Y_pred, average="micro")
print(f"[TRAIN] F1 final (pe același set) = {final_f1:.4f}")

# Persist the fitted binarizer next to the model: column j of the model's
# output corresponds to mlb.classes_[j], so without it the saved model's
# predictions cannot be mapped back to label codes in a fresh session.
# NOTE(review): artefacts are written to PLOT_DIR as in the original cell —
# confirm whether a dedicated model directory is intended.
joblib.dump(mlb, PLOT_DIR / "best_model_label_binarizer.pkl")
joblib.dump(best_model, PLOT_DIR / "best_model.pkl")


== Sanity check pe date ==
X shape = (15000, 9)
Număr de features per înregistrare: 9
Features prezente (exemplu): ['ventricular_rate', 'atrial_rate', 'qrs_duration', 'qt_interval', 'qrs_count', 'mean_r_onset_sec', 'mean_r_offset_sec', 'sex_binary', 'age']
Număr de NaN în X: 26
Număr eșantioane (Y) = 15000
Exemplu Y[0]: [164889003, 59118001, 164934002]
Y_bin shape = (15000, 89)
Fitting 3 folds for each of 36 candidates, totalling 108 fits


Python(31971) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(31973) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32105) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32111) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32193) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32194) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32261) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32412) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32413) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32423) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(32910) Malloc

Cel mai bun scor (cv) obținut: 0.43168562186440473
Parametrii cei mai buni: {'estimator__bootstrap': False, 'estimator__criterion': 'entropy', 'estimator__max_features': None, 'estimator__n_estimators': 800}
[TRAIN] F1 final (pe același set) = 0.9995


['/Users/teofil/Dev/GitHub/ekg-classification-pipeline/plots/best_model.pkl']