Accelerated Training for Massive Classification via Dynamic Class Selection (AAAI 2018)

Accelerated Training for Massive Classification via Dynamic Class Selection (HF-Softmax) pdf

Paper

Accelerated Training for Massive Classification via Dynamic Class Selection, AAAI 2018 (Oral)

Training

  1. Install PyTorch. (Installing the latest master from source is recommended.)
  2. Follow the instructions of InsightFace to download the training data.
  3. Decode the data (.rec) to images and generate the training/validation lists.
python tools/rec2img.py --in-folder xxx --out-folder yyy
  4. Try normal training. It uses torch.nn.DataParallel (multi-threaded) for parallelism.
sh scripts/train.sh dataset_path
  5. Try sampled training. It uses one GPU for training, and the default sampling number is 1000.
python paramserver/paramserver.py
sh scripts/train_hf.sh dataset_path

Distributed Training

For distributed training, there is one process on each GPU.

Several backends are available for PyTorch distributed training. If you want to use nccl as the backend, please follow the instructions to install NCCL2.

You can test your distributed setting by executing

sh scripts/test_distributed.sh

After NCCL2 is installed, re-compile PyTorch from source.

python setup.py clean install

In our case, we use libnccl2=2.2.13-1+cuda9.0, libnccl-dev=2.2.13-1+cuda9.0 and the PyTorch master build 0.5.0a0+e31ab99.

Hashing Forest

We use Annoy to approximate the hashing forest. You can adjust sample_num, ntrees and interval to balance performance and cost.
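
As a rough illustration of what the sampler approximates, the sketch below picks the active classes for one sample by exact inner-product search and computes a softmax loss over that subset only. It is a minimal pure-Python sketch, not the code in this repository: the exact search stands in for the Annoy lookup, and the names `select_active_classes` and `sampled_softmax_loss` are hypothetical. Only `sample_num` corresponds to a parameter mentioned above.

```python
import math
import random


def select_active_classes(feature, class_weights, label, sample_num):
    """Return sample_num class ids: the ground-truth label first, followed by
    the classes whose weight vectors score highest against the feature.
    Exact search is used here; an approximate index would replace this step."""
    scores = [
        (sum(f * w for f, w in zip(feature, weights)), cid)
        for cid, weights in enumerate(class_weights)
        if cid != label
    ]
    scores.sort(reverse=True)
    return [label] + [cid for _, cid in scores[: sample_num - 1]]


def sampled_softmax_loss(feature, class_weights, label, sample_num):
    """Cross-entropy computed over the selected classes only."""
    active = select_active_classes(feature, class_weights, label, sample_num)
    logits = [
        sum(f * w for f, w in zip(feature, class_weights[cid])) for cid in active
    ]
    top = max(logits)  # subtract the max for numerical stability
    log_sum = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum - logits[0]  # logits[0] belongs to the ground-truth label


# Toy data: 50 classes, 8-dimensional features (illustrative values only).
random.seed(0)
W = [[random.gauss(0, 1) for _ in range(8)] for _ in range(50)]
x = [random.gauss(0, 1) for _ in range(8)]
active = select_active_classes(x, W, label=3, sample_num=10)
partial_loss = sampled_softmax_loss(x, W, label=3, sample_num=10)
full_loss = sampled_softmax_loss(x, W, label=3, sample_num=50)
```

With `sample_num` equal to the number of classes the result is the ordinary softmax loss. A smaller `sample_num` lowers the cost per sample and drops the lowest-scoring classes from the normalizer.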

Parameter Server

The parameter server is decoupled from PyTorch. A client is provided to communicate with the server, and other platforms can integrate the parameter server via the communication API. Currently, it only supports a synchronized SGD updater.
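
To make the idea of a synchronized SGD updater concrete, here is a toy in-process sketch. It assumes that "synchronized" means the server waits for a gradient from every worker before applying one averaged update. The class `SyncSGDServer` and its `pull`/`push` methods are hypothetical names for illustration and are not the communication API of this repository.

```python
class SyncSGDServer:
    """Toy parameter store with a synchronized SGD update rule."""

    def __init__(self, params, num_workers, lr):
        self.params = dict(params)   # key -> list of floats
        self.num_workers = num_workers
        self.lr = lr
        self._pending = {}           # key -> gradients received this step

    def pull(self, key):
        """Return a copy of the current parameters for a key."""
        return list(self.params[key])

    def push(self, key, grad):
        """Buffer one worker's gradient. Once all workers have pushed,
        apply a single SGD step with the averaged gradient.
        Returns True if the update was applied by this call."""
        bucket = self._pending.setdefault(key, [])
        bucket.append(list(grad))
        if len(bucket) < self.num_workers:
            return False
        avg = [sum(g) / self.num_workers for g in zip(*bucket)]
        self.params[key] = [p - self.lr * g for p, g in zip(self.params[key], avg)]
        del self._pending[key]
        return True


server = SyncSGDServer({"w": [1.0, 2.0]}, num_workers=2, lr=0.5)
first = server.push("w", [1.0, 0.0])    # buffered, no update yet
second = server.push("w", [3.0, 2.0])   # last worker arrives, update applied
updated = server.pull("w")
```

In this sketch no worker sees updated parameters until all gradients for the step have arrived, which is the property that distinguishes a synchronized updater from an asynchronous one.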

Evaluation

./scripts/eval.sh arch model_path dataset_path outputs

It uses torch.nn.DataParallel to extract features and saves them as .npy files. The features are subsequently used to perform the verification test.
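
For readers unfamiliar with verification tests, the sketch below shows the general idea: a pair of images is predicted to be the same identity when the cosine similarity of their features reaches a threshold. This is a simplified pure-Python illustration under that assumption. The evaluation script's actual protocol (for example, how the threshold is chosen) is not described here, and the function names are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def verification_accuracy(pairs, is_same, threshold):
    """Fraction of pairs whose same/different prediction matches the label."""
    correct = 0
    for (feat_a, feat_b), same in zip(pairs, is_same):
        predicted_same = cosine_similarity(feat_a, feat_b) >= threshold
        correct += predicted_same == same
    return correct / len(pairs)


# Toy features (illustrative values only).
pairs = [
    ([1.0, 0.0], [0.9, 0.1]),
    ([1.0, 0.0], [0.0, 1.0]),
    ([0.6, 0.8], [0.8, 0.6]),
]
is_same = [True, False, True]
accuracy = verification_accuracy(pairs, is_same, threshold=0.5)
```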

If you use distributed training, set strict=False during feature extraction.

Note that the bin files from InsightFace, lfw.bin for example, are pickled with Python 2 and cannot be loaded directly by Python 3. You can either use Python 2 for evaluation or re-pickle the bin file with Python 3 first.
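
One possible way to re-pickle such a file is sketched below. It assumes the bin file is an ordinary pickle and relies on the `encoding="bytes"` argument of `pickle.load`, which makes Python 3 read Python 2 `str` objects as `bytes`. The helper name `repickle_bin` and the demo file contents are made up for illustration.

```python
import os
import pickle
import tempfile


def repickle_bin(src_path, dst_path):
    """Load a pickle written by Python 2 and write it back with Python 3."""
    with open(src_path, "rb") as f:
        # encoding="bytes" keeps Python 2 str payloads as raw bytes.
        data = pickle.load(f, encoding="bytes")
    with open(dst_path, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    return data


# Self-contained demo on a temporary file using an old pickle protocol.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "old.bin")
    dst = os.path.join(tmp, "new.bin")
    with open(src, "wb") as f:
        pickle.dump(([b"\xff\xd8jpeg-bytes"], [True]), f, protocol=2)
    repickle_bin(src, dst)
    with open(dst, "rb") as f:
        bins, issame = pickle.load(f)
```

If the original file contains dictionaries, their string keys also come back as `bytes` after loading with `encoding="bytes"`, so downstream code may need to decode them.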

Citation

Please cite the following paper if you use this repository in your research.

@inproceedings{zhang2018accelerated,
  title     = {Accelerated Training for Massive Classification via Dynamic Class Selection},
  author    = {Xingcheng Zhang and Lei Yang and Junjie Yan and Dahua Lin},
  booktitle = {AAAI},
  year      = {2018},
}