SiSTA: Target-Aware Generative Augmentations for Single-Shot Adaptation

Abstract

While several test-time adaptation techniques have emerged, they typically rely on synthetic toolbox data augmentations in cases of limited target data availability. We consider the challenging setting of single-shot adaptation and explore the design of augmentation strategies. We argue that augmentations utilized by existing methods are insufficient to handle large distribution shifts, and hence propose a new approach, SiSTA (Single-Shot Target Augmentations), which first fine-tunes a generative model from the source domain using a single-shot target, and then employs novel sampling strategies for curating synthetic target data. Using experiments on a variety of benchmarks, distribution shifts and image corruptions, we find that SiSTA produces significantly improved generalization over existing baselines in face attribute detection and multi-class object recognition. Furthermore, SiSTA performs competitively with models obtained by training on larger target datasets.


Requirements

The project's dependencies are given as a conda environment file (SiSTA.yml):

conda env create -f SiSTA.yml
conda activate SiSTA
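
To sanity-check the environment, a quick import test (assuming PyTorch is among the pinned dependencies, as is typical for StyleGAN2-based code):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"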

Datasets

Place the datasets following the file structure below:

├── SISTA
│   ├── create_reference.sh
│   ├── finetune_GAN.sh
│   ├── README.md
│   ├── SiSTA.yml
│   ├── source_adapt.sh
│   ├── source_train.sh
│   └── synth_data.sh
├── data
│   ├── AFHQ
│   │   ├── target
│   │   ├── test
│   │   └── train
│   ├── CelebA-HQ
│   │   ├── target
│   │   ├── test
│   │   └── train
│   └── CIFAR-10
│       ├── target
│       ├── test
│       └── train
├── generative_augmentation
│   ├── e4e_projection.py
│   ├── GAN_AFHQ.py
│   ├── GAN_CelebA.py
│   ├── GAN_CIFAR-10.py
│   ├── model.py
│   ├── transformations.py
│   ├── util.py
│   ├── data
│   ├── e4e
│   ├── models
│   └── op
└── SISTA_DA
    ├── celebahq_dataloader.py
    ├── celeba_dataloader.py
    ├── data_list.py
    ├── image_NRC_target.py
    ├── image_source.py
    ├── image_target.py
    ├── image_target_memo.py
    ├── loss.py
    ├── network.py
    ├── randconv.py
    ├── README.md
    ├── run.sh
    └── utils_memo.py
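
One way to create the expected data layout before copying the images in (a minimal sketch; the paths match the tree above):

mkdir -p data/{AFHQ,CelebA-HQ,CIFAR-10}/{target,test,train}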

Target images:
To create the target images for the different domains used in the paper, run create_reference.sh <data_path> <dst_path> <domain>
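
For example, to build the target reference for CelebA-HQ (the 'sketch' domain name here is a placeholder; use a domain supported by create_reference.sh):

bash create_reference.sh ./data/CelebA-HQ ./data/CelebA-HQ/target sketch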

Algorithm

Our method has four major steps:

  1. Source model and GAN training
  2. Single-shot StyleGAN fine-tuning
  3. Synthetic data generation
  4. Source-free UDA using the synthetic data

Source model training

  • CelebA-HQ binary attribute classification:
    source_train.sh CelebA-HQ <attribute>
  • AFHQ multi-class classification:
    source_train.sh AFHQ
  • CIFAR-10 multi-class classification:
    source_train.sh CIFAR-10
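
For example, to train a binary attribute classifier (the attribute name is illustrative; "Smiling" is one of the standard CelebA attributes):

bash source_train.sh CelebA-HQ Smiling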

We download pretrained source generators for each dataset.

Single-shot StyleGAN fine-tuning

  • CelebA-HQ:
    finetune_GAN.sh CelebA-HQ <domain>
  • AFHQ multi-class classification:
    finetune_GAN.sh AFHQ <domain> <cls> (cls in {'cat', 'dog', 'wild'})
  • CIFAR-10 multi-class classification:
    finetune_GAN.sh CIFAR-10 <domain> <num_cls> (num_cls: an integer in [1, 10])
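
For example, to fine-tune the AFHQ generator for the 'cat' class (the domain name is a placeholder for one of the paper's target domains):

bash finetune_GAN.sh AFHQ sketch cat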

Synthetic data generation

  • Base: synth_data.sh <data_type> <domain> <cls> base
  • Prune-zero: synth_data.sh <data_type> <domain> <cls> prune-zero
  • Prune-rewind: synth_data.sh <data_type> <domain> <cls> prune-rewind
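
For example, to curate synthetic AFHQ 'cat' data with the prune-rewind strategy (again, the domain name is a placeholder):

bash synth_data.sh AFHQ sketch cat prune-rewind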

Source-free UDA

To run source-free UDA, follow the instructions in SISTA_DA/README.md
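
SISTA_DA also ships a run.sh (see the tree above); assuming it wraps the adaptation entry points, a typical invocation would look like the following, with the exact arguments documented in SISTA_DA/README.md:

cd SISTA_DA && bash run.sh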

Acknowledgments

This code builds upon the following codebases: StyleGAN2 by rosinality, e4e, StyleGAN-NADA, NRC, MEMO and RandConv. We thank the authors of the respective works for publicly sharing their code. Please cite them when appropriate.
