

StepDrop Demo

StepDrop

Stochastic Step Skipping in Tiny Diffusion Models
Explore the Demo Notebook »

Table of Contents
  1. About The Project
  2. Built With
  3. Getting Started
  4. Quick Start
  5. Pipeline Script
  6. Training
  7. Sampling
  8. Evaluation & Benchmarking
  9. Interpreting Metrics
  10. StepDrop Skip Strategies
  11. Visualization Utilities
  12. Project Structure
  13. HPC / SLURM Support
  14. Roadmap
  15. Contributing
  16. License
  17. Contact
  18. Acknowledgments

About The Project

StepDrop is a novel sampling method designed to accelerate inference in diffusion models, particularly tiny ones. By introducing a stochastic step skipping technique, it significantly reduces the number of required sampling steps while maintaining high-quality image generation.

This repository contains the official implementation, experiments, and demo notebooks for the StepDrop project.
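To make the core idea concrete, here is a minimal sketch of stochastic step skipping with a constant skip probability. This is an illustration only, not the repository's implementation (which lives in `src/sampler/StepDrop.py`); the function name and the boundary-protection details are assumptions:

```python
import torch

def stepdrop_schedule(n_timesteps=1000, p=0.3, generator=None):
    """Sample which reverse-diffusion steps to run, skipping each step
    independently with probability p (the 'constant' strategy)."""
    keep = torch.rand(n_timesteps, generator=generator) >= p
    keep[0] = True    # always run the final denoising step (t = 0)
    keep[-1] = True   # always start from pure noise (t = T - 1)
    # Indices of kept steps, in descending order (high noise -> clean image)
    return torch.nonzero(keep).squeeze(1).flip(0).tolist()

steps = stepdrop_schedule(n_timesteps=1000, p=0.3)
print(len(steps))  # roughly 70% of the 1000 steps survive on average
```

The sampler then runs its usual denoising update only on the timesteps in `steps`, which is where the NFE savings come from.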

(back to top)

Built With

Python

PyTorch

Hugging Face

Jupyter

(back to top)

Getting Started

Prerequisites

  • Python 3.8+
  • pip or conda
  • CUDA-compatible GPU (recommended)

Installation

  1. Clone the repository

    git clone https://github.com/wanghley/stepdrop-tiny-diffusion.git
    cd stepdrop-tiny-diffusion
  2. Create a virtual environment (recommended)

    python -m venv venv
    source venv/bin/activate  # On Windows: venv\Scripts\activate
  3. Install dependencies

    pip install -r requirements.txt
  4. Verify installation

    python scripts/checklibs.py

(back to top)

Quick Start

Interactive Menu

The easiest way to get started is with the interactive quick start menu:

./scripts/quick_start.sh

This provides a menu-driven interface for common tasks like training, sampling, and benchmarking.

One-Command Pipeline

Run the full pipeline (train → sample → evaluate) with a single command:

chmod +x scripts/pipeline.sh
./scripts/pipeline.sh --all --dataset cifar10 --epochs 10 --eval-samples 1000

Quick Test

For a fast sanity check on MNIST:

./scripts/pipeline.sh --all --dataset mnist --epochs 5 --n-samples 16 --eval-samples 100

(back to top)

Pipeline Script

The main automation tool is scripts/pipeline.sh. It orchestrates training, sampling, and evaluation.

Usage

./scripts/pipeline.sh [OPTIONS]

Pipeline Stages

| Flag | Description |
| --- | --- |
| `--train` | Run training stage |
| `--sample` | Run sampling stage |
| `--evaluate` | Run evaluation/benchmarking |
| `--all` | Run all stages (train → sample → evaluate) |
| `--clean` | Clean generated files |

Common Options

| Option | Default | Description |
| --- | --- | --- |
| `--dataset` | `cifar10` | Dataset: `mnist`, `cifar10`, `custom` |
| `--epochs` | `50` | Training epochs |
| `--batch-size` | `128` | Training batch size |
| `--base-channels` | `64` | U-Net base channels |
| `--checkpoint` | auto | Path to model checkpoint |
| `--n-samples` | `64` | Number of samples to generate |
| `--method` | `ddim` | Sampling method: `ddpm`, `ddim`, `stepdrop` |
| `--eval-samples` | `1000` | Samples for FID/IS evaluation |
| `--device` | `cuda` | Device: `cuda` or `cpu` |

Examples

# Full CIFAR-10 training with evaluation
./scripts/pipeline.sh --all --dataset cifar10 --epochs 100 --base-channels 128 --eval-samples 5000

# Train only on MNIST
./scripts/pipeline.sh --train --dataset mnist --epochs 20

# Sample with DDIM from existing checkpoint
./scripts/pipeline.sh --sample --checkpoint checkpoints/model.pt --method ddim --ddim-steps 50 --n-samples 64

# Sample with StepDrop
./scripts/pipeline.sh --sample --checkpoint checkpoints/model.pt --method stepdrop --skip-prob 0.3 --skip-strategy linear

# Evaluate with full metrics
./scripts/pipeline.sh --evaluate --checkpoint checkpoints/model.pt --eval-samples 5000 --full-metrics

# Compare StepDrop strategies against DDIM baselines
./scripts/pipeline.sh --evaluate --checkpoint checkpoints/model.pt --compare-stepdrop --eval-samples 1000

# Dry run (show commands without executing)
./scripts/pipeline.sh --all --dataset mnist --epochs 5 --dry-run

(back to top)

Training

Direct Training Script

python src/train.py --dataset cifar10 --epochs 50 --batch_size 128

Training Options

| Argument | Default | Description |
| --- | --- | --- |
| `--dataset` | `mnist` | Dataset: `mnist`, `cifar10`, `custom` |
| `--custom_data_dir` | `None` | Path to custom images folder |
| `--img_size` | `28` | Image size |
| `--channels` | `1` | Number of image channels |
| `--batch_size` | `128` | Training batch size |
| `--epochs` | `20` | Number of epochs |
| `--lr` | `2e-4` | Learning rate |
| `--n_timesteps` | `1000` | Diffusion timesteps |
| `--schedule_type` | `cosine` | Noise schedule: `linear`, `cosine` |
| `--base_channels` | `64` | U-Net base channels |
| `--save_path` | `checkpoints/model.pt` | Model save path |
| `--resume` | `None` | Resume from checkpoint |

Resume Training

python src/train.py --resume checkpoints/checkpoint_epoch_50.pt --epochs 100

(back to top)

Sampling

Direct Sampling Script

python src/sample.py --checkpoint checkpoints/model.pt --method ddim --ddim_steps 50 --n_samples 16

Sampling Methods

| Method | Command | Description |
| --- | --- | --- |
| DDPM | `--method ddpm` | Full 1000 steps, highest quality |
| DDIM | `--method ddim --ddim_steps 50` | Accelerated, deterministic |
| StepDrop | `--method stepdrop --skip_prob 0.3` | Stochastic step skipping |
| Adaptive StepDrop | `--method adaptive_stepdrop` | Error-based dynamic skipping |

Sampling Options

| Argument | Default | Description |
| --- | --- | --- |
| `--checkpoint` | required | Path to trained model |
| `--method` | `ddpm` | Sampling method |
| `--n_samples` | `16` | Number of samples |
| `--ddim_steps` | `50` | DDIM inference steps |
| `--ddim_eta` | `0.0` | DDIM stochasticity (0 = deterministic) |
| `--skip_prob` | `0.3` | StepDrop skip probability |
| `--skip_strategy` | `linear` | StepDrop strategy |
| `--output_dir` | `samples` | Output directory |
| `--save_grid` | `True` | Save as image grid |
| `--save_individual` | `False` | Save individual images |

Examples

# DDPM (best quality, slow)
python src/sample.py --checkpoint checkpoints/model.pt --method ddpm --n_samples 16

# DDIM (fast)
python src/sample.py --checkpoint checkpoints/model.pt --method ddim --ddim_steps 25 --n_samples 64

# StepDrop with linear strategy
python src/sample.py --checkpoint checkpoints/model.pt --method stepdrop --skip_prob 0.3 --skip_strategy linear

# StepDrop with quadratic strategy (more aggressive)
python src/sample.py --checkpoint checkpoints/model.pt --method stepdrop --skip_prob 0.5 --skip_strategy quadratic

# Adaptive StepDrop
python src/sample.py --checkpoint checkpoints/model.pt --method adaptive_stepdrop

(back to top)

Evaluation & Benchmarking

Benchmark Script

Run comprehensive benchmarks comparing different sampling strategies:

# Quick test with dummy model
python scripts/benchmark_strategies.py --dummy --samples 10

# Full benchmark with trained model
python scripts/benchmark_strategies.py --checkpoint checkpoints/model.pt --samples 5000

# With full metrics (FID, KID, IS, Precision, Recall, LPIPS, SSIM, PSNR, Vendi)
python scripts/benchmark_strategies.py --checkpoint checkpoints/model.pt --samples 5000 --full-metrics

Via Pipeline

# Basic evaluation
./scripts/pipeline.sh --evaluate --checkpoint checkpoints/model.pt --eval-samples 1000

# Full metrics
./scripts/pipeline.sh --evaluate --checkpoint checkpoints/model.pt --eval-samples 5000 --full-metrics

# Compare all StepDrop strategies vs DDIM
./scripts/pipeline.sh --evaluate --checkpoint checkpoints/model.pt --compare-stepdrop

# Evaluate only StepDrop variants
./scripts/pipeline.sh --evaluate --checkpoint checkpoints/model.pt --stepdrop-only

# Specific strategies only
./scripts/pipeline.sh --evaluate --checkpoint checkpoints/model.pt \
  --strategies "DDIM_50,StepDrop_Linear_0.3,StepDrop_Quadratic_0.3"

Output

Results are saved to results/<timestamp>/:

  • report.json - Full metrics data
  • report.csv - Summary for Excel/Sheets
  • *.png - Auto-generated plots (Pareto frontier, radar charts, etc.)
  • samples/ - Generated sample images per strategy

(back to top)

Interpreting Metrics

| Metric | Full Name | Goal | Description |
| --- | --- | --- | --- |
| FID | Fréchet Inception Distance | Lower is better | Similarity to the real dataset. <10: excellent, 10-30: good, >50: poor |
| IS | Inception Score | Higher is better | Clarity and diversity. CIFAR-10 real data ≈ 11.0 |
| KID | Kernel Inception Distance | Lower is better | Similar to FID, less biased for small sample sizes |
| Precision | — | Higher is better | Quality: are generated images realistic? |
| Recall | — | Higher is better | Diversity: does the model cover the data distribution? |
| LPIPS | Learned Perceptual Image Patch Similarity | Lower is better | Perceptual distance (diversity among samples) |
| Throughput | Images/second | Higher is better | Generation speed |
| NFE | Number of Function Evaluations | Lower is better | U-Net forward passes per image |
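NFE relates directly to the skip schedule: since step $t$ is skipped with probability $P(t)$, the expected NFE is $\sum_t (1 - P(t))$. A quick sketch of that calculation (the helper name is illustrative, not part of the codebase):

```python
def expected_nfe(n_timesteps, skip_prob_fn):
    """Expected number of U-Net evaluations when the step at normalized
    time t/T is skipped with probability skip_prob_fn(t/T)."""
    return sum(1.0 - skip_prob_fn(t / n_timesteps) for t in range(n_timesteps))

# Linear strategy P(t) = p * 4t(1-t) with p = 0.3 over 1000 steps.
# The parabola integrates to 2/3, so ~20% of steps are skipped:
nfe = expected_nfe(1000, lambda t: 0.3 * 4 * t * (1 - t))
print(round(nfe))  # 800
```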

(back to top)

StepDrop Skip Strategies

Probability-Based Strategies (StepDropSampler)

| Strategy | Formula | Description |
| --- | --- | --- |
| `constant` | $P(t) = p$ | Fixed skip probability |
| `linear` | $P(t) = p \cdot 4t(1-t)$ | Parabolic peak at the middle |
| `cosine_sq` | $P(t) = p \cdot \sin^2(\pi t)$ | Smooth cosine curve |
| `quadratic` | $P(t) = p \cdot 16t^2(1-t)^2$ | Sharper middle peak |
| `early_skip` | $P(t) = p \cdot t$ | Skip more at high noise |
| `late_skip` | $P(t) = p \cdot (1-t)$ | Skip more at low noise |
| `critical_preserve` | variable | Protect the $[0.3, 0.7]$ interval |
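The closed-form schedules in the table translate directly into code. The sketch below is illustrative (the function name and signature are assumptions, not the `StepDropSampler` API); $t$ is the normalized timestep, with $t = 1$ at pure noise and $t = 0$ at the final image:

```python
import math

def skip_prob(strategy, t, p=0.3):
    """Skip probability P(t) for a normalized timestep t in [0, 1]."""
    if strategy == "constant":
        return p
    if strategy == "linear":
        return p * 4 * t * (1 - t)          # parabola, peaks at t = 0.5
    if strategy == "cosine_sq":
        return p * math.sin(math.pi * t) ** 2
    if strategy == "quadratic":
        return p * 16 * t**2 * (1 - t) ** 2  # sharper peak at t = 0.5
    if strategy == "early_skip":
        return p * t                          # skip more at high noise
    if strategy == "late_skip":
        return p * (1 - t)                    # skip more at low noise
    raise ValueError(f"unknown strategy: {strategy}")

# Every shaped strategy is normalized to reach p at its peak:
print(skip_prob("linear", 0.5))     # 0.3
print(skip_prob("quadratic", 0.5))  # 0.3
```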

Adaptive Strategy (AdaptiveStepDropSampler)

Dynamically adjusts skipping based on reconstruction error:

  • Low error → skip more aggressively
  • High error → force denoising steps
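One way such an error-driven rule could look is sketched below. The threshold, the probability scaling, and the function name are all illustrative assumptions, not the internals of `AdaptiveStepDropSampler`:

```python
import torch

def should_skip(x_prev, x_curr, p_max=0.5, err_threshold=0.05):
    """Adaptive skip decision (illustrative): compare successive latents
    and only allow skipping once the per-pixel change has settled."""
    err = (x_curr - x_prev).abs().mean().item()
    if err >= err_threshold:
        return False  # high error -> force a real denoising step
    # Low error -> skip with a probability that grows as the error shrinks
    prob = p_max * (1.0 - err / err_threshold)
    return torch.rand(()).item() < prob
```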

Target NFE Strategy (TargetNFEStepDropSampler)

Targets a specific step budget:

  • uniform - Evenly spaced (like DDIM)
  • importance - More steps at start/end
  • stochastic - Random with boundary protection
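The `uniform` variant, for example, reduces to evenly spaced timesteps under a fixed budget, exactly as in DDIM; a minimal sketch (helper name assumed):

```python
def uniform_steps(n_timesteps=1000, budget=50):
    """Pick `budget` evenly spaced timesteps, descending (DDIM-style)."""
    stride = n_timesteps / budget
    return sorted({int(i * stride) for i in range(budget)}, reverse=True)

print(uniform_steps(1000, 50)[:3])  # [980, 960, 940]
```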

(back to top)

Visualization Utilities

Generate Comparison Grid

python scripts/generate_grid.py

Output: results/comparison_grid.png - Side-by-side DDPM vs DDIM vs StepDrop

Visualize Schedules

python scripts/plot_schedules.py --save_path results/schedules.png

Output: Probability curves and step sizes for different strategies

Benchmark Plots

python scripts/plot_results.py --results results/2025-12-07_12-00-00/

Output: Pareto frontiers, radar charts, metric comparisons

Denoising Evolution

python scripts/plot_denoising_evolution.py

Output: results/plot_denoising_evolution.png - Film strip showing denoising progression

Efficiency Plots

python scripts/plot_efficiency.py --results results/

Output: FLOPs/Memory analysis

(back to top)

Project Structure

stepdrop-tiny-diffusion/
├── src/
│   ├── config.py          # Configuration management
│   ├── dataset.py         # Data loading (MNIST, CIFAR-10, custom)
│   ├── modules.py         # U-Net architecture
│   ├── scheduler.py       # Noise schedules
│   ├── train.py           # Training script
│   ├── sample.py          # Sampling script
│   ├── sampler/           # Sampler implementations
│   │   ├── DDPM.py
│   │   ├── DDIM.py
│   │   ├── StepDrop.py
│   │   └── AdaptiveStepDrop.py
│   └── eval/              # Evaluation metrics
│       └── metrics_utils.py
├── scripts/
│   ├── pipeline.sh        # Main automation script
│   ├── quick_start.sh     # Interactive menu
│   ├── benchmark_strategies.py
│   ├── plot_results.py
│   ├── plot_schedules.py
│   ├── plot_denoising_evolution.py
│   └── generate_grid.py
├── notebooks/             # Jupyter notebooks
├── checkpoints/           # Saved models
├── samples/               # Generated samples
├── results/               # Benchmark results
└── docs/                  # Documentation

(back to top)

HPC / SLURM Support

For cluster environments:

# Submit job to SLURM
sbatch scripts/run_pipeline.slurm

# With custom arguments
sbatch scripts/run_pipeline.slurm --train --dataset cifar10 --epochs 100

(back to top)

Roadmap

  • Core StepDrop sampler implementation
  • Pipeline automation script
  • Comprehensive benchmarking suite
  • Multiple skip strategies
  • Example notebook for Tiny Diffusion
  • Example notebook for Stable Diffusion 1.5
  • Package as pip-installable library
  • Integration with HuggingFace Diffusers
  • Support for more diffusion schedulers

See open issues for proposed features and known issues.

(back to top)

Contributing

Contributions are welcome!

  1. Fork the Project
  2. Create your Feature Branch (git checkout -b feature/AmazingFeature)
  3. Commit your Changes (git commit -m 'Add some AmazingFeature')
  4. Push to the Branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

(back to top)

License

Distributed under the MIT License. See LICENSE for more information.

(back to top)

Contact

Wanghley Soares Martins - @wanghley - me@wanghley.com

Nicolas Vasilescu - @NicolasVasilescu

Project Link: https://github.com/wanghley/stepdrop-tiny-diffusion

(back to top)

Acknowledgments

(back to top)
