In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import wfdb
from scipy.signal import resample

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, classification_report
from imblearn.over_sampling import RandomOverSampler
from imblearn.pipeline import Pipeline

# === 0) GENERAL PARAMETERS ===
data_dir = r'C:\Users\Stella\Desktop\TCC\db'  # folder holding RECORDS and the WFDB record files
fs = 977         # nominal sampling frequency (Hz) used to size the beat window
window_ms = 300  # beat window length in milliseconds
# Window length in samples, truncated toward zero (977 Hz * 0.3 s = 293.1 -> 293).
window_size = int(window_ms * fs / 1000)

# === 1) RECORD LIST & CLASS MAPPING ===
# RECORDS holds one record name per line; blank lines are skipped.
with open(os.path.join(data_dir, 'RECORDS')) as records_file:
    records = [name for name in map(str.strip, records_file) if name]

# Every distinct annotation symbol found in any record becomes one class.
# NOTE(review): this includes symbols such as '+' / '~', which WFDB
# conventionally uses for non-beat annotations — confirm they are wanted as classes.
symbol_set = set()
for rec in records:
    symbol_set |= set(wfdb.rdann(os.path.join(data_dir, rec), 'atr').symbol)

# Sorted so the symbol -> index mapping is deterministic across runs.
symbol_list = sorted(symbol_set)
n_classes = len(symbol_list)
symbol_to_idx = dict(zip(symbol_list, range(n_classes)))
print(f"→ {len(records)} registros, {n_classes} classes: {symbol_list}")

# === 2) BEAT EXTRACTION ===
def extract_beats(rec_name, channel=0):
    """Cut one fixed-length window around every annotation of a record.

    Each window spans ``window_ms`` milliseconds at the record's own sampling
    rate, starts ``half`` samples before the annotation sample, and is
    returned as exactly ``window_size`` samples. It is resampled only when the
    record's rate differs from the nominal ``fs``. Annotations whose window
    would fall outside the signal are skipped.

    Args:
        rec_name: record name relative to ``data_dir`` (as listed in RECORDS).
        channel: column of ``p_signal`` to use (default 0, the first channel).

    Returns:
        ``(X, y)`` where ``X`` is a float32 array of shape
        ``(n_beats, window_size)`` and ``y`` is an int64 array of class
        indices taken from ``symbol_to_idx``. Both have 0 rows when no
        annotation yields a complete window.
    """
    path = os.path.join(data_dir, rec_name)
    rec = wfdb.rdrecord(path)
    ann = wfdb.rdann(path, 'atr')
    sig = rec.p_signal[:, channel]

    # Window length in the record's native samples. It uses the same formula
    # as ``window_size``, so rec.fs == fs gives exactly window_size and no
    # resampling is needed.
    native_size = int(rec.fs * window_ms / 1000)
    half = native_size // 2

    beats, labels = [], []
    for samp, sym in zip(ann.sample, ann.symbol):
        # Use end = start + native_size rather than samp + half. For an odd
        # length, samp ± half is one sample short (292 instead of 293). That
        # used to force an FFT resample, with edge ringing, on every beat.
        start = samp - half
        end = start + native_size
        if start < 0 or end > len(sig):
            continue
        beat = sig[start:end]
        if native_size != window_size:
            beat = resample(beat, window_size)
        beats.append(beat)
        labels.append(symbol_to_idx[sym])

    # The reshape keeps a 2-D (0, window_size) shape for records without
    # usable beats, so the later np.vstack over all records does not fail.
    X_rec = np.asarray(beats, dtype=np.float32).reshape(-1, window_size)
    y_rec = np.asarray(labels, dtype=np.int64)
    return X_rec, y_rec

# === 3) BUILD THE FULL DATASET ===
# One (beats, labels) pair per record, then stacked into a single matrix
# (one row per beat) and a single label vector.
per_record = [extract_beats(rec) for rec in records]
X_list = [beats for beats, _ in per_record]
y_list = [labels for _, labels in per_record]
X = np.vstack(X_list)
y = np.hstack(y_list)
print(f"→ Extraídos {X.shape[0]} batimentos × {X.shape[1]} amostras")

# === 4) PREPROCESSING ===
# z-score each beat independently (row-wise). The small epsilon prevents a
# division by zero on perfectly flat beats.
beat_mean = X.mean(axis=1, keepdims=True)
beat_std = X.std(axis=1, keepdims=True)
X_norm = (X - beat_mean) / (beat_std + 1e-8)
# Ensure a 2-D (n_beats, n_features) layout for the classifiers.
X_flat = X_norm.reshape(len(X_norm), -1)

# === 5) TRAIN/TEST SPLIT (on the original, non-oversampled beats) ===
# The split must come BEFORE oversampling. Resampling first copies the same
# minority-class beats into both train and test, so the test set is no longer
# unseen data and every reported metric is inflated.
class_counts = np.bincount(y, minlength=n_classes)
# Stratification needs at least 2 beats per class. Classes rarer than that
# are kept out of the stratified split and appended to the training set.
can_stratify = class_counts[y] >= 2
all_idx = np.arange(len(y))
train_idx, test_idx = train_test_split(
    all_idx[can_stratify], test_size=0.2, stratify=y[can_stratify], random_state=42
)
train_idx = np.concatenate([train_idx, all_idx[~can_stratify]])

X_train, X_test = X_flat[train_idx], X_flat[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# === 6-7) OVERSAMPLING + GLOBAL NORMALIZATION (as pipeline steps) ===
# Inside an imblearn Pipeline, the sampler runs only in fit(). GridSearchCV
# therefore oversamples each training fold alone, and the validation folds
# and X_test are never duplicated. The scaler is also fit on training data
# only. At predict() time the sampler is skipped and the scaler is applied.
# NOTE: each parallel fit now oversamples its own fold; lower n_jobs if
# memory becomes a problem.
pipeline = Pipeline([
    ('ros', RandomOverSampler(random_state=42)),
    ('scaler', StandardScaler()),
    ('rf', RandomForestClassifier(random_state=42, n_jobs=-1)),
])

# === 8) GRID SEARCH ON THE RANDOM FOREST ===
# Keys carry the 'rf__' prefix to address the forest step of the pipeline.
# class_weight is kept as before, although after oversampling the classes
# are already balanced.
param_grid = {
    'rf__n_estimators': [100, 200],
    'rf__max_depth': [10, 20],
    'rf__min_samples_split': [2, 5],
    'rf__class_weight': ['balanced_subsample'],
}

grid = GridSearchCV(
    pipeline,
    param_grid=param_grid,
    cv=3,
    scoring='f1_macro',
    n_jobs=-1,
)
grid.fit(X_train, y_train)

# === 9) BEST MODEL AND EVALUATION ===
clf = grid.best_estimator_
print("✅ Melhor modelo selecionado:")
print(grid.best_params_)

y_pred = clf.predict(X_test)

# Pin the label set to every known class so report rows and matrix axes
# always line up with symbol_list, even when a rare class is absent from both
# y_test and y_pred. Without this, classification_report raises on the
# target_names length mismatch and the heatmap tick labels would be shifted.
all_labels = np.arange(len(symbol_list))

print("\nClassification Report:")
print(classification_report(
    y_test,
    y_pred,
    labels=all_labels,
    target_names=symbol_list,
    zero_division=0,  # report 0 for classes never predicted, without warnings
))

cm = confusion_matrix(y_test, y_pred, labels=all_labels)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=symbol_list, yticklabels=symbol_list, cmap='Blues')
plt.xlabel('Predito')
plt.ylabel('Verdadeiro')
plt.title('Matriz de Confusão - Random Forest com Oversampling + GridSearch')
plt.tight_layout()
plt.show()


→ 39 registros, 16 classes: ['+', '/', 'A', 'F', 'J', 'L', 'N', 'Q', 'R', 'V', 'X', 'a', 'b', 'f', 'j', '~']
→ Extraídos 118193 batimentos × 293 amostras
