Energy-Based Modeling library for PyTorch, offering tools for sampling, inference, and learning in complex distributions.
Energy-Based Models (EBMs) offer a powerful and flexible framework for generative modeling by assigning a scalar energy to each data point: lower energy corresponds to higher probability.
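Formally, given an energy function $E_\theta(x)$, the model defines the Boltzmann distribution

$$p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z_\theta}, \qquad Z_\theta = \int e^{-E_\theta(x)} \, dx,$$

where the normalizing constant $Z_\theta$ is typically intractable; this is why EBMs rely on MCMC sampling and on losses (such as contrastive divergence and score matching) that sidestep computing it.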
TorchEBM simplifies working with EBMs in PyTorch. It provides a suite of tools designed for researchers and practitioners, enabling efficient implementation and exploration of:
- Defining complex energy functions: Easily create custom energy landscapes using PyTorch modules (see the sketch after this list).
- Training: Loss functions and procedures suitable for EBM parameter estimation, including score matching and contrastive divergence variants.
- Sampling: Algorithms to draw samples from the learned distribution $p(x)$.
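As a first taste of defining custom energies, here is a minimal sketch (the quadratic landscape is purely illustrative, and the `torchebm.core` import path for `BaseEnergyFunction` is assumed from the subclassing pattern in the training example below):

```python
import torch
from torchebm.core import BaseEnergyFunction  # assumed import path

class QuadraticEnergy(BaseEnergyFunction):
    """E(x) = 0.5 * ||x||^2 - the energy of a standard Gaussian (illustrative only)."""
    def forward(self, x):
        return 0.5 * (x ** 2).sum(dim=-1)  # one scalar energy per sample
```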
For detailed documentation, including installation instructions, usage examples, and API references, please visit the TorchEBM website.
Core Components:
- Energy functions: Standard energy landscapes (Gaussian, Double Well, Rosenbrock, etc.)
- Datasets: Data generators for training and evaluation (see the snippet after this list)
- Loss functions: Contrastive Divergence, Score Matching, and more
- Sampling algorithms: Langevin Dynamics, Hamiltonian Monte Carlo (HMC), and more
- Evaluation metrics: Diagnostics for sampling and training
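For instance, the built-in data generators return tensors ready for a standard `DataLoader` (the call below mirrors the training example further down; the printed shape assumes 2D mixture data):

```python
from torchebm.datasets import GaussianMixtureDataset

# Generate a small 2D Gaussian-mixture dataset (same call as in the training example below)
data = GaussianMixtureDataset(n_samples=500, n_components=4, std=0.1, seed=123).get_data()
print(data.shape)  # expected: (500, 2) - 500 two-dimensional points
```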
Performance Optimizations:
- CUDA-accelerated implementations
- Parallel sampling capabilities (sketched after this list)
- Extensive diagnostics
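Parallelism comes from batching: each chain is one row of a tensor, so moving a sampler to CUDA advances all chains at once. A brief sketch reusing the HMC API shown later in this README:

```python
import torch
from torchebm.core import GaussianEnergy
from torchebm.samplers import HamiltonianMonteCarlo

device = "cuda" if torch.cuda.is_available() else "cpu"

energy = GaussianEnergy(mean=torch.zeros(10, device=device), cov=torch.eye(10, device=device))
sampler = HamiltonianMonteCarlo(
    energy_function=energy, step_size=0.1, n_leapfrog_steps=10, device=device
)

# 10,000 chains advance together as a single batched tensor computation
samples = sampler.sample(dim=10, n_steps=100, n_samples=10_000)
print(samples.shape)  # expected: (10000, 10)
```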
*Energy landscape visualizations: Gaussian, Double Well, Rastrigin, and Rosenbrock functions.*
Installation:

```bash
pip install torchebm
```
Requirements:
- PyTorch (with CUDA support for optimal performance)
- Other dependencies are listed in `requirements.txt`
```python
import torch
from torchebm.core import GaussianEnergy, DoubleWellEnergy

# Set device for computation
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define dimensions
dim = 10
n_samples = 250
n_steps = 500

# Create a multivariate Gaussian energy function
gaussian_energy = GaussianEnergy(
    mean=torch.zeros(dim, device=device),  # Center at origin
    cov=torch.eye(dim, device=device),     # Identity covariance (standard normal)
)

# Create a double well potential
double_well_energy = DoubleWellEnergy(barrier_height=2.0)
```
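To draw samples from one of these landscapes, attach a sampler to the energy; a minimal sketch, assuming `LangevinDynamics.sample(...)` mirrors the HMC signature shown further below:

```python
from torchebm.samplers import LangevinDynamics

# Minimal sketch: sample from the Gaussian energy defined above
# (assumes LangevinDynamics.sample(...) mirrors the HMC signature below)
langevin_sampler = LangevinDynamics(
    energy_function=gaussian_energy, step_size=0.01, device=device
)
samples = langevin_sampler.sample(dim=dim, n_steps=n_steps, n_samples=n_samples)
print(samples.shape)  # expected: (250, 10) - (n_samples, dim)
```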
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset
from torchebm.samplers import LangevinDynamics

device = "cuda" if torch.cuda.is_available() else "cpu"

# Define an NN energy model
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  # one scalar energy per sample

energy_fn = MLPEnergy(input_dim=2).to(device)

sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)

cd_loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10,  # MCMC steps used to generate negative samples
)

optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)

mixture_dataset = GaussianMixtureDataset(
    n_samples=500, n_components=4, std=0.1, seed=123
).get_data()
dataloader = DataLoader(mixture_dataset, batch_size=32, shuffle=True)

# Training loop
EPOCHS = 10
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)

        optimizer.zero_grad()
        loss, neg_samples = cd_loss_fn(batch_data)  # negative samples come from the sampler
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.6f}")
```
```python
import torch
from torchebm.core import GaussianEnergy
from torchebm.samplers import HamiltonianMonteCarlo

device = "cuda" if torch.cuda.is_available() else "cpu"
dim, n_samples, n_steps = 10, 250, 500

# Define a 10-D Gaussian energy function
energy_fn = GaussianEnergy(mean=torch.zeros(dim, device=device), cov=torch.eye(dim, device=device))

# Initialize the HMC sampler
hmc_sampler = HamiltonianMonteCarlo(
    energy_function=energy_fn, step_size=0.1, n_leapfrog_steps=10, device=device
)

# Sample 10,000 points in 10 dimensions
final_samples = hmc_sampler.sample(
    dim=dim, n_steps=n_steps, n_samples=10000, return_trajectory=False
)
print(final_samples.shape)  # Result batch_shape: (10000, 10) - (n_samples, dim)
# Sample with diagnostics and the full trajectory
final_samples, diagnostics = hmc_sampler.sample(
    n_samples=n_samples,
    n_steps=n_steps,
    dim=dim,
    return_trajectory=True,
    return_diagnostics=True,
)
print(final_samples.shape)  # Trajectory batch_shape: (250, 500, 10) - (n_samples, n_steps, dim)
print(diagnostics.shape)  # Diagnostics batch_shape: (500, 4, 250, 10) - (n_steps, 4, n_samples, dim)
# Along the size-4 axis: Mean (index 0), Variance (index 1), Energy (index 2), Acceptance rate (index 3)
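
# Hypothetical illustration (layout taken from the comment above, not verified
# against the API): acceptance rates sit at index 3 along the size-4 axis
acceptance_rates = diagnostics[:, 3, :, :]  # batch_shape: (500, 250, 10)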
# Sample from a custom initialization
x_init = torch.randn(n_samples, dim, dtype=torch.float32, device=device)
samples = hmc_sampler.sample(x=x_init, n_steps=100)
print(samples.shape)  # Result batch_shape: (250, 10) - (n_samples, dim)
```
```
torchebm/
├── core/                     # Core functionality
│   ├── energy_function.py    # Energy function definitions
│   ├── basesampler.py        # Base sampler class
│   └── ...
├── samplers/                 # Sampling algorithms
│   ├── langevin_dynamics.py  # Langevin dynamics implementation
│   ├── mcmc.py               # HMC implementation
│   └── ...
├── models/                   # Neural network models
├── evaluation/               # Evaluation metrics and utilities
├── datasets/
│   └── generators.py         # Data generators for training
├── losses/                   # Loss functions for training
├── utils/                    # Utility functions
└── cuda/                     # CUDA optimizations
```
*Sampling visualizations: Langevin dynamics sampling, a single Langevin dynamics trajectory, and parallel Langevin dynamics sampling.*
Check out the `examples/` directory for sample scripts:

- `samplers/`: Demonstrates different sampling algorithms
- `datasets/`: Depicts data generation using built-in datasets
- `training_models/`: Shows how to train energy-based models using TorchEBM
- `visualization/`: Visualizes sampling results and trajectories
- and more!
Contributions are welcome! Step-by-step instructions for contributing to the project can be found on the contributing.md page on the website.
Please check the issues page for current tasks or create a new issue to discuss proposed changes.
Please ⭐ this repository if TorchEBM helped you, and spread the word.
Thank you!
If you use TorchEBM in your research, please cite it using the following BibTeX entry:
```bibtex
@misc{torchebm_library_2025,
  author = {Ghaderi, Soran and Contributors},
  title  = {{TorchEBM}: A PyTorch Library for Training Energy-Based Models},
  year   = {2025},
  url    = {https://github.com/soran-ghaderi/torchebm},
}
```
For a detailed list of changes between versions, please see our CHANGELOG.
This project is licensed under the MIT License - see the LICENSE file for details.
If you are interested in collaborating on research projects (diffusion-/flow-/energy-based models) or have any questions about the library, please feel free to reach out. I am open to discussions and collaborations that can enhance the capabilities of TorchEBM and contribute to the field of generative modeling.