PyTorch implementation of "An Ensemble of Epoch-wise Empirical Bayes for Few-shot Learning" (ECCV 2020)

An Ensemble of Epoch-wise Empirical Bayes for Few-shot Learning

[Paper] [Project Page] [GitLab@MPI]

This repository contains the PyTorch implementation for the ECCV 2020 Paper "An Ensemble of Epoch-wise Empirical Bayes for Few-shot Learning". If you have any questions on this repository or the related paper, feel free to create an issue or send me an email.

Introduction

Few-shot learning aims to train efficient predictive models with only a few examples. The lack of training data leads to poor models that make high-variance or low-confidence predictions. In this paper, we propose to meta-learn an ensemble of epoch-wise empirical Bayes models (E3BM) to achieve robust predictions. "Epoch-wise" means that each training epoch has a Bayes model whose parameters are specifically learned and deployed. "Empirical" means that the hyperparameters, e.g., those used for learning and ensembling the epoch-wise models, are generated by hyperprior learners conditional on task-specific data. We introduce four kinds of hyperprior learners by considering inductive vs. transductive and epoch-dependent vs. epoch-independent variants in the paradigm of meta-learning. We conduct extensive experiments for five-class few-shot tasks on three challenging benchmarks: miniImageNet, tieredImageNet, and FC100, and achieve top performance using the epoch-dependent transductive hyperprior learner, which captures the richest information. Our ablation study shows that both "epoch-wise ensemble" and "empirical" encourage high efficiency and robustness in the model performance.

Figure: Conceptual illustrations of the model adaptation on the blue, red, and yellow tasks. (a) MAML is the classical inductive method that meta-learns a network initialization θ that is used to learn a single base-learner on each task. (b) SIB is a transductive method that formulates a variational posterior as a function of both labeled training data T(tr) and unlabeled test data x(te). It also uses a single base-learner and optimizes the learner by running several synthetic gradient steps on x(te). (c) Our E3BM is a generic method that learns to combine the epoch-wise base-learners, and to generate task-specific learning rates α and combination weights v that encourage robust adaptation.
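To make the idea concrete, here is a minimal, framework-agnostic sketch of the E3BM prediction scheme described above: a base-learner is adapted for several epochs on the support set, and the query predictions from every epoch are combined with per-epoch weights. All names and values (`alphas`, `v`, the linear base-learner, the random data) are illustrative stand-ins, not the repository's actual API; in E3BM the learning rates and combination weights would come from the hyperprior learners rather than being fixed constants.

```python
import numpy as np

rng = np.random.default_rng(0)
n_way, feat_dim, n_epochs = 5, 16, 3

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# One 5-way 5-shot task (random stand-ins for extracted features).
support_x = rng.standard_normal((25, feat_dim))
support_y = np.repeat(np.arange(n_way), 5)
query_x = rng.standard_normal((10, feat_dim))

# Values a hyperprior learner would generate per epoch (illustrative):
alphas = [0.5, 0.3, 0.1]   # epoch-wise learning rates (alpha in the figure)
v = [0.2, 0.3, 0.5]        # epoch-wise combination weights (v in the figure)

W = np.zeros((feat_dim, n_way))       # linear base-learner weights
onehot = np.eye(n_way)[support_y]

epoch_logits = []
for alpha in alphas:
    # Cross-entropy gradient of the linear classifier on the support set.
    probs = softmax(support_x @ W)
    grad = support_x.T @ (probs - onehot) / len(support_y)
    W = W - alpha * grad              # one adaptation step for this epoch
    epoch_logits.append(query_x @ W)  # keep this epoch's query prediction

# E3BM prediction: weighted combination of the epoch-wise base-learners.
ensemble = sum(v_t * logits for v_t, logits in zip(v, epoch_logits))
pred = ensemble.argmax(axis=1)
print(pred.shape)  # (10,)
```

The key contrast with MAML-style adaptation is the last two lines: instead of keeping only the final epoch's learner, every epoch contributes to the prediction.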

Installation

To run this repository, we advise you to install Python 3.6 and PyTorch 1.2.0 with Anaconda. You may download Anaconda and read the installation instructions on the official website: https://www.anaconda.com/download/

Create a new environment and install PyTorch and torchvision on it:

```bash
conda create --name e3bm-pytorch python=3.6
conda activate e3bm-pytorch
conda install pytorch=1.2.0
conda install torchvision -c pytorch
```

Install other requirements:

```bash
pip install -r requirements.txt
```

Inductive Experiments

Performance (ResNet-12)

Experiment results (%) for 5-way few-shot classification on miniImageNet and tieredImageNet, using a ResNet-12 backbone (same implementation as this repository).

| Method | Backbone | *mini* 1-shot | *mini* 5-shot | *tiered* 1-shot | *tiered* 5-shot |
| --- | --- | --- | --- | --- | --- |
| ProtoNet | ResNet-12 | 60.37 ± 0.83 | 78.02 ± 0.57 | 65.65 ± 0.92 | 83.40 ± 0.65 |
| MatchNet | ResNet-12 | 63.08 ± 0.80 | 75.99 ± 0.60 | 68.50 ± 0.92 | 80.60 ± 0.71 |
| MetaOptNet | ResNet-12 | 62.64 ± 0.61 | 78.63 ± 0.46 | 65.99 ± 0.72 | 81.56 ± 0.53 |
| Meta-Baseline | ResNet-12 | 63.17 ± 0.23 | 79.26 ± 0.17 | 68.62 ± 0.27 | 83.29 ± 0.18 |
| CAN | ResNet-12 | 63.85 ± 0.48 | 79.44 ± 0.34 | 69.89 ± 0.51 | 84.93 ± 0.38 |
| E3BM (Ours) | ResNet-12 | 64.09 ± 0.37 | 80.29 ± 0.25 | 71.34 ± 0.41 | 85.82 ± 0.29 |
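The ± margins above are, as is standard for few-shot benchmarks, presumably 95% confidence intervals of the mean accuracy over many sampled test episodes. A quick sketch of how such an interval is computed (the per-episode accuracies here are toy values, not the paper's):

```python
import math

def conf_interval95(accs):
    """Mean and 95% CI half-width of per-episode accuracies (normal approx.)."""
    n = len(accs)
    mean = sum(accs) / n
    var = sum((a - mean) ** 2 for a in accs) / (n - 1)  # sample variance
    half = 1.96 * math.sqrt(var / n)                    # 1.96 * std error
    return mean, half

accs = [0.60, 0.64, 0.58, 0.66, 0.62]  # toy per-episode accuracies
mean, half = conf_interval95(accs)
print(f"{100 * mean:.2f} ± {100 * half:.2f}")  # prints "62.00 ± 2.77"
```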

Running experiments

Run meta-training with default settings:

```bash
python main.py -backbone resnet12 -shot 1 -way 5 -mode meta_train -dataset miniimagenet
python main.py -backbone resnet12 -shot 5 -way 5 -mode meta_train -dataset miniimagenet
python main.py -backbone resnet12 -shot 1 -way 5 -mode meta_train -dataset tieredimagenet
python main.py -backbone resnet12 -shot 5 -way 5 -mode meta_train -dataset tieredimagenet
```

Run pre-training with default settings:

```bash
python main.py -backbone resnet12 -mode pre_train -dataset miniimagenet
python main.py -backbone resnet12 -mode pre_train -dataset tieredimagenet
```

Download resources

All the datasets and pre-trained models will be downloaded automatically.

You may also download the resources from Google Drive or 百度网盘 (Baidu Netdisk) using the following links:

- Dataset 1 - miniImageNet: [Google Drive] [百度网盘] (extraction code: p6w4)
- Dataset 2 - tieredImageNet: [Google Drive] [百度网盘] (extraction code: 729f)
- Pre-trained models: [Google Drive] [百度网盘] (extraction code: 2e7p)
- Meta-trained checkpoints: [Google Drive] [百度网盘] (extraction code: wc7g)

Transductive Experiments

See the transductive setting experiments in this branch: https://github.com/yaoyao-liu/E3BM/tree/transductive.

Citation

Please cite our paper if it is helpful to your work:

```bibtex
@inproceedings{Liu2020E3BM,
  author    = {Liu, Yaoyao and
               Schiele, Bernt and
               Sun, Qianru},
  title     = {An Ensemble of Epoch-wise Empirical Bayes for Few-shot Learning},
  booktitle = {European Conference on Computer Vision (ECCV)},
  year      = {2020}
}
```

Acknowledgements

Our implementations use the source code from the following repositories:
