In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import wfdb
import ast
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report
from sklearn.preprocessing import MultiLabelBinarizer

# Data loading and preprocessing functions
def load_raw_data(df, sampling_rate, path):
    """Read every record listed in ``df`` and stack the signals into one array.

    Args:
        df: Table exposing ``filename_lr`` and ``filename_hr`` columns whose
            entries are record names relative to ``path``.
        sampling_rate: ``100`` selects the ``filename_lr`` column; any other
            value selects ``filename_hr``.
        path: Prefix that is concatenated directly in front of each record
            name (no path separator is inserted).

    Returns:
        ``numpy.ndarray`` built from the signal half of every
        ``wfdb.rdsamp`` result; the metadata half is discarded.
    """
    record_names = df.filename_lr if sampling_rate == 100 else df.filename_hr

    signals = []
    for record_name in record_names:
        signal, _meta = wfdb.rdsamp(path + record_name)
        signals.append(signal)
    return np.array(signals)

def aggregate_diagnostic(y_dic, agg=None):
    """Map the SCP codes of one record to its distinct diagnostic classes.

    Args:
        y_dic: Mapping whose keys are SCP statement codes; the values are
            not used.
        agg: Optional lookup table indexed by SCP code with a
            ``diagnostic_class`` column. When omitted, the module-level
            ``agg_df`` is used, so ``Series.apply(aggregate_diagnostic)``
            keeps working unchanged.

    Returns:
        Sorted list of the distinct ``diagnostic_class`` values of every key
        that is present in the lookup table's index. Keys missing from the
        table are skipped; an empty mapping yields ``[]``.
    """
    table = agg_df if agg is None else agg
    classes = {
        table.loc[code].diagnostic_class
        for code in y_dic
        if code in table.index
    }
    # A bare list(set(...)) comes back in hash order, which differs between
    # interpreter runs for strings. Callers pick the first label of this
    # list, so the order has to be reproducible. key=str keeps the sort from
    # failing should a non-string value (e.g. NaN) ever appear.
    return sorted(classes, key=str)

def split(X, Y, test_fold=10):
    """Split signals and labels into train and test parts by stratification fold.

    Args:
        X: Array of signals whose first axis is aligned row-for-row with ``Y``.
        Y: DataFrame with a ``strat_fold`` column and a
            ``diagnostic_superclass`` column.
        test_fold: Value of ``strat_fold`` that marks the held-out rows.
            Defaults to ``10``, the value that was previously hard-coded.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` where the ``y_*`` items
        are the ``diagnostic_superclass`` Series restricted to the matching
        rows (original index preserved).
    """
    # Build the mask once and reuse it for both the array and the labels so
    # the two selections cannot drift apart.
    is_test = (Y.strat_fold == test_fold).to_numpy()
    is_train = ~is_test

    X_train = X[is_train]
    y_train = Y.diagnostic_superclass[is_train]
    X_test = X[is_test]
    y_test = Y.diagnostic_superclass[is_test]
    return X_train, y_train, X_test, y_test

def encode_filter(X_train, y_train, X_test, y_test, permute=True, batch_size=32):
    """Drop unlabeled records, integer-encode labels and wrap data in loaders.

    Records whose label list is empty are removed. Every remaining record is
    reduced to a single target: the index of the first entry of its label
    list.

    Args:
        X_train: Array of training signals; its first axis is aligned with
            ``y_train``.
        y_train: Iterable of label lists for the training records.
        X_test: Array of test signals; its first axis is aligned with
            ``y_test``.
        y_test: Iterable of label lists for the test records.
        permute: When true, the last two axes of the 3-D signal tensors are
            swapped via ``permute(0, 2, 1)``.
        batch_size: Batch size of both loaders. Defaults to ``32``, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(train_loader, test_loader, num_classes, class_to_idx)``.
        The training loader shuffles, the test loader does not.

    Raises:
        KeyError: If the first label of a test record never occurs in
            ``y_train`` (the mapping is built from training labels only).
    """
    # Sort the label set so that the class -> index mapping is identical on
    # every run. Iterating a set of strings directly follows hash order,
    # which changes between interpreter sessions and would silently permute
    # the class indices of saved models and reports.
    unique_classes = sorted(
        {label for labels in y_train for label in labels}, key=str
    )
    class_to_idx = {cls: idx for idx, cls in enumerate(unique_classes)}

    def _to_dataset(signals, label_lists):
        """Filter one split and turn it into a ``TensorDataset``."""
        # An explicit bool dtype keeps the mask usable as an index even when
        # the split contains no records at all.
        keep = np.fromiter((bool(labels) for labels in label_lists), dtype=bool)
        targets = [class_to_idx[labels[0]] for labels in label_lists if labels]

        inputs = torch.tensor(signals[keep], dtype=torch.float32)
        if permute:
            inputs = inputs.permute(0, 2, 1)
        return TensorDataset(inputs, torch.tensor(targets, dtype=torch.long))

    train_loader = DataLoader(
        _to_dataset(X_train, y_train), batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        _to_dataset(X_test, y_test), batch_size=batch_size, shuffle=False
    )

    return train_loader, test_loader, len(unique_classes), class_to_idx

# Main code execution
if __name__ == '__main__':
    # Dataset location (used as a raw string prefix) and which filename
    # column of the annotation table to read recordings from.
    path = ''  # Set your dataset path here
    sampling_rate = 100

    # Annotation table, keyed by ecg_id. The scp_codes column arrives as
    # text and is parsed into Python objects with ast.literal_eval.
    Y = pd.read_csv(path + 'ptbxl_database.csv', index_col='ecg_id')
    Y['scp_codes'] = Y['scp_codes'].apply(ast.literal_eval)

    # Waveforms for every row of the annotation table
    X = load_raw_data(Y, sampling_rate, path)

    # Statement table, reduced to the rows flagged as diagnostic. It is read
    # as a module-level name by aggregate_diagnostic.
    agg_df = pd.read_csv(path + 'scp_statements.csv', index_col=0)
    agg_df = agg_df.loc[agg_df['diagnostic'] == 1]

    # Attach the aggregated diagnostic classes to every record
    Y['diagnostic_superclass'] = Y['scp_codes'].apply(aggregate_diagnostic)

    # Hold out the test fold
    X_train, y_train, X_test, y_test = split(X, Y)

    # Integer-encode the labels and build the data loaders
    train_loader, test_loader, num_classes, class_to_idx = encode_filter(
        X_train, y_train, X_test, y_test
    )

    # Define the corrected InceptionTime model
    class InceptionModule1D(nn.Module):
        """One Inception block for 1-D signals.

        The input passes through a 1x1 bottleneck convolution that feeds three
        parallel convolutions with different kernel sizes. A fourth branch
        applies max-pooling followed by a 1x1 convolution to the raw input.
        The four branch outputs are concatenated along the channel axis
        (``4 * out_channels`` channels), batch-normalized, optionally summed
        with a shortcut of the input, and passed through ReLU.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels produced by each branch.
            kernel_sizes: Kernel sizes of the three convolution branches; only
                the first three entries are used. Each must be odd: with
                ``padding = k // 2`` an even ``k`` lengthens the output by one
                sample, so the branches could not be concatenated with the
                pooling branch.
            bottleneck_channels: Upper bound for the bottleneck width; the
                actual width is ``min(in_channels, bottleneck_channels)``.
            use_residual: Whether to add the shortcut before the activation.

        Raises:
            ValueError: If one of the three kernel sizes is even.
        """

        def __init__(self, in_channels, out_channels, kernel_sizes=(9, 19, 39), bottleneck_channels=32, use_residual=True):
            super().__init__()
            self.use_residual = use_residual

            k1, k2, k3 = kernel_sizes[0], kernel_sizes[1], kernel_sizes[2]
            if any(k % 2 == 0 for k in (k1, k2, k3)):
                # Fail at construction time instead of with an opaque
                # torch.cat size mismatch during the first forward pass.
                raise ValueError(
                    f"kernel_sizes must all be odd, got {tuple(kernel_sizes)}"
                )

            # Never widen the input: cap the bottleneck at in_channels
            self.bottleneck_channels = min(in_channels, bottleneck_channels)

            # Bottleneck layer
            self.bottleneck = nn.Conv1d(in_channels, self.bottleneck_channels, kernel_size=1, bias=False)

            # Parallel convolutions on the bottleneck output; padding k // 2
            # keeps the sequence length unchanged for odd k
            self.conv1 = nn.Conv1d(self.bottleneck_channels, out_channels, kernel_size=k1,
                                   padding=k1 // 2, bias=False)
            self.conv2 = nn.Conv1d(self.bottleneck_channels, out_channels, kernel_size=k2,
                                   padding=k2 // 2, bias=False)
            self.conv3 = nn.Conv1d(self.bottleneck_channels, out_channels, kernel_size=k3,
                                   padding=k3 // 2, bias=False)

            # Pooling branch works on the raw input, not on the bottleneck
            self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
            self.conv4 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

            # Batch normalization and activation over the concatenated branches
            self.bn = nn.BatchNorm1d(out_channels * 4)
            self.relu = nn.ReLU()

            # Shortcut: project with a 1x1 convolution only when the channel
            # counts differ, otherwise pass the input through untouched
            if self.use_residual and in_channels != out_channels * 4:
                self.residual = nn.Sequential(
                    nn.Conv1d(in_channels, out_channels * 4, kernel_size=1, bias=False),
                    nn.BatchNorm1d(out_channels * 4)
                )
            else:
                self.residual = nn.Identity()

        def forward(self, x):
            """Apply the block to ``x``; the channel count becomes ``4 * out_channels``."""
            shortcut_input = x

            reduced = self.bottleneck(x)
            branches = [
                self.conv1(reduced),
                self.conv2(reduced),
                self.conv3(reduced),
                self.conv4(self.maxpool(shortcut_input)),
            ]

            out = self.bn(torch.cat(branches, dim=1))

            if self.use_residual:
                # Out-of-place add: avoids mutating a tensor autograd may
                # still reference.
                out = out + self.residual(shortcut_input)
            return self.relu(out)

    class InceptionTime(nn.Module):
        """Chain of Inception blocks followed by global pooling and a linear head.

        Args:
            in_channels: Channel count of the input fed to the first block.
            num_classes: Size of the output of the final linear layer.
            num_modules: How many Inception blocks to chain.
        """

        def __init__(self, in_channels, num_classes, num_modules=6):
            super().__init__()
            # Only the first block sees the raw channel count; each later
            # block consumes the 128 channels (4 branches x 32) emitted by
            # its predecessor.
            block_inputs = [in_channels if idx == 0 else 128 for idx in range(num_modules)]
            blocks = [InceptionModule1D(width, 32) for width in block_inputs]

            self.inception_modules = nn.Sequential(*blocks)
            self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
            self.fc = nn.Linear(128, num_classes)

        def forward(self, x):
            """Return the linear layer's output for ``x``."""
            features = self.inception_modules(x)
            # Average over the last axis, then drop that now size-1 axis
            pooled = self.global_avg_pool(features).squeeze(-1)
            return self.fc(pooled)

    # Build the model on the GPU when one is visible, otherwise on the CPU,
    # together with its loss function and optimizer
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # X_train is the array returned by split(), i.e. before encode_filter
    # swaps the last two axes, so axis 2 here is the axis the model
    # receives as its channel axis.
    input_channels = X_train.shape[2]
    model = InceptionTime(in_channels=input_channels, num_classes=num_classes)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training function with metrics computation
    def train_model(model, train_loader, test_loader, criterion, optimizer, epochs=10):
        """Train ``model`` and print test-set metrics after every epoch.

        Args:
            model: Network to optimize; must already live on its target device.
            train_loader: Iterable of ``(inputs, targets)`` training batches.
            test_loader: Iterable of ``(inputs, targets)`` batches evaluated
                after each epoch.
            criterion: Loss called as ``criterion(outputs, targets)``.
            optimizer: Optimizer bound to the parameters of ``model``.
            epochs: Number of passes over ``train_loader``.

        Returns:
            None. The mean training loss, accuracy, macro/micro F1, macro
            precision and macro recall are printed for every epoch.
        """
        # Use the device the model actually lives on rather than relying on
        # a module-level variable that happens to be defined by the caller.
        device = next(model.parameters()).device

        for epoch in range(epochs):
            model.train()
            train_loss = 0.0
            for X_batch, y_batch in train_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                optimizer.zero_grad()
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            # max(..., 1) avoids a ZeroDivisionError on an empty loader
            avg_train_loss = train_loss / max(len(train_loader), 1)
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_train_loss:.4f}")

            # Evaluation
            model.eval()
            y_pred, y_true = [], []
            with torch.no_grad():
                for X_batch, y_batch in test_loader:
                    predicted = model(X_batch.to(device)).argmax(dim=1)
                    y_pred.extend(predicted.cpu().tolist())
                    y_true.extend(y_batch.tolist())

            # zero_division=0 yields the same scores as the default but does
            # not emit a warning whenever a class is never predicted, which
            # is common during the first epochs.
            accuracy = accuracy_score(y_true, y_pred)
            macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
            micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
            macro_precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
            macro_recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
            print(f"Validation Accuracy: {accuracy * 100:.2f}%")
            print(f"Macro F1 Score: {macro_f1:.4f}")
            print(f"Micro F1 Score: {micro_f1:.4f}")
            print(f"Macro Precision: {macro_precision:.4f}")
            print(f"Macro Recall: {macro_recall:.4f}")

    # Train the model
    train_model(model, train_loader, test_loader, criterion, optimizer, epochs=10)

    # Collect the final test-set predictions for the detailed report
    model.eval()
    y_pred, y_true = [], []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            predicted = model(X_batch.to(device)).argmax(dim=1)
            y_pred.extend(predicted.cpu().tolist())
            y_true.extend(y_batch.tolist())

    print("\nClassification Report:")
    # Class names ordered by their integer index
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    target_names = [idx_to_class[i] for i in range(len(idx_to_class))]
    # Pass the full label set explicitly: without it the report infers the
    # labels from the data and raises when a class is missing from both the
    # test targets and the predictions, because the number of inferred
    # labels then no longer matches target_names.
    print(classification_report(
        y_true,
        y_pred,
        labels=list(range(len(target_names))),
        target_names=target_names,
        zero_division=0,
    ))

    # Compute confusion matrix
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

    # Fix the label set so the matrix always has one row/column per known
    # class; otherwise a class missing from the test targets and predictions
    # shrinks the matrix and it no longer lines up with display_labels.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(target_names))))

    # Plot confusion matrix
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
    fig, ax = plt.subplots(figsize=(8, 8))
    disp.plot(cmap=plt.cm.Blues, ax=ax, values_format='d')
    plt.xticks(rotation=45)
    plt.title('Confusion Matrix')
    plt.show()

    # Compute ROC curves and AUC
    from sklearn.metrics import roc_curve, auc
    from itertools import cycle

    n_classes = num_classes  # Number of classes

    # Collect per-class probabilities and the matching true labels
    all_probs = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch.to(device))
            all_probs.append(F.softmax(outputs, dim=1).cpu().numpy())
            all_labels.append(y_batch.cpu().numpy())

    # Concatenate all batches
    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # One-hot encode the integer labels. Indexing an identity matrix always
    # yields one column per class, including the two-class case where
    # label_binarize would return a single column and break the per-class
    # indexing below.
    all_labels_bin = np.eye(n_classes, dtype=int)[all_labels]

    # Compute ROC curve and ROC area for each class (one-vs-rest)
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(all_labels_bin[:, i], all_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area by pooling every
    # (sample, class) decision into one binary problem
    fpr["micro"], tpr["micro"], _ = roc_curve(all_labels_bin.ravel(), all_probs.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Plot ROC curves
    plt.figure(figsize=(10, 8))
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green', 'red'])

    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                label=f'ROC curve of class {target_names[i]} (area = {roc_auc[i]:0.2f})')

    # Plot micro-average ROC curve
    plt.plot(fpr["micro"], tpr["micro"],
            label=f'micro-average ROC curve (area = {roc_auc["micro"]:0.2f})',
            color='deeppink', linestyle=':', linewidth=4)

    # Chance-level diagonal for reference
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([-0.05, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=16)
    plt.legend(loc="lower right", fontsize=12)
    plt.show()


TypeError: InceptionModule1D.__init__() got an unexpected keyword argument 'num_classes'