In [1]:
from jmlib.data import Data
from jmlib.data.loaders import PTBXL

from jmlib.data.processing.common import LambdaModule
from jmlib.data.writers.common import Writer
from jmlib.data.generators.fewshot import FewShotGenerator
from jmlib.data.splitters.common import ClassSplitter 

from typing import Unpack
from keras.layers import TimeDistributed, Lambda, Activation, Flatten, Dense
from keras.layers import BatchNormalization, MaxPooling1D, Conv1D
from keras.models import Sequential
from keras.optimizers.legacy import Adam
from keras.callbacks import ReduceLROnPlateau
from keras.losses import CategoricalCrossentropy

from jmlib.models.common import BaseModel, BaseModelParams
from jmlib.util.models import reduce_tensor, reshape_query, proto_dist
from jmlib.util.models import LinearFusion

from keras.metrics import AUC

import numpy as np
import matplotlib.pyplot as plt

import os

# Root of the extracted PTB-XL 1.0.3 release. Set the PTBXL_DATA_DIR environment
# variable to run on another machine; the fallback is the original author's path.
PTBXL_DATA_DIR = os.environ.get(
    "PTBXL_DATA_DIR",
    "/Users/jbthompson/Documents/final_folder/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3",
)

# Data pipeline: load PTB-XL, split into train/val/test, then wrap each split
# in a few-shot episode generator.
ptbxl = Data(name="raw_PTBXL", verbose=True)
ptbxl.add(
    PTBXL(data_dir=PTBXL_DATA_DIR),
    # 60/20/20 split; presumably partitions by class (per the name) — verify.
    ClassSplitter({"train": 0.6, "val": 0.2, "test": 0.2}),
    # 5-way / 5-shot episodes with 5 queries; batch_size=100 — confirm whether
    # that counts episodes or samples per batch.
    FewShotGenerator(way=5, shot=5, query=5, batch_size=100),
)
ptbxl.run()

# Pair of Conv1D kernel sizes, one per parallel CNN branch of PMCNN2022
# (e.g. the default (3, 7)). Each entry is passed to Conv1D as kernel_size.
KERNEL_TUPLE = tuple[int, int]
# Number of training epochs passed to model.fit().
EPOCHS = 100

class PMCNN2022(BaseModel):
    """Zicong Li et al., 2022 - https://doi.org/10.1109/BHI56158.2022.9926948.

    Implementation by Joe McMahon. Adapted from code provided by
    Zicong Li. Summary from paper: "a parallel multi-scale CNN (PM-CNN) based
    prototypical network for arrhythmia classification".
    This implementation supports categorical classification.
    Designed for use on the CPSC-2018 dataset.

    Architecture: two parallel 1D CNN branches (one per kernel size) embed
    both the support and the query inputs. Each branch's outputs are merged
    by ``LinearFusion``. Support embeddings are then reduced with
    ``reduce_tensor`` and query embeddings reshaped with ``reshape_query``.
    Finally the two are compared by ``proto_dist``.

    Parameters
    ----------
    lr : float, default=0.001
        Learning rate for the Adam optimizer.
    depth : int, default=4
        Number of Conv1D -> BatchNorm -> ReLU -> MaxPool1D stages in each
        CNN branch.
    fd : int, default=512
        Size of the Dense layer that ends each CNN branch. Also passed to
        ``LinearFusion`` as the fused feature size.
    kernels : tuple of two ints, default=(3, 7)
        Conv1D kernel sizes for the two parallel branches. Only the first
        two entries are used.
    filters : int, default=64
        Number of filters in every convolutional layer.
    **kwargs
        Keyword arguments to pass to the super class. See
        jmlib.models.common.BaseModel.
    """

    LR: float
    DEPTH: int
    FD: int
    KERNELS: KERNEL_TUPLE
    FILTERS: int

    def __init__(self,
                 lr: float = 0.001,
                 depth: int = 4,
                 fd: int = 512,
                 kernels: KERNEL_TUPLE = (3, 7),
                 filters: int = 64,
                 **kwargs: Unpack[BaseModelParams]):
        # Hyperparameters must be set before super().__init__, which may
        # build the model (and therefore call _layers) — TODO confirm in
        # BaseModel.
        self.LR = lr
        self.DEPTH = depth
        self.FD = fd
        self.KERNELS = kernels
        self.FILTERS = filters
        super().__init__(**kwargs)

    def _layers(self, X):
        """Build the forward graph from a (support, query) input pair.

        Parameters
        ----------
        X : pair of tensors
            ``(Xs, Xq)``: support and query inputs. Axis -4 is read as the
            shot count (support) and query count (query). The full axis
            layout is not visible here — TODO confirm against the generator.

        Returns
        -------
        tensor
            Output of ``proto_dist`` applied to the reduced support
            embeddings and the reshaped query embeddings.
        """
        Xs, Xq = X
        shot = Xs.shape[-4]
        query = Xq.shape[-4]

        # One shared CNN per kernel size, applied to both support and query
        # so the two sets are embedded with the same weights. The layer names
        # mention 3/7 because those are the default kernel sizes.
        proto_model3 = TimeDistributed(
            self._proto_model(self.KERNELS[0]), name="Prototype_CNN_3"
        )
        proto_model7 = TimeDistributed(
            self._proto_model(self.KERNELS[1]), name="Prototype_CNN_7"
        )

        Xs3 = proto_model3(Xs)
        Xq3 = proto_model3(Xq)

        Xs7 = proto_model7(Xs)
        Xq7 = proto_model7(Xq)

        # Merge the two branch embeddings into one FD-sized representation.
        Xs = LinearFusion(shot, self.FD)(Xs7, Xs3)
        Xq = LinearFusion(query, self.FD)(Xq7, Xq3)

        Xs = Lambda(reduce_tensor, name="Reduce_Support")(Xs)
        Xq = Lambda(reshape_query, name="Reshape_Query")(Xq)

        X = Lambda(proto_dist, name="Prototype_Distance")([Xs, Xq])

        return X

    def _proto_model(self, k) -> Sequential:
        """Return one CNN branch with Conv1D kernel size ``k``.

        The branch is DEPTH repetitions of
        Conv1D(FILTERS, k, padding='same') -> BatchNormalization -> ReLU
        -> MaxPooling1D, followed by Flatten -> Dense(FD).
        """
        cnn = Sequential()
        for _ in range(self.DEPTH):  # type: ignore
            cnn.add(Conv1D(self.FILTERS, k, padding='same'))
            cnn.add(BatchNormalization())
            cnn.add(Activation('relu'))
            cnn.add(MaxPooling1D())

        cnn.add(Flatten())
        cnn.add(Dense(self.FD))
        return cnn

    @property
    def optimizer(self):
        """Adam optimizer (keras legacy implementation) with learning rate LR."""
        return Adam(learning_rate=self.LR)  # type: ignore

    @property
    def loss(self):
        """Categorical cross-entropy loss with default Keras settings."""
        return CategoricalCrossentropy()

    @property
    def callbacks(self):
        """Training callbacks: a ReduceLROnPlateau schedule.

        The schedule watches ``val_loss``. It multiplies the learning rate by
        0.4 after 2 epochs without improvement, waits 2 epochs of cooldown
        before reducing again, and never goes below 1e-8. It requires
        validation data to be passed to ``fit``.
        """
        reduce_lr = ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.4,
            patience=2,
            min_lr=1e-8,  # type: ignore
            cooldown=2
        )
        return [reduce_lr]
    
# Build the model on the pipeline's input tensors and compile it with the
# optimizer and loss defined by the class.
mymodel = PMCNN2022(input_tensors=ptbxl.tensor)
mymodel.model.compile(
    optimizer=mymodel.optimizer,
    loss=mymodel.loss,
    metrics=['categorical_accuracy'],
)
mymodel.model.summary()

# Pass the class's callbacks (ReduceLROnPlateau on val_loss). Previously they
# were defined on PMCNN2022 but never handed to fit(), so the learning-rate
# schedule never ran.
results = mymodel.model.fit(
    ptbxl.generators['train'],
    epochs=EPOCHS,
    validation_data=ptbxl.generators['val'],
    callbacks=mymodel.callbacks,
    verbose=2,
)
test_results = mymodel.model.evaluate(
    ptbxl.generators['test'], verbose=2, return_dict=True
)
print(test_results)


def _plot_history(history, metric, label, filename):
    """Plot training vs. validation curves for one metric, then save and show.

    Parameters
    ----------
    history : dict
        Keras ``History.history`` mapping. It must contain ``metric`` and
        ``"val_" + metric``.
    metric : str
        History key of the training series, e.g. ``'loss'``.
    label : str
        Human-readable metric name used in the legend, title and y-axis.
    filename : str
        Passed to ``plt.savefig``. Without an extension, matplotlib uses its
        default format.
    """
    train_values = history[metric]
    val_values = history[f"val_{metric}"]
    # Size the x-axis from the recorded history rather than EPOCHS, so the
    # plot still works if training stops early.
    epochs = np.arange(1, len(train_values) + 1)

    plt.figure()
    plt.plot(epochs, train_values, label=f'Training {label}')
    plt.plot(epochs, val_values, label=f'Validation {label}')
    plt.title(f'Training and Validation {label}')
    plt.xlabel('Epoch')
    plt.ylabel(label)
    plt.legend()
    plt.savefig(filename)
    plt.show()


_plot_history(results.history, 'categorical_accuracy', 'Accuracy', 'accuracy')
_plot_history(results.history, 'loss', 'Loss', 'loss')

from sklearn.metrics import f1_score, roc_auc_score

# Per-sample class scores from the model on the test generator. This assumes
# predict() returns a 2-D (n_samples, n_classes) array of probabilities —
# TODO confirm against proto_dist's output. roc_auc_score with
# multi_class='ovr' rejects rows that do not sum to 1.
predictions = mymodel.model.predict(ptbxl.generators['test'])
predicted_labels = predictions.argmax(axis=1)

# Labels stored on the test generator. These are presumably one-hot rows
# (they are argmax-ed along axis 1) — verify against FewShotGenerator.
stored_labels = np.asarray(ptbxl.generators['test']['data'][1])

# The notebook assumes predict() emits the stored label set repeated end to
# end. Derive the repeat count from the prediction count instead of
# hard-coding 100, and fail loudly if the two cannot line up.
n_repeats, leftover = divmod(len(predictions), len(stored_labels))
if leftover or n_repeats == 0:
    raise ValueError(
        f"Cannot align {len(predictions)} predictions with "
        f"{len(stored_labels)} stored test labels"
    )
true_labels = np.argmax(np.tile(stored_labels, (n_repeats, 1)), axis=1)

# Support-weighted F1 across classes.
f1 = f1_score(true_labels, predicted_labels, average='weighted')

# One-vs-rest multiclass ROC AUC computed from the class scores.
auc = roc_auc_score(true_labels, predictions, multi_class='ovr')

print("Overall F1 Score:", f1)
print('Overall AUC Score:', auc)

ModuleNotFoundError: No module named 'jmlib'