Phoenix1153/ViT_OOD_generalization

Out-of-distribution Generalization Investigation on Vision Transformers

This repository contains the PyTorch evaluation code for the CVPR 2022 paper Delving Deep into the Generalization of Vision Transformers under Distribution Shifts.

Taxonomy of Distribution Shifts

Illustration of our taxonomy of distribution shifts. We build the taxonomy on which kinds of semantic concepts are modified from the original image, and divide distribution shifts into four cases: background shifts, corruption shifts, texture shifts, and style shifts. The marker in the illustration denotes the vision cues left unmodified under each type of distribution shift. Please refer to the paper for details.

Dataset

We build OOD-Net, a collection of data under four types of distribution shift together with their in-distribution counterparts, for a comprehensive investigation of models' out-of-distribution generalization properties. The download link is available here.

Dataset                 Shift Type
ImageNet-9              Background Shift
ImageNet-C              Corruption Shift
Cue Conflict Stimuli    Texture Shift
Stylized-ImageNet       Texture Shift
ImageNet-R              Style Shift
DomainNet               Style Shift

Generalization-Enhanced Vision Transformers

A framework overview of the three generalization-enhanced ViTs we designed. All networks use a Vision Transformer as the feature encoder together with a label prediction head. In this setting, the models receive labeled source examples and unlabeled target examples.

  • Top left: T-ADV encourages the network to learn domain-invariant representations by introducing a domain classifier for domain adversarial training.

  • Top right: T-MME leverages a minimax process on the conditional entropy of the target data to reduce the distribution gap while learning discriminative features for the task. The network uses a cosine similarity-based classifier architecture to produce class prototypes.

  • Bottom: T-SSL is an end-to-end, prototype-based self-supervised learning framework. The architecture uses two memory banks to calculate cluster centroids, and a cosine classifier is used for classification.
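To make the cosine similarity-based classifier and the conditional entropy mentioned above more concrete, here is a minimal sketch in plain Python. It is not the repository's implementation: the function names, the temperature value of 0.05, and the toy two-class prototypes are all assumptions chosen for illustration.

```python
import math


def cosine_logits(feature, prototypes, temperature=0.05):
    """Return one logit per class: cosine(feature, prototype) / temperature.

    `temperature` is a hypothetical value, not taken from this repository.
    """
    def norm(v):
        # Fall back to 1.0 for an all-zero vector to avoid division by zero.
        return math.sqrt(sum(x * x for x in v)) or 1.0

    f_norm = norm(feature)
    logits = []
    for proto in prototypes:
        dot = sum(a * b for a, b in zip(feature, proto))
        logits.append(dot / (f_norm * norm(proto)) / temperature)
    return logits


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy (in nats) of a predicted class distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


# Toy class prototypes for a hypothetical two-class problem.
PROTOTYPES = [[1.0, 0.0], [0.0, 1.0]]

# A feature aligned with one prototype gives a confident, low-entropy prediction.
confident = entropy(softmax(cosine_logits([1.0, 0.0], PROTOTYPES)))
# A feature equidistant from both prototypes gives the maximum entropy, log(2).
ambiguous = entropy(softmax(cosine_logits([1.0, 1.0], PROTOTYPES)))
```

The entropy of predictions on unlabeled target examples is the quantity that a minimax process of this kind acts on: features far from every prototype yield high entropy, while features close to a single prototype yield low entropy.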

Run Our Code

Environment Installation

conda create -n vit python=3.6
conda activate vit
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.0 -c pytorch

Before Running

conda activate vit
export PYTHONPATH=$PYTHONPATH:.

Evaluation

CUDA_VISIBLE_DEVICES=0 python main.py \
--model deit_small_b16_384 \
--num-classes 345 \
--checkpoint data/checkpoints/deit_small_b16_384_baseline_real.pth.tar \
--meta-file data/metas/DomainNet/sketch_test.jsonl \
--root-dir data/images/DomainNet/sketch/test
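To evaluate one checkpoint on several DomainNet domains, the command above can be generated per domain. The sketch below only builds and prints the commands. It assumes that the other domains follow the same path pattern as the sketch example (data/metas/DomainNet/<domain>_test.jsonl and data/images/DomainNet/<domain>/test), which this README does not confirm; the helper name build_eval_command is hypothetical.

```python
# The four DomainNet domains that appear in the results table of this README.
DOMAINS = ["clipart", "painting", "real", "sketch"]


def build_eval_command(
    target_domain,
    checkpoint="data/checkpoints/deit_small_b16_384_baseline_real.pth.tar",
    model="deit_small_b16_384",
    num_classes=345,
):
    """Build the argument list for one evaluation run of main.py.

    Only flags shown in this README are used. The per-domain paths are an
    assumption extrapolated from the sketch example.
    """
    return [
        "python", "main.py",
        "--model", model,
        "--num-classes", str(num_classes),
        "--checkpoint", checkpoint,
        "--meta-file", f"data/metas/DomainNet/{target_domain}_test.jsonl",
        "--root-dir", f"data/images/DomainNet/{target_domain}/test",
    ]


if __name__ == "__main__":
    for domain in DOMAINS:
        print(" ".join(build_eval_command(domain)))
```

Each list can be passed to subprocess.run from the repository root, with CUDA_VISIBLE_DEVICES set in the environment to select the GPU.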

Experimental Results

DomainNet

DeiT_small_b16_384

confusion matrix for the baseline model

            clipart   painting   real    sketch
clipart     80.25     33.75      55.26   43.43
painting    36.89     75.32      52.08   31.14
real        50.59     45.81      84.78   39.31
sketch      52.16     35.27      48.19   71.92
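As a quick way to read the matrix above, the following sketch compares the mean of the diagonal entries with the mean of the off-diagonal entries. It assumes that each entry is an accuracy in percent for one pair of domains and that the diagonal corresponds to pairs where both domains are the same; the README itself does not spell this out.

```python
DOMAINS = ["clipart", "painting", "real", "sketch"]

# Values copied from the table above, one row per line.
MATRIX = [
    [80.25, 33.75, 55.26, 43.43],
    [36.89, 75.32, 52.08, 31.14],
    [50.59, 45.81, 84.78, 39.31],
    [52.16, 35.27, 48.19, 71.92],
]


def same_and_cross_domain_means(matrix):
    """Return (mean of diagonal entries, mean of off-diagonal entries)."""
    n = len(matrix)
    same = [matrix[i][i] for i in range(n)]
    cross = [matrix[i][j] for i in range(n) for j in range(n) if i != j]
    return sum(same) / len(same), sum(cross) / len(cross)


same_mean, cross_mean = same_and_cross_domain_means(MATRIX)
print(f"same-domain mean:  {same_mean:.1f}")
print(f"cross-domain mean: {cross_mean:.1f}")
```

Under these assumptions the same-domain mean is about 78.1 and the cross-domain mean is about 43.7, a gap of roughly 34 points for the baseline model.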

The models used above can be found here.

Remarks

  • These results may differ slightly from those in our paper due to differences in environments.

  • We will continue to update this repo.

Citation

If you find these investigations useful in your research, please consider citing:

@article{zhang2021delving,
  title={Delving deep into the generalization of vision transformers under distribution shifts},
  author={Zhang, Chongzhi and Zhang, Mingyuan and Zhang, Shanghang and Jin, Daisheng and Zhou, Qiang and Cai, Zhongang and Zhao, Haiyu and Yi, Shuai and Liu, Xianglong and Liu, Ziwei},
  journal={arXiv preprint arXiv:2106.07617},
  year={2021}
}
