This repo implements and trains a memory-augmented neural network (MANN): a black-box meta-learner that uses a recurrent neural network for few-shot classification.
This repository contains:
- The Python code (`BlackBox.py`)
- The config file (`config.json`)
- The CS330 HW1 handout
- This README
The Omniglot dataset is designed for developing more human-like learning algorithms. It contains 1623 different handwritten characters from 50 different alphabets; each character was drawn online via Amazon's Mechanical Turk by 20 different people. The dataset is split into a background set of 30 alphabets and an evaluation set of 20 alphabets.
A stacked two-layer LSTM model is employed. Inputs from the support set are concatenated with their true labels, one-hot encoded, whereas inputs from the query set are concatenated with all zeros. The model is expected to predict the true labels of the query set. A sketch of this input construction is shown below.
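A minimal sketch of the idea (the exact tensor shapes and variable names in `BlackBox.py` may differ; the `N`-way/`K`-shot sizes, the flattened 28x28 images, and the hidden size of 128 are assumptions here):

```python
import torch
import torch.nn as nn

N, K, D = 5, 1, 784  # N-way, K-shot, flattened 28x28 images (assumed sizes)

# Support set: concatenate each image with its one-hot label.
support_images = torch.rand(N * K, D)
support_labels = torch.eye(N).repeat_interleave(K, dim=0)  # one-hot, shape (N*K, N)
support_inputs = torch.cat([support_images, support_labels], dim=-1)

# Query set: concatenate each image with an all-zero "label".
query_images = torch.rand(N, D)
query_inputs = torch.cat([query_images, torch.zeros(N, N)], dim=-1)

# Feed the full sequence (support first, then queries) through a stacked 2-layer LSTM.
sequence = torch.cat([support_inputs, query_inputs], dim=0).unsqueeze(0)  # (1, T, D+N)
lstm = nn.LSTM(input_size=D + N, hidden_size=128, num_layers=2, batch_first=True)
head = nn.Linear(128, N)  # project hidden states to class logits

hidden_states, _ = lstm(sequence)
logits = head(hidden_states[:, -N:])  # predictions at the query positions
```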
More information on the training procedure can be found in HW1 of CS330. The hyper-parameters can be changed in the config file.
Update 1 (05-08-2021): Added support for a bidirectional LSTM. Set `bi_dir` to `true` in the config file to enable the BiLSTM.
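As a rough sketch of how such a flag might be consumed (only `bi_dir` is named above; every other key and default below is a hypothetical placeholder, and the actual schema lives in `config.json`):

```python
import json

import torch.nn as nn

with open("config.json") as f:
    config = json.load(f)

# `bi_dir` is the flag described above; `hidden_size` and `num_layers` are
# hypothetical keys used here only for illustration.
bidirectional = config.get("bi_dir") in (True, "true")  # accept a JSON bool or the string "true"

lstm = nn.LSTM(
    input_size=784 + 5,                          # flattened image + one-hot label width (assumed)
    hidden_size=config.get("hidden_size", 128),  # hypothetical key
    num_layers=config.get("num_layers", 2),      # hypothetical key
    bidirectional=bidirectional,
    batch_first=True,
)
```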
Download the Omniglot data here and save the downloaded folders in a folder titled 'omniglot'. Save the Python code and config file in the same directory as 'omniglot':
```
BlackBox.py
config.json
omniglot
├── images_background
└── images_evaluation
```
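Once the folders are in place, the character images can be discovered with `glob` and loaded with `PIL`, roughly as follows (a sketch assuming the standard Omniglot layout of `alphabet/character/*.png`; the actual loading code lives in `BlackBox.py`):

```python
import glob

import numpy as np
from PIL import Image

# Each character class is a folder of 20 PNG drawings:
# omniglot/images_background/<alphabet>/<character>/*.png
character_folders = glob.glob("omniglot/images_background/*/*")
print(f"Found {len(character_folders)} background character classes")

def load_image(path):
    """Load one drawing as a flattened float array in [0, 1]."""
    img = Image.open(path).resize((28, 28))  # downsample, as is common for Omniglot
    return np.asarray(img, dtype=np.float32).flatten() / 255.0

sample = load_image(glob.glob(character_folders[0] + "/*.png")[0])
```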
Install the following libraries to run the code (`glob` ships with the Python standard library; `PIL` is provided by the `Pillow` package):
- torch
- numpy
- Pillow (PIL)
- matplotlib
Run `BlackBox.py`:
```
python3 BlackBox.py
```
- This work is inspired by Stanford's CS 330: Deep Multi-Task and Meta Learning
- A similar implementation can be found here
- An awesome blog on Meta-Learning
- Stanford's lecture series on CS 330: Deep Multi-Task and Meta Learning on YouTube