This is the official implementation of our BMVC 2022 paper SP-ViT: Learning 2D Spatial Priors for Vision Transformers.
Comparison between our proposed SP-ViT and state-of-the-art vision transformers. Note that we exclude models pretrained on extra data or larger resolution than 224 × 224 for a fair comparison.
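At a high level, SP-ViT augments the self-attention logits with a learned 2D spatial prior before the softmax. Below is a minimal sketch of that idea; the function name, shapes, and the fact that the prior is passed in as a plain array are illustrative assumptions, not the repo's exact code.

```python
import numpy as np

def spatial_prior_attention(q, k, v, prior):
    """Self-attention with an additive 2D spatial prior on the logits (sketch).

    q, k, v: (N, d) token features; prior: (N, N) learned spatial bias.
    """
    d = q.shape[-1]
    logits = q @ k.T / np.sqrt(d) + prior                 # bias scores by spatial prior
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    attn = np.exp(logits)
    attn = attn / attn.sum(axis=-1, keepdims=True)        # row-wise softmax
    return attn @ v
```

With a zero prior this reduces to plain scaled dot-product attention; see the paper for the exact parameterization of the prior.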
Our implementation is based on TokenLabeling and pytorch-image-models.
SP-ViT models (using TokenLabeling as the baseline):
Model | Layers | Dim | Image resolution | Params | Top-1 Acc. (%) | Download |
---|---|---|---|---|---|---|
SP-ViT-S | 16 | 384 | 224 | 26M | 83.9 | link |
SP-ViT-S | 16 | 384 | 384 | 26M | 85.1 | link |
SP-ViT-M | 20 | 512 | 224 | 56M | 84.9 | link |
SP-ViT-M | 20 | 512 | 384 | 56M | 86.0 | link |
SP-ViT-L | 24 | 768 | 224 | 150M | 85.5 | link |
SP-ViT-L | 24 | 768 | 384 | 150M | 86.3 | link |
Requirements: `pyyaml`, `scipy`, `timm==0.4.5`
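For convenience, the dependencies above can be pinned in a `requirements.txt` (versions other than `timm` are not pinned in the original line, so they are left unconstrained here):

```
pyyaml
scipy
timm==0.4.5
```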
Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this script.
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
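The layout above can be sanity-checked before training with a small helper (hypothetical, not part of this repo):

```python
import os

def check_imagenet_layout(root):
    """Return True if `root` has the train/val class-folder layout shown above."""
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            return False
        # each split must contain at least one class folder (e.g. n01440764)
        classes = [d for d in os.listdir(split_dir)
                   if os.path.isdir(os.path.join(split_dir, d))]
        if not classes:
            return False
    return True
```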
Replace DATA_DIR with your ImageNet validation set path and MODEL_DIR with the checkpoint path:
CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint
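The evaluation script reports top-1 accuracy, as in the table above. For reference, the metric reduces to the following (an illustrative helper, not taken from this repo):

```python
import numpy as np

def top1_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class matches the label.

    logits: (N, num_classes) model outputs; labels: (N,) ground-truth ids.
    """
    return float((logits.argmax(axis=1) == labels).mean())
```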
Please go to TokenLabeling for downloading the token label data.
Train SP-ViT:
bash distributed_train.sh
To fine-tune the pre-trained SP-ViT on images at 384×384 resolution:
bash distributed_fine_tune.sh
Please refer to TokenLabeling for more details:
Backbone | Method | Crop size | Lr schd | mIoU | mIoU (ms) | Pixel Acc. | Params |
---|---|---|---|---|---|---|---|
LV-ViT-S | UperNet | 512x512 | 160k | 47.9 | 48.6 | 83.1 | 44M |
SP-ViT-S | UperNet | 512x512 | 160k | 49.0 | 49.8 | 83.4 | 44M |
We apply the visualization method of [Transformer-Explainability](https://github.com/hila-chefer/Transformer-Explainability) to visualize the parts of the image that led to a certain classification for DeiT-S and our SP-ViT-S (w/o TokenLabeling). The image regions the network uses to make its decision are highlighted in red.
To generate token label data for training:
python3 generate_label.py /path/to/imagenet/train /path/to/save/label_top5_train_nfnet --model dm_nfnet_f6 --pretrained --img-size 576 -b 32 --crop-pct 1.0
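Token labeling stores, for each patch location, the teacher's top-5 class scores (hence `label_top5` in the output path above). A minimal NumPy sketch of that idea follows; it is an illustration of the general technique, not the repo's exact implementation:

```python
import numpy as np

def top5_token_labels(logits):
    """Keep per-token top-5 soft labels from dense teacher predictions (sketch).

    logits: (C, H, W) dense class scores from the teacher network.
    Returns (values, indices), each of shape (5, H*W), sorted by score.
    """
    C, H, W = logits.shape
    flat = logits.reshape(C, H * W)
    # softmax over the class dimension
    e = np.exp(flat - flat.max(axis=0, keepdims=True))
    probs = e / e.sum(axis=0, keepdims=True)
    # top-5 class ids per token, highest score first
    idx = np.argsort(probs, axis=0)[-5:][::-1]
    val = np.take_along_axis(probs, idx, axis=0)
    return val, idx
```

Storing only the top-5 entries keeps the label files compact compared with saving the full (C, H, W) score map.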
If you use this repo or find it useful, please consider citing:
@inproceedings{BMVC2022,
author={Zhou, Yuxuan and Xiang, Wangmeng and Li, Chao and Wang, Biao and Wei, Xihan and Zhang, Lei and Keuper, Margret and Hua, Xiansheng},
booktitle = {The 33rd British Machine Vision Conference},
title = {SP-ViT: Learning 2D Spatial Priors for Vision Transformers},
url = {https://bmvc2022.mpi-inf.mpg.de/0564.pdf},
year = {2022}
}