Code for the "Neural Expectation Maximization" paper.
illustrations added illustration Feb 15, 2018
.gitignore init commit Aug 10, 2017
README.md added illustration Feb 15, 2018
datasets.py init commit Aug 10, 2017
nem.py init commit Aug 10, 2017
nem_model.py removed gaussian_sigma option Sep 8, 2017
network.py init commit Aug 10, 2017
requirements.txt init commit Aug 10, 2017
utils.py init commit Aug 10, 2017


Neural Expectation-Maximization


This is the code repository complementing the paper "Neural Expectation Maximization". All experiments from the paper can be reproduced from this repository. The datasets can be found here.

Dependencies and Setup

  • tensorflow == 1.2.1
  • numpy >= 1.13.1
  • sacred == 0.7.0
  • pymongo == 3.4.0
  • Pillow == 4.2.1
  • scipy >= 0.19.1
  • scikit-learn >= 0.18.2
  • scikit-image >= 0.13.0
  • matplotlib >= 2.0.2
  • h5py >= 2.7.0

Experiments

Use the following commands to reproduce the experiments.

Static Shapes

RNN-EM

python nem.py with dataset.shapes network.shapes nem.k=4 nem.nr_steps=15 noise.prob=0.1

N-EM

python nem.py with dataset.shapes network.NEM nem.k=4 nem.nr_steps=15 noise.prob=0.1

Flying Shapes

python nem.py with nem.sequential dataset.flying_shapes network.flying_shapes nem.k=3 nem.nr_steps=20

By varying K and the number of objects in the dataset (using dataset.flying_shapes_4 or dataset.flying_shapes_5), all results in Table 1 can be computed.
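A sweep over K and the dataset variants can be scripted rather than typed by hand. The sketch below only builds the command lines. The dataset names come from this README, while the grid of K values and the helper name build_sweep_commands are assumptions for illustration, not the exact settings behind Table 1.

```python
from itertools import product

# Dataset names are taken from the README. The K values are an assumed grid,
# not necessarily the set used for Table 1.
DATASETS = ["flying_shapes", "flying_shapes_4", "flying_shapes_5"]
K_VALUES = [3, 4, 5]


def build_sweep_commands(datasets=DATASETS, k_values=K_VALUES, nr_steps=20):
    """Return one argument list per (dataset, K) combination."""
    commands = []
    for dataset, k in product(datasets, k_values):
        commands.append([
            "python", "nem.py", "with", "nem.sequential",
            "dataset.{}".format(dataset),
            "network.flying_shapes",
            "nem.k={}".format(k),
            "nem.nr_steps={}".format(nr_steps),
        ])
    return commands


if __name__ == "__main__":
    # Print the commands; pass each list to subprocess.run to execute it.
    for cmd in build_sweep_commands():
        print(" ".join(cmd))
```

Each printed line has the same form as the Flying Shapes command above, with only the dataset and nem.k changed.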

Flying MNIST

Training directly:

python nem.py with nem.sequential dataset.flying_mnist_hard_2 network.flying_mnist nem.k=2 nem.nr_steps=20 nem.loss_inter_weight=0.2 training.params.learning_rate=0.0005

Training in stages:

20 variations:

python nem.py with nem.sequential dataset.flying_mnist_medium_20_2 network.flying_mnist nem.k=2 nem.nr_steps=20 nem.loss_inter_weight=0.2

500 variations:

python nem.py with nem.sequential dataset.flying_mnist_medium_500_2 network.flying_mnist nem.k=2 nem.nr_steps=20 nem.loss_inter_weight=0.2 training.params.learning_rate=0.0005 net_path=debug_out/best

full dataset:

python nem.py with nem.sequential dataset.flying_mnist_hard_2 network.flying_mnist nem.k=2 nem.nr_steps=20 nem.loss_inter_weight=0.2 training.params.learning_rate=0.0005 net_path=debug_out/best

Evaluation

During training, an overview of the losses and the ARI scores on the training and validation sets is computed (by default only on the first 1000 samples). At test time, the run_from_file command can compute either the AMI scores (which are much more expensive to compute) or the next-step prediction loss.

For example, after training RNN-EM on flying shapes with the following config:

python nem.py with nem.sequential dataset.flying_shapes network.flying_shapes nem.k=3 nem.nr_steps=20

one could evaluate it on the test set and compute the AMI scores by calling:

python nem.py run_from_file with <config, see above> run_config.AMI=True

or similarly obtain the BCE next-step prediction loss by calling:

python nem.py run_from_file with <config, see above> run_config.AMI=False
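To make the reported quantities concrete, the sketch below shows one way ARI, AMI and a BCE next-step loss could be computed for a single image with scikit-learn and NumPy. It is not the repository's evaluation code. The function names, the array shapes, the choice to ignore background pixels (label 0) and the hard argmax assignment are all assumptions made for illustration.

```python
import numpy as np
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score


def clustering_scores(true_groups, gamma):
    """Return (ARI, AMI) for one image.

    true_groups: int array of shape (H, W); label 0 is treated as background
                 and ignored (an assumption, not confirmed by the README).
    gamma:       float array of shape (K, H, W) with per-component pixel
                 responsibilities.
    """
    truth = np.asarray(true_groups).ravel()
    # Hard-assign every pixel to the component with the largest responsibility.
    predicted = np.argmax(np.asarray(gamma), axis=0).ravel()
    foreground = truth > 0
    ari = adjusted_rand_score(truth[foreground], predicted[foreground])
    ami = adjusted_mutual_info_score(truth[foreground], predicted[foreground])
    return ari, ami


def bce_next_step(prediction, target, eps=1e-6):
    """Mean binary cross-entropy between a predicted frame and the true next frame."""
    prediction = np.clip(np.asarray(prediction, dtype=float), eps, 1.0 - eps)
    target = np.asarray(target, dtype=float)
    losses = target * np.log(prediction) + (1.0 - target) * np.log(1.0 - prediction)
    return float(-np.mean(losses))
```

Both clustering scores are invariant to how the components are numbered, so component 0 matching object 2 is scored the same as component 0 matching object 1.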