Deep Parameter Interpolation for Scalar Conditioning


Abstract

We propose deep parameter interpolation (DPI), a general-purpose method for transforming an existing deep neural network architecture into one that accepts an additional scalar input. Recent deep generative models, including diffusion models and flow matching, employ a single neural network to learn a time- or noise level-dependent vector field. Designing a network architecture to accurately represent this vector field is challenging because the network must integrate information from two different sources: a high-dimensional vector (usually an image) and a scalar. Common approaches either encode the scalar as an additional image input or combine scalar and vector information in specific network components, which restricts architecture choices. Instead, we propose to maintain two learnable parameter sets within a single network and to introduce the scalar dependency by dynamically interpolating between the parameter sets based on the scalar value during training and sampling. DPI is a simple, architecture-agnostic method for adding scalar dependence to a neural network. We demonstrate that our method improves denoising performance and enhances sample quality for both diffusion and flow matching models, while achieving computational efficiency comparable to standard scalar conditioning techniques.
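To make the idea concrete, the sketch below shows a minimal parameter-interpolated convolution in PyTorch: the layer stores two weight/bias sets and blends them with a scalar coefficient λ ∈ [0, 1] at every forward call. This is an illustrative approximation of the approach, not the repository's implementation (the actual layer is `OurConv2d` in `guided_diffusion/nn_ours.py`), and the names used here are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class InterpolatedConv2d(nn.Module):
    """Conv layer with two learnable parameter sets blended by a scalar lambda.

    Illustrative sketch only; the repository's OurConv2d may differ in detail.
    """

    def __init__(self, in_ch, out_ch, kernel_size, padding=0):
        super().__init__()
        shape = (out_ch, in_ch, kernel_size, kernel_size)
        # Two learnable parameter sets for the same layer.
        self.weight0 = nn.Parameter(torch.randn(shape) * 0.02)
        self.weight1 = nn.Parameter(torch.randn(shape) * 0.02)
        self.bias0 = nn.Parameter(torch.zeros(out_ch))
        self.bias1 = nn.Parameter(torch.zeros(out_ch))
        self.padding = padding

    def forward(self, x, lam):
        # lam: scalar tensor in [0, 1], derived from the time / noise level
        # and shared across the batch in this simplified sketch.
        weight = (1.0 - lam) * self.weight0 + lam * self.weight1
        bias = (1.0 - lam) * self.bias0 + lam * self.bias1
        return F.conv2d(x, weight, bias, padding=self.padding)
```

Because the interpolation acts on the parameters rather than the activations, the layer keeps the same input/output interface as a standard convolution, which is what makes the method architecture-agnostic.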

⚙️ Environment setup

Create and activate the conda environment

cd parameter_interpolation

conda create -n PI python=3.9.19

conda activate PI

conda install -c conda-forge mpi4py mpich

pip install -r requirements.txt

🚀 Training Model and Generating Images with Parameter Interpolation

Step 1: Choose a configuration file from the configs_ours directory:

🧠 Config for training

  • configs_ours/train/ADM/diffusion_adm_ours_ffhq_img64.yaml
  • configs_ours/train/ADM/flow_adm_ours_ffhq_img64.yaml
  • configs_ours/train/DRUNet/diffusion_adm_ours_ffhq_img64.yaml
  • configs_ours/train/DRUNet/flow_adm_ours_ffhq_img64.yaml

🎨 Config for image generation

  • configs_ours/generation/ADM/diffusion_adm_ours_ffhq_img64.yaml
  • configs_ours/generation/ADM/flow_adm_ours_ffhq_img64.yaml
  • configs_ours/generation/DRUNet/diffusion_adm_ours_ffhq_img64.yaml
  • configs_ours/generation/DRUNet/flow_adm_ours_ffhq_img64.yaml
Note: Baseline configurations used in the paper are available under `configs_baselines`.

Step 2: Run training or image generation

🧠 Training

# For training
python first_train.py --task_config {PATH_TO_TASK_CONFIG_YAML}    # example: python first_train.py --task_config configs_ours/train/ADM/diffusion_adm_ours_ffhq_img64.yaml

🎨 Image generation

# For generation
python first_sample.py --task_config {PATH_TO_TASK_CONFIG_YAML}    # example: python first_sample.py --task_config configs_ours/generation/ADM/diffusion_adm_ours_ffhq_img64.yaml

Implementation overview

first_train.py                              # Entry point for training
first_sample.py                             # Entry point for image generation
│
└── guided_diffusion
        ├── train_util.py                   # Training loop, configuration, and data loading
        ├── gaussian_diffusion.py           # Training loss functions
        ├── nn_ours.py                      # Parameter interpolation operations (e.g., OurConv2d)
        └── unet_ours.py                    # Model forward pass with parameter interpolation
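For orientation, a drastically simplified version of how these pieces fit together might look like the toy model below: the scalar input t is first mapped to an interpolation coefficient λ by a learnable monotone function (sketched in the next section), which is then passed to every interpolated layer. This composition is an assumption for illustration and builds on the hypothetical classes from the sketches in this README; see `unet_ours.py` for the actual forward pass.

```python
class TinyDPINet(nn.Module):
    """Toy network composing interpolated layers (illustration, not unet_ours)."""

    def __init__(self):
        super().__init__()
        # InterpolatedConv2d is the sketch from the abstract section;
        # MonotonicLambda is sketched in the next section.
        self.conv1 = InterpolatedConv2d(3, 32, 3, padding=1)
        self.conv2 = InterpolatedConv2d(32, 3, 3, padding=1)
        self.lam_fn = MonotonicLambda()

    def forward(self, x, t):
        lam = self.lam_fn(t)          # scalar time / noise level -> lambda in [0, 1]
        h = F.silu(self.conv1(x, lam))
        return self.conv2(h, lam)
```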

🔑 Core Logic for parameter interpolation

The following components are central to the parameter interpolation method:

# Parameter interpolation implementation (e.g., OurConv2d)
guided_diffusion.nn_ours

# Learnable monotonic lambda function
guided_diffusion.models.ADM.diffusion_adm_ours.NormalizedSoftmaxApprox
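The learnable monotonic λ function maps the conditioning scalar (e.g., a diffusion time in [0, 1]) to the interpolation coefficient λ ∈ [0, 1]. One common way to build such a monotone map is to take the cumulative sum of softmax-normalized learnable bin weights and interpolate it at t; the sketch below uses this construction purely as an assumed stand-in and is not necessarily how `NormalizedSoftmaxApprox` in `guided_diffusion/models/ADM/diffusion_adm_ours.py` is implemented.

```python
class MonotonicLambda(nn.Module):
    """Learnable, monotonically increasing map from t in [0, 1] to lambda in [0, 1].

    Assumed construction (cumulative softmax over bins); not necessarily the
    repository's NormalizedSoftmaxApprox.
    """

    def __init__(self, num_bins=64):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(num_bins))

    def forward(self, t):
        # Softmax gives non-negative weights summing to 1, so their cumulative
        # sum is monotone increasing from ~0 to 1.
        w = torch.softmax(self.logits, dim=0)
        cdf = torch.cat([w.new_zeros(1), torch.cumsum(w, dim=0)])  # (num_bins + 1,)
        # Piecewise-linear interpolation of the cumulative curve at position t.
        pos = t.clamp(0.0, 1.0) * (cdf.numel() - 1)
        lo = pos.floor().long()
        hi = pos.ceil().long()
        frac = pos - lo.to(pos.dtype)
        return (1.0 - frac) * cdf[lo] + frac * cdf[hi]
```

Any smooth monotone parameterization could play the same role; the key property is that λ stays in [0, 1] and varies monotonically with the conditioning scalar.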
