Finalize 2nd order schedulers (huggingface#1503)
* up

* up

* finish

* finish

* up

* up

* finish
patrickvonplaten authored and Thomas Capelle committed Dec 12, 2022
1 parent 291be99 commit 6cb3811
Showing 26 changed files with 1,020 additions and 36 deletions.
28 changes: 27 additions & 1 deletion docs/source/api/schedulers.mdx
@@ -76,6 +76,33 @@ Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [im

[[autodoc]] DPMSolverMultistepScheduler

+ #### Heun scheduler inspired by Karras et al. paper
+
+ Algorithm 1 of [Karras et al.](https://arxiv.org/abs/2206.00364).
+ Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library.
+
+ All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/).
+
+ [[autodoc]] HeunDiscreteScheduler
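
As a usage illustration (a minimal sketch, not part of this commit; the checkpoint id and prompt are placeholders), the Heun scheduler can be swapped into an existing pipeline via its config:

```python
from diffusers import HeunDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)

# Heun is a 2nd order scheduler: interior steps run two model evaluations
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```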

+ #### DPM Discrete Scheduler inspired by Karras et al. paper
+
+ Inspired by [Karras et al.](https://arxiv.org/abs/2206.00364).
+ Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library.
+
+ All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/).
+
+ [[autodoc]] KDPM2DiscreteScheduler

+ #### DPM Discrete Scheduler with ancestral sampling inspired by Karras et al. paper
+
+ Inspired by [Karras et al.](https://arxiv.org/abs/2206.00364).
+ Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library.
+
+ All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/).
+
+ [[autodoc]] KDPM2AncestralDiscreteScheduler
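
A similar sketch for the ancestral variant; since ancestral sampling injects fresh noise at every step, seeding a generator is the assumed route to reproducible outputs (model id and prompt are again placeholders):

```python
import torch
from diffusers import KDPM2AncestralDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# ancestral samplers add random noise each step; seed for reproducibility
generator = torch.Generator("cpu").manual_seed(0)
image = pipe("a photo of a corgi", num_inference_steps=25, generator=generator).images[0]
```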

#### Variance exploding, stochastic sampling from Karras et al.

Original paper can be found [here](https://arxiv.org/abs/2006.11239).
@@ -86,7 +113,6 @@ Original paper can be found [here](https://arxiv.org/abs/2006.11239).

Original implementation can be found [here](https://arxiv.org/abs/2206.00364).


[[autodoc]] LMSDiscreteScheduler

#### Pseudo numerical methods for diffusion models (PNDM)
18 changes: 14 additions & 4 deletions examples/community/sd_text2img_k_diffusion.py
@@ -21,7 +21,7 @@
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import is_accelerate_available, logging
- from k_diffusion.external import CompVisDenoiser
+ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -33,7 +33,12 @@ def __init__(self, model, alphas_cumprod):
self.alphas_cumprod = alphas_cumprod

def apply_model(self, *args, **kwargs):
-        return self.model(*args, **kwargs).sample
+        if len(args) == 3:
+            encoder_hidden_states = args[-1]
+            args = args[:2]
+        if kwargs.get("cond", None) is not None:
+            encoder_hidden_states = kwargs.pop("cond")
+        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
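
For context, the rewritten `apply_model` normalizes two calling conventions — the conditioning may arrive as a third positional argument (CompVis-style `apply_model(x, t, cond)`) or as the `cond` keyword that k-diffusion's wrappers pass — and forwards it to the diffusers UNet as `encoder_hidden_states`. A sketch of the two call styles (all names here are hypothetical placeholders):

```python
# `wrapper` wraps a UNet as above; `latents`, `t`, `text_embeddings` are placeholders
out = wrapper.apply_model(latents, t, text_embeddings)       # positional style
out = wrapper.apply_model(latents, t, cond=text_embeddings)  # keyword style
```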


class StableDiffusionPipeline(DiffusionPipeline):
@@ -63,6 +68,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
+    _optional_components = ["safety_checker", "feature_extractor"]

def __init__(
self,
@@ -99,7 +105,10 @@ def __init__(
)

model = ModelWrapper(unet, scheduler.alphas_cumprod)
-        self.k_diffusion_model = CompVisDenoiser(model)
+        if scheduler.prediction_type == "v_prediction":
+            self.k_diffusion_model = CompVisVDenoiser(model)
+        else:
+            self.k_diffusion_model = CompVisDenoiser(model)

def set_sampler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
@@ -417,6 +426,7 @@ def __call__(
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
sigmas = self.scheduler.sigmas
+        sigmas = sigmas.to(text_embeddings.dtype)

# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
@@ -437,7 +447,7 @@ def __call__(
def model_fn(x, t):
latent_model_input = torch.cat([x] * 2)

-            noise_pred = self.k_diffusion_model(latent_model_input, t, encoder_hidden_states=text_embeddings)
+            noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -49,6 +49,8 @@
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
+    KDPM2AncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
PNDMScheduler,
RePaintScheduler,
SchedulerMixin,
@@ -558,7 +558,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
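
This guard (repeated across the pipelines below) accounts for 2nd order schedulers: they report `order == 2` and duplicate interior timesteps, so the progress bar and callback should fire once per full denoising step, and the leading `i == len(timesteps) - 1` clause ensures the final iteration always ticks. A toy illustration with made-up values:

```python
# hypothetical: a 2nd order scheduler, 4 inference steps, no warmup
order, num_warmup_steps = 2, 0
timesteps = [999, 749, 749, 499, 499, 249, 249]  # interior timesteps doubled

for i, t in enumerate(timesteps):
    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % order == 0):
        print(f"tick at i={i}, t={t}")  # fires at i = 1, 3, 5, 6
```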
@@ -580,7 +580,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -666,7 +666,7 @@ def __call__(
).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -557,7 +557,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -440,7 +440,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -587,7 +587,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -701,7 +701,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -602,7 +602,7 @@ def __call__(
latents = (init_latents_proper * mask) + (latents * (1 - mask))

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -515,7 +515,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -711,7 +711,7 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
-            if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
4 changes: 3 additions & 1 deletion src/diffusers/schedulers/__init__.py
@@ -22,8 +22,10 @@
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
- from .scheduling_heun import HeunDiscreteScheduler
+ from .scheduling_heun_discrete import HeunDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
+ from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
+ from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
8 changes: 4 additions & 4 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -106,10 +106,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
-    prediction_type (`str`, default `epsilon`):
-        indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
-        `v-prediction` is not supported for this scheduler.
+    prediction_type (`str`, default `epsilon`, optional):
+        prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+        process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+        https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
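
For reference, `v_prediction` is the velocity parameterization from Salimans & Ho (2022), also summarized in section 2.4 of the Imagen Video paper. A minimal sketch of the training target under a variance-preserving schedule (function and argument names are illustrative, not library API):

```python
import torch

def v_prediction_target(sample: torch.Tensor, noise: torch.Tensor, alpha_t: float, sigma_t: float) -> torch.Tensor:
    # v = alpha_t * eps - sigma_t * x0, assuming alpha_t**2 + sigma_t**2 == 1
    return alpha_t * noise - sigma_t * sample
```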
7 changes: 4 additions & 3 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -99,9 +99,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
-    prediction_type (`str`, default `epsilon`):
-        indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
-        `v-prediction` is not supported for this scheduler.
+    prediction_type (`str`, default `epsilon`, optional):
+        prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+        process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+        https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
7 changes: 4 additions & 3 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -87,9 +87,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
-    prediction_type (`str`, default `epsilon`):
-        indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
-        or `v-prediction`.
+    prediction_type (`str`, default `epsilon`, optional):
+        prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+        process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+        https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -64,6 +64,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+    prediction_type (`str`, default `epsilon`, optional):
+        prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+        process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+        https://imagen.research.google/video/paper.pdf)
"""

4 changes: 4 additions & 0 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -65,6 +65,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+    prediction_type (`str`, default `epsilon`, optional):
+        prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+        process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+        https://imagen.research.google/video/paper.pdf)
"""

@@ -24,14 +24,16 @@

class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
-    Args:
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022) for discrete beta schedules. Based on the original
    k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90
    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.
+    Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
@@ -40,7 +42,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
-    tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+    prediction_type (`str`, default `epsilon`, optional):
+        prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+        process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+        https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@@ -77,7 +82,7 @@ def __init__(
def index_for_timestep(self, timestep):
indices = (self.timesteps == timestep).nonzero()
if self.state_in_first_order:
-            pos = 0 if indices.shape[0] < 2 else 1
+            pos = -1
else:
pos = 0
return indices[pos].item()
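
Because `set_timesteps` now repeats every interior timestep (see the next hunk), `(self.timesteps == timestep).nonzero()` can return two indices for the same value; `pos = -1` selects the later occurrence while the scheduler is in its first order state. A small sketch:

```python
import torch

# hypothetical interleaved schedule: interior timesteps appear twice
timesteps = torch.tensor([999, 749, 749, 499, 499, 249, 249])
indices = (timesteps == 499).nonzero()  # tensor([[3], [4]])
print(indices[-1].item())  # 4 -> first order state picks the later index
```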
@@ -132,7 +137,7 @@ def set_timesteps(
self.init_noise_sigma = self.sigmas.max()

timesteps = torch.from_numpy(timesteps)
-        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2), timesteps[-1:]])
+        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

if str(device).startswith("mps"):
# mps does not support float64
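
The revised concatenation drops the extra trailing timestep, so only interior timesteps are doubled and each two-stage Heun step sees the same `t` for its predictor and corrector passes. A hypothetical 4-step example:

```python
import torch

timesteps = torch.tensor([999.0, 749.0, 499.0, 249.0])
interleaved = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
print(interleaved)  # tensor([999., 749., 749., 499., 499., 249., 249.])
```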
Expand Down Expand Up @@ -199,9 +204,9 @@ def step(
)

if self.state_in_first_order:
-            # 2. Convert to an ODE derivative
+            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
-            # 3. 1st order derivative
+            # 3. delta timestep
dt = sigma_next - sigma_hat

# store for 2nd order step
@@ -213,7 +218,7 @@ def step(
derivative = (sample - pred_original_sample) / sigma_next
derivative = (self.prev_derivative + derivative) / 2

-        # 3. Retrieve 1st order derivative
+        # 3. take prev timestep & sample
dt = self.dt
sample = self.sample

