Skip to content

Commit

Permalink
fix(api): set CUDA device in ORT session
Browse files: browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 5, 2023
1 parent d636ce3 commit 04a2faf
Show file tree
Hide file tree
Showing 21 changed files with 116 additions and 52 deletions.
6 changes: 6 additions & 0 deletions api/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
mypy

types-Flask-Cors
types-jsonschema
types-Pillow
types-PyYAML
5 changes: 3 additions & 2 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
class StageCallback(Protocol):
def __call__(
self,
job: JobContext,
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
Expand Down Expand Up @@ -83,7 +84,7 @@ def __call__(self, job: JobContext, server: ServerContext, params: ImageParams,
stage_params.tile_size)

def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(server, stage_params, params, tile,
tile = stage_pipe(job, server, stage_params, params, tile,
**kwargs)

if is_debug():
Expand All @@ -95,7 +96,7 @@ def stage_tile(tile: Image.Image, _dims) -> Image.Image:
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
logger.info('image within tile size, running stage')
image = stage_pipe(server, stage_params, params, image,
image = stage_pipe(job, server, stage_params, params, image,
**kwargs)

logger.info('finished stage %s, result size: %sx%s',
Expand Down
8 changes: 6 additions & 2 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from logging import getLogger
from PIL import Image

from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
load_pipeline,
)
Expand All @@ -21,7 +24,8 @@


def blend_img2img(
_ctx: ServerContext,
job: JobContext,
_server: ServerContext,
_stage: StageParams,
params: ImageParams,
source_image: Image.Image,
Expand All @@ -34,7 +38,7 @@ def blend_img2img(
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)

pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
params.model, params.scheduler, job.get_device())

rng = np.random.RandomState(params.seed)

Expand Down
18 changes: 11 additions & 7 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from PIL import Image
from typing import Callable, Tuple

from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
Expand Down Expand Up @@ -38,7 +41,8 @@


def blend_inpaint(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
Expand All @@ -65,21 +69,21 @@ def blend_inpaint(
mask_filter=mask_filter)

if is_debug():
save_image(ctx, 'last-source.png', source_image)
save_image(ctx, 'last-mask.png', mask_image)
save_image(ctx, 'last-noise.png', noise_image)
save_image(server, 'last-source.png', source_image)
save_image(server, 'last-mask.png', mask_image)
save_image(server, 'last-noise.png', noise_image)

def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*image.size)
mask = mask_image.crop((left, top, left + tile, top + tile))

if is_debug():
save_image(ctx, 'tile-source.png', image)
save_image(ctx, 'tile-mask.png', mask)
save_image(server, 'tile-source.png', image)
save_image(server, 'tile-mask.png', mask)

pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.provider, params.scheduler)
params.model, params.scheduler, job.get_device())

latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
Expand Down
8 changes: 6 additions & 2 deletions api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from realesrgan import RealESRGANer
from typing import Optional

from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
Expand Down Expand Up @@ -60,7 +63,8 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[


def correct_gfpgan(
ctx: ServerContext,
_job: JobContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
Expand All @@ -74,7 +78,7 @@ def correct_gfpgan(
return source_image

logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model)
gfpgan = load_gfpgan(ctx, upscale, upsampler=upsampler)
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)

output = np.array(source_image)
_, _, output = gfpgan.enhance(
Expand Down
4 changes: 4 additions & 0 deletions api/onnx_web/chain/persist_disk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from logging import getLogger
from PIL import Image

from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
Expand All @@ -16,6 +19,7 @@


def persist_disk(
_job: JobContext,
ctx: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/chain/persist_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from logging import getLogger
from PIL import Image

from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/chain/reduce_crop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from logging import getLogger
from PIL import Image

from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/chain/reduce_thumbnail.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from logging import getLogger
from PIL import Image

from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/chain/source_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from PIL import Image
from typing import Callable

from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,
Expand Down
8 changes: 6 additions & 2 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from logging import getLogger
from PIL import Image

from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
Expand All @@ -23,7 +26,8 @@


def source_txt2img(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
Expand All @@ -39,7 +43,7 @@ def source_txt2img(
logger.warn('a source image was passed to a txt2img stage, but will be discarded')

pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler)
params.model, params.scheduler, job.get_device())

latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from PIL import Image, ImageDraw
from typing import Callable, Tuple

from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
get_tile_latents,
Expand Down
13 changes: 9 additions & 4 deletions api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
from PIL import Image
from realesrgan import RealESRGANer

from ..device_pool import (
JobContext,
)
from ..onnx import (
OnnxNet,
)
from ..params import (
DeviceParams,
ImageParams,
StageParams,
UpscaleParams,
Expand All @@ -25,7 +29,7 @@
last_pipeline_params = (None, None)


def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0):
global last_pipeline_instance
global last_pipeline_params

Expand All @@ -41,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):

# use ONNX acceleration, if available
if params.format == 'onnx':
model = OnnxNet(ctx, model_file, provider=params.provider)
model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
Expand Down Expand Up @@ -76,7 +80,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):


def upscale_resrgan(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
Expand All @@ -87,7 +92,7 @@ def upscale_resrgan(
logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale)

output = np.array(source_image)
upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)

output, _ = upsampler.enhance(output, outscale=upscale.outscale)

Expand Down
19 changes: 12 additions & 7 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
from os import path
from PIL import Image

from ..device_pool import (
JobContext,
)
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from ..params import (
DeviceParams,
ImageParams,
StageParams,
UpscaleParams,
Expand All @@ -27,7 +31,7 @@
last_pipeline_params = (None, None)


def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams):
global last_pipeline_instance
global last_pipeline_params

Expand All @@ -39,11 +43,11 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
return last_pipeline_instance

if upscale.format == 'onnx':
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, upscale.provider)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider)
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
else:
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, upscale.provider)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider)
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)

last_pipeline_instance = pipeline
last_pipeline_params = cache_params
Expand All @@ -53,7 +57,8 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):


def upscale_stable_diffusion(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
Expand All @@ -65,7 +70,7 @@ def upscale_stable_diffusion(
prompt = prompt or params.prompt
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)

pipeline = load_stable_diffusion(ctx, upscale)
pipeline = load_stable_diffusion(server, upscale, job.get_device())
generator = torch.manual_seed(params.seed)

return pipeline(
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/device_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def cancel(self, key: str) -> bool:
return True
else:
job.set_cancel()
return True

return False

def done(self, key: str) -> Tuple[bool, int]:
for job in self.jobs:
Expand Down

0 comments on commit 04a2faf

Please sign in to comment.