
yuyongcan/Benchmark-TTA


Prerequisites

We provide a conda environment for this repository:

conda update conda
conda env create -f environment.yaml
conda activate Benchmark_TTA 

Project Structure

This project contains several directories. Their roles are listed as follows:

  • ./best_cfgs: the best config files for each dataset and algorithm are saved here.
  • ./robustbench: the official RobustBench library, which we use to load robust datasets and models.
  • ./src/
    • data: code for loading our datasets and dataloaders.
    • methods: implementations of the various TTA methods.
    • models: code for loading and defining the various models.
    • utils: useful tools for the project.

Run

This repository supports a wide range of datasets, models, settings, and methods. A quick overview is given below:

  • Datasets

    • cifar10_c CIFAR10-C

    • cifar100_c CIFAR100-C

    • imagenet_c ImageNet-C

    • domainnet126 DomainNet (cleaned)

    • officehome Office-Home

    • The dataset directory structure is as follows:

      |-- datasets
          |-- cifar-10
          |-- cifar-100
          |-- ImageNet
              |-- train
              |-- val
          |-- ImageNet-C
          |-- CIFAR-10-C
          |-- CIFAR-100-C
          |-- DomainNet
              |-- clipart
              |-- painting
              |-- real
              |-- sketch
              |-- clipart126_test.txt
              ......
          |-- office-home
              |-- Art
              |-- Clipart
              |-- Product
              |-- Real_World

    You can download the .txt files for DomainNet in ./dataset/DomainNet, and generate the .txt files for Office-Home following SHOT.

  • Models

    • For adapting to ImageNet variations, the ResNet-50 models available in Torchvision and the ViT models available in timm (PyPI) can be used.
    • For the corruption benchmarks, pre-trained models from RobustBench can be used.
    • For the DomainNet-126 benchmark, there is a pre-trained model for each domain.
    • The checkpoints of the pre-trained models are stored in the ckpt directory.
  • Methods

  • Modular Design

    • Adding new methods should be rather simple, thanks to the modular design.
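Before running anything, it can help to confirm that the dataset folder matches the directory structure listed above. The following is a minimal sketch using only the Python standard library; EXPECTED_LAYOUT and missing_dirs are hypothetical names, and it assumes the folder set as _C.DATA_DIR in conf.py plays the role of the datasets root in the tree. It checks directories only, not the .txt lists or the image files themselves.

```python
from pathlib import Path

# Expected folders under the datasets root, transcribed from the directory
# tree above (hypothetical helper, not part of the repository). Only the
# top-level folders and the listed sub-folders are checked.
EXPECTED_LAYOUT = {
    "cifar-10": [],
    "cifar-100": [],
    "ImageNet": ["train", "val"],
    "ImageNet-C": [],
    "CIFAR-10-C": [],
    "CIFAR-100-C": [],
    "DomainNet": ["clipart", "painting", "real", "sketch"],
    "office-home": ["Art", "Clipart", "Product", "Real_World"],
}


def missing_dirs(root):
    """Return the expected directories that do not exist under `root`."""
    root = Path(root)
    missing = []
    for top, subdirs in EXPECTED_LAYOUT.items():
        top_dir = root / top
        if not top_dir.is_dir():
            missing.append(str(top_dir))
            continue  # children of a missing folder are not reported separately
        for sub in subdirs:
            if not (top_dir / sub).is_dir():
                missing.append(str(top_dir / sub))
    return missing


# Example: print(missing_dirs("./data"))  # an empty list means the layout matches
```

Reporting only the missing top-level folder, rather than each of its children, keeps the output short when a whole dataset has not been downloaded yet.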

Get Started

To run one of the following benchmarks, the corresponding datasets need to be downloaded.

Next, specify the root folder for all datasets by setting _C.DATA_DIR = "./data" in conf.py.

The best parameters for each method and dataset are saved in ./best_cfgs.

Download the checkpoints of the pre-trained models and the data loading sequences from here, and put them in ./ckpt.

How to reproduce

The entry script for SHOT, NRC, and PLUE is SFDA-eva.sh.

To evaluate these methods, modify DATASET and METHOD in SFDA-eva.sh

and then

bash SFDA-eva.sh

The entry file for other algorithms is test-time-eva.sh

To evaluate these methods, modify DATASET and METHOD in test-time-eva.sh

and then

bash test-time-eva.sh
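Since SHOT, NRC, and PLUE use SFDA-eva.sh while all other algorithms use test-time-eva.sh, picking the script and editing DATASET and METHOD can be automated. This is a hypothetical helper, not part of the repository: entry_script and set_script_vars are illustrative names, and the sketch assumes both scripts set DATASET and METHOD as plain shell assignments at the start of a line (e.g. DATASET="cifar10_c"). Check the real scripts before relying on it.

```python
import re

# Methods that are launched through SFDA-eva.sh; every other method is
# launched through test-time-eva.sh.
SFDA_METHODS = {"SHOT", "NRC", "PLUE"}


def entry_script(method):
    """Pick the evaluation script for a method name (case-insensitive)."""
    return "SFDA-eva.sh" if method.upper() in SFDA_METHODS else "test-time-eva.sh"


def set_script_vars(script_text, dataset, method):
    """Rewrite the `DATASET=...` and `METHOD=...` lines of a shell script.

    Assumption: each variable is assigned at the start of its own line.
    Raises ValueError if an assignment is not found.
    """
    for name, value in (("DATASET", dataset), ("METHOD", method)):
        replacement = f'{name}="{value}"'
        # A function replacement avoids re's escape handling of the new value.
        script_text, count = re.subn(
            rf"^{name}=.*$",
            lambda _m, r=replacement: r,
            script_text,
            flags=re.MULTILINE,
        )
        if count == 0:
            raise ValueError(f"no {name}= line found in script")
    return script_text
```

For example, set_script_vars(Path(entry_script("NRC")).read_text(), "domainnet126", "NRC") returns the edited script text, which can be written back before running it with bash as shown above.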

Add your own algorithm, dataset and model

We decouple the loading of datasets, models, and methods, so you can add each of them to our benchmark independently.

To add an algorithm

  1. Add a Python file Algorithm_XX.py for your algorithm in ./src/methods/.

  2. Add the setup function of your algorithm, setup_XX(model, cfg), to ./src/methods/setup.py.

  3. Add two lines of setup code at line 22 of ./test-time.py, for example:

        elif cfg.MODEL.ADAPTATION == "XX":
            model, param_names = setup_XX(base_model, cfg)
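The elif branch above dispatches on cfg.MODEL.ADAPTATION and expects setup_XX to return the adapted model together with the names of the parameters it updates. As an alternative to growing the elif chain, the same contract can be expressed with a dictionary. This is a swapped-in, dict-based dispatch for illustration only, not the repository's code: SETUP_REGISTRY and setup_model are hypothetical names, the parameter names are placeholders, and SimpleNamespace stands in for the real config object.

```python
from types import SimpleNamespace


def setup_XX(model, cfg):
    """Hypothetical placeholder for your method's setup function.

    Takes the base model and the config, and returns the (possibly wrapped)
    model plus the names of the parameters that will be adapted.
    """
    param_names = ["bn1.weight", "bn1.bias"]  # illustrative names only
    return model, param_names


# Map each cfg.MODEL.ADAPTATION value to its setup function; adding a method
# means adding one entry here instead of another elif branch.
SETUP_REGISTRY = {
    "XX": setup_XX,
}


def setup_model(base_model, cfg):
    """Look up and run the setup function for cfg.MODEL.ADAPTATION."""
    try:
        setup_fn = SETUP_REGISTRY[cfg.MODEL.ADAPTATION]
    except KeyError:
        raise ValueError(f"unknown adaptation method: {cfg.MODEL.ADAPTATION}") from None
    return setup_fn(base_model, cfg)


if __name__ == "__main__":
    cfg = SimpleNamespace(MODEL=SimpleNamespace(ADAPTATION="XX"))
    model, names = setup_model(object(), cfg)
    print(names)  # ['bn1.weight', 'bn1.bias']
```

Raising on an unknown key makes a typo in the config fail loudly instead of silently falling through the chain.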

To add a dataset

  1. Write a function load_dataset_name() in ./src/data/data.py to load your dataset Dataset_new.

  2. Define the transforms used to load your dataset in the function get_transform() in ./src/data/data.py.

  3. Add two lines to load your dataset in the function load_dataset() in ./src/data/data.py, for example:

        elif dataset == 'dataset_name':
            return load_dataset_name(root=root, batch_size=batch_size, workers=workers, split=split, transforms=transforms,
                                     ckpt=ckpt)
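For list-based datasets such as DomainNet or Office-Home, a load_dataset_name() function typically starts by reading the .txt image lists mentioned above (e.g. clipart126_test.txt). A minimal standard-library sketch follows; read_image_list is a hypothetical name, and the file format is an assumption (one sample per line, a relative image path followed by an integer label, separated by whitespace), so verify it against the actual list files.

```python
from pathlib import Path


def read_image_list(list_file, image_root):
    """Parse an image-list file into (image_path, label) pairs.

    Assumed format: one sample per line, `relative/path/to/image.jpg <label>`.
    Paths are resolved relative to `image_root`; blank lines are skipped.
    """
    image_root = Path(image_root)
    samples = []
    with open(list_file, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # skip blank lines
            # Split once from the right so paths containing spaces stay intact.
            parts = line.rsplit(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(f"{list_file}:{line_no}: expected '<path> <label>'")
            rel_path, label = parts
            samples.append((image_root / rel_path, int(label)))
    return samples
```

The resulting pairs can back a dataset class that opens each image and applies the transforms defined in get_transform().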

To add a model

  1. Add the code for loading your model to the load_model() function in ./src/models/load_model.py, for example:

        elif model_name == 'model_new':
            model = ...  # replace with the code for loading your model

You can cite our work as follows:

@article{yu2023benchmarking,
  title={Benchmarking test-time adaptation against distribution shifts in image classification},
  author={Yu, Yongcan and Sheng, Lijun and He, Ran and Liang, Jian},
  journal={arXiv preprint arXiv:2307.03133},
  year={2023}
}

Acknowledgements
