This repository contains PyTorch evaluation code for the CVPR 2022 paper *Delving Deep into the Generalization of Vision Transformers under Distribution Shifts*.
Illustration of our taxonomy of distribution shifts. We build the taxonomy upon what kinds of semantic concepts are modified from the original image, and divide distribution shifts into four cases: background shifts, corruption shifts, texture shifts, and style shifts. The marked entries denote the vision cues left unmodified under each type of distribution shift. Please refer to the paper for details.
We build OOD-Net, a collection of datasets covering the four types of distribution shift together with their in-distribution counterparts, for a comprehensive investigation into the out-of-distribution generalization properties of models. The download link is available here.
Dataset | Shift Type |
---|---|
ImageNet-9 | Background Shift |
ImageNet-C | Corruption Shift |
Cue Conflict Stimuli | Texture Shift |
Stylized-ImageNet | Texture Shift |
ImageNet-R | Style Shift |
DomainNet | Style Shift |
A framework overview of the three designed generalization-enhanced ViTs. All networks use a Vision Transformer as the feature encoder together with a label prediction head. Under this setting, the models take labeled source examples and unlabeled target examples as input. Top left: T-ADV promotes the network to learn domain-invariant representations by introducing a domain classifier for domain adversarial training. Top right: T-MME leverages the minimax process on the conditional entropy of target data to reduce the distribution gap while learning discriminative features for the task; the network uses a cosine-similarity-based classifier architecture to produce class prototypes. Bottom: T-SSL is an end-to-end prototype-based self-supervised learning framework, which uses two memory banks to calculate cluster centroids and a cosine classifier for classification.
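The domain adversarial training behind T-ADV hinges on reversing gradients between the feature encoder and the domain classifier, so the encoder is pushed toward domain-invariant features. Below is a minimal PyTorch sketch of such a gradient reversal layer; the names (`GradReverse`, `grad_reverse`) and the `lambd` scaling factor are illustrative and not taken from this repository.

```python
import torch
from torch.autograd import Function

class GradReverse(Function):
    """Identity in the forward pass; multiplies incoming gradients by -lambd
    in the backward pass, so the encoder is trained to *fool* the domain
    classifier while the classifier itself is trained normally."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient w.r.t. x is sign-flipped and scaled; lambd gets no gradient.
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)

# Toy check: gradients flowing back through the layer are flipped and scaled.
x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, lambd=0.5).sum()
y.backward()
print(x.grad)  # → tensor([-0.5000, -0.5000, -0.5000])
```

In a full T-ADV-style setup, features would pass through this layer before the domain classifier, so a single backward pass trains both branches with opposing objectives.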
```shell
conda create -n vit python=3.6
conda activate vit
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.0 -c pytorch
```
```shell
conda activate vit
export PYTHONPATH=$PYTHONPATH:.
```
```shell
CUDA_VISIBLE_DEVICES=0 python main.py \
    --model deit_small_b16_384 \
    --num-classes 345 \
    --checkpoint data/checkpoints/deit_small_b16_384_baseline_real.pth.tar \
    --meta-file data/metas/DomainNet/sketch_test.jsonl \
    --root-dir data/images/DomainNet/sketch/test
```
Confusion matrix for the baseline model:
| | clipart | painting | real | sketch |
|---|---|---|---|---|
| clipart | 80.25 | 33.75 | 55.26 | 43.43 |
| painting | 36.89 | 75.32 | 52.08 | 31.14 |
| real | 50.59 | 45.81 | 84.78 | 39.31 |
| sketch | 52.16 | 35.27 | 48.19 | 71.92 |
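One way to read the table is that, for each domain, the diagonal entry is in-distribution accuracy and the off-diagonal entries are out-of-distribution accuracies. The short script below summarizes them per row, assuming (an assumption, not stated above) that rows index the source domain and columns the evaluation domain:

```python
# Cross-domain accuracies (%) copied from the table above.
domains = ["clipart", "painting", "real", "sketch"]
acc = [
    [80.25, 33.75, 55.26, 43.43],
    [36.89, 75.32, 52.08, 31.14],
    [50.59, 45.81, 84.78, 39.31],
    [52.16, 35.27, 48.19, 71.92],
]

for i, row in enumerate(acc):
    # Off-diagonal entries are the out-of-distribution targets for this row.
    ood = [v for j, v in enumerate(row) if j != i]
    print(f"{domains[i]}: in-domain {row[i]:.2f}, mean OOD {sum(ood) / len(ood):.2f}")
```

This makes the in-domain vs. out-of-domain gap of the baseline model explicit at a glance.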
The models used above can be found here.
- These results may differ slightly from those in our paper due to differences in environments.
- We will continuously update this repo.
If you find these investigations useful in your research, please consider citing:
@article{zhang2021delving,
title={Delving deep into the generalization of vision transformers under distribution shifts},
author={Zhang, Chongzhi and Zhang, Mingyuan and Zhang, Shanghang and Jin, Daisheng and Zhou, Qiang and Cai, Zhongang and Zhao, Haiyu and Yi, Shuai and Liu, Xianglong and Liu, Ziwei},
journal={arXiv preprint arXiv:2106.07617},
year={2021}
}