In [1]:
import numpy as np
import torch
import os
import time
import json
import sys
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, confusion_matrix
import sys
sys.path.append('/home/peili')
from src.datasets import get_dataset_transformer, get_dataset_jepa
from src.dataloaders import get_dataloader
from src.encoders import get_encoder
from src.models import get_model

In [11]:
# Restrict this process to physical GPU 0. The variable is only honoured if it is
# set before CUDA is first initialised in this kernel, so run this cell before any
# cell that touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
cuda_available = torch.cuda.is_available()
print(cuda_available)  # Verify CUDA is available
if cuda_available:
    # Index within the *visible* devices: 0 here refers to physical GPU 0.
    # Guarded because current_device() raises when no CUDA device is usable.
    print(torch.cuda.current_device())

True
0


# Load Dataset

In [17]:
# --- Experiment configuration -------------------------------------------------
# These values drive dataset loading here and model setup / training below.
dataset_name = 'Ga'
model_name = 'ecg-jepa'
seg_method = 'fix'
train_encoder = 'train'   # 'train' -> fine-tune the encoder together with the head
num_epochs = 10
ckpt = None               # optional path to a checkpoint to resume from

# Turn the string switches into booleans.
fixed = seg_method == 'fix'
train_encoder = train_encoder == 'train'

# The conditional must be parenthesised: `a + b if c else ''` parses as
# `(a + b) if c else ''`, which produced an empty save_dir (and a failing
# os.makedirs('')) whenever the encoder was not being trained.
save_dir = f"{dataset_name}_{model_name}_{seg_method}" + ("_encoder_trained" if train_encoder else "")
os.makedirs(save_dir, exist_ok=True)

if model_name == 'transformer':
    train_dataset, test_dataset = get_dataset_transformer(dataset_name, fixed=fixed, window_size=10, overlap=5)
elif model_name == 'ecg-jepa':
    train_dataset, test_dataset = get_dataset_jepa(dataset_name, reload=True)
else:
    # Fail here instead of leaving train_dataset/test_dataset undefined (or stale
    # from an earlier run of this cell) for the following cells.
    raise ValueError(f"Unsupported model_name: {model_name!r}")



Processing ECG files:   0%|          | 0/10344 [00:00<?, ?it/s]

Processing ECG files: 100%|██████████| 10344/10344 [00:28<00:00, 364.21it/s]
Processing Records: 100%|██████████| 10292/10292 [00:01<00:00, 10114.63it/s]
Combining Results: 100%|██████████| 10292/10292 [00:04<00:00, 2263.71it/s]


invalid samples: 7


# Set Up Model

In [18]:
# Build the encoder, the data loaders, the classification head and the optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder, embed_dim = get_encoder(model_name, device, load=True)
train_loader, test_loader = get_dataloader(model_name, train_dataset, test_dataset, encoder, device, train_encoder)

criterion = nn.CrossEntropyLoss()
n_labels = 7  # number of output classes passed to get_model
model = get_model(model_name, embed_dim, n_labels, device)

if train_encoder:
    # Fine-tuning: the encoder's parameters are optimised jointly with the head.
    assert encoder is not None, "train_encoder=True requires an encoder"
    for param in encoder.parameters():
        param.requires_grad = True
    optimizer = optim.Adam(list(encoder.parameters()) + list(model.parameters()), lr=0.001)
else:
    # Only the classification head is optimised.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

if ckpt:
    # map_location lets a checkpoint written on one device be loaded on another
    # (e.g. a GPU checkpoint on a CPU-only machine).
    checkpoint = torch.load(ckpt, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Checkpoints written by the training cell while fine-tuning also carry the
    # encoder weights under "encoder"; restore them so the head is not paired
    # with the original pretrained encoder.
    if encoder is not None and "encoder" in checkpoint:
        encoder.load_state_dict(checkpoint["encoder"])

  ckpt = torch.load(ckpt_dir)


# Training

In [20]:
# --- Run configuration (repeated so this cell can be re-run on its own) --------
# NOTE(review): the data loaders, model and optimizer were built by the setup
# cell from its own copy of these values; keep the two in sync.
dataset_name = 'Ga'
model_name = 'ecg-jepa'
seg_method = 'fix'
train_encoder = 'train'
num_epochs = 10
ckpt = None

fixed = seg_method == 'fix'
train_encoder = train_encoder == 'train'

# Parenthesise the conditional: without it the whole concatenation is the "true"
# branch, save_dir becomes '' when the encoder is frozen, and checkpoints would
# be written to '/epoch_N.pth'.
save_dir = f"{dataset_name}_{model_name}_{seg_method}" + ("_encoder_trained" if train_encoder else "")
os.makedirs(save_dir, exist_ok=True)  # checkpoints are written here every epoch

start_time = time.time()
training_metrics = {"loss": []}  # mean training loss per epoch

for epoch in range(num_epochs):
    if train_encoder:
        encoder.train()
    model.train()
    running_loss = 0.0

    for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch'):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        if train_encoder:
            # The loader yields encoder inputs; compute features on the fly so
            # gradients reach the encoder.
            features = encoder.representation(inputs)
            outputs = model(features)
        else:
            # The loader output is fed straight into the classification head.
            outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    average_loss = running_loss / len(train_loader)
    training_metrics["loss"].append(average_loss)
    print(f'Average Loss: {average_loss:.4f}')

    checkpoint_path = f'{save_dir}/epoch_{epoch + 1}.pth'

    # One checkpoint per epoch; the encoder weights are included only when the
    # encoder is being fine-tuned.
    checkpoint_state = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": average_loss,
    }
    if train_encoder:
        checkpoint_state["encoder"] = encoder.state_dict()
    torch.save(checkpoint_state, checkpoint_path)

    print(f'Model checkpoint saved at {checkpoint_path}')
    torch.cuda.empty_cache()
end_time = time.time()
training_time = end_time - start_time
print(f"Time taken: {training_time} seconds")

Epoch 1/10: 100%|██████████| 258/258 [06:01<00:00,  1.40s/batch]


Average Loss: 1.6772
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_1.pth


Epoch 2/10: 100%|██████████| 258/258 [06:00<00:00,  1.40s/batch]


Average Loss: 1.6561
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_2.pth


Epoch 3/10: 100%|██████████| 258/258 [05:59<00:00,  1.39s/batch]


Average Loss: 1.6422
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_3.pth


Epoch 4/10: 100%|██████████| 258/258 [05:59<00:00,  1.39s/batch]


Average Loss: 1.6240
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_4.pth


Epoch 5/10: 100%|██████████| 258/258 [06:01<00:00,  1.40s/batch]


Average Loss: 1.6177
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_5.pth


Epoch 6/10: 100%|██████████| 258/258 [06:04<00:00,  1.41s/batch]


Average Loss: 1.6256
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_6.pth


Epoch 7/10: 100%|██████████| 258/258 [06:01<00:00,  1.40s/batch]


Average Loss: 1.6094
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_7.pth


Epoch 8/10: 100%|██████████| 258/258 [04:13<00:00,  1.02batch/s]


Average Loss: 1.6023
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_8.pth


Epoch 9/10: 100%|██████████| 258/258 [02:59<00:00,  1.44batch/s]


Average Loss: 1.6019
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_9.pth


Epoch 10/10: 100%|██████████| 258/258 [03:01<00:00,  1.43batch/s]


Average Loss: 1.5958
Model checkpoint saved at Ga_ecg-jepa_fix_encoder_trained/epoch_10.pth
Time taken: 3155.2359256744385 seconds


# Evaluation

In [21]:
# Evaluate on the test loader: accuracy, macro AUC, and per-class / macro
# sensitivity and specificity; the summary is written to <save_dir>/test_metrics.json.
start_time = time.time()
# The setup cell treats `encoder` as possibly None, so guard the call.
if encoder is not None:
    encoder.eval()
model.eval()

y_true = []    # ground-truth class index per sample
y_pred = []    # predicted class index per sample
y_scores = []  # softmax probability vector per sample

total = 0
correct = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        if train_encoder:
            features = encoder.representation(inputs)
            outputs = model(features)
        else:
            outputs = model(inputs)

        _, predicted = torch.max(outputs, 1)

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())
        y_scores.extend(torch.softmax(outputs, dim=1).cpu().numpy())  # Probabilities for all classes

        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    torch.cuda.empty_cache()

accuracy = 100 * correct / total
print(f'Accuracy on test data: {accuracy:.2f}%')

# Evaluation: Macro-averaged AUC, Sensitivity, and Specificity
score_matrix = np.asarray(y_scores)          # (num_samples, num_classes)
num_classes = score_matrix.shape[1]
y_true_one_hot = np.zeros((len(y_true), num_classes))  # Convert to one-hot encoding
y_true_one_hot[np.arange(len(y_true)), y_true] = 1

# One-vs-rest AUC per class from the one-hot targets, macro-averaged.
auc = roc_auc_score(y_true_one_hot, score_matrix, average='macro')
print(f'Macro-Averaged AUC: {auc:.4f}')

# Sensitivity and Specificity per class
sensitivities = []
specificities = []

# Highest-probability class per sample; computed once instead of once per class.
top_class = np.argmax(score_matrix, axis=1)

for i in range(num_classes):
    y_true_binary = y_true_one_hot[:, i]  # Binary ground truth for class i
    y_pred_binary = (top_class == i).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even when class i never occurs in either
    # the targets or the predictions; otherwise the matrix collapses to 1x1 and
    # the four-way unpacking fails.
    tn, fp, fn, tp = confusion_matrix(y_true_binary, y_pred_binary, labels=[0, 1]).ravel()

    # Class-specific metrics (0 when the denominator is empty)
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    sensitivities.append(sensitivity)
    specificities.append(specificity)

    print(f'Class {i}: Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}')

# Macro-averaged metrics
macro_sensitivity = np.mean(sensitivities)
macro_specificity = np.mean(specificities)
print(f'Macro-Averaged Sensitivity: {macro_sensitivity:.4f}')
print(f'Macro-Averaged Specificity: {macro_specificity:.4f}')

# Plain Python floats keep the JSON dump independent of NumPy scalar types.
test_metrics = {
    "accuracy": float(accuracy),
    "macro_auc": float(auc),
    "macro_sensitivity": float(macro_sensitivity),
    "macro_specificity": float(macro_specificity),
    "class_sensitivities": [float(value) for value in sensitivities],
    "class_specificities": [float(value) for value in specificities],
}

metrics_path = f"{save_dir}/test_metrics.json"
with open(metrics_path, "w") as f:
    json.dump(test_metrics, f)
print(f"Test metrics saved to {metrics_path}")

end_time = time.time()
testing_time = end_time - start_time
print(f"Time taken: {testing_time} seconds")

Accuracy on test data: 36.62%
Macro-Averaged AUC: 0.5462
Class 0: Sensitivity: 0.9164, Specificity: 0.2363
Class 1: Sensitivity: 0.0000, Specificity: 1.0000
Class 2: Sensitivity: 0.0000, Specificity: 1.0000
Class 3: Sensitivity: 0.0046, Specificity: 0.9963
Class 4: Sensitivity: 0.0267, Specificity: 0.9771
Class 5: Sensitivity: 0.0000, Specificity: 1.0000
Class 6: Sensitivity: 0.2675, Specificity: 0.8647
Macro-Averaged Sensitivity: 0.1736
Macro-Averaged Specificity: 0.8678
Test metrics saved to Ga_ecg-jepa_fix_encoder_trained/test_metrics.json
Time taken: 26.18837881088257 seconds
