In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy, Precision, Recall, F1Score

# -------------------
# 1. Hyperparameters
# -------------------
data_dir = r"D:\res_work\ECG_analysis_for_CVD\PCA\extrct"  # Your dataset path
num_classes = 4  # We have 4 classes
batch_size = 32
num_epochs = 10  # Adjust epochs as needed
learning_rate = 1e-4

# Fractions of the dataset assigned to each split. The validation split is
# sized as the remainder after train and test, so these three must sum to 1;
# otherwise val_ratio would be silently ignored and the validation set could
# end up a different size than intended (or empty).
train_ratio = 0.7
test_ratio = 0.2
val_ratio = 0.1

if min(train_ratio, test_ratio, val_ratio) <= 0:
    raise ValueError("train_ratio, test_ratio and val_ratio must all be positive")
if abs(train_ratio + test_ratio + val_ratio - 1.0) > 1e-6:
    raise ValueError(
        "train_ratio + test_ratio + val_ratio must equal 1, got "
        f"{train_ratio + test_ratio + val_ratio:.6f}"
    )

# ---------------------
# 2. Data Transforms
# ---------------------
# Each image is resized to a fixed 224x224 input (the standard size for
# ImageNet-pretrained ResNets), converted to a float tensor, and normalized
# with the ImageNet per-channel mean/std that the pretrained weights expect.
# No augmentation is applied. The same pipeline is shared by every split,
# because the whole ImageFolder is split afterwards.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
]
data_transforms = transforms.Compose(_transform_steps)

# ---------------------------------------
# 3. Load Dataset (ImageFolder)
# ---------------------------------------
# ImageFolder derives each image's class label from its sub-directory
# under data_dir.
full_dataset = datasets.ImageFolder(root=data_dir, transform=data_transforms)

# Train and test sizes are truncated fractions of the total. Validation gets
# whatever is left, so the three sizes always add up to len(full_dataset).
dataset_size = len(full_dataset)
train_size, test_size = (int(ratio * dataset_size) for ratio in (train_ratio, test_ratio))
val_size = dataset_size - (train_size + test_size)

# A fixed-seed generator keeps the random split identical across runs.
_split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, test_size, val_size],
    generator=_split_generator,
)

# ----------------------------
# 4. Create DataLoaders
# ----------------------------
# All three loaders share batch size and worker count. Only the training
# loader shuffles; evaluation loaders iterate in a fixed order.
_common_loader_args = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_common_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **_common_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **_common_loader_args)

# -------------------------
# 5. Initialize Model
# -------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Start from ImageNet-pretrained ResNet-50 weights and replace the
# classification head with a freshly initialized num_classes-way linear layer.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
_backbone_features = model.fc.in_features
model.fc = nn.Linear(_backbone_features, num_classes)
model = model.to(device)

# Cross-entropy on the class logits. Adam receives every model parameter,
# so the whole backbone is fine-tuned rather than frozen.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# ---------------------------------------
# 6. Training and Validation Loop
# ---------------------------------------
def _run_epoch(loader, training):
    """Run ``model`` once over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``training`` is True, the model is put in train mode and the optimizer
    steps after every batch. Otherwise the model is put in eval mode and
    gradient tracking is disabled.

    Both returned values are Python floats averaged per sample.

    Raises:
        ValueError: if ``loader`` wraps an empty dataset, since per-sample
            averages would be undefined.
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("Cannot run an epoch over an empty dataset split")

    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_correct = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            preds = outputs.argmax(dim=1)
            # criterion returns the batch mean; re-weight by batch size so the
            # epoch average stays per-sample when the last batch is smaller.
            total_loss += loss.item() * images.size(0)
            total_correct += (preds == labels).sum().item()

    return total_loss / num_samples, total_correct / num_samples


best_val_acc = 0.0

for epoch in range(num_epochs):
    print(f"Epoch [{epoch+1}/{num_epochs}]")

    # ----------- TRAIN -----------
    epoch_loss, epoch_acc = _run_epoch(train_loader, training=True)
    print(f"Train Loss: {epoch_loss:.4f}  |  Train Acc: {epoch_acc:.4f}")

    # ----------- VALIDATE -----------
    val_epoch_loss, val_epoch_acc = _run_epoch(val_loader, training=False)
    print(f"Val Loss:   {val_epoch_loss:.4f}  |  Val Acc:   {val_epoch_acc:.4f}")

    # Save best model. Always checkpoint after the first epoch, so a file
    # exists for the test section even if validation accuracy never rises
    # above 0.
    if epoch == 0 or val_epoch_acc > best_val_acc:
        best_val_acc = val_epoch_acc
        torch.save(model.state_dict(), "best_resnet50_ecg.pth")
        print("Model saved!")

print("Training complete. Best val accuracy: {:.4f}".format(best_val_acc))

# --------------------------------
# 7. Testing and Metrics
# --------------------------------
# Load the checkpoint with the best validation accuracy. map_location remaps
# tensors that were saved on a GPU onto the current device, so the checkpoint
# also loads on a CPU-only machine.
model.load_state_dict(torch.load("best_resnet50_ecg.pth", map_location=device))
model.eval()

# Precision, recall and F1 are macro-averaged, so each class counts equally
# regardless of how many test images it has. Accuracy keeps torchmetrics'
# default averaging for the task wrapper.
# NOTE(review): confirm whether micro- or macro-averaged accuracy is intended.
accuracy_metric = Accuracy(task="multiclass", num_classes=num_classes).to(device)
precision_metric = Precision(task="multiclass", num_classes=num_classes, average='macro').to(device)
recall_metric = Recall(task="multiclass", num_classes=num_classes, average='macro').to(device)
f1_metric = F1Score(task="multiclass", num_classes=num_classes, average='macro').to(device)
_test_metrics = (accuracy_metric, precision_metric, recall_metric, f1_metric)

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        # Each metric accumulates state across batches; compute() aggregates it.
        for _metric in _test_metrics:
            _metric.update(preds, labels)

test_accuracy, test_precision, test_recall, test_f1 = (
    _metric.compute().item() for _metric in _test_metrics
)

print(f"Test Accuracy : {test_accuracy:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall   : {test_recall:.4f}")
print(f"Test F1 Score : {test_f1:.4f}")


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\FireFly/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|█████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:21<00:00, 4.76MB/s]


Epoch [1/10]
Train Loss: 1.1058  |  Train Acc: 0.5877
Val Loss:   1.2793  |  Val Acc:   0.4681
Model saved!
Epoch [2/10]
Train Loss: 0.3234  |  Train Acc: 0.9400
Val Loss:   0.3120  |  Val Acc:   0.9149
Model saved!
Epoch [3/10]
Train Loss: 0.0723  |  Train Acc: 0.9862
Val Loss:   0.1003  |  Val Acc:   0.9787
Model saved!
Epoch [4/10]
Train Loss: 0.0422  |  Train Acc: 0.9954
Val Loss:   0.1145  |  Val Acc:   0.9787
Epoch [5/10]
Train Loss: 0.0144  |  Train Acc: 1.0000
Val Loss:   0.1143  |  Val Acc:   0.9787
Epoch [6/10]
Train Loss: 0.0424  |  Train Acc: 0.9892
Val Loss:   0.1844  |  Val Acc:   0.9468
Epoch [7/10]
Train Loss: 0.0347  |  Train Acc: 0.9954
Val Loss:   0.1530  |  Val Acc:   0.9787
Epoch [8/10]
Train Loss: 0.0287  |  Train Acc: 0.9938
Val Loss:   0.1482  |  Val Acc:   0.9787
Epoch [9/10]
Train Loss: 0.0134  |  Train Acc: 1.0000
Val Loss:   0.2050  |  Val Acc:   0.9787
Epoch [10/10]
Train Loss: 0.0148  |  Train Acc: 0.9923
Val Loss:   0.1931  |  Val Acc:   0.9787
Training c