zhengli97/PromptKD

PromptKD: Unsupervised Prompt Distillation for Vision-Language Models

Zheng Li, Xiang Li#, Xinyi Fu, Xin Zhang, Weiqiang Wang, Shuo Chen, Jian Yang#.
Nankai University, Ant Group, RIKEN
CVPR 2024
[Paper] [Project Page] [Chinese Explanation]




Abstract

In this paper, we introduce an unsupervised domain prompt distillation framework, which aims to transfer the knowledge of a larger teacher model to a lightweight target model through prompt-driven imitation using unlabeled domain images.

To the best of our knowledge, we are the first to (1) perform unsupervised domain-specific prompt-driven knowledge distillation for CLIP, and (2) establish a practical mechanism for pre-storing text features as shared class vectors between teacher and student.

Framework

Figure 1. An overview of our PromptKD framework. (a) We first pre-train a large CLIP teacher model with labeled training images. (b) Reuse the existing higher-quality teacher text features for unsupervised prompt distillation. (c) The well-trained student and pre-stored teacher text features are utilized for final inference.

Highlights

(1). A novel two-stage unsupervised prompt distillation framework for Vision-Language Models.

(2). Reuse high-quality teacher text features instead of training the student's own text encoder.

(3). Distillation on large amounts of unlabeled domain images using soft labels provided by the teacher.

(4). PromptKD outperforms all existing prompt learning methods on 11 diverse recognition datasets.
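
Highlights (2) and (3) can be read as a standard temperature-scaled distillation objective. The formula below is a sketch under that assumption, not a quotation from the paper; the symbols q^t, q^s, τ, and σ are notation introduced here for illustration.

```latex
\mathcal{L}_{\mathrm{KD}}
  = \tau^{2}\,\mathrm{KL}\!\left(
      \sigma\!\left(q^{t}/\tau\right)
      \,\middle\|\,
      \sigma\!\left(q^{s}/\tau\right)
    \right)
```

Here q^t and q^s stand for the teacher and student logits on the same unlabeled image, both computed against the shared, pre-stored teacher text features; σ is the softmax and τ is a temperature. Refer to the paper for the exact loss used.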

Experimental Results

The results reported below show accuracy on base and novel classes across 11 recognition datasets, averaged over 3 seeds.

Base-to-Novel

Table 1. Comparison with existing state-of-the-art methods on base-to-novel generalization. Our PromptKD demonstrates strong generalization ability and achieves significant improvements on 11 recognition datasets given the ViT-B/16 image encoder of the CLIP model. The symbol △ denotes the performance improvement compared to the previous SOTA method.

Cross Dataset

Table 2. Comparison of PromptKD with existing advanced approaches on the cross-dataset benchmark. Based on our pipeline, we perform unsupervised prompt distillation separately on the unlabeled data of each target domain (i.e., the transductive setting). The source model is trained on ImageNet. "ZSL" denotes the setting type for Zero-Shot Learning.

Running

Preliminary

  1. Create the environment and install the Dassl.pytorch library. Please follow the instructions detailed in INSTALL.md.

  2. Either (1) pre-train your own large teacher CLIP model (see below) or (2) use our publicly released pre-trained ViT-L/14 CLIP teacher models (highly recommended).
    Our pre-trained teacher models are publicly available at [Baidu Yun] [TeraBox] [Google Cloud]
    (Note that due to cloud space limitations, we only provide a limited number of models in Google Cloud. Sorry.)
    After obtaining the teacher model, unzip these files and place the model in the ./teacher_model folder.
    The accuracy of each teacher model is shown in Tables 10 and 11 in the supplementary material of the paper.

  3. Download the original ViT-B/16 and ViT-L/14 CLIP model weights from the official OpenAI website. Then place these models in the ./clip folder.
    [ViT-B/16 CLIP] [ViT-L/14 CLIP]

  4. Prepare the dataset. Please follow the instructions detailed in DATASETS.md.
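
Before training, it can help to confirm that the folders named in steps 2 and 3 are in place. The following is a minimal sketch; the function name check_layout is hypothetical and not part of the repository, and only the ./teacher_model and ./clip folder names come from the steps above.

```shell
#!/bin/sh
# Hypothetical helper (not part of the repository): report whether the
# folders named in the preliminary steps exist under a given root.
# Usage: check_layout [repo_root]   (repo_root defaults to ".")
check_layout() {
    root="${1:-.}"
    missing=0
    for dir in teacher_model clip; do
        if [ -d "$root/$dir" ]; then
            echo "ok: $root/$dir"
        else
            echo "missing: $root/$dir"
            missing=1
        fi
    done
    # Exit status 0 only when both folders were found.
    return "$missing"
}
```

Calling check_layout from the repository root prints one line per folder and returns a non-zero status if either one is missing. It does not check the contents of the folders.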

Train Your Teacher Model

In our paper, we use PromptSRC by default to pre-train our ViT-L/14 CLIP teacher model. We have already provided the config file in configs/trainers/PromptSRC/vit_l14_c2_ep20_batch8_4+4ctx.yaml

If you want to train your own teacher model, first change CFG=vit_b16_c2_ep20_batch4_4+4ctx on line 11 of scripts/promptsrc/base2new_train.sh to vit_l14_c2_ep20_batch8_4+4ctx. Then follow the instructions listed in docs/PromptSRC.md and run the script.
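
The config change can also be scripted instead of edited by hand. This is a sketch only: the function name set_teacher_cfg is hypothetical, and it assumes the script assigns the config on exactly one line that starts with CFG=, so check the file before and after running it.

```shell
#!/bin/sh
# Hypothetical helper (not part of the repository): rewrite the CFG
# assignment in a training script. Assumes the assignment sits on a line
# that starts with "CFG="; every such line is rewritten.
# Usage: set_teacher_cfg <script> [config_name]
set_teacher_cfg() {
    script="$1"
    new_cfg="${2:-vit_l14_c2_ep20_batch8_4+4ctx}"
    [ -f "$script" ] || { echo "no such file: $script" >&2; return 1; }
    # Edit via a temporary file so this does not rely on non-portable `sed -i`.
    # Writing back with `cat` keeps the original file's permissions.
    tmp="$script.tmp.$$"
    sed "s|^CFG=.*|CFG=$new_cfg|" "$script" > "$tmp" \
        && cat "$tmp" > "$script" \
        && rm -f "$tmp"
}
```

For example, set_teacher_cfg scripts/promptsrc/base2new_train.sh switches the script to the ViT-L/14 config named above.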

Important Note:
The accuracy of your own teacher model may vary depending on your computing environment. To ensure that your teacher model is adequate for distillation, please refer to Appendix Table 10 to check whether your model achieves appropriate accuracy.

If your teacher model cannot achieve the corresponding accuracy or cannot be trained due to computational constraints, I highly recommend that you use our publicly available pre-trained models for distillation.

Running PromptKD

(1) Base-to-Novel Experiments.

  1. The base-to-novel experimental settings are provided in the config file at configs/trainers/PromptKD/vit_b16_c2_ep20_batch8_4+4ctx.yaml. You can modify the hyperparameters in this config file according to your needs.

  2. Change the dataset path in scripts/promptkd/base2new_train.sh line 4 to your current path.

  3. Run the commands below to train PromptKD on the specified dataset.

For example:

# dataset=imagenet, seed=1 
sh scripts/promptkd/base2new_train.sh imagenet 1

# seed=2
sh scripts/promptkd/base2new_train.sh imagenet 2

# seed=3
sh scripts/promptkd/base2new_train.sh imagenet 3

# dataset=caltech101, seed=1
sh scripts/promptkd/base2new_train.sh caltech101 1
  4. The output results will be automatically saved at output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed_${SEED}.
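
Since the reported results are averaged over 3 seeds, the three per-seed commands can be wrapped in a loop. The sketch below is one way to do that; run_three_seeds and the DRY_RUN variable are hypothetical names introduced here, not options of the repository's scripts. By default it only prints the commands.

```shell
#!/bin/sh
# Hypothetical wrapper (not part of the repository): run, or by default
# just print, the base-to-novel training command for seeds 1, 2 and 3.
# Usage: run_three_seeds <dataset>
# Set DRY_RUN=0 to actually launch training from the repository root.
run_three_seeds() {
    dataset="$1"
    for seed in 1 2 3; do
        cmd="sh scripts/promptkd/base2new_train.sh $dataset $seed"
        if [ "${DRY_RUN:-1}" = "1" ]; then
            echo "$cmd"          # default: only show what would be run
        else
            $cmd || return 1     # stop at the first failing seed
        fi
    done
}
```

For example, run_three_seeds imagenet lists the three commands shown above, and DRY_RUN=0 run_three_seeds imagenet executes them one after another.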

(2) Cross-dataset Experiments.

  1. The cross-dataset experimental settings are provided in the config file at configs/trainers/PromptKD/vit_b16_c2_ep20_batch8_4+4ctx_cross_datasets.yaml. You can modify the hyperparameters in this config file according to your needs.

  2. Change the dataset path in scripts/promptkd/base2new_train.sh line 4 to your current path.

  3. Run the commands below to train PromptKD on the specified dataset.

For example:

# dataset=caltech101, seed=1 
sh scripts/promptkd/xd_train.sh caltech101 1

# seed=2
sh scripts/promptkd/xd_train.sh caltech101 2

# seed=3
sh scripts/promptkd/xd_train.sh caltech101 3

# dataset=oxford_pets, seed=1
sh scripts/promptkd/xd_train.sh oxford_pets 1
  4. The output results will be automatically saved at output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}.
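
To locate the logs of a given run, the path template in the step above can be filled in programmatically. This is an illustrative sketch: xd_output_dir is a hypothetical helper, and the default TRAINER and SHOTS values are assumptions made for the example, so check the training script for the values it actually uses.

```shell
#!/bin/sh
# Hypothetical helper (not part of the repository): print the cross-dataset
# output directory for one run, following the template quoted above.
# The TRAINER and SHOTS defaults are assumptions for illustration only;
# override them through the environment if your script uses other values.
# Usage: xd_output_dir <dataset> <seed>
xd_output_dir() {
    dataset="$1"
    seed="$2"
    trainer="${TRAINER:-PromptKD}"
    cfg="${CFG:-vit_b16_c2_ep20_batch8_4+4ctx_cross_datasets}"
    shots="${SHOTS:-0}"
    echo "output/${dataset}/${trainer}/${cfg}_${shots}shots/seed${seed}"
}
```

For example, xd_output_dir caltech101 1 prints the directory expected for the first command above under those assumed defaults.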

Model Zoo

Here we provide the pre-trained student models and complete training logs using 64 shots and 0 shots (i.e., the full dataset) on the ImageNet dataset for your reference. Please refer to [Releases Part].

Contact

For any questions, please contact me via email (zhengli97[at]mail.nankai.edu.cn).

Citation

If you find our paper or repo helpful for your research, please cite our paper and give this repo a star⭐.

@inproceedings{li2024promptkd,
  title={PromptKD: Unsupervised Prompt Distillation for Vision-Language Models},
  author={Li, Zheng and Li, Xiang and Fu, Xinyi and Zhang, Xin and Wang, Weiqiang and Chen, Shuo and Yang, Jian},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}

Acknowledgements

Our code is based on the PromptSRC, MaPLe, Co-CoOp and CoOp repositories. We thank the authors for releasing their code.