This is a PyTorch implementation of "VirFace: Enhancing Face Recognition via Unlabeled Shallow Data" (CVPR 2021).
| Training Stage | Datasets | args.method | Pretrained model needed |
|---|---|---|---|
| Pre-train backbone | Labeled | pretrain | None |
| Pre-train generator | Labeled | generator | backbone, head |
| Train VirClass | Labeled & unlabeled | virclass | backbone, head |
| Train VirFace | Labeled & unlabeled | virface | backbone, head, generator |
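The "Pretrained model needed" column works like a small dependency lookup: each stage needs the listed components, and the Pretrain note below says which components each pre-training stage saves. The helper below is a hypothetical sketch, not part of this repository. It assumes a checkpoint can be described simply by the set of component names it holds (backbone, head, generator).

```python
# Hypothetical helper, not part of this repository. It transcribes the
# "Pretrained model needed" column of the training-stage table.
STAGE_REQUIREMENTS = {
    "pretrain": set(),
    "generator": {"backbone", "head"},
    "virclass": {"backbone", "head"},
    "virface": {"backbone", "head", "generator"},
}

# Components saved by the two pre-training stages (from the Pretrain note).
STAGE_OUTPUTS = {
    "pretrain": {"backbone", "head"},
    "generator": {"backbone", "head", "generator"},
}


def missing_components(method, checkpoint_components):
    """Return the components `method` needs that the checkpoint does not provide."""
    if method not in STAGE_REQUIREMENTS:
        raise ValueError(f"unknown method: {method!r}")
    return STAGE_REQUIREMENTS[method] - set(checkpoint_components)


def can_follow(previous, method):
    """True if the checkpoint saved by stage `previous` is enough to start `method`."""
    return not missing_components(method, STAGE_OUTPUTS[previous])


if __name__ == "__main__":
    print(can_follow("pretrain", "virclass"))  # True
    print(can_follow("pretrain", "virface"))   # False: generator missing
    print(sorted(missing_components("virface", ["backbone", "head"])))  # ['generator']
```

Read this way, VirClass can start directly from the pre-trained backbone, while VirFace first needs the generator stage to add the generator component.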
| Arch | ResNet-18 | ResNet-34 | ResNet-50 | ResNet-101 | ResNet-152 | User Arch |
|---|---|---|---|---|---|---|
| args.arch | resnet18 | resnet34 | resnet50 | resnet101 | resnet152 | usr |
- To use a user-defined architecture (args.arch = usr), edit config.py: import your modified '.py' file, then change "model_usr=None" to "model_usr=%filename.%modelname(%params)".
- Note: the user-defined architecture should not contain any classification FC layer (e.g. the last FC layer in ArcFace); its output should be the embedding feature.
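A user architecture only has to map an input face image to an embedding vector. The PyTorch module below is a hypothetical minimal example: the class name TinyEmbeddingNet, the layer sizes, and the 112x112 input are assumptions, not taken from this repository. It ends in an embedding projection of length feat_len rather than a classification FC layer, since the table above lists the head as a separate component.

```python
import torch
import torch.nn as nn


class TinyEmbeddingNet(nn.Module):
    """Hypothetical user-defined backbone that outputs an embedding, not class logits."""

    def __init__(self, feat_len: int = 512):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),  # -> (N, 64, 1, 1) for any input size
        )
        # Projection to the embedding length; this is NOT a classification layer.
        self.embedding = nn.Linear(64, feat_len)
        self.bn = nn.BatchNorm1d(feat_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x).flatten(1)     # (N, 64)
        return self.bn(self.embedding(x))   # (N, feat_len) embedding feature


if __name__ == "__main__":
    net = TinyEmbeddingNet(feat_len=512).eval()
    with torch.no_grad():
        emb = net(torch.randn(2, 3, 112, 112))
    print(emb.shape)  # torch.Size([2, 512])
```

If this class lived in a file named my_arch.py (also hypothetical), config.py would import my_arch and set model_usr=my_arch.TinyEmbeddingNet(512), presumably with the embedding length matching feat_len, and training would be launched with --arch usr.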
- Pretrain
python3 train.py --method pretrain
python3 train.py --method generator --pretrain_file %backbone_path
Note: 'pretrain' trains the backbone (including the head parameters) and saves the [backbone, head] models; 'generator' trains the generator and saves the [backbone, head, generator] models.
- Training from scratch
- VirClass
python3 train.py --method pretrain
python3 train.py --method virclass --pretrain_file %backbone_path
- VirFace
python3 train.py --method pretrain
python3 train.py --method generator --pretrain_file %backbone_path
python3 train.py --method virface --pretrain_file %generator_path
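The from-scratch recipes are chains in which each stage's saved checkpoint is passed to the next via --pretrain_file. The helper below is a hypothetical sketch that builds those command lines in order and, optionally, runs them with subprocess, stopping at the first failing stage. The checkpoint paths are placeholders: this README does not specify where snapshot_prefix places each file, so you supply them yourself.

```python
import shlex
import subprocess


def from_scratch_commands(target, backbone_path, generator_path=None):
    """Ordered train.py invocations for training `target` ("virclass" or "virface") from scratch.

    backbone_path / generator_path are the checkpoints written by the earlier
    stages; they only exist after those stages have finished.
    """
    base = ["python3", "train.py", "--method"]
    cmds = [base + ["pretrain"]]
    if target == "virclass":
        cmds.append(base + ["virclass", "--pretrain_file", backbone_path])
    elif target == "virface":
        if generator_path is None:
            raise ValueError("virface needs the checkpoint saved by the generator stage")
        cmds.append(base + ["generator", "--pretrain_file", backbone_path])
        cmds.append(base + ["virface", "--pretrain_file", generator_path])
    else:
        raise ValueError(f"unknown target: {target!r}")
    return cmds


def run(cmds, dry_run=True):
    """Print each command; with dry_run=False, execute them in order."""
    for cmd in cmds:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError if a stage fails


if __name__ == "__main__":
    # Hypothetical checkpoint paths; substitute the files your own runs save.
    run(from_scratch_commands("virface", "snapshots/backbone.pth", "snapshots/generator.pth"))
```

Using check=True means a failed stage raises immediately instead of letting the next stage start from a missing or stale checkpoint.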
- Training from a pretrained model
- VirClass
python3 train.py --method virclass --pretrain_file %backbone_path
- VirFace
python3 train.py --method virface --pretrain_file %generator_path
Note: this pretrained file should contain [backbone, head, generator]
- Argument settings
method: training stage to run. [pretrain, generator, virclass, virface]
label_batch_size: batch size for labeled data.
unlabel_batch_size: batch size for unlabeled data.
arch: backbone architecture. [resnet18, resnet34, resnet50, resnet101, resnet152, usr]
feat_len: length of embedding feature.
num_ids: number of identities in labeled dataset.
gen_num: number of features generated by the generator.
KL: weight of the KL loss.
L2: weight of the MSE loss.
resume: whether to resume training from a checkpoint.
resume_file: resume checkpoint path.
pretrained_file: path to the pretrained checkpoint. Required whenever method is not "pretrain".
tensorboard: whether to use TensorBoard.
snapshot_prefix: path prefix for saving checkpoints.
eval: whether to evaluate on LFW, CFP-FF and CFP-FP after each epoch.
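For reference, the argument list can be mirrored with argparse. This is a hypothetical sketch, not train.py itself: the types, the defaults, and the use of store_true for resume, tensorboard and eval are assumptions. Because the commands in this README pass --pretrain_file while the list names pretrained_file, the sketch accepts both spellings. It also enforces the stated rule that a pretrained checkpoint is required for every method except pretrain.

```python
import argparse


def build_parser():
    # Option names follow the argument list; types and defaults are
    # placeholders chosen for this sketch, not the repository's values.
    p = argparse.ArgumentParser(description="VirFace training arguments (sketch)")
    p.add_argument("--method", default="pretrain",
                   choices=["pretrain", "generator", "virclass", "virface"])
    p.add_argument("--label_batch_size", type=int, default=64)
    p.add_argument("--unlabel_batch_size", type=int, default=64)
    p.add_argument("--arch", default="resnet50",
                   choices=["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "usr"])
    p.add_argument("--feat_len", type=int, default=512)
    p.add_argument("--num_ids", type=int, default=None)
    p.add_argument("--gen_num", type=int, default=1)
    p.add_argument("--KL", type=float, default=1.0)   # weight of the KL loss
    p.add_argument("--L2", type=float, default=1.0)   # weight of the MSE loss
    p.add_argument("--resume", action="store_true")
    p.add_argument("--resume_file", default=None)
    # Accept both spellings and store the value under one name.
    p.add_argument("--pretrain_file", "--pretrained_file", dest="pretrain_file", default=None)
    p.add_argument("--tensorboard", action="store_true")
    p.add_argument("--snapshot_prefix", default="snapshots/")
    p.add_argument("--eval", action="store_true")
    return p


def parse_args(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)
    # A pretrained checkpoint is needed for every method except "pretrain".
    if args.method != "pretrain" and args.pretrain_file is None:
        parser.error("--pretrain_file is required unless --method pretrain")
    return args


if __name__ == "__main__":
    print(vars(parse_args(["--method", "virface", "--pretrain_file", "g.pth"])))
```

parser.error prints the usage message and exits with status 2, so a missing checkpoint path is reported before any training starts.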
If our paper helps your research, please cite it in your publications:
@inproceedings{li2021virface,
title={VirFace: Enhancing Face Recognition via Unlabeled Shallow Data},
author={Li, Wenyu and Guo, Tianchu and Li, Pengyu and Chen, Binghui and Wang, Biao and Zuo, Wangmeng and Zhang, Lei},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={14729--14738},
year={2021}
}