In [13]:
from back import generate_ecg, extract_ecg_features
import numpy as np
import pandas as pd

In [14]:
num_samples = 100
np.random.seed(42)

data_list = []

for i in range(num_samples):
    age = np.random.randint(20,80)
    gender = np.random.choice([0,1])
    heart_rate = np.random.randint(60,110)

    ecg = generate_ecg(duration = 10, sampling_rate = 500, heart_rate=heart_rate)

    feats = extract_ecg_features(ecg,500)

    feats['age'] = age
    feats['sex']= gender

    risk = 0

    if age <45 and 60<feats['HR']<=80 and feats['QRS']<120:
        risk=0

    elif(45<=age<=60) or (80<feats['HR']<= 100) or (120<= feats['QRS']<=140):
        risk = 1

    elif age>60 or feats['HR']>100 or feats['QRS']>140 or feats['SDNN']<50:
        risk = 2

    
    feats['risk'] = risk

    data_list.append(feats)


In [15]:
df = pd.DataFrame(data_list)

| Risk Level     | Criteria                                                                   |
| -------------- | -------------------------------------------------------------------------- |
| **Low (0)**    | Age < 45 **AND** HR 60–80 bpm **AND** normal QRS (<120 ms)                 |
| **Medium (1)** | Age 45–60 **OR** HR 80–100 bpm **OR** mild QRS prolongation (120–140 ms)   |
| **High (2)**   | Age > 60 **OR** HR >100 bpm **OR** QRS > 140 ms **OR** abnormal RMSSD/SDNN |


In [16]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler


X = df.drop("risk", axis =1)
y = df["risk"]


X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2,random_state=42)


scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.fit_transform(X_test)


model = RandomForestClassifier()

model.fit(X_train_scaled,y_train)

print("Accuracy", model.score(X_test_scaled,y_test))

Accuracy 0.85


Model trained on synthetic data using rule-based labels

In [17]:
import joblib

joblib.dump(model, "ecg_risk_model.pkl")
joblib.dump(X.columns.tolist(), "feature_names.pkl")

['feature_names.pkl']