In [2]:
from torch.utils.data import DataLoader, WeightedRandomSampler
from resnet import ResNet2
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, RocCurveDisplay
import numpy as np
from custom_dataset import FinalDataset, BeatsDataset
import torch
from torch import nn
import torch.optim as optim
import matplotlib.pyplot as plt
from FocalLoss import FocalLoss
import copy

In [None]:
# Path of the HDF5 file holding the averaged beats.
file_path = "./data/average_beats.hdf5"

# Dataset built with downsampling enabled and a 0.95 majority ratio.
dataset2 = BeatsDataset(
    file_path,
    downsample=True,
    majority_ratio=0.95,
)

In [None]:
# ---- Focal loss ----
gamma, alpha = 2.0, 0.95

# ---- ResNet ----
in_channels, out_channels = 12, 64

# ---- Optimizer ----
learning_rate = 0.0005
# Candidate values for a learning-rate grid search.
learning_rates = [0.0001, 0.0005, 0.001]
momentum = 0.9
weight_decay = 0.0001

# Seed handed to the train/val/test splitting calls.
random_state = 42

# ---- Training ----
epochs = 200
batch_size = 256

# Early-stopping bookkeeping: allowed epochs without improvement, and the
# running counter of such epochs.
patience = 10
count = 0

In [None]:
# Stratified split of dataset2 positions: 15% test first, then 0.17647 of the
# remaining 85% (~15% of the total) for validation, leaving ~70% for training.
train_val_idx_2, test_idx_2 = train_test_split(
    np.arange(len(dataset2)),
    test_size=0.15,
    random_state=random_state,
    stratify=dataset2.get_labels(),
)
# Sorted copy instead of an in-place sort; the resulting array is identical.
train_val_idx_2 = np.sort(train_val_idx_2)
train_idx_2, val_idx_2 = train_test_split(
    train_val_idx_2,
    test_size=0.17647,  # ~15% of total
    random_state=random_state,
    stratify=dataset2.labels[train_val_idx_2],
)


def _make_loader(split_idx, shuffle):
    """Build a DataLoader over the ``dataset2`` entries at positions ``split_idx``.

    The split positions are translated through ``dataset2.indices`` and a new
    ``BeatsDataset`` is opened on ``file_path`` -- the same file ``dataset2``
    was built from -- so the loaders cannot point at a different file than
    the one the split was computed on.
    """
    subset = BeatsDataset(file_path, indices=dataset2.indices[split_idx])
    return DataLoader(subset, batch_size=batch_size, shuffle=shuffle)


# Only the training loader is shuffled; evaluation order stays fixed.
train_loader_2 = _make_loader(train_idx_2, shuffle=True)
val_loader_2 = _make_loader(val_idx_2, shuffle=False)
test_loader_2 = _make_loader(test_idx_2, shuffle=False)


In [5]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, device, l1_lambda=1e-5):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that supports ``len()`` and
            exposes a ``dataset`` attribute.
        model: Module to optimise; it is switched to training mode here.
        loss_fn: Callable ``loss_fn(pred, y)``. Its result is reduced with
            ``.mean()``, so both already-reduced and per-element losses work.
        optimizer: Optimizer used to update the parameters of ``model``.
        device: Device each batch is moved to before the forward pass.
        l1_lambda: Accepted for backward compatibility only. L1
            regularisation is disabled, so this value has no effect.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    dataset = dataloader.dataset
    # Prefer the dataset's ``indices`` when it has them, otherwise fall back
    # to ``len(dataset)`` so ordinary datasets work as well.
    size = len(getattr(dataset, "indices", dataset))

    model.train()
    total_loss = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)

        # No L1 term is added: the norm over all parameters used to be
        # computed on every batch and then thrown away, which only cost time.
        loss = loss_fn(pred, y).mean()

        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f} [{batch * len(X):>5d}/{size:>5d}]")
    return total_loss / len(dataloader)


def test_loop(dataloader, model, loss_fn, device,val = False):
    size = len(dataloader.dataset.indices)
    total_loss = 0
    model.eval()
    all_preds, all_targets = [], []
    
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y).mean()
            total_loss += loss.item()
            all_preds.append(pred.cpu())
            all_targets.append(y.cpu())
    
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)

    probabilities = torch.sigmoid(all_preds)

    targets = all_targets.int().numpy()

    # for threshold in np.arange(0.1, 1.0, 0.1):
    #     print(f"\n==== Threshold: {threshold:.1f} ====")
    #     predicted = (probabilities > threshold).int().numpy()

    #     cm = confusion_matrix(targets, predicted)
    #     print("Confusion Matrix:")
    #     print(cm)
        
    #     report = classification_report(targets, predicted, digits=4, zero_division=0)
    #     print("Classification Report:")
    #     print(report)
    if val:
        RocCurveDisplay.from_predictions(all_targets, probabilities, name="ROC Curve")
        plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
        plt.title('ROC Curve')
        plt.savefig('roc_curve_RETE_BEATS_NO_L1.png')
        plt.show()
    return total_loss / len(dataloader)

In [None]:
# Tag shared by the artefacts written by this cell.
run_tag = "RETE_BEATS_NO_L1"

# Seed torch's global RNG so weight initialisation and the training loader's
# shuffling are repeatable from one run to the next.
torch.manual_seed(random_state)

# --- Loss function ---------------------------------------------------------
# Alternatives from earlier experiments, kept for reference:
#   * class-weighted BCE (95% majority ratio):
#       pos_weight = torch.tensor(155609 / 8190).to(device)
#       loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
#   * focal loss (95% majority ratio):
#       loss_fn = FocalLoss(gamma, alpha)
# NOTE(review): plain BCE was labelled "with 50% majority ratio" -- confirm
# that this matches the majority_ratio the dataset was built with.
loss_fn = nn.BCEWithLogitsLoss()

# --- Model, optimizer, LR schedule -----------------------------------------
model = ResNet2(in_channels, out_channels).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
# Halve the learning rate after 3 epochs without validation-loss improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# --- Train / validate --------------------------------------------------------
train_losses = []
val_losses = []
best_loss = float('inf')
best_model_state = None
count = 0
# Early stopping is off, so all `epochs` epochs run. Set to True to stop
# after `patience` consecutive epochs without a new best validation loss.
use_early_stopping = False

for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train_loop(train_loader_2, model, loss_fn, optimizer, device, l1_lambda=1e-5)
    val_loss = test_loop(val_loader_2, model, loss_fn, device)
    scheduler.step(val_loss)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    if val_loss < best_loss:
        best_loss = val_loss
        count = 0
        # Snapshot the weights; deepcopy keeps later updates from altering it.
        best_model_state = copy.deepcopy(model.state_dict())
        print(f"New best model saved with validation loss: {val_loss:.4f}")
    elif use_early_stopping:
        count += 1
        print(f"No improvement for {count} epochs. Best val loss: {best_loss:.4f}")
        if count >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

# Evaluate and save the weights with the lowest validation loss, not the
# weights of the last epoch.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model for testing")

# --- Test --------------------------------------------------------------------
# val=True makes test_loop draw and save the ROC curve for the test split.
test_loss = test_loop(test_loader_2, model, loss_fn, device, val=True)
print(f"Final Test Loss: {test_loss:.4f}")

# --- Loss curves ---------------------------------------------------------------
fig, ax = plt.subplots()

ax.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss", color='tab:blue')
ax.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss", color='tab:red')

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()

# The curves are train and validation losses, so the title says so; the
# output filename is left unchanged.
ax.set_title("Train and Validation Loss")
fig.tight_layout()
fig.savefig(f"train_test_loss_{run_tag}.png")
plt.show()

torch.save(model.state_dict(), f"model_{run_tag}.pth")
print("Saved model")

print("Training Completed!")
