Zhicaiwww/Diff-Mix

Enhance Image Classification Via Inter-Class Image Mixup With Diffusion Model

Introduction

A common way to enhance image classification is to expand the training set with synthetic data generated by text-to-image (T2I) models. Here, we propose an inter-class data augmentation method, Diff-Mix. Diff-Mix expands the dataset by performing image translation across classes, significantly improving the diversity of the synthetic data. We observe an improved trade-off between faithfulness and diversity with Diff-Mix, yielding significant performance gains across various image classification settings, including few-shot, conventional, and long-tail classification, particularly on domain-specific datasets.

Datasets

For convenience, the well-structured datasets hosted on Hugging Face can be used. The fine-grained datasets CUB and Aircraft we experimented with can be downloaded from Multimodal-Fatima/CUB_train and Multimodal-Fatima/FGVC_Aircraft_train, respectively. If you encounter network connection problems during training, pre-download the data from the website and specify the saved local path as HUG_LOCAL_IMAGE_TRAIN_DIR in semantic_aug/datasets/cub.py.
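As a rough sketch of the fallback described above (the helper name `resolve_cub_source` and its logic are illustrative, not the repository's actual code), the local-path override amounts to:

```python
import os

# Hypothetical sketch: prefer a pre-downloaded local copy when the network
# is unreliable, otherwise fall back to the Hugging Face Hub repo id.
HUG_LOCAL_IMAGE_TRAIN_DIR = os.environ.get("HUG_LOCAL_IMAGE_TRAIN_DIR", "")

def resolve_cub_source(local_dir: str) -> str:
    """Return the local path if it exists, otherwise the Hub repo id."""
    if local_dir and os.path.isdir(local_dir):
        return local_dir
    return "Multimodal-Fatima/CUB_train"
```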

Code Description

Fine-tuning

We fine-tune both the textual tokens (textual inversion) and the U-Net (via LoRA, using diffusers) of the pre-trained Stable Diffusion model to expedite the fine-tuning process.

For ease of use, the concrete fine-tuning command is wrapped in the script scripts/finetune.sh. Distributed training is performed with the accelerate tool, and the GPUs should be specified via the environment variable CUDA_VISIBLE_DEVICES. The simplified command for fine-tuning on the full training set of CUB for a total of 35,000 steps is:

source scripts/finetune.sh
bash finetune 'cub' 'ti_db' -1 35000

To fine-tune in a 5-shot setting, modify the shell command to

source scripts/finetune.sh
bash finetune 'cub' 'ti_db' 5 35000

The fine-tuned checkpoints will be saved under outputs/finetune_model/finetune_ti_db{_5shot}/cub/. Afterwards, please manually add the checkpoint's meta-information to config/finetuned_ckpts.yaml in the following format:

cub: 
  ti_db_latest:
    model_path: "runwayml/stable-diffusion-v1-5"
    lora_path: "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/checkpoint-35000/pytorch_model.bin"
    embed_path: "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/learned_embeds-steps-35000.bin"

This structure lets you locate the checkpoint paths simply by the key pair ('cub', 'ti_db_latest').
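The lookup itself is just nested dictionary access; the sketch below mirrors the YAML entry above once parsed (the helper name `lookup_ckpt` is illustrative, not the repository's actual API):

```python
# Mirrors config/finetuned_ckpts.yaml after YAML parsing.
FINETUNED_CKPTS = {
    "cub": {
        "ti_db_latest": {
            "model_path": "runwayml/stable-diffusion-v1-5",
            "lora_path": "outputs/finetune_model/finetune_ti_db/"
                         "sd-cub-model-lora-rank10/checkpoint-35000/pytorch_model.bin",
            "embed_path": "outputs/finetune_model/finetune_ti_db/"
                          "sd-cub-model-lora-rank10/learned_embeds-steps-35000.bin",
        }
    }
}

def lookup_ckpt(dataset: str, ckpt_key: str) -> dict:
    """Resolve the checkpoint paths for a (dataset, checkpoint-key) pair."""
    return FINETUNED_CKPTS[dataset][ckpt_key]
```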

Construct synthetic data

Similarly, the command details are wrapped in scripts/sample.sh. To expedite inference, we use the multiprocessing tool to launch multiple inference processes. The worker processes are specified via the environment variable GPU_IDS, where each item in the list denotes a process running on the GPU with that index.
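The GPU_IDS convention can be sketched as follows (`shard_tasks` is a hypothetical helper, not the repository's code): each list entry spawns one worker pinned to that GPU index, so `(0, 0, 0, 1, 1, 1)` runs three workers on GPU 0 and three on GPU 1, with the sampling workload split round-robin among them.

```python
# Hypothetical sketch of splitting sampling tasks across worker processes.
# Each entry in gpu_ids corresponds to one worker on that GPU index.
def shard_tasks(tasks, gpu_ids):
    """Split tasks round-robin, one shard per worker; each shard carries its GPU id."""
    return [(gpu, tasks[i::len(gpu_ids)]) for i, gpu in enumerate(gpu_ids)]
```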

The simplified command for sampling a $5\times$ synthetic subset in an inter-class translation manner (Diff-Mix) with strength $s=0.7$ is:

source scripts/sample.sh
export GPU_IDS=(0 0 0 1 1 1)
bash sample 'cub' 'ti_db_latest' 'diff-mix' 0.7

One can also construct the synthetic subset with other expansion strategies by replacing diff-mix with diff-aug (Diff-Aug, fine-tuned intra-class translation), real-mix (Real-Mix, pre-trained inter-class translation), or real-guidance (Real-Aug, pre-trained intra-class translation).
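For reference, the four strategies differ along two axes, whether the diffusion model is fine-tuned and whether the translation crosses class boundaries; a minimal summary of the mapping above:

```python
# Summary of the expansion strategies: (uses fine-tuned model, inter-class translation).
STRATEGIES = {
    "diff-mix":      {"finetuned": True,  "inter_class": True},
    "diff-aug":      {"finetuned": True,  "inter_class": False},
    "real-mix":      {"finetuned": False, "inter_class": True},
    "real-guidance": {"finetuned": False, "inter_class": False},
}
```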

To sample in the 5-shot setting, modify the shell command to:

source scripts/sample.sh
export GPU_IDS=(0 0 0 1 1 1)
bash sample_fewshot 5 'cub' '5shot_ti_db_latest' 'diff-mix' 0.7 

The sampled subset will be cached at outputs/aug_samples{_5shot}/cub. Afterwards, please manually add the subset's meta-information to synthetic_datasets.yaml in the following format:

cub: 
  diffmix_0.7: 'outputs/aug_samples/cub/diff-mix-Multi7-ti_db35000-Strength0.7'
  5shot_diffmix_0.7: 'outputs/aug_samples_5shot/cub/diff-mix-Multi7-ti_db35000-Strength0.7'

This allows you to locate the synthetic data paths simply by the key pair ('cub', 'diffmix_0.7') in case there are multiple subsets.

Downstream classification

After completing the sampling process, you can integrate the synthetic data into downstream classification and initiate training using the following commands:

source scripts/classification.sh
# main_cls {dataset_name} {gpu} {seed} {model} {resolution} {nepoch} {syndata_key} {gamma} {synthetic_prob}
main_cls 'cub' '0' 2020 'resnet50' '224' 120 'diffmix_0.7' 0.5 0.1
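A minimal sketch of how a mixing probability like the synthetic_prob argument above could gate batch construction (`pick_source` is hypothetical; the repository's actual sampling logic may differ): with probability synthetic_prob a training sample is drawn from the synthetic subset, otherwise from the real training set.

```python
import random

# Hypothetical sketch: draw each training sample from the synthetic subset
# with probability synthetic_prob, otherwise from the real training set.
def pick_source(synthetic_prob: float, rng: random.Random) -> str:
    return "synthetic" if rng.random() < synthetic_prob else "real"
```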

Acknowledgements

This project is built upon the repositories Da-fusion and diffusers. Special thanks to the contributors.
