Skip to content

Commit

Permalink
feat: make pipeline type a request parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 13, 2023
1 parent b255680 commit 2af1530
Show file tree
Hide file tree
Showing 18 changed files with 136 additions and 104 deletions.
40 changes: 13 additions & 27 deletions api/onnx_web/chain/blend_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,23 @@ def blend_controlnet(

pipe = load_pipeline(
server,
OnnxStableDiffusionControlNetPipeline,
"controlnet",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)

rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength, # TODO: ControlNet strength
callback=callback,
)

output = result.images[0]

Expand Down
5 changes: 2 additions & 3 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@ def blend_img2img(

pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,
params.pipeline,
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
Expand Down
3 changes: 1 addition & 2 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)

if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
Expand Down
5 changes: 2 additions & 3 deletions api/onnx_web/chain/blend_pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ def blend_pix2pix(

pipe = load_pipeline(
server,
OnnxStableDiffusionInstructPix2PixPipeline,
"pix2pix",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
Expand Down
5 changes: 2 additions & 3 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,13 @@ def source_txt2img(
latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline,
"txt2img",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)

if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
Expand Down
5 changes: 2 additions & 3 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,12 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
latents = get_tile_latents(full_latents, dims)
pipe = load_pipeline(
server,
OnnxStableDiffusionInpaintPipeline,
"inpaint",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
Expand Down
48 changes: 31 additions & 17 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
KDPM2DiscreteScheduler,
LMSDiscreteScheduler,
OnnxRuntimeModel,
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionPipeline,
PNDMScheduler,
StableDiffusionPipeline,
)
from onnx import load_model
from transformers import CLIPTokenizer

from ..constants import ONNX_MODEL
from ..diffusers.utils import expand_prompt

try:
from diffusers import DEISMultistepScheduler
except ImportError:
Expand All @@ -38,8 +38,13 @@
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler

from ..constants import ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from ..diffusers.lpw_stable_diffusion_onnx import OnnxStableDiffusionLongPromptWeightingPipeline
from ..diffusers.utils import expand_prompt
from ..params import DeviceParams, Size
from ..server import ServerContext
from ..utils import run_gc
Expand All @@ -49,6 +54,15 @@
latent_channels = 4
latent_factor = 8

available_pipelines = {
"controlnet": OnnxStableDiffusionControlNetPipeline,
"img2img": OnnxStableDiffusionImg2ImgPipeline,
"inpaint": OnnxStableDiffusionInpaintPipeline,
"lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
"pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
"txt2img": OnnxStableDiffusionPipeline,
}

pipeline_schedulers = {
"ddim": DDIMScheduler,
"ddpm": DDPMScheduler,
Expand All @@ -68,8 +82,12 @@
}


def get_pipeline_schedulers():
return pipeline_schedulers
def get_available_pipelines() -> List[str]:
return list(available_pipelines.keys())


def get_pipeline_schedulers() -> List[str]:
return list(pipeline_schedulers.keys())


def get_scheduler_name(scheduler: Any) -> Optional[str]:
Expand Down Expand Up @@ -111,11 +129,10 @@ def get_tile_latents(

def load_pipeline(
server: ServerContext,
pipeline: DiffusionPipeline,
pipeline: str,
model: str,
scheduler_name: str,
device: DeviceParams,
lpw: bool,
control: Optional[str] = None,
inversions: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None,
Expand All @@ -129,7 +146,7 @@ def load_pipeline(
)
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
pipe_key = (
pipeline.__name__,
pipeline,
model,
device.device,
device.provider,
Expand Down Expand Up @@ -170,11 +187,6 @@ def load_pipeline(
logger.debug("unloading previous diffusion pipeline")
run_gc([device])

if lpw:
custom_pipeline = "./onnx_web/diffusers/lpw_stable_diffusion_onnx.py"
else:
custom_pipeline = None

logger.debug("loading new diffusion pipeline from %s", model)
components = {
"scheduler": scheduler_type.from_pretrained(
Expand Down Expand Up @@ -281,6 +293,7 @@ def load_pipeline(
)
)

# ControlNet component
if control is not None:
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
Expand All @@ -290,9 +303,10 @@ def load_pipeline(
)
)

pipe = pipeline.from_pretrained(
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
model,
custom_pipeline=custom_pipeline,
provider=device.ort_provider(),
sess_options=device.sess_options(),
revision="onnx",
Expand All @@ -306,12 +320,12 @@ def load_pipeline(

optimize_pipeline(server, pipe)

# TODO: CPU VAE, etc
if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device.torch_str())

# monkey-patch pipeline
if not lpw:
patch_pipeline(server, pipe, pipeline)
patch_pipeline(server, pipe, pipeline)

server.cache.set("diffusion", pipe_key, pipe)
server.cache.set("scheduler", scheduler_key, components["scheduler"])
Expand Down
16 changes: 6 additions & 10 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,16 @@ def run_txt2img_pipeline(

pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline,
"txt2img",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
inversions,
loras,
)
progress = job.get_progress_callback()

if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
Expand Down Expand Up @@ -117,11 +116,10 @@ def run_txt2img_pipeline(
# load img2img pipeline once
highres_pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,
"img2img",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
inversions,
loras,
)
Expand Down Expand Up @@ -153,7 +151,7 @@ def highres_tile(tile: Image.Image, dims):
callback=highres_progress,
)

if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for highres")
rng = torch.manual_seed(params.seed)
result = highres_pipe.img2img(
Expand Down Expand Up @@ -233,18 +231,16 @@ def run_img2img_pipeline(

pipe = load_pipeline(
server,
# OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionControlNetPipeline,
"img2img",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
control=params.control,
inversions=inversions,
loras=loras,
)
progress = job.get_progress_callback()
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ def make_output_name(

hash_value(sha, mode)
hash_value(sha, params.model)
hash_value(sha, params.pipeline)
hash_value(sha, params.scheduler)
hash_value(sha, params.prompt)
hash_value(sha, params.negative_prompt)
hash_value(sha, params.cfg)
hash_value(sha, params.seed)
hash_value(sha, params.steps)
hash_value(sha, params.lpw)
hash_value(sha, params.eta)
hash_value(sha, params.batch)
hash_value(sha, size.width)
Expand Down

0 comments on commit 2af1530

Please sign in to comment.