Skip to content

Commit

Permalink
fix(api): make SD upscaling compatible with more schedulers
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 15, 2023
1 parent 3e5edb1 commit 2b29b09
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
4 changes: 3 additions & 1 deletion api/onnx_web/convert/diffusion_stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def convert_diffusion_stable(
# UNET
if single_vae:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
unet_scale = torch.tensor(4).to(
device=ctx.training_device, dtype=torch.long
)
else:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = torch.tensor(False).to(
Expand Down
27 changes: 16 additions & 11 deletions api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@

logger = getLogger(__name__)

# TODO: make this dynamic, from self.vae.config.latent_channels
num_channels_latents = 4

# TODO: make this dynamic, from self.unet.config.in_channels
unet_in_channels = 7
# Number of latent channels requested from prepare_latents.
NUM_LATENT_CHANNELS = 4
# Channel count the UNet input must add up to (latent channels + image channels).
NUM_UNET_INPUT_CHANNELS = 7

# Dtype the noise level is cast to before being passed to the UNet as
# `class_labels`. The conversion script exports that input as torch.long,
# a 64-bit integer, so the width is spelled out here: the `np.long` alias was
# removed in NumPy 1.24 and, where it exists, follows the platform C long
# (32-bit on Windows), which would not match the exported model.
# TODO: should this be a lookup? it needs to match the conversion script
class_labels_dtype = np.int64

# TODO: should this be a lookup or converted? can it vary on ONNX?
text_embeddings_dtype = torch.float32


def preprocess(image):
Expand Down Expand Up @@ -93,7 +97,6 @@ def __call__(
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
text_embeddings_dtype = torch.float32 # TODO: convert text_embeddings.dtype to torch dtype

# 4. Preprocess image
image = preprocess(image)
Expand All @@ -116,7 +119,7 @@ def __call__(
height, width = image.shape[2:]
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
NUM_LATENT_CHANNELS,
height,
width,
text_embeddings_dtype,
Expand All @@ -127,12 +130,12 @@ def __call__(

# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
if num_channels_latents + num_channels_image != unet_in_channels:
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
raise ValueError(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
f" `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)

Expand All @@ -158,7 +161,7 @@ def __call__(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=text_embeddings,
class_labels=noise_level,
class_labels=noise_level.astype(class_labels_dtype),
)[0]

# perform guidance
Expand All @@ -167,7 +170,9 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
).prev_sample

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
Expand Down

0 comments on commit 2b29b09

Please sign in to comment.