Code for the NeurIPS 2024 submitted paper: Theoretical Analysis of Diffusion Models Under Class Imbalance
This repo contains the PyTorch implementation of [Theoretical Analysis of Diffusion Models Under Class Imbalance].
Diffusion models have demonstrated remarkable success in generating high-quality data across various domains. However, their performance deteriorates when trained on class-imbalanced datasets, particularly for underrepresented classes. In this paper, we present a theoretical analysis of diffusion models under class imbalance, focusing on the geometric properties of data representations in the latent space. We introduce the Label Chaos Entropy (LCE) metric to quantify the uncertainty in the label distributions of generated samples and establish a connection between LCE and the geometric overlap of data representations on a hyper-sphere. Through rigorous theorems, we prove that the posterior distribution used in DDPM amplifies class bias under imbalanced conditions, exacerbating the generation of geometrically overlapping samples. Based on these theoretical insights, we propose a novel regularization term that minimizes the geometric overlap between generated samples of different classes, thereby mitigating the impact of class imbalance on the generation quality. Experimental results on benchmark datasets validate the effectiveness of our approach, particularly for underrepresented classes. Our work provides a solid theoretical foundation for understanding and addressing class imbalance in diffusion models, opening avenues for further research in this direction.
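The two core ideas above — measuring label uncertainty of generated samples and penalizing geometric overlap on a hyper-sphere — can be sketched in PyTorch. This is an illustrative sketch, not the repo's actual implementation: it assumes LCE takes the form of the mean entropy of a classifier's predicted label distribution, and that the regularizer penalizes cosine similarity between L2-normalized features of different classes.

```python
# Hypothetical sketch (function names and exact forms are assumptions,
# not the repo's API): LCE as mean prediction entropy, and an overlap
# regularizer on the unit hyper-sphere.
import torch
import torch.nn.functional as F


def label_chaos_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Mean entropy of the predicted label distribution (assumed LCE form)."""
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)
    return entropy.mean()


def overlap_regularizer(feats: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Penalize geometric overlap: mean cosine similarity between samples
    of *different* classes after projecting features onto the hyper-sphere."""
    z = F.normalize(feats, dim=-1)          # project onto unit hyper-sphere
    sim = z @ z.t()                         # pairwise cosine similarities
    diff_class = labels.unsqueeze(0) != labels.unsqueeze(1)
    if not diff_class.any():
        return sim.new_zeros(())
    return sim[diff_class].mean()
```

Minimizing the regularizer alongside the diffusion loss would push representations of different classes apart on the sphere, which is the intuition the paper's regularization term formalizes.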
The repo is implemented based on https://github.com/w86763777/pytorch-ddpm. It currently supports training on four datasets: CIFAR10, CIFAR10-LT, CIFAR100, and CIFAR100-LT.
- Regular (conditional or unconditional) diffusion model training
- Class-balancing model training
- Class-balancing model fine-tuning initialized from a regular diffusion model
We mainly provide scripts for training and evaluating on the CIFAR100-LT dataset. To run the code, set the 'root' argument to the path where the dataset is downloaded.
Please find the precomputed features for CIFAR100 and CIFAR10 used by the precision/recall/F_beta metrics. Place them in the 'stats' folder and the code will be ready to run. Note that these two metrics are only evaluated when the number of samples is 50k; otherwise they return 0.
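The 50k-sample guard described above can be sketched as follows. This is a minimal illustration of the documented behaviour; the function name and return format are hypothetical, not the repo's actual evaluation API.

```python
# Hypothetical sketch of the evaluation guard: precision/recall/F_beta are
# only computed for the full 50k-sample evaluation; smaller sample counts
# return zeros, as noted in the README.
def maybe_compute_prf(num_samples, compute_fn):
    """Run compute_fn() only when evaluating exactly 50k samples."""
    if num_samples != 50_000:
        # Metrics are defined against 50k reference features; skip otherwise.
        return {"precision": 0.0, "recall": 0.0, "f_beta": 0.0}
    return compute_fn()
```

In practice this means sanity-check runs with fewer generated samples will report zeros for these metrics rather than misleading partial values.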
This implementation is based on / inspired by:
- https://github.com/w86763777/pytorch-ddpm
- https://github.com/crowsonkb/k-diffusion/blob/master/train.py (we follow its implementation of ADA augmentation)