In [1]:
import numpy as np
from tensorflow.keras import optimizers

from prosenet import ProSeNet, datasets

In [2]:
# Set any non-default args
new_rnn_args = {
    'layers' : [32, 32, 32]
}

new_proto_args = {
    'dmin' : 2.0,
    'Ld' : 0.01,
    'Lc' : 0.0,
    'Le' : 1.0
}

pnet = ProSeNet(input_shape=(187, 1), nclasses=5, k=30, 
            rnn_args=new_rnn_args,
            prototypes_args=new_proto_args)

pnet.build( (None, 187, 1) )

pnet.summary()

Model: "pro_se_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential (Sequential)      (None, 64)                58368     
_________________________________________________________________
prototypes (Prototypes)      multiple                  1920      
_________________________________________________________________
classifier (Dense)           multiple                  150       
Total params: 60,438
Trainable params: 60,438
Non-trainable params: 0
_________________________________________________________________


In [3]:
data = datasets.ArrhythmiaDataset('../data/')
print(data)

MIT-BIH Arrhythmia Dataset
Num classes: 5
Input shape: (187, 1)



In [4]:
train_gen = datasets.DataGenerator(data.X_train, data.y_train, batch_size=128)

test_gen = datasets.DataGenerator(data.X_test, data.y_test, batch_size=128)

In [5]:
sgd = optimizers.SGD(learning_rate=0.001)

pnet.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])

pnet.fit_generator(train_gen, validation_data=test_gen, epochs=12, shuffle=False)



<tensorflow.python.keras.callbacks.History at 0x7fe7885e15f8>

In [6]:
pnet.predict(data.X_train[np.newaxis,:3,:], batch_size=1)

array([[0.33492956, 0.10209967, 0.15463914, 0.13533103, 0.27300063]],
      dtype=float32)

In [7]:
data.y_train[:3]

array([1., 0., 0., 0., 0.], dtype=float32)

In [9]:
pnet.predict(data.X_test[np.newaxis,:3,:], batch_size=1)

array([[0.33492956, 0.10209967, 0.15463914, 0.13533103, 0.27300063]],
      dtype=float32)

In [10]:
data.y_test[0:3]

array([1., 0., 0., 0., 0.], dtype=float32)