Skip to content

Commit

Permalink
fix(api): use VAE model dtype when converting sample
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jun 6, 2023
1 parent 6a09404 commit 395a632
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
2 changes: 0 additions & 2 deletions api/onnx_web/diffusers/patches/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from diffusers import OnnxRuntimeModel

from ...server import ServerContext
from .vae import set_vae_dtype

logger = getLogger(__name__)

Expand Down Expand Up @@ -37,7 +36,6 @@ def __call__(
timestep.dtype,
encoder_hidden_states.dtype,
)
set_vae_dtype(timestep.dtype)

if self.prompt_embeds is not None:
step_index = self.prompt_index % len(self.prompt_embeds)
Expand Down
26 changes: 11 additions & 15 deletions api/onnx_web/diffusers/patches/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,14 @@
from diffusers import OnnxRuntimeModel
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.vae import DecoderOutput
from onnx.helper import tensor_dtype_to_np_dtype

from ...server import ServerContext

logger = getLogger(__name__)

LATENT_CHANNELS = 4

# TODO: does this need to change for fp16 modes?
timestep_dtype = np.float32


def set_vae_dtype(dtype):
global timestep_dtype
timestep_dtype = dtype


class VAEWrapper(object):
def __init__(
Expand All @@ -46,7 +39,10 @@ def set_window_size(self, window: int, overlap: float):
self.tile_overlap_factor = overlap

def __call__(self, latent_sample=None, sample=None, **kwargs):
global timestep_dtype
# set timestep dtype to input type
inputs = self.wrapped.model.graph.input
sample_input = [i for i in inputs if i.name == "sample" or i.name == "latent_sample"][0]
sample_dtype = tensor_dtype_to_np_dtype(sample_input.type.tensor_type.elem_type)

logger.trace(
"VAE %s parameter types: %s, %s",
Expand All @@ -55,13 +51,13 @@ def __call__(self, latent_sample=None, sample=None, **kwargs):
(sample.dtype if sample is not None else "none"),
)

if latent_sample is not None and latent_sample.dtype != timestep_dtype:
logger.debug("converting VAE latent sample dtype")
latent_sample = latent_sample.astype(timestep_dtype)
if latent_sample is not None and latent_sample.dtype != sample_dtype:
logger.debug("converting VAE latent sample dtype to %s", sample_dtype)
latent_sample = latent_sample.astype(sample_dtype)

if sample is not None and sample.dtype != timestep_dtype:
logger.debug("converting VAE sample dtype")
sample = sample.astype(timestep_dtype)
if sample is not None and sample.dtype != sample_dtype:
logger.debug("converting VAE sample dtype to %s", sample_dtype)
sample = sample.astype(sample_dtype)

if self.tiled:
if self.decoder:
Expand Down

0 comments on commit 395a632

Please sign in to comment.