[ICCV 2021] Official implementation of "Scalable Vision Transformers with Hierarchical Pooling"

Scalable Vision Transformers with Hierarchical Pooling

This is the official PyTorch implementation of the ICCV 2021 paper: Scalable Vision Transformers with Hierarchical Pooling.

By Zizheng Pan, Bohan Zhuang, Jing Liu, Haoyu He, and Jianfei Cai.

In our paper, we propose a Hierarchical Visual Transformer (HVT) that progressively pools visual tokens to shrink the sequence length and hence reduce the computational cost, analogous to feature map downsampling in Convolutional Neural Networks (CNNs). Moreover, we empirically find that the average-pooled visual tokens contain more discriminative information than the single class token.
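As a rough illustration of the idea (not the repository's implementation), the sketch below pools a toy token sequence along the sequence axis and then averages the remaining tokens into a single representation. It assumes 1D max pooling with kernel size 3, stride 2, and padding of `kernel_size // 2`; these values and the function names `max_pool_tokens` and `average_tokens` are illustrative assumptions.

```python
def max_pool_tokens(tokens, kernel_size=3, stride=2):
    """1D max pooling over the token axis, applied independently per channel.

    `tokens` is a list of equal-length feature vectors (one per visual token).
    Padding of kernel_size // 2 is an assumption made for this sketch.
    """
    padding = kernel_size // 2
    dim = len(tokens[0])
    pad = [[float("-inf")] * dim for _ in range(padding)]
    padded = pad + tokens + pad
    pooled = []
    for start in range(0, len(padded) - kernel_size + 1, stride):
        window = padded[start:start + kernel_size]
        pooled.append([max(tok[c] for tok in window) for c in range(dim)])
    return pooled


def average_tokens(tokens):
    """Average all tokens into one vector, used here in place of a class token."""
    dim = len(tokens[0])
    return [sum(tok[c] for tok in tokens) / len(tokens) for c in range(dim)]


# Toy sequence: 8 tokens with 2 channels each.
tokens = [[float(i), float(-i)] for i in range(8)]
pooled = max_pool_tokens(tokens)      # sequence length shrinks from 8 to 4
summary = average_tokens(pooled)      # single vector fed to a classifier head
print(len(tokens), len(pooled), summary)
```

Each pooling step roughly halves the number of tokens that later Transformer blocks must process, which is where the compute saving comes from.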

If you use this code for a paper, please cite:

@InProceedings{Pan_2021_ICCV,
    author    = {Pan, Zizheng and Zhuang, Bohan and Liu, Jing and He, Haoyu and Cai, Jianfei},
    title     = {Scalable Vision Transformers With Hierarchical Pooling},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {377-386}
}

Updates

  • 30/12/2021: HVT has been integrated into PaddleViT; check out the third-party implementation here!

Usage

First, clone the repository locally:

git clone https://github.com/MonashAI/HVT

Then, install PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models (timm) 0.3.2:

conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
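After installing, it can be handy to confirm that the environment matches the versions listed above. The following is a minimal sketch using only the standard library; the helper names `parse_version` and `check_environment` are hypothetical, and the distribution names `torch`, `torchvision`, and `timm` are assumed to be the ones registered by pip or conda.

```python
import re
from importlib import metadata

# Version constraints taken from the installation instructions above.
MINIMUM = {"torch": (1, 7, 0), "torchvision": (0, 8, 1)}
PINNED = {"timm": (0, 3, 2)}


def parse_version(text):
    """Turn '1.7.0+cu110' into (1, 7, 0), ignoring any build suffix."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) if part else 0 for part in match.groups())


def check_environment():
    """Return a list of human-readable problems (empty if all checks pass)."""
    problems = []
    for name, wanted in {**MINIMUM, **PINNED}.items():
        try:
            found = parse_version(metadata.version(name))
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        except ValueError as err:
            problems.append(f"{name}: {err}")
            continue
        if name in PINNED and found != wanted:
            problems.append(f"{name} is {found}, expected exactly {wanted}")
        elif name in MINIMUM and found < wanted:
            problems.append(f"{name} is {found}, expected at least {wanted}")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment looks fine"]:
        print(line)
```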

Data preparation

ImageNet

Download the ImageNet 2012 dataset from here, and prepare the dataset based on this script. The file structure should look like:

imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
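Before launching a long training run, it may be worth checking that the folder layout matches the tree above. The sketch below is one way to do that with the standard library; `summarize_split` and `check_layout` are hypothetical helpers, not part of this repository.

```python
from pathlib import Path


def summarize_split(root, split):
    """Map each class folder under root/split to its number of files."""
    split_dir = Path(root) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing split directory: {split_dir}")
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return counts


def check_layout(root):
    """Return a list of problems found in the train/val folder structure."""
    splits = {name: summarize_split(root, name) for name in ("train", "val")}
    problems = []
    mismatch = sorted(set(splits["train"]) ^ set(splits["val"]))
    if mismatch:
        problems.append(f"classes differ between train and val: {mismatch}")
    for name, counts in splits.items():
        empty = [cls for cls, n in counts.items() if n == 0]
        if empty:
            problems.append(f"empty class folders in {name}: {empty}")
    return problems


if __name__ == "__main__":
    import sys
    # Usage: python check_layout.py [path/to/imagenet]
    for line in check_layout(sys.argv[1]) or ["layout looks fine"]:
        print(line)
```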

CIFAR100

Download the CIFAR100 dataset from here.

Training

To train HVT-Ti-1 on ImageNet with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --config config/hvt-ti-1.json --data-set IMNET --data-path [path/to/imagenet]

We also provide configuration files for HVT-S-1 and Scale HVT-Ti-4 under the config folder.

To train HVT-Ti-1 on CIFAR100, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --config config/hvt-ti-1.json --data-set CIFAR --data-path [path/to/cifar100]
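The two training commands differ only in the dataset name and path. If you launch runs from a script, a small helper such as the one sketched below can assemble the argument list. It uses only the flags shown above; `build_train_command` is a hypothetical name, and `/data/imagenet` is a placeholder path.

```python
import shlex


def build_train_command(config, data_set, data_path, gpus=8):
    """Assemble the distributed training command as an argument list."""
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={gpus}", "--use_env", "main.py",
        "--config", config,
        "--data-set", data_set,
        "--data-path", data_path,
    ]


command = build_train_command("config/hvt-ti-1.json", "IMNET", "/data/imagenet")
print(shlex.join(command))  # pass `command` to subprocess.run(...) to launch it
```

Keeping the command as a list avoids shell-quoting problems when the data path contains spaces.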

Evaluation

To evaluate a model on ImageNet, e.g. HVT-S-1, run:

python main.py --config config/hvt-s-1.json --data-set IMNET --data-path [path/to/imagenet] --eval --resume [path/to/hvt_s_1.pth]
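The results tables below report top-1 and top-5 accuracy. For readers unfamiliar with these metrics, here is a minimal pure-Python sketch of how top-k accuracy is commonly computed from per-class scores. It is an illustration only, not the evaluation code used by `main.py`, and `topk_accuracy` is a hypothetical name.

```python
def topk_accuracy(scores, labels, k=1):
    """Percentage of samples whose true label is among the k highest scores.

    `scores` is a list of per-class score lists, `labels` the true class indices.
    """
    hits = 0
    for row, label in zip(scores, labels):
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        hits += label in ranked[:k]
    return 100.0 * hits / len(labels)


# Three samples, three classes.
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
print(topk_accuracy(scores, labels, k=1))
print(topk_accuracy(scores, labels, k=2))
```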

Scaling HVT

You can scale an HVT model with various settings, which are supported in the configuration file:

  • input_size: The input image size.
  • patch_size: The patch size used to split an image.
  • num_heads: The number of self-attention heads in an MSA layer.
  • num_blocks: The number of Transformer blocks in a model.
  • pool_kernel_size: The kernel size of the pooling layer.
  • pool_stride: The stride of the pooling layer.
  • pool_block_width: The number of blocks for each stage.
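To get a feel for how these settings interact, the sketch below estimates the token sequence length in each stage. It rests on several assumptions that are not stated in this README: patches are non-overlapping, one pooling layer is applied at the start of every stage, the pooling uses padding of `pool_kernel_size // 2`, and the output length follows the usual pooling formula. The example values are illustrative and not read from the provided config files; `stage_lengths` is a hypothetical helper.

```python
def stage_lengths(input_size, patch_size, num_blocks,
                  pool_block_width, pool_kernel_size, pool_stride):
    """Estimate the number of visual tokens processed in each stage."""
    tokens = (input_size // patch_size) ** 2      # tokens after patch embedding
    padding = pool_kernel_size // 2               # assumed padding
    lengths = []
    for _ in range(0, num_blocks, pool_block_width):   # one iteration per stage
        # Standard 1D pooling output length: floor((L + 2p - k) / s) + 1
        tokens = (tokens + 2 * padding - pool_kernel_size) // pool_stride + 1
        lengths.append(tokens)
    return lengths


# Illustrative settings: 224x224 input, 16x16 patches, 12 blocks in 4 stages.
print(stage_lengths(input_size=224, patch_size=16, num_blocks=12,
                    pool_block_width=3, pool_kernel_size=3, pool_stride=2))
```

Under these assumptions, the 196 initial tokens shrink to 98, 49, 25, and 13 across the four stages, so a smaller `pool_block_width` (more stages) leaves later blocks with far fewer tokens to process.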

Results on ImageNet

Main Results

Name            FLOPs (G)  Params (M)  Top-1 Acc. (%)  Top-5 Acc. (%)  Model   Log
HVT-Ti-1        0.64       5.74        69.64           89.40           github  log
Scale HVT-Ti-4  1.39       22.12       75.23           92.30           github  log
HVT-S-1         2.40       22.09       78.00           93.83           github  log

Note that HVT-Ti-1 and HVT-S-1 have been retrained; the weights and logs listed above come from the retrained models.

More Pooling Stages with HVT-S

Name     FLOPs (G)  Params (M)  Top-1 Acc. (%)  Top-5 Acc. (%)  Model   Log
HVT-S-0  4.57       22.05       80.39           95.13           github  log
HVT-S-1  2.40       22.09       78.00           93.83           github  log
HVT-S-2  1.94       22.11       77.36           93.55           github  log
HVT-S-3  1.62       22.11       76.32           92.90           github  log
HVT-S-4  1.39       22.12       75.23           92.30           github  log

For CIFAR-100 results, please see our paper.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgement

This repository adopts code from DeiT. We thank the authors for open-sourcing their code.
