fix(api): update SD upscaling pipeline
ssube committed Feb 17, 2023
1 parent 1ea1a57 commit 3d73b9e
Showing 2 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/convert/utils.py
@@ -222,4 +222,4 @@ def load_tensor(name: str, map_location=None):
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
)

-return checkpoint
+return checkpoint
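
The load_tensor hunk above sits in the checkpoint-unwrapping path: some checkpoints (for example, ones saved by training frameworks such as PyTorch Lightning) nest the weights under a "state_dict" key, and load_tensor falls back to the raw object otherwise. A minimal sketch of that pattern, assuming a plain torch.load call and omitting the other formats the real function handles:

import torch


def load_tensor_sketch(name: str, map_location=None):
    # Load the serialized checkpoint; the real load_tensor also handles
    # safetensors and other formats, which are omitted here.
    checkpoint = torch.load(name, map_location=map_location)

    # Some frameworks wrap the weights in a "state_dict" entry;
    # unwrap it when present, otherwise use the object as-is.
    checkpoint = (
        checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    )

    return checkpoint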
24 changes: 15 additions & 9 deletions api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -23,11 +23,10 @@
NUM_LATENT_CHANNELS = 4
NUM_UNET_INPUT_CHANNELS = 7

-# TODO: should this be a lookup? it needs to match the conversion script
-class_labels_dtype = np.long
-
-# TODO: should this be a lookup or converted? can it vary on ONNX?
-text_embeddings_dtype = torch.float32
+TORCH_DTYPES = {
+    "float16": torch.float16,
+    "float32": torch.float32,
+}


def preprocess(image):
@@ -98,6 +97,8 @@ def __call__(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)

+latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
+
# 4. Preprocess image
image = preprocess(image)
image = image.cpu()
@@ -108,7 +109,7 @@ def __call__(

# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
+noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)

batch_multiplier = 2 if do_classifier_free_guidance else 1
@@ -122,7 +123,7 @@ def __call__(
NUM_LATENT_CHANNELS,
height,
width,
-text_embeddings_dtype,
+latents_dtype,
device,
generator,
latents,
@@ -142,6 +143,11 @@ def __call__(
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+timestep_dtype = next(
+    (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+)
+timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -154,14 +160,14 @@ def __call__(
latent_model_input = np.concatenate([latent_model_input, image], axis=1)

# timestep to tensor
-timestep = np.array([t], dtype=np.float32)
+timestep = np.array([t], dtype=timestep_dtype)

# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=text_embeddings,
-class_labels=noise_level.astype(class_labels_dtype),
+class_labels=noise_level.astype(np.int64),
)[0]

# perform guidance
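
The dtype changes in this pipeline replace the hard-coded text_embeddings_dtype with a lookup keyed on whatever dtype the ONNX text encoder actually returns: text_embeddings is a NumPy array, str(array.dtype) yields "float16" or "float32", and TORCH_DTYPES maps that string to the matching torch dtype used for the noise and latent tensors. A minimal sketch of the lookup in isolation, with a stand-in array in place of the encoder output:

import numpy as np
import torch

TORCH_DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
}

# Stand-in for the text encoder output; the pipeline gets this from ONNX.
text_embeddings = np.zeros((2, 77, 1024), dtype=np.float32)

# str() on a NumPy dtype gives "float32" / "float16", which keys the lookup.
latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]

# The same dtype is then used for the added noise and the initial latents.
noise = torch.randn((1, 4, 64, 64), dtype=latents_dtype)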

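The new timestep_dtype block queries the converted UNet for the declared type of its "timestep" input instead of assuming float32: get_inputs() on the underlying ONNX Runtime session returns NodeArg objects whose .type is an ONNX type string such as "tensor(float)" or "tensor(float16)", and the ORT_TO_NP_TYPE table translates that string to a NumPy dtype. A minimal sketch of the same query against a plain onnxruntime session, with a placeholder model path and a trimmed mapping:

import numpy as np
import onnxruntime

# Trimmed mapping from ONNX Runtime type strings to NumPy dtypes; the
# table used by the pipeline covers more types than shown here.
ORT_TO_NP_TYPE = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}

# Placeholder path for an exported UNet model.
session = onnxruntime.InferenceSession("unet/model.onnx")

# Find the declared type of the "timestep" input, defaulting to float32.
timestep_type = next(
    (i.type for i in session.get_inputs() if i.name == "timestep"),
    "tensor(float)",
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_type]

# Each denoising step can then build the timestep tensor in the right dtype.
timestep = np.array([999], dtype=timestep_dtype)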