In [1]:
import pandas as pd
import glob
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import random
import os
import torch
import torch.nn as nn
from models.rnn import LeadIILSTM, LSTM_12Leads
from torch.utils.data import Dataset, DataLoader, TensorDataset
from neurokit2 import ecg
import neurokit2 as nk
import numpy as np

import matplotlib.pyplot as plt

SAMPLING_RATE = 1000


In [2]:
%load_ext jupyternotify

<IPython.core.display.Javascript object>

In [6]:
class ECGDataset(Dataset):
    """In-memory dataset of 12-lead ECG recordings stored as parquet files.

    Every class lives in its own sub-folder of ``data_folder``. For each class,
    ``files_per_class`` paths are drawn: without replacement when the folder
    holds at least that many files, with replacement (upsampling) otherwise.
    All drawn files are loaded eagerly in ``__init__``.

    Each item is a ``(signal, label, ecg_id)`` tuple where ``signal`` is a
    float32 tensor of shape ``(12, time)`` (one row per entry of ``self.leads``),
    ``label`` is the value from ``class_folders`` and ``ecg_id`` is the file's
    base name.

    Notes:
        * Files that cannot be read, or that lack one of the 12 lead columns,
          are skipped with a printed message, so a class may end up with fewer
          than ``files_per_class`` samples.
        * When a class is upsampled the same recording appears several times
          (sharing one tensor object). Any train/validation split should keep
          all copies of an ``ecg_id`` on the same side.

    Args:
        data_folder: Root directory containing one sub-folder per class.
        class_folders: Mapping ``sub-folder name -> label``.
        files_per_class: Number of file paths to draw for every class.
        seed: Optional seed for a private RNG used for the file sampling. When
            ``None`` (default) the global ``random`` module state is used.
    """

    def __init__(self, data_folder, class_folders, files_per_class=200, seed=None):
        self.samples = []
        self.leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
        # A dedicated RNG keeps the sampling repeatable without touching global state.
        rng = random.Random(seed) if seed is not None else random

        for folder, label in class_folders.items():
            # Sorted so that a given seed selects the same files on every run.
            files = sorted(glob.glob(os.path.join(data_folder, folder, '*.parquet.gzip')))
            if not files:
                # random.choices() raises IndexError on an empty population.
                print(f"No files found for class '{folder}' in {data_folder}, skipping")
                continue

            # draw files_per_class paths per class (downsample or upsample with replacement)
            if len(files) >= files_per_class:
                files = rng.sample(files, files_per_class)
            else:
                files = rng.choices(files, k=files_per_class)

            # Upsampling repeats paths: read every distinct file only once.
            loaded = {}
            for f in files:
                if f not in loaded:
                    loaded[f] = self._load_signal(f)
                signal = loaded[f]
                if signal is None:
                    continue
                self.samples.append((signal, label, os.path.basename(f)))

    def _load_signal(self, path):
        """Read one parquet file; return a (12, time) float32 tensor or None if unusable."""
        try:
            df = pd.read_parquet(path, engine='fastparquet')
        except Exception as e:
            print(f"Failed to read {path}: {e}")
            return None

        # ensure required lead columns exist
        if not set(self.leads).issubset(df.columns):
            print(f"Missing leads in {path}, skipping")
            return None

        # convert lead columns to numeric, coerce non-numeric to NaN, then fill and cast
        df_leads = df[self.leads].apply(pd.to_numeric, errors='coerce').fillna(0).astype(np.float32)

        # (time, 12) -> (12, time)
        return torch.tensor(df_leads.values.T, dtype=torch.float32)

    def __len__(self):
        """Number of loaded samples (duplicates from upsampling included)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(signal, label, ecg_id)`` tuple stored at ``idx``."""
        signal, label, ecg_id = self.samples[idx]
        return signal, label, ecg_id

# Usage example
# Map every class sub-folder of `data_folder` to its integer label.
class_folders = {
    'arritmia': 0,
    'block': 1,
    'fibrilation': 2,
    'normal': 3
}
data_folder = 'data'

# Seed the RNGs used below: ECGDataset samples files with Python's `random`,
# and DataLoader(shuffle=True) draws its permutation from torch's RNG. Without
# this, every kernel restart builds a different dataset. (The file selection is
# repeatable only as long as the directory listing itself does not change.)
random.seed(42)
torch.manual_seed(42)

dataset = ECGDataset(data_folder, class_folders, files_per_class=1970)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Inspect one batch
for signals, labels, ecg_ids in dataloader:
    print('Signals shape:', signals.shape)  # (batch, 12, time)
    print('Labels:', labels)
    print('ECG IDs:', ecg_ids)
    break

Signals shape: torch.Size([8, 12, 10000])
Labels: tensor([2, 3, 0, 2, 2, 1, 0, 2])
ECG IDs: ('270369.parquet.gzip', '533892.parquet.gzip', '399764.parquet.gzip', '375116.parquet.gzip', '52639.parquet.gzip', '349968.parquet.gzip', '495783.parquet.gzip', '320776.parquet.gzip')


ECG eliminados por peso (2kb):

- block: 8846, 314864
- normal: 74424

In [7]:
# --- Prepare DataLoader for 12-lead LSTM with downsampling to fixed length ---
# Original signals are (12, time). We'll resample each to `target_len` (e.g., 300) to keep memory and compute reasonable.
import torch
import torch.nn.functional as F
from torch.utils.data import Subset

target_len = 300  # desired sequence length after downsampling (try 200-300)

def resample_signal(sig, target_len):
    """Linearly resample a multi-channel signal along time.

    Args:
        sig: Tensor of shape (channels, time), e.g. (12, time).
        target_len: Number of time steps in the result.

    Returns:
        Tensor of shape (target_len, channels), i.e. time-major.
    """
    # F.interpolate expects a leading batch axis: (1, channels, time)
    batched = sig[None, ...]
    resized = F.interpolate(
        batched,
        size=target_len,
        mode='linear',
        align_corners=False,
    )
    # drop the batch axis and swap to (target_len, channels)
    return resized[0].transpose(0, 1)

def collate_ecg(batch):
    """Collate ECG samples into a fixed-length batch.

    Args:
        batch: List of (signal, label, ecg_id) items, where signal is a
            (12, time) tensor.

    Returns:
        Tuple (batch_tensor, labels, ids): batch_tensor has shape
        (batch, target_len, 12), labels is a long tensor and ids is a list
        of the items' ecg_id values in batch order.
    """
    sequences = []
    for item in batch:
        # every signal becomes (target_len, 12) using the module-level target_len
        sequences.append(resample_signal(item[0], target_len))
    batch_tensor = torch.stack(sequences, dim=0)

    targets = []
    ids = []
    for item in batch:
        targets.append(item[1])
        ids.append(item[2])
    labels = torch.tensor(targets, dtype=torch.long)

    return batch_tensor, labels, ids

# DataLoader for a quick demo; note it wraps the full `dataset` (no subset is taken here) and resamples via `collate_ecg`
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_ecg)

In [11]:
%%notify
indices = list(range(len(dataset)))
labels_arr = [dataset.samples[i][1] for i in indices]  # extract labels for stratify

train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=labels_arr, random_state=42)

train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_ecg)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_ecg)

# --- Instantiate model & optim (existing) ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTM_12Leads(n_channels=12, hidden_size=64, num_layers=1, num_classes=4, bidirectional=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training with validation after each epoch
for epoch in range(100):
    model.train()
    total_loss = 0.0
    total = 0
    correct = 0
    for signals, labels, ids in train_loader:
        signals = signals.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(signals)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * signals.size(0)
        total += signals.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
    train_acc = correct / total if total else 0.0
    print(f'Epoch {epoch+1} - train_loss: {total_loss/total:.4f} - train_acc: {train_acc:.4f}')

    # validation
    model.eval()
    val_total = 0
    val_correct = 0
    with torch.inference_mode():
        for signals, labels, ids in val_loader:
            signals = signals.to(device)
            labels = labels.to(device)
            logits = model(signals)
            preds = logits.argmax(dim=1)
            val_total += signals.size(0)
            val_correct += (preds == labels).sum().item()
    val_acc = val_correct / val_total if val_total else 0.0
    print(f'          val_acc: {val_acc:.4f}')



Epoch 1 - train_loss: 1.2571 - train_acc: 0.4132
          val_acc: 0.4467
          val_acc: 0.4467
Epoch 2 - train_loss: 1.1747 - train_acc: 0.4659
Epoch 2 - train_loss: 1.1747 - train_acc: 0.4659
          val_acc: 0.4784
          val_acc: 0.4784
Epoch 3 - train_loss: 1.1391 - train_acc: 0.4860
Epoch 3 - train_loss: 1.1391 - train_acc: 0.4860
          val_acc: 0.4860
          val_acc: 0.4860
Epoch 4 - train_loss: 1.1139 - train_acc: 0.4957
Epoch 4 - train_loss: 1.1139 - train_acc: 0.4957
          val_acc: 0.4860
          val_acc: 0.4860
Epoch 5 - train_loss: 1.1031 - train_acc: 0.5060
Epoch 5 - train_loss: 1.1031 - train_acc: 0.5060
          val_acc: 0.4829
          val_acc: 0.4829
Epoch 6 - train_loss: 1.0846 - train_acc: 0.5138
Epoch 6 - train_loss: 1.0846 - train_acc: 0.5138
          val_acc: 0.5006
          val_acc: 0.5006
Epoch 7 - train_loss: 1.0622 - train_acc: 0.5274
Epoch 7 - train_loss: 1.0622 - train_acc: 0.5274
          val_acc: 0.5076
          val_acc: 0.5076

<IPython.core.display.Javascript object>