Skip to content

Commit

Permalink
fix(api): run garbage collection after each model change
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 2, 2023
1 parent 2baf6ed commit 066b1a0
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 17 deletions.
2 changes: 2 additions & 0 deletions api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
from .upscale_resrgan import (
Expand Down Expand Up @@ -51,6 +52,7 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[

last_pipeline_instance = gfpgan
last_pipeline_params = face_path
run_gc()

return gfpgan

Expand Down
2 changes: 2 additions & 0 deletions api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)

Expand Down Expand Up @@ -66,6 +67,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):

last_pipeline_instance = upsampler
last_pipeline_params = cache_params
run_gc()

return upsampler

Expand Down
2 changes: 2 additions & 0 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)

Expand Down Expand Up @@ -44,6 +45,7 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):

last_pipeline_instance = pipeline
last_pipeline_params = cache_params
run_gc()

return pipeline

Expand Down
15 changes: 6 additions & 9 deletions api/onnx_web/diffusion/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from ..params import (
Size,
)
from ..utils import (
run_gc,
)

import gc
import numpy as np
import torch

logger = getLogger(__name__)

Expand Down Expand Up @@ -39,7 +41,7 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np
xt = x + t
yt = y + t

return full_latents[:,:,y:yt,x:xt]
return full_latents[:, :, y:yt, x:xt]


def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
Expand All @@ -55,8 +57,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
logger.info('unloading previous diffusion pipeline')
last_pipeline_instance = None
last_pipeline_scheduler = None
gc.collect()
torch.cuda.empty_cache()
run_gc()

logger.info('loading new diffusion pipeline')
pipe = pipeline.from_pretrained(
Expand All @@ -83,10 +84,6 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu

pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler

logger.info('running garbage collection during pipeline change')
gc.collect()
run_gc()

return pipe


18 changes: 10 additions & 8 deletions api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from diffusers import (
OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
)
from logging import getLogger
from PIL import Image, ImageChops
Expand All @@ -10,9 +9,6 @@
from ..chain import (
upscale_outpaint,
)
from ..image import (
expand_image,
)
from ..params import (
ImageParams,
Border,
Expand All @@ -24,19 +20,20 @@
UpscaleParams,
)
from ..utils import (
is_debug,
base_join,
run_gc,
ServerContext,
)
from .load import (
get_latents_from_seed,
load_pipeline,
get_latents_from_seed,
load_pipeline,
)

import numpy as np

logger = getLogger(__name__)


def run_txt2img_pipeline(
ctx: ServerContext,
params: ImageParams,
Expand Down Expand Up @@ -69,6 +66,7 @@ def run_txt2img_pipeline(

del image
del result
run_gc()

logger.info('saved txt2img output: %s', dest)

Expand Down Expand Up @@ -104,6 +102,7 @@ def run_img2img_pipeline(

del image
del result
run_gc()

logger.info('saved img2img output: %s', dest)

Expand Down Expand Up @@ -139,7 +138,8 @@ def run_inpaint_pipeline(
if image.size == source_image.size:
image = ImageChops.blend(source_image, image, strength)
else:
logger.info('output image size does not match source, skipping post-blend')
logger.info(
'output image size does not match source, skipping post-blend')

image = run_upscale_correction(
ctx, stage, params, image, upscale=upscale)
Expand All @@ -148,6 +148,7 @@ def run_inpaint_pipeline(
image.save(dest)

del image
run_gc()

logger.info('saved inpaint output: %s', dest)

Expand All @@ -167,5 +168,6 @@ def run_upscale_pipeline(
image.save(dest)

del image
run_gc()

logger.info('saved img2img output: %s', dest)
9 changes: 9 additions & 0 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from time import time
from typing import Any, Dict, List, Optional, Union, Tuple

import gc
import torch

from .params import (
ImageParams,
Param,
Expand Down Expand Up @@ -158,3 +161,9 @@ def make_output_name(
def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
return path.join(base, tail_path)


def run_gc():
logger.debug('running garbage collection')
gc.collect()
torch.cuda.empty_cache()

0 comments on commit 066b1a0

Please sign in to comment.