
A TensorFlow [2.0] implementation of ProSeNet: "Interpretable and Steerable Sequence Learning via Prototypes" (Ming et al., 2019)


rgmyr/tf-ProSeNet


TensorFlow-ProSeNet

This is a tf.keras implementation of "Interpretable and Steerable Sequence Learning via Prototypes" (Ming et al., 2019). It's unlikely I'll implement any mechanisms for the "steering" part, but most of the interpretive stuff would be useful for my work.

Contributions are welcome!

Status

This implementation needs troubleshooting / debugging work.

MIT-BIH Dataset

I'm currently testing on the MIT-BIH Arrhythmia ECG Dataset, which is available in the pre-processed format described in (Kachuee et al., 2018) on Kaggle Datasets: ECG Heartbeat Categorization Dataset.

See notebooks/test_arrythmia.ipynb. There are some outstanding issues...

No matter how I weight the different regularization terms, the classifier weights tend toward vectors of the form [+const, 0, 0, 0, 0], which effectively means predicting the first class for every input. (The first class, "Normal" beats, accounts for ~83% of the dataset.)
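To make this failure mode concrete, here is a minimal NumPy sketch of a prototype layer followed by a linear classifier and softmax. It assumes the similarity a_i = exp(-||e - p_i||^2) as I read it from the paper; the function name `prosenet_head` is hypothetical and is not this repo's code. If every prototype's weight vector is [+c, 0, 0, 0, 0], the class-0 logit is c times a sum of strictly positive similarities while every other logit is 0, so the argmax is class 0 regardless of the input.

```python
import numpy as np

def prosenet_head(embeddings, prototypes, weights):
    """Hypothetical prototype layer + linear classifier + softmax.

    embeddings: (batch, d) sequence encodings from the encoder
    prototypes: (k, d) prototype vectors
    weights:    (k, num_classes) classifier weights
    """
    # Squared Euclidean distance between every embedding and every prototype
    diffs = embeddings[:, None, :] - prototypes[None, :, :]   # (batch, k, d)
    sq_dist = np.sum(diffs ** 2, axis=-1)                     # (batch, k)
    # Similarity scores in (0, 1]: a_i = exp(-||e - p_i||^2)
    sims = np.exp(-sq_dist)
    logits = sims @ weights                                   # (batch, num_classes)
    # Numerically stable softmax over classes
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    return sims, probs

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(8, 4))
prototypes = rng.normal(size=(3, 4))

# Degenerate weights: every prototype votes only for class 0 ("Normal")
degenerate = np.zeros((3, 5))
degenerate[:, 0] = 2.0

sims, probs = prosenet_head(embeddings, prototypes, degenerate)
preds = probs.argmax(axis=1)
print(preds)  # [0 0 0 0 0 0 0 0]
```

Note that the similarities still vary per input, so the prototypes can look meaningful even while the classifier ignores them.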

I've tried:

  • Using class weights, even reducing the weight of the first class to a negligibly small value. Does not seem to help.
  • Training just the LSTM encoder first (achieves up to ~95% accuracy), then freezing those layers and training only the prototypes_layer and classifier. Does not seem to help.
  • Heavier regularization (of both prototype vector diversity and classifier weights). Does not seem to help.
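For reference, one common way to choose class weights is inverse class frequency, n_samples / (n_classes * count_c), the same heuristic scikit-learn calls `class_weight="balanced"`. The sketch below computes that dict in the `{class_index: weight}` shape that tf.keras `Model.fit(class_weight=...)` accepts; the helper name and the toy labels are illustrative assumptions, not the values used in the notebook.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Inverse-frequency class weights: n_samples / (n_classes * count_c).

    Each class ends up with the same total weight (n_samples / n_classes).
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

# Toy labels with an ~83% majority class, mimicking the "Normal" imbalance
labels = [0] * 83 + [1] * 5 + [2] * 5 + [3] * 4 + [4] * 3
weights = balanced_class_weights(labels)
print({c: round(w, 3) for c, w in weights.items()})
# {0: 0.241, 1: 4.0, 2: 4.0, 3: 5.0, 4: 6.667}
```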

Other ideas to try:

  • Heavily downsampling the first class. (Undesirable, since it wouldn't be a reproduction of results.)
  • Verifying custom regularization terms more rigorously.
  • Computing the cross-entropy loss on logits (rather than on softmax outputs).
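The last idea matters mainly for numerical stability: taking log(softmax(z)) after the fact can underflow to log(0) when one logit dominates, while computing the loss from logits with the log-sum-exp trick stays finite. In tf.keras this corresponds to dropping the softmax activation on the final layer and using a loss such as `SparseCategoricalCrossentropy(from_logits=True)`. The NumPy functions below are an illustrative sketch of the two computations, not the repo's loss.

```python
import numpy as np

def ce_from_probs(logits, label):
    """Cross-entropy computed the naive way: softmax first, then log."""
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    with np.errstate(divide="ignore"):
        return -np.log(probs[label])

def ce_from_logits(logits, label):
    """Cross-entropy computed directly from logits via log-sum-exp."""
    m = logits.max()
    log_sum_exp = m + np.log(np.sum(np.exp(logits - m)))
    return log_sum_exp - logits[label]

logits = np.array([1000.0, 0.0, 0.0, 0.0, 0.0])
print(ce_from_probs(logits, 1))   # inf: probs[1] underflows to 0.0
print(ce_from_logits(logits, 1))  # 1000.0
```

For moderate logits the two agree; they only diverge when the softmax saturates, which is exactly the regime a heavily imbalanced classifier can drift into.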

Synthetic Signals Dataset

I've written a SyntheticSignalsDataset class that generates saw/square/sine signals. This is a very simple dataset which a simple LSTM can easily master. I'm using this dataset to troubleshoot. See notebooks/test_synthetic.ipynb.
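A dataset of this kind can be generated in a few lines. The sketch below is a hypothetical stand-in for `SyntheticSignalsDataset`, not its actual code: the function name, the frequency and phase ranges, and the noise level are all assumptions chosen for illustration. It returns arrays shaped `(n, length, 1)` so they can be fed straight into an LSTM encoder.

```python
import numpy as np

CLASS_NAMES = ("sine", "square", "saw")

def make_synthetic_signals(n_per_class, length=100, seed=0):
    """Generate labelled sine/square/sawtooth sequences with random
    frequency, phase, and additive Gaussian noise.

    Returns X with shape (n, length, 1) and integer labels y with shape (n,).
    """
    rng = np.random.default_rng(seed)
    t = np.linspace(0.0, 1.0, length)
    X, y = [], []
    for label, name in enumerate(CLASS_NAMES):
        for _ in range(n_per_class):
            freq = rng.uniform(2.0, 6.0)   # cycles per sequence
            phase = rng.uniform(0.0, 1.0)  # fraction of a cycle
            cycles = freq * t + phase
            if name == "sine":
                sig = np.sin(2 * np.pi * cycles)
            elif name == "square":
                sig = np.sign(np.sin(2 * np.pi * cycles))
            else:  # sawtooth: ramps from -1 to 1 once per cycle
                sig = 2.0 * (cycles - np.floor(cycles)) - 1.0
            sig = sig + rng.normal(scale=0.1, size=length)
            X.append(sig)
            y.append(label)
    X = np.asarray(X)[..., None]  # add a feature axis for the LSTM
    y = np.asarray(y)
    perm = rng.permutation(len(y))  # shuffle so classes are interleaved
    return X[perm], y[perm]

X, y = make_synthetic_signals(n_per_class=50, length=120)
print(X.shape, np.bincount(y))  # (150, 120, 1) [50 50 50]
```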

This dataset seems to work OK since I fixed a bug in the diversity regularization function; I'm still troubleshooting the ECG one.
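As I read the paper, the diversity term penalizes every pair of prototypes that sits closer than a margin d_min: R_d = sum over i < j of max(0, d_min - ||p_i - p_j||_2)^2. The NumPy sketch below is a reference version that a TensorFlow implementation can be checked against; the function name and the default `d_min=1.0` are assumptions, not values from this repo.

```python
import numpy as np

def diversity_penalty(prototypes, d_min=1.0):
    """Hinge-style diversity penalty on prototype vectors:
    sum over pairs i < j of max(0, d_min - ||p_i - p_j||_2) ** 2.
    """
    diffs = prototypes[:, None, :] - prototypes[None, :, :]
    dists = np.sqrt(np.sum(diffs ** 2, axis=-1))   # (k, k) pairwise distances
    i, j = np.triu_indices(len(prototypes), k=1)   # each unordered pair once
    return np.sum(np.maximum(0.0, d_min - dists[i, j]) ** 2)

# Two identical prototypes and one far away: only the duplicate pair is penalized
protos = np.array([[0.0, 0.0], [0.0, 0.0], [5.0, 5.0]])
print(diversity_penalty(protos, d_min=1.0))  # 1.0
```

Using `np.triu_indices` with `k=1` skips the diagonal. Summing over the full k×k matrix instead would add a constant d_min**2 for every prototype (its zero distance to itself) and count each pair twice. In a TensorFlow version, the gradient of a square root at zero distance is undefined, so adding a small epsilon inside the square root is a common safeguard.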

Additional Features

Prototype Projection has been implemented as a Callback, but I have not yet implemented Prototype Simplification (via beam search). NOTE: the paper is somewhat ambiguous as to whether simplification is used during training (as the "projection" step) or whether it is just an interpretation tool. I think it's the latter, but may try to raise an issue on their template repo to make sure.
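The core of the projection step is a nearest-neighbour lookup: each prototype is replaced by the closest encoded training sequence, so every prototype corresponds to a real example. The NumPy sketch below shows only that update; the function name is hypothetical, and which embeddings to search over (whole sequences or, as the paper describes, subsequences) is an assumption left to the caller. In a Keras Callback, this would typically run in `on_epoch_end` every few epochs, with the result written back into the prototype layer's weights.

```python
import numpy as np

def project_prototypes(prototypes, embeddings):
    """Replace each prototype with its nearest training embedding.

    prototypes: (k, d) current prototype vectors
    embeddings: (n, d) encoder outputs for training (sub)sequences
    Returns the projected prototypes and, for each prototype, the index of
    the embedding it was snapped to (useful for interpretation).
    """
    diffs = prototypes[:, None, :] - embeddings[None, :, :]
    sq_dist = np.sum(diffs ** 2, axis=-1)  # (k, n) squared distances
    nearest = np.argmin(sq_dist, axis=1)   # (k,) closest embedding per prototype
    return embeddings[nearest].copy(), nearest

embeddings = np.array([[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]])
prototypes = np.array([[0.9, 1.2], [3.5, 3.9]])
projected, idx = project_prototypes(prototypes, embeddings)
print(idx)  # [1 2]
```

Keeping the returned indices is what makes the prototypes interpretable afterwards: they point back to the training examples the prototypes now represent.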

Additional functions for prototype interpretation would be useful, but I want to make sure I can get the network to train properly first.

