Skip to content

Commit

Permalink
feat(api): add tile size and stride to image parameters
Browse files: browse the repository at this point in the history
  • Loading branch information
ssube committed May 2, 2023
1 parent 746e33b commit 95725ff
Show file tree
Hide file tree
Showing 15 changed files with 140 additions and 47 deletions.
1 change: 1 addition & 0 deletions api/onnx_web/chain/blend_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def blend_controlnet(

pipe = load_pipeline(
server,
params,
"controlnet",
params.model,
params.scheduler,
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def blend_img2img(
pipe_type = "lpw" if params.lpw() else "img2img"
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def blend_inpaint(
pipe_type = "lpw" if params.lpw() else "inpaint"
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/blend_pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def blend_pix2pix(

pipe = load_pipeline(
server,
params,
"pix2pix",
params.model,
params.scheduler,
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def source_txt2img(
pipe_type = "lpw" if params.lpw() else "txt2img"
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
Expand Down
5 changes: 2 additions & 3 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
Expand Down Expand Up @@ -125,9 +126,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):

if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(
source, (source.width, source.height, max(source.width, source.height))
)
return outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def upscale_stable_diffusion(

pipeline = load_pipeline(
server,
params,
"upscale",
path.join(server.model_path, upscale.upscale_model),
params.scheduler,
Expand Down
47 changes: 33 additions & 14 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..diffusers.utils import expand_prompt
from ..models.meta import NetworkModel
from ..params import DeviceParams
from ..params import DeviceParams, ImageParams
from ..server import ServerContext
from ..utils import run_gc
from .patches.unet import UNetWrapper
Expand Down Expand Up @@ -93,20 +92,20 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:

def load_pipeline(
server: ServerContext,
params: ImageParams,
pipeline: str,
model: str,
scheduler_name: str,
device: DeviceParams,
control: Optional[NetworkModel] = None,
inversions: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None,
):
inversions = inversions or []
loras = loras or []
control_key = control.name if control is not None else None
model = params.model

torch_dtype = server.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", torch_dtype)

control_key = params.control.name if params.control is not None else None
pipe_key = (
pipeline,
model,
Expand All @@ -116,8 +115,8 @@ def load_pipeline(
inversions,
loras,
)
scheduler_key = (scheduler_name, model)
scheduler_type = pipeline_schedulers[scheduler_name]
scheduler_key = (params.scheduler, model)
scheduler_type = pipeline_schedulers[params.scheduler]

cache_pipe = server.cache.get("diffusion", pipe_key)

Expand Down Expand Up @@ -164,8 +163,10 @@ def load_pipeline(
unet_type = "unet"

# ControlNet component
if pipeline == "controlnet" and control is not None:
cnet_path = path.join(server.model_path, "control", f"{control.name}.onnx")
if pipeline == "controlnet" and params.control is not None:
cnet_path = path.join(
server.model_path, "control", f"{params.control.name}.onnx"
)
logger.debug("loading ControlNet weights from %s", cnet_path)
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
Expand Down Expand Up @@ -317,6 +318,11 @@ def load_pipeline(
)
)

# additional options for panorama pipeline
if pipeline == "panorama":
components["window"] = params.tiles
components["stride"] = params.stride()

pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
Expand All @@ -333,12 +339,12 @@ def load_pipeline(

optimize_pipeline(server, pipe)

# TODO: CPU VAE, etc
# TODO: remove this, not relevant with ONNX
if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device.torch_str())

# monkey-patch pipeline
patch_pipeline(server, pipe, pipeline)
patch_pipeline(server, pipe, pipeline_class, params)

server.cache.set("diffusion", pipe_key, pipe)
server.cache.set("scheduler", scheduler_key, components["scheduler"])
Expand Down Expand Up @@ -402,6 +408,7 @@ def patch_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,
pipeline: Any,
params: ImageParams,
) -> None:
logger.debug("patching SD pipeline")
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
Expand All @@ -411,9 +418,21 @@ def patch_pipeline(

if hasattr(pipe, "vae_decoder"):
original_decoder = pipe.vae_decoder
pipe.vae_decoder = VAEWrapper(server, original_decoder, decoder=True)
pipe.vae_decoder = VAEWrapper(
server,
original_decoder,
decoder=True,
tiles=params.tiles,
stride=params.stride(),
)
original_encoder = pipe.vae_encoder
pipe.vae_encoder = VAEWrapper(server, original_encoder, decoder=False)
pipe.vae_encoder = VAEWrapper(
server,
original_encoder,
decoder=False,
tiles=params.tiles,
stride=params.stride(),
)
elif hasattr(pipe, "vae"):
pass # TODO: current wrapper does not work with upscaling VAE
else:
Expand Down
31 changes: 22 additions & 9 deletions api/onnx_web/diffusers/patches/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
logger = getLogger(__name__)

LATENT_CHANNELS = 4
LATENT_SIZE = 32
SAMPLE_SIZE = 256

# TODO: does this need to change for fp16 modes?
timestep_dtype = np.float32
Expand All @@ -25,14 +23,23 @@ def set_vae_dtype(dtype):


class VAEWrapper(object):
def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
def __init__(
self,
server: ServerContext,
wrapped: OnnxRuntimeModel,
decoder: bool,
tiles: int,
stride: int,
):
self.server = server
self.wrapped = wrapped
self.decoder = decoder
self.tiles = tiles
self.stride = stride

self.tile_sample_min_size = SAMPLE_SIZE
self.tile_latent_min_size = LATENT_SIZE
self.tile_overlap_factor = 0.25
self.tile_latent_min_size = tiles
self.tile_sample_min_size = tiles * 8
self.tile_overlap_factor = stride / tiles

def __call__(self, latent_sample=None, sample=None, **kwargs):
global timestep_dtype
Expand All @@ -52,10 +59,16 @@ def __call__(self, latent_sample=None, sample=None, **kwargs):
logger.debug("converting VAE sample dtype")
sample = sample.astype(timestep_dtype)

if self.decoder:
return self.tiled_decode(latent_sample, **kwargs)
if self.tiles is not None and self.stride is not None:
if self.decoder:
return self.tiled_decode(latent_sample, **kwargs)
else:
return self.tiled_encode(sample, **kwargs)
else:
return self.tiled_encode(sample, **kwargs)
if self.decoder:
return self.wrapped(latent_sample=latent_sample)
else:
return self.wrapped(sample=sample)

def __getattr__(self, attr):
return getattr(self.wrapped, attr)
Expand Down
16 changes: 12 additions & 4 deletions api/onnx_web/diffusers/pipelines/panorama.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
NUM_UNET_INPUT_CHANNELS = 9
NUM_LATENT_CHANNELS = 4

DEFAULT_WINDOW = 32
DEFAULT_STRIDE = 8


def preprocess(image):
if isinstance(image, torch.Tensor):
Expand Down Expand Up @@ -105,9 +108,14 @@ def __init__(
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
window: Optional[int] = None,
stride: Optional[int] = None,
):
super().__init__()

self.window = window or DEFAULT_WINDOW
self.stride = stride or DEFAULT_STRIDE

if (
hasattr(scheduler.config, "steps_offset")
and scheduler.config.steps_offset != 1
Expand Down Expand Up @@ -338,7 +346,7 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

def get_views(self, panorama_height, panorama_width, window_size=32, stride=8):
def get_views(self, panorama_height, panorama_width, window_size, stride):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
panorama_width /= 8
Expand Down Expand Up @@ -514,7 +522,7 @@ def text2img(
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

# panorama additions
views = self.get_views(height, width)
views = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents)
value = np.zeros_like(latents)

Expand Down Expand Up @@ -816,7 +824,7 @@ def img2img(
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

# panorama additions
views = self.get_views(height, width)
views = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents)
value = np.zeros_like(latents)

Expand Down Expand Up @@ -1124,7 +1132,7 @@ def inpaint(
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

# panorama additions
views = self.get_views(height, width)
views = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents)
value = np.zeros_like(latents)

Expand Down
13 changes: 4 additions & 9 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def run_loopback(

pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
inversions=inversions,
loras=loras,
Expand Down Expand Up @@ -140,9 +139,8 @@ def run_highres(

highres_pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
inversions=inversions,
loras=loras,
Expand Down Expand Up @@ -242,9 +240,8 @@ def run_txt2img_pipeline(

pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
inversions=inversions,
loras=loras,
Expand Down Expand Up @@ -350,11 +347,9 @@ def run_img2img_pipeline(
pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
control=params.control,
inversions=inversions,
loras=loras,
)
Expand Down

0 comments on commit 95725ff

Please sign in to comment.