In [5]:
# [Capstone Project] ML for Healthcare
# Ramisha Mahiyat
# MODEL-CNN-BENCHMARK-MODEL
 

In [1]:
!pip install numpy pandas scikit-learn torch matplotlib



In [2]:
# model_cnn_benchmark_ecg.py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.preprocessing import StandardScaler

# Load data
train_data = pd.read_pickle('ECG5000_train.pickle')
val_data = pd.read_pickle('ECG5000_validation.pickle')

X_train = train_data[:, 1:]
y_train = train_data[:, 0].astype(int)
X_val = val_data[:, 1:]
y_val = val_data[:, 0].astype(int)

# Normalize
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)

torch_X_train = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1)
torch_y_train = torch.tensor(y_train, dtype=torch.long)
torch_X_val = torch.tensor(X_val, dtype=torch.float32).unsqueeze(1)
torch_y_val = torch.tensor(y_val, dtype=torch.long)

train_dataset = TensorDataset(torch_X_train, torch_y_train)
val_dataset = TensorDataset(torch_X_val, torch_y_val)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# CNN model
class SimpleECGCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, 5, padding=2)
        self.bn1 = nn.BatchNorm1d(16)
        self.conv2 = nn.Conv1d(16, 32, 5, padding=2)
        self.bn2 = nn.BatchNorm1d(32)
        self.conv3 = nn.Conv1d(32, 64, 5, padding=2)
        self.bn3 = nn.BatchNorm1d(64)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool1d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool1d(x, 2)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x).squeeze(-1)
        return self.fc(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleECGCNN(num_classes=len(np.unique(y_train))).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train
for epoch in range(10):
    model.train()
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} complete")

# Evaluate
model.eval()
y_true, y_pred, y_scores = [], [], []
with torch.no_grad():
    for X_batch, y_batch in val_loader:
        X_batch = X_batch.to(device)
        outputs = model(X_batch)
        probs = F.softmax(outputs, dim=1).cpu().numpy()
        preds = np.argmax(probs, axis=1)
        y_true.extend(y_batch.numpy())
        y_pred.extend(preds)
        y_scores.extend(probs)

print("\nEvaluation Metrics (CNN Benchmark):")
print("Accuracy:", accuracy_score(y_true, y_pred))
print("Precision:", precision_score(y_true, y_pred, average='weighted'))
print("Recall:", recall_score(y_true, y_pred, average='weighted'))
print("F1 Score:", f1_score(y_true, y_pred, average='weighted'))
print("AUROC:", roc_auc_score(y_true, y_scores, multi_class='ovr', average='weighted'))
print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))


Epoch 1 complete
Epoch 2 complete
Epoch 3 complete
Epoch 4 complete
Epoch 5 complete
Epoch 6 complete
Epoch 7 complete
Epoch 8 complete
Epoch 9 complete
Epoch 10 complete

Evaluation Metrics (CNN Benchmark):
Accuracy: 0.872
Precision: 0.7998039555590576
Recall: 0.872
F1 Score: 0.8337404388714734
AUROC: 0.9502204675032211
Confusion Matrix:
 [[752  29   0   0   0]
 [ 34 556   0   0   0]
 [ 12  31   0   0   0]
 [ 15  60   0   0   0]
 [  1  10   0   0   0]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
