From 6cb381181b00d1cc779ae2263a04ad9594a4de26 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 2 Dec 2022 16:38:35 +0100 Subject: [PATCH] Finalize 2nd order schedulers (#1503) * up * up * finish * finish * up * up * finish --- docs/source/api/schedulers.mdx | 28 +- examples/community/sd_text2img_k_diffusion.py | 18 +- src/diffusers/__init__.py | 2 + .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipeline_alt_diffusion_img2img.py | 2 +- .../pipeline_cycle_diffusion.py | 2 +- .../pipeline_stable_diffusion.py | 2 +- ...peline_stable_diffusion_image_variation.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 2 +- .../pipeline_stable_diffusion_safe.py | 2 +- src/diffusers/schedulers/__init__.py | 4 +- src/diffusers/schedulers/scheduling_ddim.py | 8 +- src/diffusers/schedulers/scheduling_ddpm.py | 7 +- .../scheduling_dpmsolver_multistep.py | 7 +- .../scheduling_euler_ancestral_discrete.py | 4 + .../schedulers/scheduling_euler_discrete.py | 4 + ...ng_heun.py => scheduling_heun_discrete.py} | 23 +- .../scheduling_k_dpm_2_ancestral_discrete.py | 324 ++++++++++++++++++ .../schedulers/scheduling_k_dpm_2_discrete.py | 299 ++++++++++++++++ .../schedulers/scheduling_lms_discrete.py | 5 +- src/diffusers/schedulers/scheduling_pndm.py | 4 + src/diffusers/utils/dummy_pt_objects.py | 30 ++ tests/test_scheduler.py | 269 +++++++++++++++ 26 files changed, 1020 insertions(+), 36 deletions(-) rename src/diffusers/schedulers/{scheduling_heun.py => scheduling_heun_discrete.py} (94%) create mode 100644 src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py create mode 100644 src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 7ed527bedf3f..82c4641cd791 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -76,6 +76,33 @@ Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [im [[autodoc]] DPMSolverMultistepScheduler +#### Heun scheduler inspired by Karras et. al paper + +Algorithm 1 of [Karras et. al](https://arxiv.org/abs/2206.00364). +Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library: + +All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/) + +[[autodoc]] HeunDiscreteScheduler + +#### DPM Discrete Scheduler inspired by Karras et. al paper + +Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364). +Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library: + +All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/) + +[[autodoc]] KDPM2DiscreteScheduler + +#### DPM Discrete Scheduler with ancestral sampling inspired by Karras et. al paper + +Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364). +Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library: + +All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/) + +[[autodoc]] KDPM2AncestralDiscreteScheduler + #### Variance exploding, stochastic sampling from Karras et. al Original paper can be found [here](https://arxiv.org/abs/2006.11239). @@ -86,7 +113,6 @@ Original paper can be found [here](https://arxiv.org/abs/2006.11239). Original implementation can be found [here](https://arxiv.org/abs/2206.00364). 
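The Heun and KDPM2 entries added above are drop-in replacements for a pipeline's default scheduler. A minimal usage sketch (the model id is illustrative, not part of this patch):

```python
# Swap one of the new 2nd-order schedulers into an existing pipeline.
from diffusers import KDPM2DiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config)
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
```

Note that a 2nd-order scheduler calls the UNet roughly twice per inference step, so 25 steps cost about as much as 50 Euler steps.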
- [[autodoc]] LMSDiscreteScheduler #### Pseudo numerical methods for diffusion models (PNDM) diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py index beb7103c15ce..1c2ba36013f8 100755 --- a/examples/community/sd_text2img_k_diffusion.py +++ b/examples/community/sd_text2img_k_diffusion.py @@ -21,7 +21,7 @@ from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import is_accelerate_available, logging -from k_diffusion.external import CompVisDenoiser +from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -33,7 +33,12 @@ def __init__(self, model, alphas_cumprod): self.alphas_cumprod = alphas_cumprod def apply_model(self, *args, **kwargs): - return self.model(*args, **kwargs).sample + if len(args) == 3: + encoder_hidden_states = args[-1] + args = args[:2] + if kwargs.get("cond", None) is not None: + encoder_hidden_states = kwargs.pop("cond") + return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample class StableDiffusionPipeline(DiffusionPipeline): @@ -63,6 +68,7 @@ class StableDiffusionPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -99,7 +105,10 @@ def __init__( ) model = ModelWrapper(unet, scheduler.alphas_cumprod) - self.k_diffusion_model = CompVisDenoiser(model) + if scheduler.prediction_type == "v_prediction": + self.k_diffusion_model = CompVisVDenoiser(model) + else: + self.k_diffusion_model = CompVisDenoiser(model) def set_sampler(self, scheduler_type: str): library = importlib.import_module("k_diffusion") @@ -417,6 +426,7 @@ def __call__( # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device) sigmas = self.scheduler.sigmas + sigmas = sigmas.to(text_embeddings.dtype) # 5. 
Prepare latent variables num_channels_latents = self.unet.in_channels @@ -437,7 +447,7 @@ def __call__( def model_fn(x, t): latent_model_input = torch.cat([x] * 2) - noise_pred = self.k_diffusion_model(latent_model_input, t, encoder_hidden_states=text_embeddings) + noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6d470cb0e65b..1bda39cf3d49 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -49,6 +49,8 @@ HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, + KDPM2AncestralDiscreteScheduler, + KDPM2DiscreteScheduler, PNDMScheduler, RePaintScheduler, SchedulerMixin, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 7e55c1c1140b..fb64a34a0bd8 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -558,7 +558,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 78d512bd4e11..346f5f727bb8 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -580,7 +580,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 94f30041d569..a688a52a7a6e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -666,7 +666,7 @@ def __call__( ).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1d616622aa42..a3a8703f3ea4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -557,7 +557,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # 
call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 19dad28bfa9b..d77e71653078 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -440,7 +440,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 829128c3025e..933f59c3bd32 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -587,7 +587,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a61d0e3619c4..b3269714cdf7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -701,7 +701,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index ea6536b0a179..60d52eaa1ab4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -602,7 +602,7 @@ def __call__( latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: 
callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index d3c30597f01a..72981aebe184 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -515,7 +515,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index d83bcf510f84..5cb0f2c03daf 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -711,7 +711,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index d708963839cd..e3069f1eed28 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -22,8 +22,10 @@ from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler - from .scheduling_heun import HeunDiscreteScheduler + from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler + from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler + from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 7a805dd5a4f2..dd38bd63c27d 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -106,10 +106,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. - `v-prediction` is not supported for this scheduler. 
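The progress-bar condition changed identically in every pipeline above because 2nd-order schedulers consume two model calls per visible step, and `(i + 1) % self.scheduler.order == 0` alone can skip the final iteration. A self-contained illustration with hypothetical numbers (not taken from the patch):

```python
# With order = 2, num_inference_steps = 5 expands to 2 * 5 - 1 = 9 interleaved
# timesteps, so num_warmup_steps as computed by the pipelines is negative here.
order = 2
num_inference_steps = 5
timesteps = list(range(2 * num_inference_steps - 1))
num_warmup_steps = len(timesteps) - num_inference_steps * order  # -1

updates = [
    i
    for i in range(len(timesteps))
    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % order == 0)
]
print(updates)  # [1, 3, 5, 7, 8]: without the `i == len(timesteps) - 1` clause, i = 8 is skipped
```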
- + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index dcf899935deb..369db8b29e7d 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -99,9 +99,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. - `v-prediction` is not supported for this scheduler. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 25d80ec5f560..b7d083802691 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -87,9 +87,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, - or `v-prediction`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 7011e1f30a4e..f5905a3f8364 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -64,6 +64,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 
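The `prediction_type` docstring added throughout describes three model parameterizations. Their conversion to a predicted clean sample at noise level `sigma` mirrors the branches implemented in the new k-diffusion-style schedulers below; `predicted_x0` is a hypothetical helper written for illustration, not library API:

```python
import torch


def predicted_x0(model_output: torch.Tensor, sample: torch.Tensor, sigma: float, prediction_type: str):
    if prediction_type == "epsilon":
        # the model predicts the noise added to the sample
        return sample - sigma * model_output
    if prediction_type == "v_prediction":
        # the model predicts a velocity-style target (Imagen Video, section 2.4)
        return model_output * (-sigma / (sigma**2 + 1) ** 0.5) + sample / (sigma**2 + 1)
    if prediction_type == "sample":
        # the model predicts x_0 directly
        return model_output
    raise ValueError(f"unknown prediction_type {prediction_type}")
```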
+ prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index a3208bfdbb21..9cb4a1eaa565 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -65,6 +65,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ diff --git a/src/diffusers/schedulers/scheduling_heun.py b/src/diffusers/schedulers/scheduling_heun_discrete.py similarity index 94% rename from src/diffusers/schedulers/scheduling_heun.py rename to src/diffusers/schedulers/scheduling_heun_discrete.py index b7bf7ca3be66..4f40a24050b4 100644 --- a/src/diffusers/schedulers/scheduling_heun.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -24,14 +24,16 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - Args: Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90 + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from @@ -40,7 +42,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. 
+    prediction_type (`str`, default `epsilon`, optional):
+        prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+        process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+        https://imagen.research.google/video/paper.pdf)
     """
 
     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@@ -77,7 +82,7 @@ def __init__(
     def index_for_timestep(self, timestep):
         indices = (self.timesteps == timestep).nonzero()
         if self.state_in_first_order:
-            pos = 0 if indices.shape[0] < 2 else 1
+            pos = -1
         else:
             pos = 0
         return indices[pos].item()
@@ -132,7 +137,7 @@ def set_timesteps(
         self.init_noise_sigma = self.sigmas.max()
 
         timesteps = torch.from_numpy(timesteps)
-        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2), timesteps[-1:]])
+        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
 
         if str(device).startswith("mps"):
             # mps does not support float64
@@ -199,9 +204,9 @@ def step(
         )
 
         if self.state_in_first_order:
-            # 2. Convert to an ODE derivative
+            # 2. Convert to an ODE derivative for 1st order
             derivative = (sample - pred_original_sample) / sigma_hat
-            # 3. 1st order derivative
+            # 3. delta timestep
            dt = sigma_next - sigma_hat
 
             # store for 2nd order step
@@ -213,7 +218,7 @@ def step(
             derivative = (sample - pred_original_sample) / sigma_next
             derivative = (self.prev_derivative + derivative) / 2
 
-            # 3. Retrieve 1st order derivative
+            # 3. take prev timestep & sample
             dt = self.dt
             sample = self.sample
 
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
new file mode 100644
index 000000000000..b7d2175f027a
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -0,0 +1,324 @@
+# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
+    https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
+
+    Scheduler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`]
+    and [`~SchedulerMixin.from_pretrained`] functions.
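The `index_for_timestep` change above handles the fact that a 2nd-order schedule repeats every timestep except the first, so a given timestep value can match two positions. Taking the last match while in the first-order state and the first match in the second-order state keeps both halves of each step paired with the same sigma interval. An illustration with made-up values:

```python
import torch

# A 2nd-order schedule interleaves duplicated timesteps (values are made up).
timesteps = torch.tensor([999.0, 800.0, 800.0, 600.0, 600.0])
indices = (timesteps == 800.0).nonzero()  # tensor([[1], [2]])

print(indices[-1].item())  # 2 -> first-order state: start of the next sigma interval
print(indices[0].item())   # 1 -> second-order state: finish the current interval
```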
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+            https://imagen.research.google/video/paper.pdf)
+    """
+
+    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    order = 2
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.00085,  # sensible defaults
+        beta_end: float = 0.012,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
+    ):
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        # set all values
+        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+
+    def index_for_timestep(self, timestep):
+        indices = (self.timesteps == timestep).nonzero()
+        if self.state_in_first_order:
+            pos = -1
+        else:
+            pos = 0
+        return indices[pos].item()
+
+    def scale_model_input(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+    ) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        step_index = self.index_for_timestep(timestep)
+
+        if self.state_in_first_order:
+            sigma = self.sigmas[step_index]
+        else:
+            sigma = self.sigmas_interpol[step_index - 1]
+
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        return sample
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        device: Union[str, torch.device] = None,
+        num_train_timesteps: Optional[int] = None,
+    ):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
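The `set_timesteps` body that follows derives, for each transition, a deterministic target `sigma_down` and a noise injection `sigma_up` such that stepping to `sigma_down` and adding fresh noise with standard deviation `sigma_up` lands exactly at the next noise level. A quick numerical check of that identity (the sigma values are arbitrary):

```python
import torch

sigma, sigma_next = torch.tensor(0.8), torch.tensor(0.5)
sigma_up = (sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2) ** 0.5
sigma_down = (sigma_next**2 - sigma_up**2) ** 0.5

# variances add: sigma_down^2 + sigma_up^2 == sigma_next^2
assert torch.allclose((sigma_down**2 + sigma_up**2) ** 0.5, sigma_next)
```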
+ """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) + + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + + # compute up and down sigmas + sigmas_next = sigmas.roll(-1) + sigmas_next[-1] = 0.0 + sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5 + sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5 + sigmas_down[-1] = 0.0 + + # compute interpolated sigmas + sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp() + sigmas_interpol[-2:] = 0.0 + + # set sigmas + self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) + self.sigmas_interpol = torch.cat( + [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] + ) + self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) + self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + timesteps = torch.from_numpy(timesteps).to(device) + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() + timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(device, dtype=torch.float32) + else: + self.timesteps = timesteps + + self.sample = None + + def sigma_to_t(self, sigma): + # get log sigma + log_sigma = sigma.log() + + # get distribution + dists = log_sigma - self.log_sigmas[:, None] + + # get sigmas range + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = self.log_sigmas[low_idx] + high = self.log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + return t + + @property + def state_in_first_order(self): + return self.sample is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep + (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. 
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+        """
+        step_index = self.index_for_timestep(timestep)
+
+        if self.state_in_first_order:
+            sigma = self.sigmas[step_index]
+            sigma_interpol = self.sigmas_interpol[step_index]
+            sigma_up = self.sigmas_up[step_index]
+            sigma_down = self.sigmas_down[step_index - 1]
+        else:
+            # 2nd order / KDPM2's method
+            sigma = self.sigmas[step_index - 1]
+            sigma_interpol = self.sigmas_interpol[step_index - 1]
+            sigma_up = self.sigmas_up[step_index - 1]
+            sigma_down = self.sigmas_down[step_index - 1]
+
+        # currently only gamma=0 is supported. This usually works best anyways.
+        # We can support gamma in the future but then need to scale the timestep before
+        # passing it to the model which requires a change in API
+        gamma = 0
+        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now
+
+        device = model_output.device
+        if device.type == "mps":
+            # randn does not work reproducibly on mps
+            noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
+                device
+            )
+        else:
+            noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
+                device
+            )
+
+        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+        if self.config.prediction_type == "epsilon":
+            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+            pred_original_sample = sample - sigma_input * model_output
+        elif self.config.prediction_type == "v_prediction":
+            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
+                sample / (sigma_input**2 + 1)
+            )
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )
+
+        if self.state_in_first_order:
+            # 2. Convert to an ODE derivative for 1st order
+            derivative = (sample - pred_original_sample) / sigma_hat
+            # 3. delta timestep
+            dt = sigma_interpol - sigma_hat
+
+            # store for 2nd order step
+            self.sample = sample
+            self.dt = dt
+            prev_sample = sample + derivative * dt
+        else:
+            # DPM-Solver-2
+            # 2. Convert to an ODE derivative for 2nd order
+            derivative = (sample - pred_original_sample) / sigma_interpol
+            # 3. delta timestep
+            dt = sigma_down - sigma_hat
+
+            sample = self.sample
+            self.sample = None
+
+            prev_sample = sample + derivative * dt
+            prev_sample = prev_sample + noise * sigma_up
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
+
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        # Make sure sigmas and timesteps have the same device and dtype as original_samples
+        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+            # mps does not support float64
+            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+        else:
+            self.timesteps = self.timesteps.to(original_samples.device)
+            timesteps = timesteps.to(original_samples.device)
+
+        step_indices = [self.index_for_timestep(t) for t in timesteps]
+
+        sigma = self.sigmas[step_indices].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
new file mode 100644
index 000000000000..8aee346c574c
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -0,0 +1,299 @@
+# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
+    https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
+
+    Scheduler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`]
+    and [`~SchedulerMixin.from_pretrained`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+            https://imagen.research.google/video/paper.pdf)
+    """
+
+    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    order = 2
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.00085,  # sensible defaults
+        beta_end: float = 0.012,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
+    ):
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        # set all values
+        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+
+    def index_for_timestep(self, timestep):
+        indices = (self.timesteps == timestep).nonzero()
+        if self.state_in_first_order:
+            pos = -1
+        else:
+            pos = 0
+        return indices[pos].item()
+
+    def scale_model_input(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+    ) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        step_index = self.index_for_timestep(timestep)
+
+        if self.state_in_first_order:
+            sigma = self.sigmas[step_index]
+        else:
+            sigma = self.sigmas_interpol[step_index]
+
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        return sample
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        device: Union[str, torch.device] = None,
+        num_train_timesteps: Optional[int] = None,
+    ):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) + + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + + # interpolate sigmas + sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp() + + self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) + self.sigmas_interpol = torch.cat( + [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] + ) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + timesteps = torch.from_numpy(timesteps).to(device) + + # interpolate timesteps + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() + timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(torch.float32) + else: + self.timesteps = timesteps + + self.sample = None + + def sigma_to_t(self, sigma): + # get log sigma + log_sigma = sigma.log() + + # get distribution + dists = log_sigma - self.log_sigmas[:, None] + + # get sigmas range + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = self.log_sigmas[low_idx] + high = self.log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + return t + + @property + def state_in_first_order(self): + return self.sample is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep + (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. 
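The `step` body that follows splits one DPM-Solver-2-style update across two calls: the first call takes an Euler step to the log-space midpoint `sigma_interpol` and stashes the sample, and the second call re-evaluates the model there and steps the full interval. Collapsed into a single function it amounts to this sketch, where `eps` is a hypothetical noise-prediction callable rather than the scheduler API:

```python
def kdpm2_step(eps, x, sigma, sigma_interpol, sigma_next):
    # 1st stage: Euler step from sigma to the interpolated midpoint
    d1 = eps(x, sigma)  # equals (x - x0_pred) / sigma under epsilon prediction
    x_mid = x + d1 * (sigma_interpol - sigma)
    # 2nd stage: re-evaluate at the midpoint, then step the full interval
    d2 = eps(x_mid, sigma_interpol)
    return x + d2 * (sigma_next - sigma)
```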
+ """ + step_index = self.index_for_timestep(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_interpol = self.sigmas_interpol[step_index + 1] + sigma_next = self.sigmas[step_index + 1] + else: + # 2nd order / KDPM2's method + sigma = self.sigmas[step_index - 1] + sigma_interpol = self.sigmas_interpol[step_index] + sigma_next = self.sigmas[step_index] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if self.state_in_first_order: + # 2. Convert to an ODE derivative for 1st order + derivative = (sample - pred_original_sample) / sigma_hat + # 3. delta timestep + dt = sigma_interpol - sigma_hat + + # store for 2nd order step + self.sample = sample + else: + # DPM-Solver-2 + # 2. Convert to an ODE derivative for 2nd order + derivative = (sample - pred_original_sample) / sigma_interpol + + # 3. delta timestep + dt = sigma_next - sigma_hat + + sample = self.sample + self.sample = None + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t) for t in timesteps] + + sigma = self.sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 044eb70c213c..28bc9bd0c608 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -64,7 +64,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. 
trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index ddfd1338a4e5..a29f7d6d44cc 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -82,6 +82,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9846927cb1ce..23afb51cf30c 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -407,6 +407,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class KDPM2DiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PNDMScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 5d885fef6b12..a78bd81c826b 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -32,6 +32,8 @@ EulerDiscreteScheduler, HeunDiscreteScheduler, IPNDMScheduler, + KDPM2AncestralDiscreteScheduler, + KDPM2DiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler, @@ -2197,3 +2199,270 @@ def test_full_loop_device(self): # CUDA assert abs(result_sum.item() - 0.1233) < 1e-2 assert abs(result_mean.item() - 0.0002) < 1e-3 + + +class KDPM2DiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (KDPM2DiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + 
config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_full_loop_with_v_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 4.6934e-07) < 1e-2 + assert abs(result_mean.item() - 6.1112e-10) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 4.693428650170972e-07) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + + def test_full_loop_no_noise(self): + if torch_device == "mps": + return + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 20.4125) < 1e-2 + assert abs(result_mean.item() - 0.0266) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 20.4125) < 1e-2 + assert abs(result_mean.item() - 0.0266) < 1e-3 + + def test_full_loop_device(self): + if torch_device == "mps": + return + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if str(torch_device).startswith("cpu"): + # The following sum varies between 148 and 156 on mps. Why? 
+ assert abs(result_sum.item() - 20.4125) < 1e-2 + assert abs(result_mean.item() - 0.0266) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 20.4125) < 1e-2 + assert abs(result_mean.item() - 0.0266) < 1e-3 + + +class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (KDPM2AncestralDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_full_loop_no_noise(self): + if torch_device == "mps": + return + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + generator = torch.Generator(device=torch_device).manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 13849.3945) < 1e-2 + assert abs(result_mean.item() - 18.0331) < 5e-3 + else: + # CUDA + assert abs(result_sum.item() - 13913.0449) < 1e-2 + assert abs(result_mean.item() - 18.1159) < 5e-3 + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_full_loop_with_v_prediction(self): + if torch_device == "mps": + return + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. 
+ generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 328.9970) < 1e-2 + assert abs(result_mean.item() - 0.4284) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 327.8027) < 1e-2 + assert abs(result_mean.item() - 0.4268) < 1e-3 + + def test_full_loop_device(self): + if torch_device == "mps": + return + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if str(torch_device).startswith("cpu"): + assert abs(result_sum.item() - 13849.3945) < 1e-2 + assert abs(result_mean.item() - 18.0331) < 5e-3 + else: + # CUDA + assert abs(result_sum.item() - 13913.0332) < 1e-1 + assert abs(result_mean.item() - 18.1159) < 1e-3
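For reference, the denoising loop these tests exercise can be reproduced outside the test harness. A minimal sketch with a dummy model, where the tensor shape and the zero "prediction" are placeholders for a real UNet:

```python
import torch

from diffusers import KDPM2AncestralDiscreteScheduler

scheduler = KDPM2AncestralDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(10)

generator = torch.Generator().manual_seed(0)
sample = torch.randn(1, 3, 8, 8, generator=generator) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = torch.zeros_like(model_input)  # stand-in for a UNet's epsilon prediction
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
```

Passing the same seeded `generator` to every `step` call is what makes the ancestral scheduler's noise injection, and therefore the test assertions above, reproducible.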