
SA2VP: Spatially Aligned-and-Adapted Visual Prompt

Paper link: https://arxiv.org/abs/2312.10376


This repository contains the official PyTorch implementation for SA2VP.

(Figure: overview of the SA2VP model architecture)

Environment settings

We build on the framework from https://github.com/microsoft/unilm/tree/master/beit

We use the following datasets for evaluation:

https://github.com/KMnP/vpt (FGVC)

https://github.com/dongzelian/SSF (VTAB-1k)

https://github.com/shikiw/DAM-VP (HTA)

This code is tested with Python 3.7.13, PyTorch 1.12.1, and CUDA 11.4, and requires the following dependencies:

  • timm = 0.6.7

We also provide requirement.txt for reference.
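
A quick sanity check that the tested versions are in place (a convenience sketch, not part of the repository):

    # Check the environment against the versions listed above
    # (Python 3.7.13, PyTorch 1.12.1, timm 0.6.7, CUDA 11.4).
    import sys
    import timm
    import torch

    print("python:", sys.version.split()[0])   # expect 3.7.x
    print("torch :", torch.__version__)        # expect 1.12.1
    print("timm  :", timm.__version__)         # expect 0.6.7
    print("cuda  :", torch.version.cuda, "| available:", torch.cuda.is_available())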

Structure of this repo

  • ./backbone_ckpt: stores the pre-trained ViT and Swin Transformer checkpoints.

  • ./data: download and set up the input datasets here (fgvc and vtab-1k).

│SA2VP/
├──data/
│   ├──fgvc/
│   │   ├──CUB_200_2011/
│   │   ├──OxfordFlower/
│   │   ├──Stanford-cars/
│   │   ├──Stanford-dogs/
│   │   ├──nabirds/
│   ├──vtab-1k/
│   │   ├──caltech101/
│   │   ├──cifar/
│   │   ├──.......
├──backbone_ckpt/
│   ├──imagenet21k_ViT-B_16.npz
│   ├──swin_base_patch4_window7_224_22k.pth
  • ./model_save: stores the final checkpoints.

  • ./log_save: stores the training logs.

  • ./vpt_main: we use the VPT code to initialize the model.

    • 👉./vpt_main/src/models/vit_backbones/vit_tinypara.py: SA2VP based on ViT backbone.

    • 👉./vpt_main/src/models/vit_backbones/vit_tinypara_acc.py: SA2VP based on ViT with accelerated attention computation.

    • 👉 ./vpt_main/src/models/vit_backbones/swin_transformer_tinypara.py: SA2VP based on Swin Transformer backbone.

    • ./vpt_main/src/models/build_swin_backbone.py: wraps SA2VP based on the Swin backbone; this file imports the model from swin_transformer_tinypara.py.

  • datasets.py: contains all dataset definitions.

  • engine_for_train.py: the training and testing engine.

  • 👉vit_train_sa2vp.py: run this to train SA2VP based on ViT. In line 37, you can switch to the accelerated version by adding '_acc' to the model name (see the sketch after this list).

  • 👉vit_train_swin.py: run this to train SA2VP based on the Swin Transformer.

  • 👉Train_nature.sh/Train_special.sh/Train_struct.sh: scripts used for automatic training.
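
The line-37 switch in vit_train_sa2vp.py mentioned above amounts to choosing between the two ViT modules. The snippet below only illustrates the idea; the actual variable and import names in the script may differ:

    import importlib

    # Hypothetical illustration of the '_acc' switch: appending '_acc' selects the
    # SA2VP-ViT variant with accelerated attention. Run from the SA2VP repo root so
    # the import path resolves; the real code in vit_train_sa2vp.py may differ.
    use_accelerated = True
    module_name = "vpt_main.src.models.vit_backbones.vit_tinypara" + ("_acc" if use_accelerated else "")
    sa2vp_vit = importlib.import_module(module_name)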

Experiment steps

  • 1. Download the pre-trained ViT and Swin checkpoints from VPT. Use the supervised ViT-B/16 and Swin-B checkpoints.

  • 2. Change the checkpoint name and path in vit_train_sa2vp.py line 48 and in vit_train_swin.py line 47.

  • 3. Set the training weights of the different branches in engine_for_train.py lines 26/177.

  • 4. Set the dataset paths in datasets.py lines 1160/1161 (prefix_fgvc/prefix_vtab). Note that you also need to choose the FGVC or VTAB transform in lines 1157/1158, and pay attention to the dataset names used below.

  • 5. Change the model config. For SA2VP based on ViT, set inter-dim in vit_tinypara.py lines 280/281/334/428 and inter-weight in line 427. For SA2VP based on Swin, set inter-dim in vit_train_swin.py lines 169/170/675 and inter-weight in line 596. The default lr is 1e-3 and weight_decay is 1e-4.

  • For ViT: (VTAB exceptions: SVHN-16-0.5, Resisc45-16-0.5, ds/ori-16-0.1, and sn/ele-32-0.5 need special handling. VTAB tasks with special learning rates: Pets 5e-4; Clevr/Count 5e-4.)

             CUB      Nabirds   Flower   DOG      CAR
inter-dim    16       32        8        32       64
inter-weight 0.1      0.1       0.1      0.1      1.5
batch size   64/128   64/128    64/128   64/128   64/128

             vtab-Natural   vtab-Special   vtab-Structure   HTA
inter-dim    8              16             32               64
inter-weight 0.1            1.5            1.5              0.1
batch size   40/64          40             40               64/128
  • For Swin:

             vtab-Natural   vtab-Special   vtab-Structure
inter-dim    8              8              8
inter-weight 0.1/0.5        0.5/1.5        1.5
batch size   40/64          40             40
  • Training Scripts:

    • Single GPU
    CUDA_VISIBLE_DEVICES=1 python vit_train_sa2vp.py  --data_set CUB --output_dir ./model_save/CUB --update_freq 1  --warmup_epochs 10 --epochs 100 --drop_path 0.0  --lr 1e-3 --weight_decay 1e-4 --nb_classes 200 --log_dir ./log_save --batch_size 64 --my_mode train_val --min_lr 1e-7
    
    • Multiple GPUs
    CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 vit_train_sa2vp.py  --data_set CIFAR --output_dir ./model_save/CIFAR --update_freq 1  --warmup_epochs 10 --epochs 100 --drop_path 0.0  --lr 1e-3 --weight_decay 1e-4 --nb_classes 100 --log_dir ./log_save --batch_size 40 --my_mode train_val --min_lr 1e-7
    
  • Test Script:

    • For VTAB-1k
    CUDA_VISIBLE_DEVICES=1 python vit_train_sa2vp.py --data_set DS_LOC --eval --batch_size 64 --resume ./model_save/DS_LOC/checkpoint-99.pth --nb_classes 16 --my_mode trainval_test
    
    • For FGVC
    CUDA_VISIBLE_DEVICES=1 python vit_train_sa2vp.py --data_set CAR --eval --batch_size 64 --resume ./model_save/CAR/checkpoint-best.pth --nb_classes 196 --my_mode trainval_test
    
  • Note: --my_mode decides how the train/val/test splits are used. In train_val, the best model is selected on the val set during training. In trainval_test, the train and val sets are used together for training and accuracy is reported on the test set. We follow the strategy of VPT. A schematic of this convention is sketched below.
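
As a schematic restatement of the --my_mode convention above (illustrative names only, not the repository's actual code):

    # Schematic of --my_mode; the function and dictionary keys are illustrative.
    def select_splits(my_mode: str):
        if my_mode == "train_val":
            # select the best checkpoint on the val set while training
            return {"train": ["train"], "eval": ["val"]}
        if my_mode == "trainval_test":
            # final run: train on train+val, report accuracy on the test set
            return {"train": ["train", "val"], "eval": ["test"]}
        raise ValueError(f"unknown my_mode: {my_mode}")

    print(select_splits("train_val"))      # {'train': ['train'], 'eval': ['val']}
    print(select_splits("trainval_test"))  # {'train': ['train', 'val'], 'eval': ['test']}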

Citation

If you find our work helpful in your research, please cite it as:

@inproceedings{pei2024sa2vp,
  title={SA^2VP: Spatially Aligned-and-Adapted Visual Prompt},
  author={Pei, Wenjie and Xia, Tongqi and Chen, Fanglin and Li, Jinsong and Tian, Jiandong and Lu, Guangming},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2024}
}

License

The code is released under MIT License (see LICENSE file for details).
