Skip to content

Commit

Permalink
fix(api): update latent window size on VAE patches
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed May 3, 2023
1 parent 72c39b6 commit 98386cb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
8 changes: 6 additions & 2 deletions api/onnx_web/diffusers/load.py
Expand Up @@ -127,8 +127,12 @@ def load_pipeline(

# update panorama params
if pipeline == "panorama":
cache_pipe.window = params.tiles // 8
cache_pipe.stride = params.stride() // 8
latent_window = params.tiles // 8
latent_stride = params.stride() // 8

cache_pipe.set_window_size(latent_window, latent_stride)
cache_pipe.vae_encoder.set_window_size(latent_window, latent_stride)
cache_pipe.vae_decoder.set_window_size(latent_window, latent_stride)

# update scheduler
cache_scheduler = server.cache.get("scheduler", scheduler_key)
Expand Down
13 changes: 8 additions & 5 deletions api/onnx_web/diffusers/patches/vae.py
Expand Up @@ -34,12 +34,15 @@ def __init__(
self.server = server
self.wrapped = wrapped
self.decoder = decoder
self.tiles = tiles
self.set_window_size(tiles, stride)

def set_window_size(self, window: int, stride: int):
self.window = window
self.stride = stride

self.tile_latent_min_size = tiles
self.tile_sample_min_size = tiles * 8
self.tile_overlap_factor = stride / tiles
self.tile_latent_min_size = self.window
self.tile_sample_min_size = self.window * 8
self.tile_overlap_factor = self.stride / self.window

def __call__(self, latent_sample=None, sample=None, **kwargs):
global timestep_dtype
Expand All @@ -59,7 +62,7 @@ def __call__(self, latent_sample=None, sample=None, **kwargs):
logger.debug("converting VAE sample dtype")
sample = sample.astype(timestep_dtype)

if self.tiles is not None and self.stride is not None:
if self.window is not None and self.stride is not None:
if self.decoder:
return self.tiled_decode(latent_sample, **kwargs)
else:
Expand Down
4 changes: 4 additions & 0 deletions api/onnx_web/diffusers/pipelines/panorama.py
Expand Up @@ -1255,3 +1255,7 @@ def __call__(
else:
logger.debug("running txt2img panorama pipeline")
return self.text2img(*args, **kwargs)

def set_window_size(self, window: int, stride: int):
self.window = window
self.stride = stride

0 comments on commit 98386cb

Please sign in to comment.