In [None]:
!pip install wfdb

Collecting wfdb
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Collecting pandas>=2.2.3 (from wfdb)
  Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Downloading wfdb-4.3.0-py3-none-any.whl (163 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pandas, wfdb
  Attempting uninstall: pandas
    Found existing installation: pandas 2.2.2
    Uninstalling pandas-2.2.2:
      Successfully uninstalled pandas-2.2.2
[31mERROR: pip's dependency resolver does not currently take into acco

In [None]:
import os
import wfdb
import random
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import amp
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from scipy.interpolate import interp1d

from sklearn.metrics import average_precision_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve


@dataclass
class TrainCfg:
    """Configuration for data preparation and training of the ECG beat classifier.

    Every setting is a real dataclass field, so it can be overridden either at
    construction time (``TrainCfg(lr=1e-3, window_sec=1.0)``) or afterwards via
    attribute assignment.

    The training fields come first, in their original order, so positional
    construction keeps working. The data/output fields follow.
    """

    # ---- training / optimisation -------------------------------------------
    # Segment length in samples; build_dataset uses it for the shape of an
    # empty split and the run script passes it to ECGNet(input_len=...).
    crop_len: int = 720
    batch_size: int = 128
    max_epochs: int = 50
    lr: float = 3e-4
    weight_decay: float = 1e-4
    # When True, train_model draws training batches with a WeightedRandomSampler.
    use_sampler: bool = True
    # NOTE(review): not read by any code in the cells shown.
    sampler_scale: float = 0.2
    # Exponent of the focal-loss modulating factor.
    focal_gamma: float = 1.5
    # NOTE(review): not read by any code in the cells shown.
    hybrid_switch_epoch: int = 15
    augment: bool = True
    # Maximum augmentation severity (a fraction, hence float).
    severity: float = 0.25
    clip_grad_norm: float = 1.0
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Enables mixed precision; train_model only honours it on CUDA devices.
    amp: bool = True
    seed: int = 42

    # ---- data ---------------------------------------------------------------
    # Width of the window cut around each annotated beat, in seconds.
    window_sec: float = 2.0
    # Record names handed to load_record (a tuple, so the default is immutable).
    records: tuple = (
        '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '111', '112',
        '113', '114', '115', '116', '117', '118', '119', '121', '122', '123', '124', '200',
        '201', '202', '203', '205', '207', '208', '209', '210', '212', '213', '214', '215',
        '217', '219', '220', '221', '222', '223', '228', '230', '231', '232', '233', '234',
    )
    # Annotation symbols kept by segment_beats ("N" -> label 0, others -> label 1).
    label_keep: tuple = ("N", "V", "F")
    # "record" splits by record name; any other value uses a stratified per-beat split.
    split_mode: str = "stratified"
    train_ratio: float = 0.7
    val_ratio: float = 0.15
    test_ratio: float = 0.15

    # ---- output -------------------------------------------------------------
    # Directory in which train_model writes its best-F1 checkpoints.
    save_dir: str = "checkpoints"

# Module-level default configuration; the data-building cell below passes it to build_dataset.
CFG = TrainCfg()

# Seed the torch, NumPy and stdlib random generators from the shared config seed.
torch.manual_seed(CFG.seed)
np.random.seed(CFG.seed)
random.seed(CFG.seed)

In [None]:
def load_record(name, duration_sec=600):
    """Read one record and its "atr" annotations, truncated to ``duration_sec``.

    Both wfdb calls are made with ``pn_dir="mitdb"``.

    Args:
        name: record name forwarded to ``wfdb.rdrecord`` / ``wfdb.rdann``.
        duration_sec: number of seconds to keep from the start of the record.

    Returns:
        ``(signal, annotation, fs)`` where ``signal`` is the leading
        ``int(fs * duration_sec)`` rows of ``p_signal``, ``annotation`` is the
        wfdb annotation object with ``sample`` and ``symbol`` filtered to that
        range, and ``fs`` is the record's sampling frequency.
    """
    record = wfdb.rdrecord(name, pn_dir="mitdb")
    annotation = wfdb.rdann(name, "atr", pn_dir="mitdb")

    fs = record.fs
    n_keep = int(fs * duration_sec)

    # Drop annotations that fall beyond the truncated signal; only `sample`
    # and `symbol` are filtered, other annotation attributes are left as read.
    in_range = annotation.sample < n_keep
    annotation.symbol = np.array(annotation.symbol)[in_range].tolist()
    annotation.sample = annotation.sample[in_range]

    return record.p_signal[:n_keep], annotation, fs

def segment_beats(signal, ann, fs, window_sec, label_keep):
    """Cut a z-scored window of channel 0 around every kept annotation.

    Args:
        signal: 2-D array indexed as ``signal[sample, channel]``.
        ann: object exposing parallel ``sample`` and ``symbol`` sequences.
        fs: sampling frequency used to convert ``window_sec`` to samples.
        window_sec: total window width in seconds, centred on the annotation.
        label_keep: container of annotation symbols to keep.

    Returns:
        ``(beats, labels)`` as NumPy arrays. Each beat is a float32 window of
        ``2 * int(fs * window_sec / 2)`` samples; the label is 0 for symbol
        ``"N"`` and 1 for every other kept symbol. Annotations whose window
        would run off either end of the signal are skipped.
    """
    half_width = int(fs * window_sec / 2)
    n_samples = len(signal)

    windows = []
    targets = []
    for center, symbol in zip(ann.sample, ann.symbol):
        lo = center - half_width
        hi = center + half_width
        # Keep only requested symbols whose window lies fully inside the signal.
        if symbol in label_keep and lo >= 0 and hi <= n_samples:
            window = signal[lo:hi, 0]
            # Per-window standardisation; the epsilon guards flat windows.
            normalised = (window - np.mean(window)) / (np.std(window) + 1e-8)
            windows.append(normalised.astype(np.float32))
            targets.append(1 if symbol != "N" else 0)

    return np.array(windows), np.array(targets)

from sklearn.model_selection import train_test_split

def _collect_beats(records, cfg):
    """Load every record in ``records`` and stack its beats.

    Returns ``(X, y)`` with ``X`` shaped ``(N, 1, L)`` and ``y`` shaped ``(N,)``.
    Records that yield no kept beats are skipped. When nothing is collected an
    empty float32 ``(0, 1, cfg.crop_len)`` array and an empty int label array
    are returned.
    """
    xs, ys = [], []
    for rec_name in records:
        sig, ann, fs = load_record(rec_name)
        x, y = segment_beats(sig, ann, fs, cfg.window_sec, cfg.label_keep)
        if x.size == 0 or y.size == 0:
            continue
        if x.ndim == 1:
            x = x[None, :]
        xs.append(x[:, None, :])  # insert the channel axis
        ys.append(y)
    if not xs:
        return np.empty((0, 1, cfg.crop_len), dtype=np.float32), np.empty((0,), dtype=int)
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)


def build_dataset(cfg):
    """Build train/validation/test splits of segmented beats.

    With ``cfg.split_mode == "record"`` (also the fallback when the attribute
    is missing) the record names are shuffled with ``cfg.seed`` and divided by
    ``cfg.train_ratio`` / ``cfg.val_ratio``; the remaining records form the
    test split, so no record contributes to more than one split.

    With any other ``split_mode`` all beats are pooled and split per beat with
    two stratified ``train_test_split`` calls. Beats from one record can then
    land in different splits.

    Args:
        cfg: configuration providing ``records``, ``window_sec``,
            ``label_keep``, ``crop_len``, ``seed`` and the split ratios.

    Returns:
        ``((X_train, y_train), (X_val, y_val), (X_test, y_test))``.

    Raises:
        ValueError: in the pooled mode when no beat was collected at all.
    """
    if getattr(cfg, "split_mode", "record") == "record":
        rng = np.random.default_rng(cfg.seed)
        records = list(cfg.records)
        rng.shuffle(records)
        n_total = len(records)
        n_train = int(n_total * cfg.train_ratio)
        n_val = int(n_total * cfg.val_ratio)

        X_train, y_train = _collect_beats(records[:n_train], cfg)
        X_val, y_val = _collect_beats(records[n_train:n_train + n_val], cfg)
        X_test, y_test = _collect_beats(records[n_train + n_val:], cfg)
    else:
        X, y = _collect_beats(cfg.records, cfg)
        if len(y) == 0:
            raise ValueError(
                "No beats matching cfg.label_keep were found in cfg.records; "
                "cannot build a stratified split."
            )

        X_train, X_temp, y_train, y_temp = train_test_split(
            X, y,
            test_size=cfg.val_ratio + cfg.test_ratio,
            stratify=y,
            random_state=cfg.seed
        )
        # Share of the held-out part that goes to the test split.
        rel_test = cfg.test_ratio / (cfg.val_ratio + cfg.test_ratio)
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp,
            test_size=rel_test,
            stratify=y_temp,
            random_state=cfg.seed
        )

    print("Class distribution:")
    for name, labels in zip(["Train", "Val", "Test"], [y_train, y_val, y_test]):
        if len(labels) > 0:
            ratio = np.mean(labels)
            print(f"{name}: {len(labels)} samples | pos ratio {ratio:.4f}")
        else:
            print(f"{name}: 0 samples")

    return (X_train, y_train), (X_val, y_val), (X_test, y_test)


In [None]:
# Build the three splits once, using the module-level default configuration CFG.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = build_dataset(CFG)

Class distribution:
Train: 19371 samples | pos ratio 0.0936
Val: 4151 samples | pos ratio 0.0937
Test: 4151 samples | pos ratio 0.0935


In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset


class AugmentationScheduler:
    """Linearly ramps augmentation severity over the first ``prog_epochs`` epochs.

    Args:
        prog_epochs: number of epochs over which severity grows from 0 to its
            maximum; values below 1 are treated as 1.
        max_severity: severity reached at the end of the ramp. ``None`` (the
            default) means no explicit cap and is treated as 1.0.
    """

    def __init__(self, prog_epochs=10, max_severity=None):
        self.prog_epochs = prog_epochs
        self.max_severity = max_severity

    def get(self, epoch, y):
        """Return the severity for ``epoch``.

        ``y`` (the sample label) is accepted for interface compatibility but
        does not influence the result.
        """
        # A missing cap would otherwise make the multiplication below raise.
        cap = 1.0 if self.max_severity is None else self.max_severity
        # Clamp to [0, 1] so a negative epoch cannot produce negative severity.
        progress = min(1.0, max(0.0, epoch / max(1, self.prog_epochs)))
        return min(progress * cap, cap)


class AddGaussianNoise(nn.Module):
    """Adds zero-mean Gaussian noise scaled by the signal's own standard deviation."""

    def __init__(self, std_factor=0.05):
        super().__init__()
        self.std_factor = std_factor

    def forward(self, x, severity):
        """Return ``x`` plus noise with std ``std(x) * std_factor * severity * 10``."""
        noise_std = torch.std(x) * self.std_factor * severity * 10
        noise = torch.randn_like(x)
        return x + noise * noise_std


class AmplitudeScale(nn.Module):
    """Multiplies the whole signal by one random gain close to 1."""

    def __init__(self, scale=0.1):
        super().__init__()
        self.scale = scale

    def forward(self, x, severity):
        """Scale ``x`` by a gain drawn uniformly from ``1 ± scale * severity``."""
        u = torch.rand(1, device=x.device)
        gain = 1 + self.scale * (2 * u - 1) * severity
        return x * gain


class TimeWarp(nn.Module):
    """Stretches or compresses the signal in time, then resamples to the input length."""

    def __init__(self, max_stretch=0.05):
        super().__init__()
        self.max_stretch = max_stretch

    def forward(self, x, severity):
        """Warp ``x`` by a random factor in ``1 ± max_stretch * severity``.

        Two leading axes are added before interpolation and removed again
        afterwards, so the output has the same shape as the input.
        """
        length = x.shape[-1]

        jitter = torch.rand(1, device=x.device) * 2 - 1
        stretch = 1 + jitter * self.max_stretch * severity
        warped_len = max(2, int(length * stretch.item()))

        batched = x.unsqueeze(0).unsqueeze(0)
        # Resample to the warped length, then back to the original length.
        out = F.interpolate(batched, size=warped_len, mode="linear", align_corners=False)
        out = F.interpolate(out, size=length, mode="linear", align_corners=False)

        return out.squeeze(0).squeeze(0)


class RandomShift(nn.Module):
    """Circularly shifts the signal along its last axis by a random offset."""

    def __init__(self, max_shift=0.05):
        super().__init__()
        self.max_shift = max_shift

    def forward(self, x, severity):
        """Roll ``x`` by up to ``int(L * max_shift * severity)`` samples either way.

        When that bound is zero the input is returned unchanged.
        """
        length = x.shape[-1]
        limit = int(length * self.max_shift * severity)
        if limit <= 0:
            return x
        offset = int(torch.randint(-limit, limit + 1, (1,), device=x.device))
        return torch.roll(x, shifts=offset, dims=-1)


class RandomDropout(nn.Module):
    """Zeroes one contiguous chunk of the signal along its last axis."""

    def __init__(self, max_frac=0.05):
        super().__init__()
        self.max_frac = max_frac

    def forward(self, x, severity):
        """Return a copy of ``x`` with a random chunk set to zero.

        The chunk is ``max(1, int(L * max_frac * severity))`` samples long, so
        at least one sample is always dropped.

        The input tensor is never modified. Writing into ``x`` directly would
        also write into whatever storage it is a view of, for example a row
        of a dataset tensor.
        """
        L = x.shape[-1]
        chunk_len = max(1, int(L * self.max_frac * severity))
        start = int(torch.randint(0, max(1, L - chunk_len), (1,), device=x.device))

        out = x.clone()
        # Index the last axis explicitly so inputs with leading axes work too.
        out[..., start:start + chunk_len] = 0.0
        return out


class RandomSpike(nn.Module):
    """Adds a short burst (1 to 3 samples) of Gaussian noise at a random position."""

    def __init__(self, max_amp=0.2):
        super().__init__()
        self.max_amp = max_amp

    def forward(self, x, severity):
        """Return a copy of ``x`` with a noise spike of std ``max_amp * severity * std(x)``.

        The input tensor is never modified. An in-place ``+=`` would also
        alter whatever storage ``x`` is a view of.
        """
        L = x.shape[-1]
        out = x.clone()
        if L == 0:
            # Nothing to perturb; randint(0, 0) below would raise.
            return out

        spike_amp = self.max_amp * severity * torch.std(x)
        at = int(torch.randint(0, L, (1,), device=x.device))
        # Truncate the burst when it would run past the end of the signal.
        width = min(int(torch.randint(1, 4, (1,), device=x.device)), L - at)

        # Index the last axis explicitly so inputs with leading axes work too.
        window = out[..., at:at + width]
        out[..., at:at + width] = window + torch.randn_like(window) * spike_amp
        return out


class ECGAugment(nn.Module):
    """Applies a random, non-empty subset of transforms in random order."""

    def __init__(self, transforms):
        super().__init__()
        self.transforms = nn.ModuleList(transforms)

    def forward(self, x, severity):
        """Run the sampled transforms on ``x`` and clamp the result to [-5, 5].

        Both the number of transforms and their order are drawn from the
        stdlib ``random`` module.
        """
        pool = list(self.transforms)
        n_apply = random.randint(1, len(self.transforms))

        out = x
        for transform in random.sample(pool, n_apply):
            out = transform(out, severity)
        return torch.clamp(out, -5, 5)


class ECGDataset(torch.utils.data.Dataset):
    """Dataset of single-channel ECG segments with optional on-the-fly augmentation.

    Args:
        X: array of segments shaped ``(N, 1, L)`` or ``(N, L)``.
        y: array of labels, one per segment.
        augment: when True, every fetched segment goes through the
            augmentation pipeline at the scheduler's current severity.
        cfg: configuration object; its ``severity`` is the maximum severity.
            May be ``None`` only when ``augment`` is False.

    Raises:
        ValueError: if ``augment`` is True but ``cfg`` is ``None``.

    Samples are always returned on the CPU, so the consumer is expected to
    move each batch to the target device. This keeps the dataset usable with
    DataLoader worker processes and avoids one host-to-device copy per sample.
    """

    def __init__(self, X, y, augment=False, cfg=None):
        if augment and cfg is None:
            raise ValueError("ECGDataset(augment=True) requires cfg to provide cfg.severity")

        X = np.asarray(X)
        # Drop only the singleton channel axis. A blanket squeeze() would also
        # collapse the sample axis when N == 1.
        if X.ndim == 3 and X.shape[1] == 1:
            X = X[:, 0, :]
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

        self.cfg = cfg
        self.augment = augment
        self.severity = cfg.severity if cfg is not None else 0.0

        self.scheduler = AugmentationScheduler(
            max_severity=self.severity,
        )
        # Updated by the training loop so severity can ramp up with the epoch.
        self.current_epoch = 0

        self.augment_pipeline = ECGAugment([
            AddGaussianNoise(),
            AmplitudeScale(),
            TimeWarp(),
            RandomShift(),
            RandomDropout(),
            RandomSpike()
        ])

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(segment, label)`` with a channel axis added in front of the segment."""
        x, y = self.X[idx], self.y[idx]
        if self.augment:
            severity = self.scheduler.get(self.current_epoch, int(y.item()))
            # self.X[idx] is a view into the stored data; augment a copy so a
            # transform that writes in place can never corrupt the dataset.
            x = self.augment_pipeline(x.clone(), severity)

        return x.unsqueeze(0), y


In [None]:
class ResidualBlock(nn.Module):
    """Two Conv1d + BatchNorm layers with a skip connection.

    The skip path is the identity unless the channel count or the stride
    changes, in which case a 1x1 convolution projects the input.
    """

    def __init__(self, in_ch, out_ch, kernel=5, stride=1):
        super().__init__()
        pad = kernel // 2
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel, stride=stride, padding=pad)
        self.bn1 = nn.BatchNorm1d(out_ch)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel, padding=pad)
        self.bn2 = nn.BatchNorm1d(out_ch)

        needs_projection = (in_ch != out_ch) or (stride != 1)
        self.down = nn.Conv1d(in_ch, out_ch, 1, stride=stride) if needs_projection else None

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``."""
        shortcut = self.down(x) if self.down is not None else x

        h = self.conv1(x)
        h = F.relu(self.bn1(h))
        h = self.conv2(h)
        h = self.bn2(h)
        return F.relu(h + shortcut)

class SEBlock1D(nn.Module):
    """Squeeze-and-excitation gate over the channel axis of a ``(B, C, L)`` tensor.

    Args:
        channels: number of input (and output) channels.
        reduction: divisor for the bottleneck width. The bottleneck is never
            narrower than one unit, so ``channels < reduction`` still gives a
            working gate instead of a zero-width layer.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        """Scale every channel of ``x`` by a learned weight in (0, 1)."""
        w = x.mean(dim=2)  # squeeze: average over the length axis
        w = F.relu(self.fc1(w))
        w = torch.sigmoid(self.fc2(w)).unsqueeze(-1)  # one gate per channel
        return x * w


class MultiScaleConv(nn.Module):
    """Three parallel Conv1d branches (kernel 3, 5 and 7) concatenated on channels.

    The output has ``3 * out_ch`` channels; each branch pads so that the
    length axis is unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv3 = nn.Conv1d(in_ch, out_ch, 3, padding=1)
        self.conv5 = nn.Conv1d(in_ch, out_ch, 5, padding=2)
        self.conv7 = nn.Conv1d(in_ch, out_ch, 7, padding=3)
        self.bn = nn.BatchNorm1d(out_ch*3)

    def forward(self, x):
        """Return ``relu(bn(cat[conv3(x), conv5(x), conv7(x)]))``."""
        branches = [conv(x) for conv in (self.conv3, self.conv5, self.conv7)]
        stacked = torch.cat(branches, dim=1)
        return F.relu(self.bn(stacked))


class ECGNet(nn.Module):
    """1-D CNN for binary beat classification that outputs one logit per sample.

    The layers are a stem convolution, four residual blocks, a
    squeeze-and-excitation gate and a multi-scale convolution, followed by one
    of two heads:

    * ``temporal_context=False``: global average and max pooling, concatenated.
    * ``temporal_context=True``: a bidirectional LSTM over the length axis.

    Args:
        input_len: accepted for interface compatibility; none of the layers
            built here depend on it.
        dropout_p: dropout probability applied before the final linear layer.
        temporal_context: selects the LSTM head instead of the pooling head.
        lstm_hidden: hidden size per direction of the LSTM head.
    """

    def __init__(self, input_len=720, dropout_p=0.5, temporal_context=False, lstm_hidden=256):
        super().__init__()
        self.temporal_context = temporal_context

        self.stem = nn.Conv1d(1, 64, 7, padding=3)

        self.r1 = ResidualBlock(64, 128)
        self.r2 = ResidualBlock(128, 256)
        self.r3 = ResidualBlock(256, 256)
        self.r4 = ResidualBlock(256, 512)
        self.se = SEBlock1D(512)

        self.ms = MultiScaleConv(512, 256)

        self.lstm = None
        if temporal_context:
            self.lstm = nn.LSTM(input_size=256*3, hidden_size=lstm_hidden,
                                batch_first=True, bidirectional=True)
            self.fc = nn.Linear(lstm_hidden*2, 1)
        else:
            self.avgpool = nn.AdaptiveAvgPool1d(1)
            self.maxpool = nn.AdaptiveMaxPool1d(1)
            self.fc = nn.Linear(256*3*2, 1)

        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Map a ``(B, 1, L)`` batch to logits of shape ``(B,)``."""
        x = self.stem(x)
        x = self.r1(x)
        x = self.r2(x)
        x = self.r3(x)
        x = self.r4(x)
        x = self.se(x)
        x = self.ms(x)  # (B, 768, L)

        if self.temporal_context:
            # The LSTM is batch_first with input_size 768, so the channel axis
            # must become the feature axis: (B, 768, L) -> (B, L, 768).
            seq = x.transpose(1, 2)
            _, (h_n, _) = self.lstm(seq)
            # Concatenate the final hidden state of each direction. The output
            # at the last time step would give the backward direction a state
            # that has seen only a single step.
            x = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            a = self.avgpool(x).squeeze(-1)
            m = self.maxpool(x).squeeze(-1)
            x = torch.cat([a, m], dim=1)

        x = self.dropout(x)
        return self.fc(x).squeeze(1)


In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss on logits with an optional weight for positive targets.

    Args:
        focal_gamma: exponent of the ``(1 - p_t)`` modulating factor. ``None``
            (the default) disables the modulation, which is the same as 0.
        pos_weight: multiplier applied to the loss of samples whose target
            equals 1. May be a Python number or a tensor; ``None`` disables it.
    """

    def __init__(self, focal_gamma=None, pos_weight=None):
        super().__init__()

        self.focal_gamma = focal_gamma
        self.pos_weight = pos_weight

    def forward(self, logits, targets, epoch=0):
        """Return the mean focal loss. ``epoch`` is accepted but unused."""
        # Without this, the default gamma of None would make ``**`` raise.
        gamma = 0.0 if self.focal_gamma is None else self.focal_gamma

        prob = torch.sigmoid(logits)
        bce = F.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        # Probability the model assigns to the true class of each sample.
        p_t = prob * targets + (1 - prob) * (1 - targets)
        focal = ((1 - p_t) ** gamma) * bce
        if self.pos_weight is not None:
            # Normalise to a tensor on the loss's device and dtype so plain
            # numbers and tensors created on another device both work.
            pos_w = torch.as_tensor(self.pos_weight, dtype=focal.dtype, device=focal.device)
            w = torch.where(targets == 1, pos_w, torch.ones_like(pos_w))
            focal = focal * w
        return focal.mean()

class BCEWeighted(nn.Module):
    """Mean binary cross-entropy on logits with an optional positive-class weight."""

    def __init__(self, pos_weight=None):
        super().__init__()

        self.pos_weight = pos_weight

    def forward(self, logits, targets, epoch=0):
        """Return the mean weighted BCE. ``epoch`` is accepted but unused."""
        loss = F.binary_cross_entropy_with_logits(
            input=logits,
            target=targets,
            pos_weight=self.pos_weight,
            reduction="mean",
        )
        return loss


In [None]:
def tune_threshold_from_probs(y_true: np.ndarray, probs: np.ndarray):
    """Pick the decision threshold with the highest F1 on the precision-recall curve.

    Args:
        y_true: binary ground-truth labels.
        probs: predicted scores for the positive class.

    Returns:
        ``(threshold, f1)``. When the best point on the curve has no matching
        threshold entry, the threshold falls back to 0.5.
    """
    prec, rec, thr = precision_recall_curve(y_true, probs)
    # The epsilon keeps the ratio finite where precision + recall is zero.
    f1_curve = 2 * prec * rec / (prec + rec + 1e-8)
    idx = int(np.nanargmax(f1_curve))

    if idx < len(thr):
        return float(thr[idx]), float(f1_curve[idx])
    return 0.5, f1_curve[idx]

In [None]:

def train_model(model: nn.Module,
                X_train: np.ndarray, y_train: np.ndarray,
                X_val: np.ndarray, y_val: np.ndarray,
                cfg):
    """Train ``model`` and keep the weights with the best validation F1.

    Each epoch runs one pass over the training loader (focal loss, optional
    mixed precision, gradient clipping) followed by a validation pass on which
    the F1-optimal threshold is tuned. Whenever validation F1 improves, the
    weights are snapshotted in memory and written to ``cfg.save_dir``.

    Args:
        model: network mapping a ``(B, 1, L)`` batch to ``(B,)`` logits.
        X_train, y_train: training segments and binary labels.
        X_val, y_val: validation segments and binary labels.
        cfg: configuration object. ``save_dir`` and ``augment`` are optional
            and default to ``"checkpoints"`` and ``True``.

    Returns:
        ``(model, threshold, best_f1)``: the model with the best weights
        loaded, the smoothed decision threshold recorded at that best epoch
        (the same value stored in the checkpoint), and the best validation F1.
    """
    device = cfg.device
    model.to(device)

    # Honour the config switch instead of always augmenting.
    use_augment = getattr(cfg, "augment", True)
    train_ds = ECGDataset(X_train, y_train, augment=use_augment, cfg=cfg)
    val_ds   = ECGDataset(X_val,   y_val, augment=False, cfg=cfg)

    labels_flat = np.asarray(y_train).astype(int).flatten()
    class_counts = np.bincount(labels_flat)
    print("Class counts:", class_counts)

    if cfg.use_sampler:
        # Inverse-frequency weights. The clamp only affects classes with no
        # samples, whose weight is never looked up.
        weights = 1. / np.maximum(class_counts, 1)
        sample_weights = weights[labels_flat]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, sampler=sampler)
    else:
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True)

    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False)

    n_pos = float(labels_flat.sum())
    pos_weight = torch.tensor(
        (len(labels_flat) - n_pos) / max(1, n_pos),
        dtype=torch.float32,
        device=device
    )

    criterion = FocalLoss(focal_gamma=cfg.focal_gamma, pos_weight=pos_weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,
        T_mult=2,
        eta_min=cfg.lr * 1e-3
    )

    scaler = amp.GradScaler("cuda", enabled=cfg.amp and device.startswith("cuda"))

    # The config may predate the save_dir setting, and the directory must
    # exist before torch.save can write into it.
    save_dir = getattr(cfg, "save_dir", "checkpoints")
    os.makedirs(save_dir, exist_ok=True)

    best_f1, best_state = -1.0, None
    smoothed_thresh = 0.5
    best_thresh = smoothed_thresh
    alpha_thresh = 0.3  # EMA factor for the per-epoch tuned threshold

    for epoch in range(1, cfg.max_epochs + 1):
        train_ds.current_epoch = epoch
        model.train()
        total_loss, n_seen = 0.0, 0

        # The scheduler's severity depends only on the epoch, so query it once
        # instead of once per sample (each per-sample .item() forced a sync).
        avg_severity = float(train_ds.scheduler.get(epoch, 0)) if train_ds.augment else 0.0

        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device).view(-1)
            optimizer.zero_grad()

            with amp.autocast('cuda', enabled=scaler.is_enabled()):
                logits = model(xb).view(-1)
                loss = criterion(logits, yb, epoch)

            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()

            total_loss += float(loss.item()) * xb.size(0)
            n_seen += xb.size(0)

        # Advance the schedule by exactly one epoch. Passing ``epoch + 1``
        # kept the learning rate one epoch ahead of training.
        scheduler.step()

        train_loss = total_loss / max(1, n_seen)

        model.eval()
        all_probs, all_y = [], []
        with torch.inference_mode(), amp.autocast('cuda', enabled=scaler.is_enabled()):
            for xb, yb in val_loader:
                xb = xb.to(device)
                logits = model(xb).view(-1)
                probs = torch.clamp(torch.sigmoid(logits), min=1e-7, max=1 - 1e-7).cpu().numpy()
                all_probs.append(probs)
                all_y.append(yb.numpy().astype(int).flatten())

        all_probs = np.concatenate(all_probs)
        all_y = np.concatenate(all_y)

        pr_auc = float(average_precision_score(all_y, all_probs))
        roc_auc = float(roc_auc_score(all_y, all_probs))
        epoch_thresh, epoch_f1 = tune_threshold_from_probs(all_y, all_probs)

        preds = (all_probs >= epoch_thresh).astype(int)
        epoch_precision = precision_score(all_y, preds, zero_division=0)
        epoch_recall = recall_score(all_y, preds, zero_division=0)

        smoothed_thresh = alpha_thresh * epoch_thresh + (1 - alpha_thresh) * smoothed_thresh

        if epoch_f1 > best_f1:
            best_f1 = epoch_f1
            # Remember the threshold that belongs to these weights, so the
            # returned pair matches what the checkpoint stores.
            best_thresh = smoothed_thresh
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            ckpt_path = os.path.join(save_dir, f"best_f1_epoch{epoch:03d}_f1{best_f1:.4f}.pt")
            torch.save({
                'model_state': best_state,
                'epoch': epoch,
                'f1': best_f1,
                'threshold': smoothed_thresh
            }, ckpt_path)

        print(f"Epoch {epoch:03d} | train_loss {train_loss:.4f} "
              f"| best_F1 {best_f1:.4f} @thr {smoothed_thresh:.3f} "
              f"(Precision {epoch_precision:.3f} | Recall {epoch_recall:.3f} | PR-AUC {pr_auc:.3f} | ROC-AUC {roc_auc:.3f} | AvgSeverity {avg_severity:.3f})")

    if best_state is not None:
        model.load_state_dict(best_state)
    print(f"Training done. Best F1: {best_f1:.4f} | Smoothed threshold: {best_thresh:.3f}")
    return model, best_thresh, best_f1

In [None]:
# Hyper-parameter overrides applied on top of the TrainCfg defaults for this run.
params = {
          'max_epochs': 150 ,
          'lr': 3e-4,
          'weight_decay': 1e-04,
          'focal_gamma': 1,
          'crop_len': 720,
          'severity': 0.25}

# Start from a fresh default config and overwrite the selected attributes.
cfg = TrainCfg()
for k, v in params.items():
    setattr(cfg, k, v)

model = ECGNet(input_len=cfg.crop_len, dropout_p=0.3)

# NOTE: the third value returned by train_model is the best validation F1,
# so `best_pr` holds an F1 score despite its name.
trained_model, threshold, best_pr = train_model(
    model, X_train, y_train, X_val, y_val, cfg
)



# Loader over the held-out test split (no augmentation); it is not consumed
# in the cells shown here.
test_loader = DataLoader(
    ECGDataset(X_test, y_test, augment=False, cfg=cfg),
    batch_size=cfg.batch_size
)


Class counts: [17557  1814]
Epoch 001 | train_loss 0.7092 | best_F1 0.5644 @thr 0.595 (Precision 0.422 | Recall 0.851 | PR-AUC 0.486 | ROC-AUC 0.932 | AvgSeverity 0.025)
Epoch 002 | train_loss 0.3390 | best_F1 0.5891 @thr 0.663 (Precision 0.456 | Recall 0.833 | PR-AUC 0.483 | ROC-AUC 0.935 | AvgSeverity 0.050)
Epoch 003 | train_loss 0.2973 | best_F1 0.5917 @thr 0.722 (Precision 0.436 | Recall 0.920 | PR-AUC 0.493 | ROC-AUC 0.942 | AvgSeverity 0.075)
Epoch 004 | train_loss 0.2810 | best_F1 0.5917 @thr 0.784 (Precision 0.455 | Recall 0.828 | PR-AUC 0.481 | ROC-AUC 0.941 | AvgSeverity 0.100)
Epoch 005 | train_loss 0.2608 | best_F1 0.6301 @thr 0.821 (Precision 0.509 | Recall 0.828 | PR-AUC 0.494 | ROC-AUC 0.945 | AvgSeverity 0.125)
Epoch 006 | train_loss 0.2474 | best_F1 0.6301 @thr 0.860 (Precision 0.510 | Recall 0.823 | PR-AUC 0.571 | ROC-AUC 0.952 | AvgSeverity 0.150)
Epoch 007 | train_loss 0.2331 | best_F1 0.6392 @thr 0.872 (Precision 0.509 | Recall 0.859 | PR-AUC 0.574 | ROC-AUC 0.953

In [None]:
# Switch to inference mode (affects dropout / batch-norm behaviour) before exporting.
trained_model.eval()

# Persist only the weights (state_dict); loading requires re-creating ECGNet first.
torch.save(trained_model.state_dict(), "cardioiq_model_v4.pt")