One-shot and few-shot learning examples


This example script implements one-shot and few-shot learning of handwritten characters on the Omniglot dataset. Given only a few sample images from each support class, it learns to classify a query image of a character into its correct class among the support classes.

Start training

Prepare the Omniglot dataset by running

python omniglot_data.py

This script downloads the dataset from ... into ~/nnabla_data. It also generates compressed dataset files (*.npy) in ./omniglot/data. Extracting the dataset can take a while (around one to two minutes).

Once the dataset is ready, start training (for example, metric-based meta-learning) by running

python metric_based_meta_learning.py

The training outputs are saved in the tmp.results directory. There you can find the "Training loss curve", the "Validation error curve", and the "Test error result", as well as a 2D t-SNE visualization of the distribution of test samples.

By default, the script runs on a GPU. To run it on a CPU instead, use

python metric_based_meta_learning.py -c cpu

Metric-based meta-learning

We group some meta-learning methods under the name "metric-based meta-learning", including:
Siamese networks for one-shot learning: https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf
Matching networks: https://arxiv.org/abs/1606.04080
Prototypical networks: https://arxiv.org/abs/1703.05175

The script metric_based_meta_learning.py demonstrates both matching networks and prototypical networks. The default setting is a prototypical network with Euclidean distance in a 20-way, 1-shot, 5-query setting. Many parameters, including the network structure, can be changed through command-line options. The following example sets the hyperparameters with the corresponding options:

python metric_based_meta_learning.py -nw 20 -ns 1 -nq 5 -nwt 20 -nst 1 -nqt 5
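An "N-way, K-shot, Q-query" episode draws N classes, then K support images and Q query images from each of those classes. The following is a minimal NumPy sketch of that sampling step, not the script's own code: the function name `sample_episode` is hypothetical, and it assumes the images are stored as an array of shape (n_classes, n_samples_per_class, H, W), which may differ from the actual layout of the *.npy files in ./omniglot/data.

```python
import numpy as np


def sample_episode(data, n_way, n_shot, n_query, rng):
    """Sample one N-way, K-shot episode with Q queries per class (hypothetical helper).

    data: array of shape (n_classes, n_samples_per_class, H, W); assumed layout.
    Returns support images (n_way, n_shot, H, W), query images
    (n_way * n_query, H, W), and episode-local query labels (n_way * n_query,).
    """
    n_classes, n_samples = data.shape[:2]
    # Pick n_way distinct classes for this episode.
    classes = rng.choice(n_classes, size=n_way, replace=False)
    support, query = [], []
    for c in classes:
        # Draw n_shot + n_query distinct samples of class c, then split them
        # so that support and query images never overlap.
        idx = rng.choice(n_samples, size=n_shot + n_query, replace=False)
        support.append(data[c, idx[:n_shot]])
        query.append(data[c, idx[n_shot:]])
    support = np.stack(support)        # (n_way, n_shot, H, W)
    query = np.concatenate(query)      # (n_way * n_query, H, W)
    # Labels are relabeled 0..n_way-1 within the episode.
    query_labels = np.repeat(np.arange(n_way), n_query)
    return support, query, query_labels
```

For the command above (-nw 20 -ns 1 -nq 5), one test episode would correspond to `sample_episode(data, 20, 1, 5, np.random.default_rng())`, giving 20 support images and 100 queries. Omniglot has 20 drawings per character, so n_shot + n_query must not exceed 20.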

Examples of options are as follows:
-nw : Number of ways in meta-test, typically 5 or 20
-ns : Number of shots per class in meta-test, typically 1 or 5
-nq : Number of queries per class in meta-test, typically 5
-nwt: Number of ways in meta-training, typically 60 or the same as meta-test
-nst: Number of shots per class in meta-training, typically the same as meta-test
-nqt: Number of queries per class in meta-training, typically the same as meta-test
-d : Similarity metric, either "cosine" or "euclid"
-n : Network type, either "matching" or "prototypical"
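The -d option selects how embeddings are compared. As a rough illustration, the sketch below turns query and support embeddings into a similarity matrix for each choice, where a larger value means "more similar". The function name `similarity` is hypothetical, and using the negative squared Euclidean distance for "euclid" follows the prototypical networks paper; whether the script squares the distance is an assumption.

```python
import numpy as np


def similarity(query, support, metric="euclid"):
    """Pairwise similarity between embeddings (hypothetical helper).

    query: (n_query, dim), support: (n_support, dim).
    Returns an (n_query, n_support) matrix; larger means more similar.
    """
    if metric == "euclid":
        # Negative squared Euclidean distance (assumed form).
        diff = query[:, None, :] - support[None, :, :]
        return -np.sum(diff ** 2, axis=-1)
    if metric == "cosine":
        # Cosine similarity; eps avoids division by zero for all-zero embeddings.
        eps = 1e-8
        q = query / (np.linalg.norm(query, axis=1, keepdims=True) + eps)
        s = support / (np.linalg.norm(support, axis=1, keepdims=True) + eps)
        return q @ s.T
    raise ValueError(f"unknown metric: {metric}")
```

Both choices produce scores that can be fed directly into a softmax, so the rest of the classifier does not depend on which metric is selected.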

Prototypical networks

The default setting of this script is a prototypical network with Euclidean distance. The embedding architecture follows the typical four-convolution network described in the papers. To avoid all-zero outputs from the embedding network, we omitted the final ReLU activation. For details, refer to the paper: https://arxiv.org/abs/1703.05175. Following the recommendation in this paper, we adopted 60-way episodes for training instead of 1- or 5-way episodes.
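In a prototypical network, each class is represented by the mean (prototype) of its support embeddings, and a query is classified by a softmax over negative squared Euclidean distances to the prototypes, as described in the paper linked above. The NumPy sketch below shows that classification step on precomputed embeddings only; the function names are hypothetical, and it does not reproduce the script's embedding network or training loop.

```python
import numpy as np


def prototypical_logits(support_emb, query_emb):
    """support_emb: (n_way, n_shot, dim); query_emb: (n_query_total, dim)."""
    # Each class prototype is the mean of its support embeddings.
    prototypes = support_emb.mean(axis=1)                 # (n_way, dim)
    # Logits are negative squared Euclidean distances to each prototype.
    diff = query_emb[:, None, :] - prototypes[None, :, :]
    return -np.sum(diff ** 2, axis=-1)                     # (n_query_total, n_way)


def softmax(logits):
    # Subtract the row maximum for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def episode_loss_and_error(support_emb, query_emb, query_labels):
    """Cross-entropy loss and error rate over the queries of one episode."""
    probs = softmax(prototypical_logits(support_emb, query_emb))
    n = len(query_labels)
    loss = -np.mean(np.log(probs[np.arange(n), query_labels] + 1e-12))
    error = np.mean(probs.argmax(axis=1) != query_labels)
    return loss, error
```

Because the prototypes are averages, the number of shots only changes how many embeddings are averaged; in the 1-shot case each prototype is simply the single support embedding.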

Matching networks

You can also select matching networks by setting the -n option to matching. However, since we are interested in the metric-learning aspect, we implemented only the softmax attention part, which works as a soft nearest-neighbor search. For details, refer to the paper: https://arxiv.org/abs/1606.04080. We omitted the paper's full context embedding, which uses an LSTM module to incorporate the context of the support set.
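The softmax attention described above can be sketched as follows: a query attends to every individual support sample through a softmax over similarities, and its predicted class distribution is the attention-weighted sum of the one-hot support labels. This is a minimal NumPy illustration on precomputed embeddings, assuming cosine similarity (as in the matching networks paper) and no full context embedding; the name `matching_predict` is hypothetical.

```python
import numpy as np


def matching_predict(support_emb, support_labels, query_emb, n_way):
    """Soft nearest-neighbor prediction (hypothetical helper).

    support_emb: (n_support, dim); support_labels: (n_support,) ints in [0, n_way);
    query_emb: (n_query, dim). Returns (n_query, n_way) class probabilities.
    """
    eps = 1e-8
    # Cosine similarity between each query and each individual support sample.
    s = support_emb / (np.linalg.norm(support_emb, axis=1, keepdims=True) + eps)
    q = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + eps)
    sim = q @ s.T                                   # (n_query, n_support)
    # Softmax over support samples gives the attention weights.
    sim = sim - sim.max(axis=1, keepdims=True)
    attn = np.exp(sim)
    attn /= attn.sum(axis=1, keepdims=True)
    # Attention-weighted vote of one-hot support labels.
    one_hot = np.eye(n_way)[support_labels]         # (n_support, n_way)
    return attn @ one_hot
```

Unlike the prototypical sketch, this attends to each support sample separately instead of to per-class means, so the two approaches coincide only in special cases.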