This is the official PyTorch implementation of ICCV 2021 paper: Scalable Vision Transformers with Hierarchical Pooling.
By Zizheng Pan, Bohan Zhuang, Jing Liu, Haoyu He, and Jianfei Cai.
In our paper, we propose a Hierarchical Visual Transformer (HVT) which progressively pools visual tokens to shrink the sequence length and hence reduces the computational cost, analogous to the feature maps downsampling in Convolutional Neural Networks (CNNs). Moreover, we empirically find that the average pooled visual tokens contain more discriminative information than the single class token.
If you use this code for a paper please cite:
@InProceedings{Pan_2021_ICCV,
author = {Pan, Zizheng and Zhuang, Bohan and Liu, Jing and He, Haoyu and Cai, Jianfei},
title = {Scalable Vision Transformers With Hierarchical Pooling},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2021},
pages = {377-386}
}
First, clone the repository locally:
git clone https://github.com/MonashAI/HVT
Then, install PyTorch 1.7.0+ and torchvision 0.8.1+ and pytorch-image-models 0.3.2:
conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
Download the ImageNet 2012 dataset from here, and prepare the dataset based on this script. The file structure should look like:
imagenet
├── train
│ ├── class1
│ │ ├── img1.jpeg
│ │ ├── img2.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img3.jpeg
│ │ └── ...
│ └── ...
└── val
├── class1
│ ├── img4.jpeg
│ ├── img5.jpeg
│ └── ...
├── class2
│ ├── img6.jpeg
│ └── ...
└── ...
Download the CIFAR100 dataset from here.
To train HVT-Ti-1 on ImageNet with 8 gpus, run:
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --config config/hvt-ti-1.json --data-set IMNET --data-path [path/to/imagenet]
We also provide configuration files for HVT-S-1 and Scale HVT-Ti-4 under the config
folder.
To train HVT-Ti-1 on CIFAR100, run:
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --config config/hvt-ti-1.json --data-set CIFAR --data-path [path/to/cifar100]
To evaluate a model on ImageNet, e.g. HVT-S-1, run:
python main.py --config config/hvt-s-1.json --data-set IMNET --data-path [path/to/imagenet] --eval --resume [path/to/hvt_s_1.pth]
You can scale a HVT model with various settings, which is supported in the configuration file:
input_size
: The input image size.patch_size
: The patch size that used to split an image.num_heads
: The number of self-attention heads in a MSA layer.num_blocks
: The number of Transformer blocks in a model.pool_kernel_size
: The kernel size of the pooling layer.pool_stride
: The stride of the pooling layer.pool_block_width
: The number of blocks for each stage.
Name | FLOPs (G) | Params (M) | Top-1 Acc. (%) | Top-5 Acc. (%) | Model | Log |
---|---|---|---|---|---|---|
HVT-Ti-1 | 0.64 | 5.74 | 69.64 | 89.40 | github | log |
Scale HVT-Ti-4 | 1.39 | 22.12 | 75.23 | 92.30 | github | log |
HVT-S-1 | 2.40 | 22.09 | 78.00 | 93.83 | github | log |
Note that model weights and logs for HVT-Ti-1 and HVT-S-1 have been retrained.
Name | FLOPs (G) | Params (M) | Top-1 Acc. (%) | Top-5 Acc. (%) | Model | Log |
---|---|---|---|---|---|---|
HVT-S-0 | 4.57 | 22.05 | 80.39 | 95.13 | github | log |
HVT-S-1 | 2.40 | 22.09 | 78.00 | 93.83 | github | log |
HVT-S-2 | 1.94 | 22.11 | 77.36 | 93.55 | github | log |
HVT-S-3 | 1.62 | 22.11 | 76.32 | 92.90 | github | log |
HVT-S-4 | 1.39 | 22.12 | 75.23 | 92.30 | github | log |
For CIFAR-100 results, please check out our paper for more details.
This repository is released under the Apache 2.0 license as found in the LICENSE file.
This repository has adopted codes from DeiT, we thank the authors for their open-sourced code.