In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ReduceLROnPlateau
import warnings

In [6]:
# Suppress PyTorch UserWarning about DataLoader iterator re-creation
warnings.filterwarnings("ignore", category=UserWarning)

In [7]:
# --- 0. Configuration ---
CLEANED_CSV_PATH = 'metadata_with_features.csv' 
CHECKPOINT_DIR = './cnn_lstm_hybrid_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

BATCH_SIZE = 32
EPOCHS = 60 
LEAD_COUNT = 12
SAMPLES = 2500
NUM_CLASSES = 3 
NUM_CLINICAL_FEATURES = 8 
HIDDEN_SIZE = 64
LEARNING_RATE = 2e-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [8]:
# --- 1. Focal Loss Function (For Imbalance) ---
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        if self.alpha is not None and not isinstance(self.alpha, torch.Tensor):
            self.alpha = torch.tensor(self.alpha, dtype=torch.float32)

    def forward(self, inputs, targets):
        log_p = F.log_softmax(inputs, dim=1)
        log_p = log_p.gather(1, targets.view(-1, 1)).view(-1)
        p = torch.exp(log_p)
        loss = - (1 - p)**self.gamma * log_p
        
        if self.alpha is not None:
            if self.alpha.device != inputs.device:
                self.alpha = self.alpha.to(inputs.device)
            alpha_t = self.alpha.gather(0, targets.view(-1))
            loss = alpha_t * loss
            
        return loss.mean() if self.reduction == 'mean' else loss.sum()

In [9]:
# --- 2. Hybrid CNN-LSTM Model Architecture ---

class CNNLSTM_FeatureExtractor(nn.Module):
    """Processes the raw ECG signal to generate a deep feature vector."""
    def __init__(self, hidden_size=HIDDEN_SIZE, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        
        # CNN (Feature Extraction)
        self.cnn = nn.Sequential(
            nn.Conv1d(LEAD_COUNT, 32, kernel_size=15, stride=2, padding=7),
            nn.BatchNorm1d(32), nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            
            nn.Conv1d(32, 64, kernel_size=11, stride=2, padding=5),
            nn.BatchNorm1d(64), nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            
            nn.Conv1d(64, 128, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm1d(128), nn.ReLU(inplace=True)
        )
        
        # LSTM (Temporal Context)
        self.lstm = nn.LSTM(
            input_size=128, hidden_size=hidden_size, num_layers=num_layers,
            batch_first=True, bidirectional=True
        )

    def forward(self, x):
        cnn_out = self.cnn(x) 
        lstm_input = cnn_out.transpose(1, 2) 
        
        _, (h_n, _) = self.lstm(lstm_input)
        
        final_state = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        return final_state

In [10]:
class FinalMultiInputHybridModel(nn.Module):
    """Combines deep features (CNN-LSTM output) and 8 handcrafted features."""
    def __init__(self, num_classes=NUM_CLASSES, num_clinical_features=NUM_CLINICAL_FEATURES):
        super().__init__()
        
        self.signal_extractor = CNNLSTM_FeatureExtractor(hidden_size=HIDDEN_SIZE)
        
        INPUT_SIZE = (2 * HIDDEN_SIZE) + num_clinical_features 
        
        self.final_fc = nn.Sequential(
            nn.Linear(INPUT_SIZE, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(64, num_classes)
        )

    def forward(self, signal, clinical_features):
        deep_features = self.signal_extractor(signal) 
        
        # Concatenate Deep Features with Handcrafted Features (ensuring float type)
        combined_features = torch.cat((deep_features, clinical_features.float()), dim=1) 
        
        return self.final_fc(combined_features)

In [11]:
# --- 3. Custom Dataset for Hybrid Input ---
class HybridECGDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load Signal from absolute path (processed_npy_path)
        try:
            signal = np.load(row['processed_npy_path']).T
            if signal.shape != (LEAD_COUNT, SAMPLES):
                signal = signal[:, :SAMPLES]
            signal = np.nan_to_num(signal, nan=0.0)
        except Exception: 
            signal = np.zeros((LEAD_COUNT, SAMPLES), dtype=np.float32)

        # Load Features from absolute path (feature_path)
        try:
            features = np.load(row['feature_path'])
        except Exception:
            features = np.zeros(NUM_CLINICAL_FEATURES, dtype=np.float32)

        label = {"Low": 0, "Moderate": 1, "High": 2}[row['severity_level']]
        
        return (torch.tensor(signal, dtype=torch.float32), 
                torch.tensor(features, dtype=torch.float32)), label

In [12]:
# --- 4. Evaluation and Checkpointing Functions ---
top_checkpoints = []
N = 5
def save_checkpoint(model, epoch, val_f1):
    global top_checkpoints
    path = os.path.join(CHECKPOINT_DIR, f"hybrid_model_epoch{epoch}.pt")
    torch.save(model.state_dict(), path)
    top_checkpoints.append((val_f1, path))
    top_checkpoints.sort(reverse=True, key=lambda x: x[0])
    if len(top_checkpoints) > N:
        _, worst_path = top_checkpoints.pop()
        if os.path.exists(worst_path):
             os.remove(worst_path)
    print(f"Saved checkpoint: {path} (F1: {val_f1:.4f})")

In [13]:
def hybrid_evaluate(model, dataloader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for (signal, features), y in dataloader:
            signal, features = signal.to(DEVICE), features.to(DEVICE)
            logits = model(signal, features)
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            y_pred.extend(preds)
            y_true.extend(y.numpy())
    f1 = f1_score(y_true, y_pred, average='weighted')
    acc = np.mean(np.array(y_true) == np.array(y_pred))
    return acc, f1

In [14]:
# Load the cleaned metadata
df = pd.read_csv(CLEANED_CSV_PATH)

# Split the data
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['severity_level'], random_state=42)

# --- Setup Weights and Sampler ---
train_labels = train_df['severity_level'].map({'Low': 0, 'Moderate': 1, 'High': 2})

In [15]:
# Focal Loss Alpha Weights
class_weights_array = compute_class_weight(class_weight="balanced", classes=np.unique(train_labels), y=train_labels)
alpha_weights = torch.tensor(class_weights_array, dtype=torch.float32).to(DEVICE)

In [16]:
# Sampler Weights (Inverse SQRT frequency for batch balance)
train_labels_np = train_labels.values
class_counts = train_df['severity_level'].value_counts().sort_index()
num_samples = len(train_df)
weights = 1.0 / np.sqrt(class_counts.values) 
sample_weights = weights[train_labels_np]

sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)

In [17]:
# DataLoaders
train_loader = DataLoader(HybridECGDataset(train_df), batch_size=BATCH_SIZE, sampler=sampler)
val_loader = DataLoader(HybridECGDataset(val_df), batch_size=BATCH_SIZE)

In [18]:
# --- Model, Loss, Optimizer ---
model = FinalMultiInputHybridModel(num_clinical_features=NUM_CLINICAL_FEATURES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 

criterion = FocalLoss(alpha=alpha_weights, gamma=2.0) 

scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)

In [20]:
# --- Training Loop ---
print("Starting Final Hybrid Training with Focal Loss...")
best_val_f1 = 0.0

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for (x_signal, x_features), y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        x_signal, x_features, y = x_signal.to(DEVICE), x_features.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        
        logits = model(x_signal, x_features)
        loss = criterion(logits, y)

        if torch.isnan(loss): continue
                
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    val_acc, val_f1 = hybrid_evaluate(model, val_loader)
    
    scheduler.step(val_f1)

    print(f"Loss: {avg_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")
    save_checkpoint(model, epoch + 1, val_f1)

Starting Final Hybrid Training with Focal Loss...


Epoch 1/60: 100%|██████████| 785/785 [16:01<00:00,  1.22s/it]


Loss: 0.4168 | Val Acc: 0.4912 | Val F1: 0.5747
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch1.pt (F1: 0.5747)


Epoch 2/60: 100%|██████████| 785/785 [07:21<00:00,  1.78it/s]


Loss: 0.4052 | Val Acc: 0.3499 | Val F1: 0.4208
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch2.pt (F1: 0.4208)


Epoch 3/60: 100%|██████████| 785/785 [04:05<00:00,  3.20it/s]


Loss: 0.4168 | Val Acc: 0.4957 | Val F1: 0.5778
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch3.pt (F1: 0.5778)


Epoch 4/60: 100%|██████████| 785/785 [02:26<00:00,  5.35it/s]


Loss: 0.3971 | Val Acc: 0.5053 | Val F1: 0.5877
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch4.pt (F1: 0.5877)


Epoch 5/60: 100%|██████████| 785/785 [02:00<00:00,  6.51it/s]


Loss: 0.4048 | Val Acc: 0.4927 | Val F1: 0.5750
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch5.pt (F1: 0.5750)


Epoch 6/60: 100%|██████████| 785/785 [01:49<00:00,  7.15it/s]


Loss: 0.3857 | Val Acc: 0.5403 | Val F1: 0.6191
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch6.pt (F1: 0.6191)


Epoch 7/60: 100%|██████████| 785/785 [01:43<00:00,  7.58it/s]


Loss: 0.3790 | Val Acc: 0.4957 | Val F1: 0.5782
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch7.pt (F1: 0.5782)


Epoch 8/60: 100%|██████████| 785/785 [01:40<00:00,  7.80it/s]


Loss: 0.3670 | Val Acc: 0.5575 | Val F1: 0.6344
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch8.pt (F1: 0.6344)


Epoch 9/60: 100%|██████████| 785/785 [01:39<00:00,  7.89it/s]


Loss: 0.4003 | Val Acc: 0.3958 | Val F1: 0.4739
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch9.pt (F1: 0.4739)


Epoch 10/60: 100%|██████████| 785/785 [01:38<00:00,  7.96it/s]


Loss: 0.3906 | Val Acc: 0.5250 | Val F1: 0.6052
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch10.pt (F1: 0.6052)


Epoch 11/60: 100%|██████████| 785/785 [01:45<00:00,  7.42it/s]


Loss: 0.3574 | Val Acc: 0.5706 | Val F1: 0.6452
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch11.pt (F1: 0.6452)


Epoch 12/60: 100%|██████████| 785/785 [01:38<00:00,  7.99it/s]


Loss: 0.3887 | Val Acc: 0.5750 | Val F1: 0.6489
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch12.pt (F1: 0.6489)


Epoch 13/60: 100%|██████████| 785/785 [01:37<00:00,  8.03it/s]


Loss: 0.3743 | Val Acc: 0.5395 | Val F1: 0.6180
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch13.pt (F1: 0.6180)


Epoch 14/60: 100%|██████████| 785/785 [01:43<00:00,  7.58it/s]


Loss: 0.3552 | Val Acc: 0.6307 | Val F1: 0.6931
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch14.pt (F1: 0.6931)


Epoch 15/60: 100%|██████████| 785/785 [01:44<00:00,  7.52it/s]


Loss: 0.3724 | Val Acc: 0.5330 | Val F1: 0.6120
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch15.pt (F1: 0.6120)


Epoch 16/60: 100%|██████████| 785/785 [01:44<00:00,  7.53it/s]


Loss: 0.3712 | Val Acc: 0.5993 | Val F1: 0.6688
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch16.pt (F1: 0.6688)


Epoch 17/60: 100%|██████████| 785/785 [01:44<00:00,  7.52it/s]


Loss: 0.3390 | Val Acc: 0.5298 | Val F1: 0.6084
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch17.pt (F1: 0.6084)


Epoch 18/60: 100%|██████████| 785/785 [01:37<00:00,  8.07it/s]


Loss: 0.3347 | Val Acc: 0.5951 | Val F1: 0.6647
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch18.pt (F1: 0.6647)


Epoch 19/60: 100%|██████████| 785/785 [01:37<00:00,  8.07it/s]


Loss: 0.3603 | Val Acc: 0.4347 | Val F1: 0.5165
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch19.pt (F1: 0.5165)


Epoch 20/60: 100%|██████████| 785/785 [01:44<00:00,  7.48it/s]


Loss: 0.3256 | Val Acc: 0.7715 | Val F1: 0.7890
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch20.pt (F1: 0.7890)


Epoch 21/60: 100%|██████████| 785/785 [01:37<00:00,  8.02it/s]


Loss: 0.3253 | Val Acc: 0.7177 | Val F1: 0.7558
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch21.pt (F1: 0.7558)


Epoch 22/60: 100%|██████████| 785/785 [01:38<00:00,  7.99it/s]


Loss: 0.2943 | Val Acc: 0.6232 | Val F1: 0.6871
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch22.pt (F1: 0.6871)


Epoch 23/60: 100%|██████████| 785/785 [01:38<00:00,  8.00it/s]


Loss: 0.3214 | Val Acc: 0.7670 | Val F1: 0.7883
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch23.pt (F1: 0.7883)


Epoch 24/60: 100%|██████████| 785/785 [01:38<00:00,  8.00it/s]


Loss: 0.2949 | Val Acc: 0.6442 | Val F1: 0.7026
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch24.pt (F1: 0.7026)


Epoch 25/60: 100%|██████████| 785/785 [01:37<00:00,  8.03it/s]


Loss: 0.2859 | Val Acc: 0.6577 | Val F1: 0.7131
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch25.pt (F1: 0.7131)


Epoch 26/60: 100%|██████████| 785/785 [01:38<00:00,  8.01it/s]


Loss: 0.2844 | Val Acc: 0.7290 | Val F1: 0.7628
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch26.pt (F1: 0.7628)


Epoch 27/60: 100%|██████████| 785/785 [01:37<00:00,  8.07it/s]


Loss: 0.2579 | Val Acc: 0.7683 | Val F1: 0.7905
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch27.pt (F1: 0.7905)


Epoch 28/60: 100%|██████████| 785/785 [01:47<00:00,  7.31it/s]


Loss: 0.2186 | Val Acc: 0.7814 | Val F1: 0.7992
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch28.pt (F1: 0.7992)


Epoch 29/60: 100%|██████████| 785/785 [01:44<00:00,  7.49it/s]


Loss: 0.2138 | Val Acc: 0.7706 | Val F1: 0.7926
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch29.pt (F1: 0.7926)


Epoch 30/60: 100%|██████████| 785/785 [01:44<00:00,  7.50it/s]


Loss: 0.2092 | Val Acc: 0.7478 | Val F1: 0.7786
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch30.pt (F1: 0.7786)


Epoch 31/60: 100%|██████████| 785/785 [01:37<00:00,  8.03it/s]


Loss: 0.2233 | Val Acc: 0.7455 | Val F1: 0.7775
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch31.pt (F1: 0.7775)


Epoch 32/60: 100%|██████████| 785/785 [01:37<00:00,  8.09it/s]


Loss: 0.2083 | Val Acc: 0.7728 | Val F1: 0.7950
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch32.pt (F1: 0.7950)


Epoch 33/60: 100%|██████████| 785/785 [01:37<00:00,  8.02it/s]


Loss: 0.2149 | Val Acc: 0.7774 | Val F1: 0.7975
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch33.pt (F1: 0.7975)


Epoch 34/60: 100%|██████████| 785/785 [01:38<00:00,  8.00it/s]


Loss: 0.2193 | Val Acc: 0.7929 | Val F1: 0.8079
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch34.pt (F1: 0.8079)


Epoch 35/60: 100%|██████████| 785/785 [01:39<00:00,  7.86it/s]


Loss: 0.2078 | Val Acc: 0.7741 | Val F1: 0.7968
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch35.pt (F1: 0.7968)


Epoch 36/60: 100%|██████████| 785/785 [01:41<00:00,  7.77it/s]


Loss: 0.2041 | Val Acc: 0.7766 | Val F1: 0.7978
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch36.pt (F1: 0.7978)


Epoch 37/60: 100%|██████████| 785/785 [01:40<00:00,  7.85it/s]


Loss: 0.2052 | Val Acc: 0.7827 | Val F1: 0.8029
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch37.pt (F1: 0.8029)


Epoch 38/60: 100%|██████████| 785/785 [01:39<00:00,  7.92it/s]


Loss: 0.2104 | Val Acc: 0.7943 | Val F1: 0.8084
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch38.pt (F1: 0.8084)


Epoch 39/60: 100%|██████████| 785/785 [01:38<00:00,  7.95it/s]


Loss: 0.2036 | Val Acc: 0.7954 | Val F1: 0.8092
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch39.pt (F1: 0.8092)


Epoch 40/60: 100%|██████████| 785/785 [01:38<00:00,  8.00it/s]


Loss: 0.1714 | Val Acc: 0.7898 | Val F1: 0.8067
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch40.pt (F1: 0.8067)


Epoch 41/60: 100%|██████████| 785/785 [01:37<00:00,  8.02it/s]


Loss: 0.1980 | Val Acc: 0.7819 | Val F1: 0.8019
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch41.pt (F1: 0.8019)


Epoch 42/60: 100%|██████████| 785/785 [01:36<00:00,  8.11it/s]


Loss: 0.1984 | Val Acc: 0.7808 | Val F1: 0.8013
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch42.pt (F1: 0.8013)


Epoch 43/60: 100%|██████████| 785/785 [01:37<00:00,  8.01it/s]


Loss: 0.1857 | Val Acc: 0.7828 | Val F1: 0.8023
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch43.pt (F1: 0.8023)


Epoch 44/60: 100%|██████████| 785/785 [01:37<00:00,  8.03it/s]


Loss: 0.1863 | Val Acc: 0.7968 | Val F1: 0.8101
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch44.pt (F1: 0.8101)


Epoch 45/60: 100%|██████████| 785/785 [01:37<00:00,  8.04it/s]


Loss: 0.1756 | Val Acc: 0.8035 | Val F1: 0.8149
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch45.pt (F1: 0.8149)


Epoch 46/60: 100%|██████████| 785/785 [01:37<00:00,  8.03it/s]


Loss: 0.1625 | Val Acc: 0.8023 | Val F1: 0.8157
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch46.pt (F1: 0.8157)


Epoch 47/60: 100%|██████████| 785/785 [01:37<00:00,  8.09it/s]


Loss: 0.1858 | Val Acc: 0.7835 | Val F1: 0.8029
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch47.pt (F1: 0.8029)


Epoch 48/60: 100%|██████████| 785/785 [01:49<00:00,  7.17it/s]


Loss: 0.1611 | Val Acc: 0.7962 | Val F1: 0.8107
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch48.pt (F1: 0.8107)


Epoch 49/60: 100%|██████████| 785/785 [02:00<00:00,  6.54it/s]


Loss: 0.1607 | Val Acc: 0.8094 | Val F1: 0.8187
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch49.pt (F1: 0.8187)


Epoch 50/60: 100%|██████████| 785/785 [05:40<00:00,  2.31it/s]


Loss: 0.1474 | Val Acc: 0.7994 | Val F1: 0.8122
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch50.pt (F1: 0.8122)


Epoch 51/60: 100%|██████████| 785/785 [06:53<00:00,  1.90it/s]


Loss: 0.1545 | Val Acc: 0.7938 | Val F1: 0.8096
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch51.pt (F1: 0.8096)


Epoch 52/60: 100%|██████████| 785/785 [03:53<00:00,  3.37it/s]


Loss: 0.1701 | Val Acc: 0.7898 | Val F1: 0.8056
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch52.pt (F1: 0.8056)


Epoch 53/60: 100%|██████████| 785/785 [02:43<00:00,  4.80it/s]


Loss: 0.1463 | Val Acc: 0.7873 | Val F1: 0.8042
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch53.pt (F1: 0.8042)


Epoch 54/60: 100%|██████████| 785/785 [01:57<00:00,  6.66it/s]


Loss: 0.1527 | Val Acc: 0.8070 | Val F1: 0.8172
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch54.pt (F1: 0.8172)


Epoch 55/60: 100%|██████████| 785/785 [01:58<00:00,  6.61it/s]


Loss: 0.1355 | Val Acc: 0.8107 | Val F1: 0.8188
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch55.pt (F1: 0.8188)


Epoch 56/60: 100%|██████████| 785/785 [01:46<00:00,  7.34it/s]


Loss: 0.1502 | Val Acc: 0.8059 | Val F1: 0.8154
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch56.pt (F1: 0.8154)


Epoch 57/60: 100%|██████████| 785/785 [01:45<00:00,  7.46it/s]


Loss: 0.1609 | Val Acc: 0.8102 | Val F1: 0.8184
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch57.pt (F1: 0.8184)


Epoch 58/60: 100%|██████████| 785/785 [01:49<00:00,  7.20it/s]


Loss: 0.1618 | Val Acc: 0.8086 | Val F1: 0.8182
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch58.pt (F1: 0.8182)


Epoch 59/60: 100%|██████████| 785/785 [01:46<00:00,  7.39it/s]


Loss: 0.1459 | Val Acc: 0.8133 | Val F1: 0.8194
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch59.pt (F1: 0.8194)


Epoch 60/60: 100%|██████████| 785/785 [01:46<00:00,  7.37it/s]


Loss: 0.1643 | Val Acc: 0.8139 | Val F1: 0.8209
Saved checkpoint: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch60.pt (F1: 0.8209)


In [21]:
print("\n--- Final Evaluation ---")
if top_checkpoints:
    best_path = top_checkpoints[0][1]
    model.load_state_dict(torch.load(best_path))
    print(f"Loaded best model from: {best_path}")

model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for (x_signal, x_features), y in val_loader:
        x_signal, x_features = x_signal.to(DEVICE), x_features.to(DEVICE)
        out = model(x_signal, x_features)
        preds = torch.argmax(out, dim=1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(y.numpy())


--- Final Evaluation ---
Loaded best model from: ./cnn_lstm_hybrid_checkpoints\hybrid_model_epoch60.pt


In [None]:
print("\nFinal Classification Report (Hybrid CNN-LSTM):")
print(classification_report(y_true, y_pred, target_names=["Low", "Moderate", "High"]))


Final Classification Report (Hybrid CNN-LSTM):
              precision    recall  f1-score   support

         Low       0.91      0.88      0.90      5493
    Moderate       0.20      0.13      0.16       190
        High       0.29      0.40      0.33       593

    accuracy                           0.81      6276
   macro avg       0.46      0.47      0.46      6276
weighted avg       0.83      0.81      0.82      6276



In [None]:
''' 
Classification Report Interpreting:

Support: total number of samples in the validation set
Weighted Average:   F1 score weighted by the size of each class, in this case will always be high due to 'Low' class dominating the score. Poor 
                    metric for imbalanced data.
Macro Average:      Calculates metric for each class independently, and then takes unweighted arithmetic mean of the values.
'''

In [None]:
'''
The major weakness is the Recall for Moderate (0.13) and High (0.40), meaning the model misses most of the severe cases.

IMPROVE USING:

A. Data Augmentation
You've trained on the same data 60 times. Augmentation creates new, slightly varied training examples without collecting new data, directly 
addressing the data scarcity for minority classes.
Action: 
    Implement Time-Series Augmentations directly in your HybridECGDataset.__getitem__ method (applied only to training data). 
    Useful techniques include:
        Time-Warping: Slightly stretching or compressing the signal.
        Scaling/Jittering: Randomly scaling the amplitude or adding Gaussian noise.
        Random Lead Dropout: Temporarily setting one or two leads to zero to improve robustness.

B. Hyperparameter Tuning (Focal Loss Gamma)
The F1 is high, but the Recall is low, suggesting the model is still finding it "easy" to achieve a good loss value without correctly classifying
the hardest samples.
Action: 
    Increase the Focal Loss gamma parameter (currently 2) to 3 or 4. A higher gamma further penalizes predictions where the model is confident 
    but wrong, forcing it to focus more intensely on the hardest, most ambiguous (often minority) samples.

'''

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# import pandas as pd
# import os
# from tqdm import tqdm
# from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
# from sklearn.metrics import f1_score, classification_report
# from sklearn.model_selection import train_test_split
# from sklearn.utils.class_weight import compute_class_weight
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# import warnings

# # Suppress PyTorch UserWarning about DataLoader iterator re-creation
# warnings.filterwarnings("ignore", category=UserWarning)

# # --- 0. Configuration ---
# CLEANED_CSV_PATH = 'metadata_with_features.csv' 
# CHECKPOINT_DIR = './cnn_lstm_hybrid_checkpoints'
# os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# BATCH_SIZE = 32
# EPOCHS = 60 
# LEAD_COUNT = 12
# SAMPLES = 2500
# NUM_CLASSES = 3 
# NUM_CLINICAL_FEATURES = 8 
# HIDDEN_SIZE = 64
# LEARNING_RATE = 2e-4
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DEVICE

# # --- 1. Focal Loss Function (For Imbalance) ---
# class FocalLoss(nn.Module):
#     def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
#         super().__init__()
#         self.alpha = alpha
#         self.gamma = gamma
#         self.reduction = reduction
#         if self.alpha is not None and not isinstance(self.alpha, torch.Tensor):
#             self.alpha = torch.tensor(self.alpha, dtype=torch.float32)

#     def forward(self, inputs, targets):
#         log_p = F.log_softmax(inputs, dim=1)
#         log_p = log_p.gather(1, targets.view(-1, 1)).view(-1)
#         p = torch.exp(log_p)
#         loss = - (1 - p)**self.gamma * log_p
        
#         if self.alpha is not None:
#             if self.alpha.device != inputs.device:
#                 self.alpha = self.alpha.to(inputs.device)
#             alpha_t = self.alpha.gather(0, targets.view(-1))
#             loss = alpha_t * loss
            
#         return loss.mean() if self.reduction == 'mean' else loss.sum()
		
# # --- 2. Hybrid CNN-LSTM Model Architecture ---

# class CNNLSTM_FeatureExtractor(nn.Module):
#     """Processes the raw ECG signal to generate a deep feature vector."""
#     def __init__(self, hidden_size=HIDDEN_SIZE, num_layers=1):
#         super().__init__()
#         self.hidden_size = hidden_size
        
#         # CNN (Feature Extraction)
#         self.cnn = nn.Sequential(
#             nn.Conv1d(LEAD_COUNT, 32, kernel_size=15, stride=2, padding=7),
#             nn.BatchNorm1d(32), nn.ReLU(inplace=True),
#             nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            
#             nn.Conv1d(32, 64, kernel_size=11, stride=2, padding=5),
#             nn.BatchNorm1d(64), nn.ReLU(inplace=True),
#             nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            
#             nn.Conv1d(64, 128, kernel_size=7, stride=2, padding=3),
#             nn.BatchNorm1d(128), nn.ReLU(inplace=True)
#         )
        
#         # LSTM (Temporal Context)
#         self.lstm = nn.LSTM(
#             input_size=128, hidden_size=hidden_size, num_layers=num_layers,
#             batch_first=True, bidirectional=True
#         )

#     def forward(self, x):
#         cnn_out = self.cnn(x) 
#         lstm_input = cnn_out.transpose(1, 2) 
        
#         _, (h_n, _) = self.lstm(lstm_input)
        
#         final_state = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
#         return final_state
		
# class FinalMultiInputHybridModel(nn.Module):
#     """Combines deep features (CNN-LSTM output) and 8 handcrafted features."""
#     def __init__(self, num_classes=NUM_CLASSES, num_clinical_features=NUM_CLINICAL_FEATURES):
#         super().__init__()
        
#         self.signal_extractor = CNNLSTM_FeatureExtractor(hidden_size=HIDDEN_SIZE)
        
#         INPUT_SIZE = (2 * HIDDEN_SIZE) + num_clinical_features 
        
#         self.final_fc = nn.Sequential(
#             nn.Linear(INPUT_SIZE, 64),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.5),
#             nn.Linear(64, num_classes)
#         )

#     def forward(self, signal, clinical_features):
#         deep_features = self.signal_extractor(signal) 
        
#         # Concatenate Deep Features with Handcrafted Features (ensuring float type)
#         combined_features = torch.cat((deep_features, clinical_features.float()), dim=1) 
        
#         return self.final_fc(combined_features)
		
# # --- 3. Custom Dataset for Hybrid Input ---
# class HybridECGDataset(Dataset):
#     def __init__(self, df):
#         self.df = df.reset_index(drop=True)

#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):
#         row = self.df.iloc[idx]
        
#         # Load Signal from absolute path (processed_npy_path)
#         try:
#             signal = np.load(row['processed_npy_path']).T
#             if signal.shape != (LEAD_COUNT, SAMPLES):
#                 signal = signal[:, :SAMPLES]
#             signal = np.nan_to_num(signal, nan=0.0)
#         except Exception: 
#             signal = np.zeros((LEAD_COUNT, SAMPLES), dtype=np.float32)

#         # Load Features from absolute path (feature_path)
#         try:
#             features = np.load(row['feature_path'])
#         except Exception:
#             features = np.zeros(NUM_CLINICAL_FEATURES, dtype=np.float32)

#         label = {"Low": 0, "Moderate": 1, "High": 2}[row['severity_level']]
        
#         return (torch.tensor(signal, dtype=torch.float32), 
#                 torch.tensor(features, dtype=torch.float32)), label
				
# # --- 4. Evaluation and Checkpointing Functions ---
# top_checkpoints = []
# N = 5
# def save_checkpoint(model, epoch, val_f1):
#     global top_checkpoints
#     path = os.path.join(CHECKPOINT_DIR, f"hybrid_model_epoch{epoch}.pt")
#     torch.save(model.state_dict(), path)
#     top_checkpoints.append((val_f1, path))
#     top_checkpoints.sort(reverse=True, key=lambda x: x[0])
#     if len(top_checkpoints) > N:
#         _, worst_path = top_checkpoints.pop()
#         if os.path.exists(worst_path):
#              os.remove(worst_path)
#     print(f"Saved checkpoint: {path} (F1: {val_f1:.4f})")
	
# def hybrid_evaluate(model, dataloader):
#     model.eval()
#     y_true, y_pred = [], []
#     with torch.no_grad():
#         for (signal, features), y in dataloader:
#             signal, features = signal.to(DEVICE), features.to(DEVICE)
#             logits = model(signal, features)
#             preds = torch.argmax(logits, dim=1).cpu().numpy()
#             y_pred.extend(preds)
#             y_true.extend(y.numpy())
#     f1 = f1_score(y_true, y_pred, average='weighted')
#     acc = np.mean(np.array(y_true) == np.array(y_pred))
#     return acc, f1
	
# # Load the cleaned metadata
# df = pd.read_csv(CLEANED_CSV_PATH)

# # Split the data
# train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['severity_level'], random_state=42)

# # --- Setup Weights and Sampler ---
# train_labels = train_df['severity_level'].map({'Low': 0, 'Moderate': 1, 'High': 2})

# # Focal Loss Alpha Weights
# class_weights_array = compute_class_weight(class_weight="balanced", classes=np.unique(train_labels), y=train_labels)
# alpha_weights = torch.tensor(class_weights_array, dtype=torch.float32).to(DEVICE)

# # Sampler Weights (Inverse SQRT frequency for batch balance)
# train_labels_np = train_labels.values
# class_counts = train_df['severity_level'].value_counts().sort_index()
# num_samples = len(train_df)
# weights = 1.0 / np.sqrt(class_counts.values) 
# sample_weights = weights[train_labels_np]

# sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)

# # DataLoaders
# train_loader = DataLoader(HybridECGDataset(train_df), batch_size=BATCH_SIZE, sampler=sampler)
# val_loader = DataLoader(HybridECGDataset(val_df), batch_size=BATCH_SIZE)

# # --- Model, Loss, Optimizer ---
# model = FinalMultiInputHybridModel(num_clinical_features=NUM_CLINICAL_FEATURES).to(DEVICE)
# optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 

# criterion = FocalLoss(alpha=alpha_weights, gamma=2.0) 

# scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)

# # --- Training Loop ---
# print("Starting Final Hybrid Training with Focal Loss...")
# best_val_f1 = 0.0

# for epoch in range(EPOCHS):
#     model.train()
#     total_loss = 0
#     for (x_signal, x_features), y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
#         x_signal, x_features, y = x_signal.to(DEVICE), x_features.to(DEVICE), y.to(DEVICE)
#         optimizer.zero_grad()
        
#         logits = model(x_signal, x_features)
#         loss = criterion(logits, y)

#         if torch.isnan(loss): continue
                
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#         optimizer.step()
#         total_loss += loss.item()

#     avg_loss = total_loss / len(train_loader)
#     val_acc, val_f1 = hybrid_evaluate(model, val_loader)
    
#     scheduler.step(val_f1)

#     print(f"Loss: {avg_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")
#     save_checkpoint(model, epoch + 1, val_f1)
	
# 	print("\n--- Final Evaluation ---")
# if top_checkpoints:
#     best_path = top_checkpoints[0][1]
#     model.load_state_dict(torch.load(best_path))
#     print(f"Loaded best model from: {best_path}")

# model.eval()
# y_true, y_pred = [], []
# with torch.no_grad():
#     for (x_signal, x_features), y in val_loader:
#         x_signal, x_features = x_signal.to(DEVICE), x_features.to(DEVICE)
#         out = model(x_signal, x_features)
#         preds = torch.argmax(out, dim=1).cpu().numpy()
#         y_pred.extend(preds)
#         y_true.extend(y.numpy())
		
# print("\nFinal Classification Report (Hybrid CNN-LSTM):")
# print(classification_report(y_true, y_pred, target_names=["Low", "Moderate", "High"]))

