Skip to content

Commit

Permalink
feat(api): add reduce stages, noise source
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 29, 2023
1 parent 8d346cb commit c905fbb
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 80 deletions.
9 changes: 9 additions & 0 deletions api/onnx_web/chain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
from .persist_s3 import (
persist_s3,
)
from .reduce_crop import (
reduce_crop,
)
from .reduce_thumbnail import (
reduce_thumbnail,
)
from .source_noise import (
source_noise,
)
from .source_txt2img import (
source_txt2img,
)
Expand Down
6 changes: 3 additions & 3 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image,
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}

logger.info('running stage %s on result image with dimensions %sx%s, %s',
logger.info('running stage %s on image with dimensions %sx%s, %s',
name, image.width, image.height, kwargs.keys())

if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
logger.info('source image larger than tile size of %s, tiling stage',
logger.info('image larger than tile size of %s, tiling stage',
stage_params.tile_size)

def stage_tile(tile: Image.Image, _dims) -> Image.Image:
Expand All @@ -89,7 +89,7 @@ def stage_tile(tile: Image.Image, _dims) -> Image.Image:
image = process_tile_grid(
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
logger.info('source image within tile size, running stage')
logger.info('image within tile size, running stage')
image = stage_pipe(ctx, stage_params, params, image,
**kwargs)

Expand Down
1 change: 0 additions & 1 deletion api/onnx_web/chain/persist_disk.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from logging import getLogger
from PIL import Image


from ..params import (
ImageParams,
StageParams,
Expand Down
30 changes: 30 additions & 0 deletions api/onnx_web/chain/reduce_crop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from logging import getLogger
from PIL import Image

from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)

logger = getLogger(__name__)


def reduce_crop(
ctx: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
*,
origin: Size,
size: Size,
**kwargs,
) -> Image.Image:
image = source_image.crop(
(origin.width, origin.height, size.width, size.height))
logger.info('created thumbnail with dimensions: %sx%s',
image.width, image.height)
return image
28 changes: 28 additions & 0 deletions api/onnx_web/chain/reduce_thumbnail.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from logging import getLogger
from PIL import Image

from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)

logger = getLogger(__name__)


def reduce_thumbnail(
ctx: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
*,
size: Size,
**kwargs,
) -> Image.Image:
image = source_image.thumbnail((size.width, size.height))
logger.info('created thumbnail with dimensions: %sx%s',
image.width, image.height)
return image
38 changes: 38 additions & 0 deletions api/onnx_web/chain/source_noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from logging import getLogger
from PIL import Image
from typing import Callable

from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)


logger = getLogger(__name__)


def source_noise(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
size: Size,
noise_source: Callable,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info('generating image from noise source')

if source_image is not None:
logger.warn(
'a source image was passed to a noise stage, but will be discarded')

output = noise_source(source_image, (size.width, size.height), (0, 0))

logger.info('final output image size: %sx%s', output.width, output.height)
return output
16 changes: 2 additions & 14 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionUpscalePipeline,
)
from logging import getLogger
Expand Down Expand Up @@ -40,19 +38,9 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
return last_pipeline_instance

if upscale.format == 'onnx':
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
model_path,
vae=AutoencoderKL.from_pretrained(
model_path, subfolder='vae_encoder'),
low_res_scheduler=DDPMScheduler.from_pretrained(
model_path, subfolder='scheduler'),
)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path)
else:
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
'stabilityai/stable-diffusion-x4-upscaler')
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path)

last_pipeline_instance = pipeline
last_pipeline_params = cache_params
Expand Down
171 changes: 110 additions & 61 deletions api/onnx_web/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from logging import getLogger
from onnx import load, save_model
Expand Down Expand Up @@ -202,7 +203,7 @@ def onnx_export(


@torch.no_grad()
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
'''
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
'''
Expand All @@ -212,6 +213,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
# diffusers go into a directory rather than .onnx file
logger.info('converting Diffusers model: %s -> %s/', name, dest_path)

if single_vae:
logger.info('converting model with single VAE')

if path.isdir(dest_path):
logger.info('ONNX model already exists, skipping.')
return
Expand Down Expand Up @@ -295,50 +299,75 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
)
del pipeline.unet

# VAE ENCODER
vae_encoder = pipeline.vae
vae_in_channels = vae_encoder.config.in_channels
vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
sample, return_dict)[0].sample()
onnx_export(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=training_device, dtype=dtype),
False,
),
output_path=output_path / "vae_encoder" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)
if single_vae:
# SINGLE VAE
vae_only = pipeline.vae
vae_in_channels = vae_only.config.in_channels
vae_sample_size = vae_only.config.sample_size
# need to get the raw tensor output (sample) from the encoder
vae_only.forward = lambda sample, return_dict: vae_only.encode(
sample, return_dict)[0].sample()
onnx_export(
vae_only,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=training_device, dtype=dtype),
False,
),
output_path=output_path / "vae" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)
else:
# VAE ENCODER
vae_encoder = pipeline.vae
vae_in_channels = vae_encoder.config.in_channels
vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
sample, return_dict)[0].sample()
onnx_export(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=training_device, dtype=dtype),
False,
),
output_path=output_path / "vae_encoder" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)

# VAE DECODER
vae_decoder = pipeline.vae
vae_latent_channels = vae_decoder.config.latent_channels
vae_out_channels = vae_decoder.config.out_channels
# forward only through the decoder part
vae_decoder.forward = vae_encoder.decode
onnx_export(
vae_decoder,
model_args=(
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
device=training_device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / "model.onnx",
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)

# VAE DECODER
vae_decoder = pipeline.vae
vae_latent_channels = vae_decoder.config.latent_channels
vae_out_channels = vae_decoder.config.out_channels
# forward only through the decoder part
vae_decoder.forward = vae_encoder.decode
onnx_export(
vae_decoder,
model_args=(
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
device=training_device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / "model.onnx",
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)
del pipeline.vae

# SAFETY CHECKER
Expand Down Expand Up @@ -376,20 +405,32 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
safety_checker = None
feature_extractor = None

onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(
output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=safety_checker is not None,
)
if single_vae:
onnx_pipeline = StableDiffusionUpscalePipeline(
vae=OnnxRuntimeModel.from_pretrained(
output_path / "vae"),
text_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
low_res_scheduler=pipeline.scheduler,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
)
else:
onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(
output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=safety_checker is not None,
)

logger.info('exporting ONNX model')

Expand All @@ -398,8 +439,15 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):

del pipeline
del onnx_pipeline
_ = OnnxStableDiffusionPipeline.from_pretrained(
output_path, provider="CPUExecutionProvider")

if single_vae:
_ = StableDiffusionUpscalePipeline.from_pretrained(
output_path, provider="CPUExecutionProvider"
)
else:
_ = OnnxStableDiffusionPipeline.from_pretrained(
output_path, provider="CPUExecutionProvider")

logger.info("ONNX pipeline is loadable")


Expand All @@ -409,7 +457,8 @@ def load_models(args, models: Models):
if source[0] in args.skip:
logger.info('Skipping model: %s', source[0])
else:
convert_diffuser(*source, args.opset, args.half, args.token)
single_vae = 'upscaling' in source[0]
convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)

if args.upscaling:
for source in models.get('upscaling'):
Expand Down

0 comments on commit c905fbb

Please sign in to comment.