In [None]:
import numpy as np
import torch
import torch.nn as nn


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Each row of the ECG5000 text files is one beat: column 0 holds the class label and the
# remaining columns hold the signal samples. The official train/test files are pooled here.
trainData, testData = (
    np.loadtxt(path)
    for path in (
        "dataset/ECG5000/ECG5000_TRAIN.txt",
        "dataset/ECG5000/ECG5000_TEST.txt",
    )
)
data = np.vstack((trainData, testData))

labels, signals = data[:, 0].astype(int), data[:, 1:]

In [None]:
# Class 1 is treated as the "normal" beat; every other class label counts as abnormal.
is_normal = labels == 1
normal_signals = signals[is_normal]
abnormal_signals = signals[~is_normal]

In [None]:
!pip install scikit-learn


In [None]:
from sklearn.model_selection import train_test_split

# Carve the normal beats into 80% train / 10% validation / 10% test: the first call holds
# out 20%, the second halves that hold-out. A fixed random_state keeps the split reproducible.
split_kwargs = dict(random_state=42, shuffle=True)
normal_train, normal_temp = train_test_split(normal_signals, test_size=0.2, **split_kwargs)
normal_val, normal_test = train_test_split(normal_temp, test_size=0.5, **split_kwargs)



In [None]:
# Train/validation tensors hold only normal beats; the test set mixes the held-out normal
# beats (label 0) with every abnormal beat (label 1).
x_train, x_val = (
    torch.tensor(split, dtype=torch.float32) for split in (normal_train, normal_val)
)

test_label_array = np.concatenate(
    [np.zeros(len(normal_test)), np.ones(len(abnormal_signals))]
)
x_test = torch.tensor(np.vstack([normal_test, abnormal_signals]), dtype=torch.float32)
y_test = torch.tensor(test_label_array, dtype=torch.float32)

for split_tensor in (x_train, x_val, x_test):
    print(split_tensor.shape)

Defining the model

In [None]:
class ECGTransformerAutoencoder(nn.Module):
    """Patch-based Transformer autoencoder that reconstructs fixed-length 1-D signals.

    The signal is cut into ``signal_length // patch_size`` non-overlapping patches. Each
    patch is linearly embedded into ``d_model`` dimensions and a learned positional embedding
    is added. The token sequence runs through a stack of Transformer encoder layers, and each
    token is linearly decoded back into ``patch_size`` samples. Output shape equals input shape.

    Args:
        signal_length: Samples per signal. Must be a positive multiple of ``patch_size``.
        patch_size: Samples per patch (one token per patch).
        d_model: Token embedding width; ``nn.MultiheadAttention`` requires it to be
            divisible by ``nhead``.
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward sub-block.
        dropout: Dropout probability inside the encoder layers (PyTorch's default is 0.1).

    Raises:
        ValueError: If ``patch_size`` is not positive or ``signal_length`` is not a positive
            multiple of it.
    """

    def __init__(
        self,
        signal_length=140,
        patch_size=10,
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=128,
        dropout=0.1,
    ):
        super().__init__()

        # Validate up front: a non-divisible length used to construct fine and then fail
        # inside forward() with an opaque view/reshape error.
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if signal_length <= 0 or signal_length % patch_size != 0:
            raise ValueError(
                f"signal_length ({signal_length}) must be a positive multiple of "
                f"patch_size ({patch_size})"
            )

        self.signal_length = signal_length
        self.patch_size = patch_size
        self.d_model = d_model
        self.num_patches = signal_length // patch_size

        # Each patch of raw samples -> one d_model-dim token.
        self.patch_embed = nn.Linear(patch_size, d_model)

        # Learned positional embedding, one vector per patch position (zero-initialised).
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, d_model))

        # batch_first=True: the encoder consumes [B, num_patches, d_model].
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Each token -> patch_size reconstructed samples.
        self.decoder = nn.Linear(d_model, patch_size)

    def forward(self, x):
        """Reconstruct a batch of signals.

        Args:
            x: Tensor of shape [B, signal_length] ([B, 140] with the defaults).

        Returns:
            Tensor of shape [B, signal_length] holding the reconstruction.

        Raises:
            ValueError: If ``x`` is not 2-D with ``signal_length`` columns.
        """
        if x.dim() != 2 or x.size(1) != self.signal_length:
            raise ValueError(
                f"expected input of shape [B, {self.signal_length}], got {tuple(x.shape)}"
            )
        batch_size = x.size(0)

        # 1) Split into patches: [B, T] -> [B, N, P]. reshape (not view) also accepts
        #    non-contiguous inputs such as strided slices.
        patches = x.reshape(batch_size, self.num_patches, self.patch_size)

        # 2) Embed patches and 3) add the positional encoding -> [B, N, d_model].
        tokens = self.patch_embed(patches) + self.pos_embed

        # 4) Mix information across patches with self-attention -> [B, N, d_model].
        tokens = self.encoder(tokens)

        # 5) Decode every token back into a patch -> [B, N, P].
        recon_patches = self.decoder(tokens)

        # 6) Stitch the patches back into one signal -> [B, T].
        return recon_patches.reshape(batch_size, self.signal_length)


In [None]:
# Smoke test: push one small batch through a throw-away model and check the output shape.
# Distinct names keep a re-run of this cell from overwriting the trained `model` / `recon`.
# The model follows x_train's device, so the cell still works after the data has been moved.
smoke_model = ECGTransformerAutoencoder().to(x_train.device)
smoke_model.eval()
smoke_batch = x_train[:8]

with torch.no_grad():
    smoke_recon = smoke_model(smoke_batch)

print("Input shape:", smoke_batch.shape)
print("Reconstruction shape:", smoke_recon.shape)
assert smoke_recon.shape == smoke_batch.shape


In [None]:
# Seed right before construction so re-running this cell reproduces the same initial weights
# (42 matches the random_state used for the data split).
torch.manual_seed(42)

# Move the model to the target device *before* building the optimizer, as the torch.optim
# documentation recommends, so the optimizer is created over the final parameters.
model = ECGTransformerAutoencoder().to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [None]:
import torch
# Report CUDA availability and where the model currently lives, move the model and every
# data split onto `device`, then report the model's device again to confirm the move.
print(torch.cuda.is_available())
print(next(model.parameters()).device)

model = model.to(device)
x_train, x_val, x_test, y_test = (
    tensor.to(device) for tensor in (x_train, x_val, x_test, y_test)
)

print(next(model.parameters()).device)

In [None]:
# Train the autoencoder on normal beats only, tracking mean reconstruction loss per epoch.
# NOTE: re-running this cell keeps training the *existing* `model` (only the loss history is
# reset); re-run the model-setup cell first for a fresh run.
train_losses = []
val_losses = []

num_epochs = 30
batch_size = 64

for epoch in range(num_epochs):

    # ===== TRAINING =====
    model.train()
    # Reshuffle every epoch so the model does not see the same fixed batch sequence each time.
    perm = torch.randperm(len(x_train), device=x_train.device)
    train_loss_sum = 0.0

    for i in range(0, len(x_train), batch_size):
        batch = x_train[perm[i:i + batch_size]]

        optimizer.zero_grad()
        recon = model(batch)
        loss = criterion(recon, batch)
        loss.backward()  # gradient computation
        optimizer.step()

        # Assumes criterion uses reduction="mean" (per-batch mean): weight it by batch size so
        # the smaller final batch does not skew the epoch average.
        train_loss_sum += loss.item() * batch.size(0)

    train_loss = train_loss_sum / len(x_train)
    train_losses.append(train_loss)

    # ===== VALIDATION =====
    model.eval()
    val_loss_sum = 0.0

    with torch.no_grad():
        for i in range(0, len(x_val), batch_size):
            batch = x_val[i:i + batch_size]
            recon = model(batch)
            val_loss_sum += criterion(recon, batch).item() * batch.size(0)

    val_loss = val_loss_sum / len(x_val)
    val_losses.append(val_loss)

    # ===== LOGGING =====
    print(
        f"Epoch {epoch+1:02d} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f}"
    )



In [None]:
!pip install matplotlib


In [None]:
import matplotlib.pyplot as plt

# Learning curves: a widening gap between the two lines would indicate overfitting.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(train_losses, label="Train Loss")
ax.plot(val_losses, label="Validation Loss")
ax.set(
    xlabel="Epoch",
    ylabel="Reconstruction Loss (MSE)",
    title="Training vs Validation Loss",
)
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Per-beat anomaly score: mean squared reconstruction error over the time axis of each
# test signal, computed in batches and gathered on the CPU.
model.eval()
per_batch_errors = []

with torch.no_grad():
    for start in range(0, len(x_test), batch_size):
        chunk = x_test[start:start + batch_size].to(device)
        squared_diff = (chunk - model(chunk)) ** 2
        per_batch_errors.append(squared_diff.mean(dim=1).cpu())

errors = torch.cat(per_batch_errors).numpy()

In [None]:
# Split test-set anomaly scores by ground-truth label (0 = normal, 1 = abnormal).
test_label_np = y_test.cpu().numpy()
normal_errors = errors[test_label_np == 0]
abnormal_errors = errors[test_label_np == 1]


In [None]:
import matplotlib.pyplot as plt

# Use one set of bin edges for both classes: with independent `bins=50`, each histogram gets
# its own bin width, so the overlaid bar heights are not comparable.
shared_bins = np.histogram_bin_edges(
    np.concatenate([normal_errors, abnormal_errors]), bins=50
)

fig, ax = plt.subplots()
ax.hist(normal_errors, bins=shared_bins, alpha=0.7, label="Normal")
ax.hist(abnormal_errors, bins=shared_bins, alpha=0.7, label="Abnormal")
ax.set_xlabel("Reconstruction Error")
ax.set_ylabel("Count")
ax.legend()
ax.set_title("Reconstruction Error Distribution")
plt.show()


In [None]:
from sklearn.metrics import roc_auc_score

# Threshold-free separability: how well the reconstruction error ranks abnormal beats
# above normal ones.
y = y_test.detach().to("cpu").numpy()
auc = roc_auc_score(y_true=y, y_score=errors)
print("ROC-AUC:", auc)


In [None]:
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt

# Name the thresholds explicitly rather than binding `_`, which IPython uses for its
# last-output history.
fpr, tpr, roc_thresholds = roc_curve(y_true=y, y_score=errors)

fig, ax = plt.subplots()
ax.plot(fpr, tpr)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curve")
ax.grid(True)
plt.show()

In [None]:
# Choose the anomaly threshold from *normal* validation beats only. A later test beat whose
# error exceeds this percentile of validation errors is flagged abnormal. Higher percentiles
# (e.g. 97.5, 99) flag fewer beats (fewer false alarms, lower recall); lower ones (e.g. 90)
# flag more.
THRESHOLD_PERCENTILE = 95

model.eval()
val_errs = []

with torch.no_grad():
    for i in range(0, len(x_val), batch_size):
        batch = x_val[i:i + batch_size].to(device)
        recon = model(batch)
        # Per-beat MSE, matching how the test-set anomaly scores are computed.
        val_errs.append(torch.mean((batch - recon) ** 2, dim=1).cpu())

val_errs = torch.cat(val_errs).numpy()
threshold = np.percentile(val_errs, THRESHOLD_PERCENTILE)

print("Threshold:", threshold)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report

# Binary predictions: 1 (abnormal) wherever a beat's error is above the validation threshold.
flagged = errors > threshold
y_pred = flagged.astype(int)

print(confusion_matrix(y, y_pred))
print(classification_report(y, y_pred, digits=4))


In [None]:
# Pick one correctly detected abnormal beat (a true positive) and compute its per-sample
# squared reconstruction residual for the localisation plot.
y_test_np = y_test.detach().cpu().numpy()
y_pred_np = np.asarray(y_pred).astype(int)

# True positives: abnormal beats (label 1) that the threshold also flagged (prediction 1).
tp_indices = np.flatnonzero((y_test_np == 1) & (y_pred_np == 1))
if tp_indices.size == 0:
    raise ValueError(
        "No true positives at the current threshold; lower the detection threshold "
        "to obtain an example to visualise."
    )
idx = tp_indices[0]  # first true positive

model.eval()
with torch.no_grad():
    x = x_test[idx:idx + 1].to(device)  # slice keeps the batch dim: [1, T]
    recon = model(x)                    # [1, T]

x = x.cpu().numpy().squeeze()
recon = recon.cpu().numpy().squeeze()
residual = (x - recon) ** 2

# Per-beat cut-off: by construction roughly (100 - RESIDUAL_PERCENTILE)% of this beat's own
# time steps exceed it, so it ranks points *within* the beat rather than testing them against
# the normal-data error scale.
RESIDUAL_PERCENTILE = 95
res_th = np.percentile(residual, RESIDUAL_PERCENTILE)


In [None]:
import matplotlib.pyplot as plt

# Two-panel figure: the beat with its reconstruction on top, the per-sample residual below.
# Time steps whose residual exceeds `res_th` are highlighted in both panels.
t = np.arange(len(x))
anomaly_mask = residual > res_th
flagged_idx = np.flatnonzero(anomaly_mask)

fig, (ax_sig, ax_res) = plt.subplots(2, 1, figsize=(14, 6), sharex=True)

# --- ECG signal ---
ax_sig.plot(t, x, label="Original ECG", linewidth=2)
ax_sig.plot(t, recon, label="Reconstruction", linestyle="--")
# Shade each flagged time step on the signal itself so the localisation is visible in context.
for k, i_flag in enumerate(flagged_idx):
    ax_sig.axvspan(
        i_flag - 0.5, i_flag + 0.5, color="red", alpha=0.15,
        label="Anomalous Region" if k == 0 else "_nolegend_",
    )
ax_sig.set_title("ECG Reconstruction")
ax_sig.legend()
ax_sig.grid(True)

# --- Residual ---
ax_res.plot(t, residual, label="Residual Error")
ax_res.axhline(res_th, color="red", linestyle="--", label="Residual Threshold")
# fill_between(where=...) only fills between *consecutive* True samples, so isolated flagged
# points were never shaded; mark every flagged point explicitly instead.
ax_res.scatter(
    t[anomaly_mask], residual[anomaly_mask],
    color="red", zorder=3, label="Anomalous Region",
)
ax_res.set_title("Residual-Based Anomaly Localization")
ax_res.set_xlabel("Time step")
ax_res.legend()
ax_res.grid(True)

fig.tight_layout()
# Save before plt.show(): under Jupyter's inline backend the figure is closed once displayed,
# so saving from a later cell writes a blank image.
fig.savefig("ecg_residual_example.png", dpi=300)
plt.show()


In [None]:
# NOTE: the plotting cell above ends with plt.show(). Under Jupyter's inline backend that
# closes the figure, so a bare plt.savefig() here would write a blank canvas. Only save when
# a figure is still open (e.g. non-inline backends); otherwise save from the plotting cell.
if plt.get_fignums():
    plt.savefig("ecg_residual_example.png", dpi=300)
