[Torch 2.0 compile] Fix more torch compile breaks (huggingface#3313)
* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.

Co-authored-by: Horace He <horacehe2007@yahoo.com>

---------

Co-authored-by: Horace He <horacehe2007@yahoo.com>
patrickvonplaten and Chillee committed May 2, 2023
1 parent 59d8041 commit cd42a2e
Showing 18 changed files with 111 additions and 56 deletions.
15 changes: 8 additions & 7 deletions models/controlnet.py
@@ -498,7 +498,7 @@ def forward(
        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

@@ -517,7 +517,7 @@ def forward(

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

-        sample += controlnet_cond
+        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
@@ -551,21 +551,22 @@ def forward(

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
-            controlnet_down_block_res_samples += (down_block_res_sample,)
+            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode:
-            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
-            scales *= conditioning_scale
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
+
+            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
-            mid_block_res_sample *= scales[-1]  # last one
+            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-            mid_block_res_sample *= conditioning_scale
+            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
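The rewrites in this file follow two patterns that tend to keep torch.compile happy: in-place tensor ops (`sample += ...`, `scales *= ...`) and tuple augmented assignment become out-of-place expressions, and factory calls such as `torch.logspace` pass an explicit `device=` instead of defaulting to CPU. A minimal, self-contained sketch of both patterns under compilation — a toy function illustrating the idiom, not the diffusers code itself:

```python
import torch

def scale_residuals(residuals, conditioning_scale):
    # Out-of-place arithmetic only, and the factory tensor is created
    # directly on the same device as the activations it will multiply.
    scales = torch.logspace(-1, 0, len(residuals), device=residuals[0].device)
    scales = scales * conditioning_scale
    return [r * s for r, s in zip(residuals, scales)]

compiled = torch.compile(scale_residuals)
out = compiled([torch.randn(2, 4), torch.randn(2, 4)], 0.5)
```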
2 changes: 1 addition & 1 deletion models/unet_2d_condition.py
@@ -740,7 +740,7 @@ def forward(
                    down_block_res_samples, down_block_additional_residuals
                ):
                    down_block_res_sample = down_block_res_sample + down_block_additional_residual
-                    new_down_block_res_samples += (down_block_res_sample,)
+                    new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

                down_block_res_samples = new_down_block_res_samples

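The same accumulation rewrite appears here: `new_down_block_res_samples += (x,)` becomes an explicit tuple concatenation. On a tuple, `+=` is already out-of-place rebinding in plain Python, so the change presumably works around how dynamo traced augmented assignment at the time — an inference from the pattern of the change, not something the commit states. A compile-friendly sketch of the accumulation loop, with toy modules and assumed shapes:

```python
import torch

def collect(blocks, x):
    outs = ()
    for block in blocks:
        x = block(x)
        outs = outs + (x,)  # explicit concatenation instead of `outs += (x,)`
    return outs

blocks = [torch.nn.Linear(4, 4) for _ in range(3)]
res = torch.compile(collect)(blocks, torch.randn(2, 4))
```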
9 changes: 5 additions & 4 deletions pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -457,7 +457,7 @@ def decode_latents(self, latents):
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -728,15 +728,16 @@ def __call__(
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -745,7 +746,7 @@ def __call__(
                        callback(i, t, latents)

        if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
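This file introduces the commit's most common change: every `model(...).sample` and `scheduler.step(...).prev_sample` becomes a `return_dict=False` call indexed with `[0]`. diffusers components return a `ModelOutput` dataclass by default, and at the time of this commit a plain tuple traced more reliably under torch.compile; attribute access and tuple indexing yield the same tensor. A self-contained sketch of the convention, using toy classes rather than the real UNet:

```python
from dataclasses import dataclass
import torch

@dataclass
class ToyOutput:
    sample: torch.Tensor

class ToyUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x, return_dict=True):
        out = self.proj(x)
        # Mirrors the diffusers convention: a dataclass by default,
        # a plain tuple when return_dict=False.
        return ToyOutput(sample=out) if return_dict else (out,)

model = torch.compile(ToyUNet())
pred = model(torch.randn(2, 8), return_dict=False)[0]  # same tensor as .sample
```

The identical rewrite is applied to the DeepFloyd IF and Stable Diffusion pipelines below.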
7 changes: 4 additions & 3 deletions pipelines/deepfloyd_if/pipeline_if_img2img.py
@@ -918,7 +918,8 @@ def __call__(
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
@@ -930,8 +931,8 @@ def __call__(

                # compute the previous noisy sample x_t -> x_t-1
                intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
7 changes: 4 additions & 3 deletions pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
@@ -1036,7 +1036,8 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    class_labels=noise_level,
                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
@@ -1048,8 +1049,8 @@ def __call__(

                # compute the previous noisy sample x_t -> x_t-1
                intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
7 changes: 4 additions & 3 deletions pipelines/deepfloyd_if/pipeline_if_inpainting.py
@@ -1033,7 +1033,8 @@ def __call__(
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
@@ -1047,8 +1048,8 @@ def __call__(
                prev_intermediate_images = intermediate_images

                intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]

                intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images

pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -1143,7 +1143,8 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    class_labels=noise_level,
                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
@@ -1157,8 +1158,8 @@ def __call__(
                prev_intermediate_images = intermediate_images

                intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]

                intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images

7 changes: 4 additions & 3 deletions pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -886,7 +886,8 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    class_labels=noise_level,
                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
@@ -898,8 +899,8 @@ def __call__(

                # compute the previous noisy sample x_t -> x_t-1
                intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
48 changes: 40 additions & 8 deletions pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -20,6 +20,7 @@
import numpy as np
import PIL.Image
import torch
+import torch.nn.functional as F
from torch import nn
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

@@ -579,9 +580,20 @@ def check_inputs(
            )

        # Check `image`
-        if isinstance(self.controlnet, ControlNetModel):
+        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+        )
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
            self.check_image(image, prompt, prompt_embeds)
-        elif isinstance(self.controlnet, MultiControlNetModel):
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
            if not isinstance(image, list):
                raise TypeError("For multiple controlnets: `image` must be type `list`")

@@ -600,10 +612,18 @@ def check_inputs(
            assert False

        # Check `controlnet_conditioning_scale`
-        if isinstance(self.controlnet, ControlNetModel):
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
-        elif isinstance(self.controlnet, MultiControlNetModel):
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
            if isinstance(controlnet_conditioning_scale, list):
                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                    raise ValueError("A single batch of multiple conditionings are supported at the moment.")
@@ -910,7 +930,14 @@ def __call__(
        )

        # 4. Prepare image
-        if isinstance(self.controlnet, ControlNetModel):
+        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+        )
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
            image = self.prepare_image(
                image=image,
                width=width,
@@ -922,7 +949,11 @@ def __call__(
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
-        elif isinstance(self.controlnet, MultiControlNetModel):
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
            images = []

            for image_ in image:
@@ -1006,15 +1037,16 @@ def __call__(
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
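`torch.compile` wraps a module in `torch._dynamo.eval_frame.OptimizedModule` and keeps the original module on `._orig_mod`, so a direct `isinstance` check against `ControlNetModel` fails for a compiled controlnet; the hunks above look through the wrapper (the `hasattr(F, "scaled_dot_product_attention")` guard doubles as a PyTorch >= 2.0 check, since `torch._dynamo` does not exist on older versions). The combined conditions rely on `and` binding tighter than `or`. A minimal sketch of the unwrapping idiom, factored into a helper for clarity — the helper name is ours, not the pipeline's:

```python
import torch
from torch import nn

def unwrap_compiled(module: nn.Module) -> nn.Module:
    # torch.compile returns an OptimizedModule whose ._orig_mod
    # is the original, un-compiled module.
    if isinstance(module, torch._dynamo.eval_frame.OptimizedModule):
        return module._orig_mod
    return module

lin = torch.compile(nn.Linear(4, 4))
assert isinstance(unwrap_compiled(lin), nn.Linear)
```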
pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -677,15 +677,17 @@ def __call__(
                latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)

                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+                    0
+                ]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -462,7 +462,7 @@ def decode_latents(self, latents):
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -734,15 +734,16 @@ def __call__(
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -751,7 +752,7 @@ def __call__(
                        callback(i, t, latents)

        if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -878,15 +878,17 @@ def __call__(
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+                    0
+                ]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -690,15 +690,17 @@ def __call__(
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+                    0
+                ]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                # masking
                if add_predicted_noise:
                    init_latents_proper = self.scheduler.add_noise(
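One hedged way to verify changes like these (our suggestion; the commit adds its own tests): compile the component with `fullgraph=True`, which turns any remaining graph break into a hard error instead of a silent fallback to eager mode.

```python
import torch
from torch import nn

# Stand-in module; in practice this would be a pipeline component such as the UNet.
model = torch.compile(nn.Sequential(nn.Linear(8, 8), nn.ReLU()), fullgraph=True)
out = model(torch.randn(2, 8))  # raises at compile time if the graph breaks
```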