This is the official PyTorch implementation of "SoTTA: Robust Test-Time Adaptation on Noisy Data Streams (NeurIPS '23)" by Taesik Gong*, Yewon Kim*, Taeckyung Lee*, Sorn Chottananurak, and Sung-Ju Lee (* Equal contribution).
[ OpenReview ] [ arXiv ] [ Website ]
- Download or clone our repository.
- Set up a Python environment using conda (see below).
- Prepare datasets (see below).
- Run the code (see below).
We use a Conda environment. You can get conda by installing Anaconda first.
We share our Python environment, which contains all required Python packages. Please refer to the ./sotta.yml file.
You can import our environment using conda:
conda env create -f sotta.yml -n sotta
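After creating the environment, activate it with conda activate sotta. To quickly confirm that the key packages resolve inside the activated environment, a minimal sketch like the one below can help. It assumes torch, torchvision, and numpy are among the packages in sotta.yml; the REQUIRED_PACKAGES list is hypothetical and is not read from the yml file, so adjust it to match.

```python
import importlib.util

# Hypothetical package list; assumed (not verified) to be part of sotta.yml.
REQUIRED_PACKAGES = ("torch", "torchvision", "numpy")


def check_packages(names=REQUIRED_PACKAGES):
    """Return {package_name: bool} telling whether each package is importable.

    importlib.util.find_spec only locates a top-level module without
    importing it, so the check is fast and has no side effects.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_packages().items():
        print(f"{name:12s} {'ok' if found else 'MISSING'}")
```

If a package is reported as MISSING, the environment was most likely not activated in the current shell.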
To run our code, you first need to download at least one of the datasets. Run the following commands:
$ cd . #project root
$ . download_cifar10c.sh #download CIFAR10/CIFAR10-C datasets
$ . download_cifar100c.sh #download CIFAR100/CIFAR100-C datasets
You can also download the following datasets and place them in the ./dataset folder (create the folder if it does not exist):
- ImageNet-C: https://zenodo.org/record/2235448
- MNIST-C: https://zenodo.org/record/3239543
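The folder preparation above can be sketched in Python as follows. This is a minimal sketch: prepare_dataset_dir and list_datasets are hypothetical helpers, and the exact subfolder names that each dataset loader expects are not stated here, so the sketch only creates ./dataset and lists what is currently inside it.

```python
from pathlib import Path


def prepare_dataset_dir(project_root="."):
    """Create <project_root>/dataset if needed and return its path.

    exist_ok=True means the folder is only created when it does not
    already exist, so calling this repeatedly is harmless.
    """
    dataset_dir = Path(project_root) / "dataset"
    dataset_dir.mkdir(parents=True, exist_ok=True)
    return dataset_dir


def list_datasets(project_root="."):
    """Return the sorted names of the entries currently inside ./dataset."""
    dataset_dir = Path(project_root) / "dataset"
    if not dataset_dir.is_dir():
        return []
    return sorted(p.name for p in dataset_dir.iterdir())


if __name__ == "__main__":
    print("dataset folder:", prepare_dataset_dir().resolve())
    print("contents:", list_datasets())
```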
"Source model" refers to a model that is trained with the source (clean) data only. Source models are required for all methods to perform test-time adaptation.
We provide pretrained source models for CIFAR10/CIFAR100 with three random seeds (0, 1, 2) at GDrive Link. After extracting log.zip, put the extracted log folder in the project root directory, i.e., SoTTA/log.
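A common mistake when extracting an archive is unpacking it one level too deep (for example, ending up with SoTTA/log/log). A minimal sketch of a placement check is shown below, assuming the archive's top-level folder is named log; check_log_folder is a hypothetical helper and does not validate the model files themselves.

```python
from pathlib import Path


def check_log_folder(project_root="."):
    """Sanity-check where log.zip was extracted.

    Returns:
      "missing" if <project_root>/log does not exist,
      "nested"  if it was unpacked one level too deep (log/log/...),
      "ok"      otherwise.
    """
    log_dir = Path(project_root) / "log"
    if not log_dir.is_dir():
        return "missing"
    if (log_dir / "log").is_dir():
        return "nested"
    return "ok"


if __name__ == "__main__":
    print(check_log_folder())
```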
Alternatively, you can train source models via:
$ . train_src.sh #generate source models for CIFAR10 as default.
You can specify which dataset to use in the script file.
Once source models are available, you can run TTA via:
$ . tta.sh #run SoTTA for tta-target: CIFAR10-C, noisy-stream: MNIST as default.
You can specify the dataset and the method in the script file.
In addition to the console output, the results are saved as a log file with the following path structure: ./log/{DATASET}/{METHOD}_noisy/{TGT}/{LOG_PREFIX}_{SEED}_{DIST}/online_eval.json
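The path template above can be filled in programmatically. The sketch below copies the template verbatim; online_eval_path is a hypothetical helper, and the example values (cifar10, example_tgt, example_prefix) are placeholders, since the real identifiers are whatever tta.sh passes for each field.

```python
from pathlib import Path

# Template copied from the README; field names match its placeholders.
LOG_TEMPLATE = (
    "./log/{DATASET}/{METHOD}_noisy/{TGT}/{LOG_PREFIX}_{SEED}_{DIST}/online_eval.json"
)


def online_eval_path(dataset, method, tgt, log_prefix, seed, dist):
    """Fill in the log-path template for one run and return it as a Path."""
    return Path(
        LOG_TEMPLATE.format(
            DATASET=dataset,
            METHOD=method,
            TGT=tgt,
            LOG_PREFIX=log_prefix,
            SEED=seed,
            DIST=dist,
        )
    )


if __name__ == "__main__":
    # Placeholder values for illustration only.
    print(online_eval_path("cifar10", "SoTTA", "example_tgt", "example_prefix", 0, 1))
```

Note that pathlib drops the leading "./", so the path prints relative to the current directory (the project root).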
To print the classification accuracy (%) on the test set, run the following command:
$ python print_acc.py --method SoTTA #prints the result of the specified condition.
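print_acc.py is the supported way to summarize results. For custom analysis, a hypothetical helper like the one below could locate every online_eval.json under ./log and recover the run fields from the directory layout. It is a sketch under two assumptions: SEED and DIST contain no underscores (LOG_PREFIX may), and the JSON contents are returned as-is because their schema is not documented here.

```python
import json
from pathlib import Path


def find_online_evals(log_root="./log"):
    """Yield (run_info, path) for every online_eval.json under log_root.

    The glob mirrors the layout
    {DATASET}/{METHOD}_noisy/{TGT}/{LOG_PREFIX}_{SEED}_{DIST}/online_eval.json
    """
    pattern = "*/*_noisy/*/*/online_eval.json"
    for path in sorted(Path(log_root).glob(pattern)):
        # Split from the right so that LOG_PREFIX may contain underscores.
        parts = path.parent.name.rsplit("_", 2)
        if len(parts) != 3:
            continue  # directory name does not follow the template
        log_prefix, seed, dist = parts
        info = {
            "dataset": path.parents[3].name,
            "method": path.parents[2].name[: -len("_noisy")],
            "tgt": path.parents[1].name,
            "log_prefix": log_prefix,
            "seed": seed,
            "dist": dist,
        }
        yield info, path


def load_online_eval(path):
    """Load one online_eval.json without interpreting its contents."""
    with open(path) as f:
        return json.load(f)


if __name__ == "__main__":
    for info, path in find_online_evals():
        print(info, path)
```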
We tested our code in the following environment:
- OS: Ubuntu 20.04.4 LTS
- GPU: NVIDIA GeForce RTX 3090
- GPU Driver Version: 470.74
- CUDA Version: 11.4
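To compare your machine against the tested setup, a small report script may help. This is a minimal sketch; environment_report is a hypothetical helper. Keep in mind that torch.version.cuda is the CUDA version PyTorch was built with, which can differ from the "CUDA Version" and driver version that nvidia-smi shows (those reflect the installed driver), so run nvidia-smi for the driver-side numbers.

```python
import importlib.util
import platform


def environment_report():
    """Collect OS, Python, and (if installed) PyTorch/CUDA details."""
    report = {
        "os": platform.platform(),
        "python": platform.python_version(),
    }
    if importlib.util.find_spec("torch") is not None:
        import torch  # imported lazily so the report works without PyTorch

        report["torch"] = torch.__version__
        # CUDA version PyTorch was built with (None for CPU-only builds).
        report["torch_cuda"] = torch.version.cuda
        report["cuda_available"] = torch.cuda.is_available()
        if torch.cuda.is_available():
            report["gpu"] = torch.cuda.get_device_name(0)
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```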
@inproceedings{ gong2023sotta,
title={{SoTTA}: Robust Test-Time Adaptation on Noisy Data Streams},
author={Gong, Taesik and Kim, Yewon and Lee, Taeckyung and Chottananurak, Sorn and Lee, Sung-Ju},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}