CBDM: Class-Balancing Diffusion Models

This repo contains the official PyTorch implementation of Class-Balancing Diffusion Models (CVPR 2023), by Yiming Qin, Huangjie Zheng, Jiangchao Yao, Mingyuan Zhou, and Ya Zhang.

Recent studies have shown that diffusion-based models generate high-quality visual data while preserving better diversity. However, this observation holds only for curated data distributions, where samples are carefully pre-processed so that their labels are uniformly distributed. In practice, long-tailed data distributions are far more common, and how diffusion models perform on such class-imbalanced data remains unknown. In this work, we first investigate this problem and observe significant degradation in both diversity and fidelity when diffusion models are trained on class-imbalanced datasets. In tail classes especially, the generations largely lose diversity and suffer from severe mode collapse. To tackle this problem, we start from the hypothesis that the data distribution is not class-balanced, and propose Class-Balancing Diffusion Models (CBDM), which are trained with a distribution-adjustment regularizer. Experiments show that images generated by CBDM exhibit higher diversity and quality, both quantitatively and qualitatively. Our method is benchmarked on the CIFAR100/CIFAR100LT datasets and shows strong performance on the downstream recognition task.
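For intuition, here is a minimal PyTorch sketch of a DDPM training loss extended with a distribution-adjustment regularizer in the spirit of CBDM. All names (eps_model, tau, gamma, T) are illustrative, and the exact weighting and class-sampling scheme is defined in the paper; treat this as a sketch, not the repo's implementation.

    import torch
    import torch.nn.functional as F

    def cbdm_loss(eps_model, x_t, t, noise, y, num_classes, T=1000, tau=1.0, gamma=0.25):
        """DDPM denoising loss plus a class-balancing regularizer (illustrative only)."""
        pred = eps_model(x_t, t, y)                     # predicted noise for the true class
        ddpm = F.mse_loss(pred, noise)                  # standard denoising term

        # Contrast against a uniformly sampled class y'; the stop-gradients make
        # each branch regularize the other, with gamma down-weighting one direction.
        y_prime = torch.randint(0, num_classes, y.shape, device=y.device)
        pred_prime = eps_model(x_t, t, y_prime)
        reg = ((pred - pred_prime.detach()) ** 2).flatten(1).mean(1) \
            + gamma * ((pred.detach() - pred_prime) ** 2).flatten(1).mean(1)

        # Timestep-weighted: the adjustment matters most for very noisy x_t,
        # where the class-conditional signal is weakest.
        return ddpm + tau * (t.float() / T * reg).mean()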

About this repository

The repo is implemented based on https://github.com/w86763777/pytorch-ddpm. It currently supports training on four datasets, namely CIFAR10(LT) and CIFAR100(LT), under the following three mechanisms:

  1. Regular (conditional or unconditional) diffusion model training
  2. Class-balancing model training
  3. Class-balancing model finetuning based on a regular diffusion model

Running the Experiments

We mainly provide scripts for training and evaluating on the CIFAR100LT dataset. To run the code, change the argument 'root' to the path where the dataset is downloaded.
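The --imb_factor flag in the commands below controls how long-tailed the training split is. As a sketch of the standard CIFAR-LT construction (the repo's exact per-class counts may differ), class sizes decay exponentially from head to tail, with imb_factor the ratio between the rarest and the most frequent class:

    def longtail_counts(n_max, num_classes, imb_factor):
        """Per-class counts n_i = n_max * imb_factor ** (i / (C - 1)),
        so imb_factor is the ratio between the rarest and the largest class."""
        return [round(n_max * imb_factor ** (i / (num_classes - 1)))
                for i in range(num_classes)]

    # CIFAR100 has 500 training images per class, so imb_factor=0.01 keeps
    # 500 images for the head class and 5 for the rarest tail class:
    counts = longtail_counts(500, 100, 0.01)
    print(counts[0], counts[-1])  # 500 5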

Files used in evaluation

Download the precomputed features for CIFAR100 and CIFAR10 used by the precision/recall/F_beta metrics and put them in the stats folder; the code is then ready to run. Note that these metrics are only evaluated when the number of samples is 50k; otherwise they return 0.
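For reference, here is a minimal sketch of how such feature-based precision/recall metrics are commonly computed, following the k-NN manifold formulation of Kynkäänniemi et al. (2019); the repo's implementation may differ in details:

    import torch

    def knn_radii(feats, k=3):
        # Distance from each point to its k-th nearest neighbour
        # (k + 1 because the nearest neighbour of a point is itself).
        return torch.cdist(feats, feats).kthvalue(k + 1, dim=1).values

    def coverage(candidates, reference, radii):
        # Fraction of candidate points inside at least one reference k-NN ball.
        d = torch.cdist(candidates, reference)          # (n_cand, n_ref)
        return (d <= radii.unsqueeze(0)).any(dim=1).float().mean().item()

    def precision_recall(real_feats, fake_feats, k=3):
        precision = coverage(fake_feats, real_feats, knn_radii(real_feats, k))
        recall = coverage(real_feats, fake_feats, knn_radii(fake_feats, k))
        return precision, recall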

Train a model

  • Regular conditional diffusion model training, supporting classifier-free guidance (cfg) sampling (see the label-dropout sketch after these training commands)

    python main.py --train  \
            --flagfile ./config/cifar100.txt --parallel \
            --logdir ./logs/cifar100lt_ddpm --total_steps 300001 \
            --conditional \
            --data_type cifar100lt --imb_factor 0.01 --img_size 32 \
            --batch_size 64 --save_step 100000 --sample_step 50000 \
            --cfg
    
  • Class-balancing model training without ADA augmentation

    python main.py --train  \
            --flagfile ./config/cifar100.txt --parallel \
            --logdir ./logs/cifar100lt_cbdm --total_steps 300001 \
            --conditional \
            --data_type cifar100lt --imb_factor 0.01 --img_size 32 \
            --batch_size 48 --save_step 100000 --sample_step 50000 \
            --cb --tau 1.0
    
  • Class-balancing model training with ADA augmentation

    python main.py --train  \
            --flagfile ./config/cifar100.txt --parallel \
            --logdir ./logs/cifar100lt_cbdm_augm --total_steps 500001 \
            --conditional \
            --data_type cifar100lt --imb_factor 0.01 --img_size 32 \
            --batch_size 48 --save_step 100000 --sample_step 50000 \
            --cb --tau 1.0 --augm
    
  • Class-balancing model finetuning: finetune a DDPM model (a 200,000-step checkpoint trained with classifier-free guidance (cfg) sampling) with the CBDM approach

    python main.py --train  \
            --flagfile ./config/cifar100.txt --parallel \
            --logdir ./logs/cifar100lt_ddpm --total_steps 100001 \
            --conditional \
            --data_type cifar100lt --imb_factor 0.01 --img_size 32 \
            --batch_size 48 --save_step 50000 --sample_step 50000 \
            --cb --tau 1.0 \
            --finetune --finetuned_logdir cifar100lt_cbdm_finetune --ckpt_step 200000
    
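As referenced in the first bullet above, training with --cfg typically means the same network also learns an unconditional branch by randomly dropping the class label, as in Ho & Salimans' classifier-free guidance. A hypothetical sketch (the dropout probability and null-token convention in this repo may differ):

    import torch

    def drop_labels(y, num_classes, p_uncond=0.1):
        # With probability p_uncond, replace the label by a "null" token
        # (here index num_classes), so one network learns both the
        # conditional and the unconditional noise prediction.
        drop = torch.rand(y.shape, device=y.device) < p_uncond
        return torch.where(drop, torch.full_like(y, num_classes), y)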

Evaluate a model

  • Sample images and evaluate each of the four models above (the guidance weight --omega is explained in the sketch after these commands).

    python main.py \
        --flagfile ./logs/cifar100lt_ddpm/flagfile.txt \
        --logdir ./logs/cifar100lt_ddpm \
        --fid_cache ./stats/cifar100.train.npz \
        --ckpt_step 200000 \
        --num_images 50000 --batch_size 64 \
        --notrain \
        --eval \
        --sample_method cfg  --omega 0.8
    
    python main.py \
        --flagfile ./logs/cifar100lt_cbdm/flagfile.txt \
        --logdir ./logs/cifar100lt_cbdm \
        --fid_cache ./stats/cifar100.train.npz \
        --ckpt_step 300000 \
        --num_images 50000 --batch_size 64 \
        --notrain \
        --eval \
        --sample_method cfg  --omega 1.6
    
    python main.py \
        --flagfile ./logs/cifar100lt_cbdm_augm/flagfile.txt \
        --logdir ./logs/cifar100lt_cbdm_augm \
        --fid_cache ./stats/cifar100.train.npz \
        --ckpt_step 500000 \
        --num_images 50000 --batch_size 192 \
        --notrain \
        --eval \
        --sample_method cfg  --omega 1.4
    
    python main.py \
        --flagfile ./logs/cifar100lt_cbdm_finetune/flagfile.txt \
        --logdir ./logs/cifar100lt_cbdm_finetune \
        --fid_cache ./stats/cifar100.train.npz \
        --ckpt_step 250000 \
        --num_images 50000 --batch_size 512 \
        --notrain \
        --eval \
        --sample_method cfg  --omega 2.0
    
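The --omega flag sets the classifier-free guidance weight used at sampling time. A minimal sketch of the guided noise prediction under the common parameterization, where omega extrapolates from the unconditional towards the conditional prediction (function and argument names are hypothetical):

    import torch

    @torch.no_grad()
    def guided_eps(eps_model, x_t, t, y, num_classes, omega):
        # Classifier-free guidance: blend conditional and unconditional
        # predictions; omega = 0 recovers the purely conditional model.
        eps_cond = eps_model(x_t, t, y)
        eps_uncond = eps_model(x_t, t, torch.full_like(y, num_classes))  # null token
        return (1.0 + omega) * eps_cond - omega * eps_uncond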

References

If you find the code useful for your research, please consider citing:

@inproceedings{qin2023class,
  title={Class-balancing diffusion models},
  author={Qin, Yiming and Zheng, Huangjie and Yao, Jiangchao and Zhou, Mingyuan and Zhang, Ya},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2023}
}

Acknowledgements

This implementation is based on / inspired by https://github.com/w86763777/pytorch-ddpm.
