Neural Processes

Replication of the Conditional Neural Processes paper by Marta Garnelo et al., 2018 [arXiv] and the Neural Processes paper by Marta Garnelo et al., 2018 [arXiv]. The model and data pipeline were implemented using TensorFlow 2.10.

This code is released as a companion to the report and the poster.

Project for the Advanced Machine Learning module, MPhil in MLMI, University of Cambridge.

William Baker, Alexandra Shaw, Tony Wu

1. Introduction

While neural networks excel at function approximation, Gaussian Processes (GPs) address different challenges, such as uncertainty prediction, continuous learning, and the ability to cope with data scarcity. Each model is therefore suited only to a restricted spectrum of tasks, one that strongly depends on the nature of the available data.

Neural Processes use neural networks to encode distributions over functions, approximating the distributions over functions given by stochastic processes such as GPs. This allows for efficient inference and scalability to large datasets. The performance of these models is evaluated on 1D regression and image completion to demonstrate visually how they learn distributions over complex functions.

Figure 1: Comparison between Gaussian Process, Neural Network and Neural Process

2. Instructions

  1. Using an environment with Python 3.10.8, install the required modules with:

    pip install -r requirements.txt
    
  2. To create, train, and evaluate instances of neural processes, run the train.py script. Use python train.py --help to display its arguments. In particular, set the --model flag to CNP, HNP, LNP, or HNPC to choose the model. Example:

    python train.py --task regression --model cnp --epochs 50 --batch 128
  3. The model will be saved in the checkpoints directory.

3. Data pipeline

Figure 2: Data pipeline and examples of generated data for Neural Processes

Unlike neural networks, which predict a single function, NPs predict distributions over functions. For this reason, we built a dedicated data loader class on top of the tf.data API to produce the examples for both training and validation. The class definitions for the data generators can be found in the dataloader module directory; the core idea is sketched below.
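
A minimal sketch of the generator, assuming 1D regression with curves drawn from a GP prior with an RBF kernel and fixed context/target sizes (the names rbf_kernel and make_example are illustrative; the actual classes in dataloader are more configurable):

import tensorflow as tf

NUM_CONTEXT = 10   # illustrative context-set size
NUM_TARGET = 50    # illustrative target-set size

def rbf_kernel(x, lengthscale=0.4, sigma=1.0):
    # Squared-exponential (RBF) kernel matrix for 1D inputs of shape (n, 1).
    sq_dist = tf.square(x - tf.transpose(x))
    return sigma**2 * tf.exp(-0.5 * sq_dist / lengthscale**2)

def make_example(_):
    # Draw one curve from the GP prior, then reveal the first NUM_CONTEXT
    # points as the context set; the target set covers all points.
    n = NUM_CONTEXT + NUM_TARGET
    x = tf.random.uniform((n, 1), minval=-2.0, maxval=2.0)
    cov = rbf_kernel(x) + 1e-6 * tf.eye(n)  # jitter for numerical stability
    y = tf.linalg.cholesky(cov) @ tf.random.normal((n, 1))
    return (x[:NUM_CONTEXT], y[:NUM_CONTEXT], x), y

dataset = (
    tf.data.Dataset.range(10_000)  # one entry per generated example
    .map(make_example)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

Each batch then carries (context_x, context_y, target_x) as inputs and target_y as labels.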

4. Models

Figure 3: Architecture diagram of CNP, LNP, and HNP

CNP, LNP, and HNP share a similar encoder-decoder architecture. Each is implemented as a class inheriting from tf.keras.Model, so training with the tf.data API is straightforward and optimized. A sketch of the design follows.
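
As an illustration, a minimal CNP-style model might look as follows. This is a sketch with hypothetical layer sizes, not the repository's exact implementation; LNP and HNP replace or augment the deterministic representation with a latent variable.

import tensorflow as tf

def mlp(sizes):
    # Stack of Dense layers, ReLU on all but the last.
    layers = [tf.keras.layers.Dense(s, activation="relu") for s in sizes[:-1]]
    layers.append(tf.keras.layers.Dense(sizes[-1]))
    return tf.keras.Sequential(layers)

class CNP(tf.keras.Model):
    # Encoder-decoder: encode each (x, y) context pair, average the codes
    # into one representation r, then decode (r, x_target) into a Gaussian.
    def __init__(self, repr_dim=128):
        super().__init__()
        self.encoder = mlp([128, 128, repr_dim])
        self.decoder = mlp([128, 128, 2])  # outputs mean and raw std

    def call(self, inputs):
        context_x, context_y, target_x = inputs  # (batch, points, dim) each
        pairs = tf.concat([context_x, context_y], axis=-1)
        r = tf.reduce_mean(self.encoder(pairs), axis=1, keepdims=True)
        r = tf.repeat(r, tf.shape(target_x)[1], axis=1)
        out = self.decoder(tf.concat([r, target_x], axis=-1))
        mean, raw_std = tf.split(out, 2, axis=-1)
        std = 0.01 + 0.99 * tf.nn.softplus(raw_std)  # keep std positive
        return mean, std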

5. Experiments

Training can be run in two ways: either in an interactive session (IPython), with the arguments set in the "Training parameters" section of train.py (beginning around line 40), or from the terminal with command-line arguments, by commenting out that section and uncommenting the "Parse training parameters" section (around line 25). In both cases the models are fitted by maximum likelihood; one training step is sketched below.
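
Training minimises the Gaussian negative log-likelihood of the targets (the latent models add a KL term to form an ELBO). A minimal sketch of one training step, assuming the CNP interface sketched above:

import math
import tensorflow as tf

def gaussian_nll(target_y, mean, std):
    # Negative log-likelihood of the targets under the predicted Gaussians,
    # averaged over batch and target points.
    return tf.reduce_mean(
        0.5 * tf.square((target_y - mean) / std)
        + tf.math.log(std)
        + 0.5 * math.log(2.0 * math.pi)
    )

@tf.function
def train_step(model, optimizer, inputs, target_y):
    # inputs = (context_x, context_y, target_x)
    with tf.GradientTape() as tape:
        mean, std = model(inputs)
        loss = gaussian_nll(target_y, mean, std)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss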

5.1. Regression training

python train.py --task regression

Example result:

Figure 4: Comparison between GP, CNP and LNP on the 1D-regression task (fixed kernel parameter)

5.2. MNIST training

python train.py --task mnist
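
Image completion is cast as the same regression problem: each pixel is a data point whose input x is its (row, column) coordinate and whose output y is its intensity, and a random subset of pixels serves as the context. A minimal sketch for MNIST, with a hypothetical helper name:

import tensorflow as tf

def image_to_np_example(image, num_context=100):
    # Treat a 28x28 grayscale image as a function from normalised pixel
    # coordinates (x) to intensities (y); reveal a random subset as context.
    h, w = 28, 28
    rows, cols = tf.meshgrid(tf.range(h), tf.range(w), indexing="ij")
    coords = tf.cast(tf.stack([rows, cols], axis=-1), tf.float32)
    coords = tf.reshape(coords, (h * w, 2)) / [h - 1.0, w - 1.0]
    values = tf.reshape(tf.cast(image, tf.float32) / 255.0, (h * w, 1))
    idx = tf.random.shuffle(tf.range(h * w))[:num_context]
    return (tf.gather(coords, idx), tf.gather(values, idx), coords), values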

Example result:

Figure 5: CNP pixel mean and variance predictions on images from MNIST

5.3. CelebA training

Instructions: Download the aligned and cropped CelebA images and extract the files into the ./data directory.

python train.py --task celeb

Example result:

Figure 6: CNP pixel mean and variance predictions on images from CelebA

5.4. Extension: HNP and HNPC

Objective: Combine the deterministic link between the context representations (used by CNP) with the stochastic link through the latent representation space (used by LNP) to produce a model with a richer embedding space, as sketched below.
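
A minimal sketch of the hybrid decoding step (hypothetical helper; it assumes one global representation per batch element, of shape (batch, 1, dim)): the deterministic representation r from the CNP-style path is concatenated with a reparameterised sample z from the LNP-style latent path before decoding.

import tensorflow as tf

def hybrid_decode(decoder, r, z_mean, z_std, target_x):
    # Reparameterised sample from the latent distribution (stochastic path).
    z = z_mean + z_std * tf.random.normal(tf.shape(z_std))
    # Broadcast both representations across the target points, then decode
    # from the concatenation of the deterministic and stochastic paths.
    num_target = tf.shape(target_x)[1]
    r = tf.repeat(r, num_target, axis=1)
    z = tf.repeat(z, num_target, axis=1)
    return decoder(tf.concat([r, z, target_x], axis=-1))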

Figure 7: Latent Variable Distribution - Mean and Standard Deviation Statistics during training.

6. Appendix

To go further, read the poster and the report, both of which can be found in the report folder of this repository.

Figure 8: Thumbnail of the CNP/LNP poster