[Arxiv 2024] The Official code of "Theoretical Analysis of Diffusion Models Under Class Imbalance."


yanliang3612/DiffROP


The code of the NeurIPS 2024 submitted paper: Theoretical Analysis of Diffusion Models Under Class Imbalance

This repo contains the PyTorch implementation for "Theoretical Analysis of Diffusion Models Under Class Imbalance."

Diffusion models have demonstrated remarkable success in generating high-quality data across various domains. However, their performance deteriorates when trained on class-imbalanced datasets, particularly for underrepresented classes. In this paper, we present a theoretical analysis of diffusion models under class imbalance, focusing on the geometric properties of data representations in the latent space. We introduce the Label Chaos Entropy (LCE) metric to quantify the uncertainty in the label distributions of generated samples and establish a connection between LCE and the geometric overlap of data representations on a hyper-sphere. Through rigorous theorems, we prove that the posterior distribution used in DDPM amplifies class bias under imbalanced conditions, exacerbating the generation of geometrically overlapping samples. Based on these theoretical insights, we propose a novel regularization term that minimizes the geometric overlap between generated samples of different classes, thereby mitigating the impact of class imbalance on the generation quality. Experimental results on benchmark datasets validate the effectiveness of our approach, particularly for underrepresented classes. Our work provides a solid theoretical foundation for understanding and addressing class imbalance in diffusion models, opening avenues for further research in this direction.
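To make the proposed regularization concrete, here is a minimal sketch of one plausible form of the overlap penalty: features are projected onto the unit hyper-sphere, and the mean cosine similarity between samples of *different* classes is penalized. The function name, the use of cosine similarity, and the weighting are illustrative assumptions, not the paper's exact formulation.

```python
import torch
import torch.nn.functional as F

def overlap_penalty(features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of a geometric-overlap regularizer.

    Projects features onto the unit hyper-sphere and penalizes the mean
    cosine similarity between pairs of samples from different classes.
    """
    z = F.normalize(features, dim=1)       # unit hyper-sphere representations
    sim = z @ z.t()                        # pairwise cosine similarities
    cross = labels.unsqueeze(0) != labels.unsqueeze(1)  # cross-class pairs only
    if cross.sum() == 0:
        return sim.new_zeros(())           # single-class batch: no penalty
    return sim[cross].mean()

# Sketch of the combined objective (lam is a hypothetical weight):
# loss = ddpm_loss + lam * overlap_penalty(feats, y)
```

A high penalty means representations of different classes point in similar directions (geometric overlap); minimizing it pushes the classes apart on the hyper-sphere.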

About this repository

The repo is implemented based on https://github.com/w86763777/pytorch-ddpm. It currently supports training on four datasets: CIFAR10, CIFAR10-LT, CIFAR100, and CIFAR100-LT.

  1. Regular (conditional or unconditional) diffusion model training
  2. Class-balancing model training
  3. Class-balancing model finetuning based on a regular diffusion model
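The third mode (finetuning a class-balancing model from a regular one) follows the usual two-stage PyTorch pattern. The sketch below uses a toy stand-in module; `TinyDenoiser` and all hyperparameters are illustrative assumptions, not the repo's actual classes.

```python
import io
import torch
import torch.nn as nn

# Minimal stand-in for the repo's denoising network (illustrative only).
class TinyDenoiser(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 8)

    def forward(self, x, t):
        # a real DDPM denoiser also conditions on the timestep t
        return self.net(x)

# Stage 1: regular diffusion training produces a checkpoint.
pretrained = TinyDenoiser()
buf = io.BytesIO()                      # stands in for a checkpoint file
torch.save(pretrained.state_dict(), buf)

# Stage 2: class-balancing finetuning starts from that checkpoint.
buf.seek(0)
model = TinyDenoiser()
model.load_state_dict(torch.load(buf))
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# ...continue training with the class-balancing objective added to the loss...
```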

Running the Experiments

We mainly provide the scripts for training and evaluating on the CIFAR100-LT dataset. To run the code, please set the 'root' argument to the path where the dataset is downloaded.

Files used in evaluation

Please find the precomputed features for CIFAR100 and CIFAR10 used in the precision/recall/F_beta metrics. Put them in the stats folder and the code will be ready to run. Note that these two metrics are only evaluated when the number of samples is 50k; otherwise the evaluation returns 0.
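The 50k-sample condition described above can be expressed as a simple guard. The function name and return shape below are hypothetical, chosen only to mirror the documented behavior.

```python
from typing import Callable, Tuple

def precision_recall_or_zero(
    num_samples: int,
    compute_fn: Callable[[], Tuple[float, float]],
    required: int = 50_000,
) -> Tuple[float, float]:
    """Only evaluate precision/recall on exactly 50k samples;
    otherwise return 0 for both, as the README notes."""
    if num_samples != required:
        return 0.0, 0.0
    return compute_fn()
```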

Acknowledgements

This implementation is based on / inspired by pytorch-ddpm (https://github.com/w86763777/pytorch-ddpm).
