Skip to content

Commit

Permalink
feat(api): add feature flag for single-tile panorama highres
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Nov 11, 2023
1 parent 798fa5f commit 5fb2de8
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 12 deletions.
4 changes: 2 additions & 2 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ def run(
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)

# reseed latents as needed
reseed_rng = np.random.default_rng(params.seed)
reseed_rng = np.random.RandomState(params.seed)
prompt, reseed = parse_reseed(prompt)
for top, left, bottom, right, region_seed in reseed:
if region_seed == -1:
region_seed = reseed_rng.integers(2**32)
region_seed = reseed_rng.random_integers(2**32 - 1)

logger.debug(
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
Expand Down
7 changes: 6 additions & 1 deletion api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@ def run_txt2img_pipeline(
)

# apply upscaling and correction, before highres
stage = StageParams(tile_size=params.unet_tile)
if params.is_panorama() and server.panorama_tiles:
highres_size = tile_size * highres.scale
else:
highres_size = params.unet_tile

stage = StageParams(tile_size=highres_size)
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
Expand Down
3 changes: 0 additions & 3 deletions api/onnx_web/diffusers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,6 @@ def get_tokens_from_prompt(
pattern: Pattern,
parser=parse_float_group,
) -> Tuple[str, List[Tuple[str, float]]]:
"""
TODO: replace with Arpeggio
"""
remaining_prompt = prompt

tokens = []
Expand Down
48 changes: 42 additions & 6 deletions api/onnx_web/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,60 @@

logger = getLogger(__name__)

DEFAULT_ANY_PLATFORM = True
DEFAULT_CACHE_LIMIT = 5
DEFAULT_JOB_LIMIT = 10
DEFAULT_IMAGE_FORMAT = "png"
DEFAULT_SERVER_VERSION = "v0.10.0"
DEFAULT_SHOW_PROGRESS = True
DEFAULT_PANORAMA_TILES = False
DEFAULT_WORKER_RETRIES = 3


class ServerContext:
bundle_path: str
model_path: str
output_path: str
params_path: str
cors_origin: str
any_platform: bool
block_platforms: List[str]
default_platform: str
image_format: str
cache_limit: int
cache_path: str
show_progress: bool
optimizations: List[str]
extra_models: List[str]
job_limit: int
memory_limit: int
admin_token: str
server_version: str
worker_retries: int
panorama_tiles: bool

def __init__(
self,
bundle_path: str = ".",
model_path: str = ".",
output_path: str = ".",
params_path: str = ".",
cors_origin: str = "*",
any_platform: bool = True,
any_platform: bool = DEFAULT_ANY_PLATFORM,
block_platforms: Optional[List[str]] = None,
default_platform: Optional[str] = None,
image_format: str = DEFAULT_IMAGE_FORMAT,
cache_limit: int = DEFAULT_CACHE_LIMIT,
cache_path: Optional[str] = None,
show_progress: bool = True,
show_progress: bool = DEFAULT_SHOW_PROGRESS,
optimizations: Optional[List[str]] = None,
extra_models: Optional[List[str]] = None,
job_limit: int = DEFAULT_JOB_LIMIT,
memory_limit: Optional[int] = None,
admin_token: Optional[str] = None,
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
worker_retries: Optional[int] = 3,
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
panorama_tiles: Optional[bool] = DEFAULT_PANORAMA_TILES,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
Expand All @@ -58,6 +84,7 @@ def __init__(
self.admin_token = admin_token or token_urlsafe()
self.server_version = server_version
self.worker_retries = worker_retries
self.panorama_tiles = panorama_tiles

self.cache = ModelCache(self.cache_limit)

Expand All @@ -76,12 +103,16 @@ def from_environ(cls):
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
# others
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
any_platform=get_boolean(
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
),
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
show_progress=get_boolean(
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
),
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
Expand All @@ -90,7 +121,12 @@ def from_environ(cls):
server_version=environ.get(
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
),
worker_retries=int(environ.get("ONNX_WEB_WORKER_RETRIES", 3)),
worker_retries=int(
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
),
panorama_tiles=get_boolean(
environ, "ONNX_WEB_PANORAMA_TILES", DEFAULT_PANORAMA_TILES
),
)

def torch_dtype(self):
Expand Down
1 change: 1 addition & 0 deletions onnx-web.code-workspace
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"ddpm",
"deis",
"denoise",
"denoised",
"denoising",
"directml",
"Dreambooth",
Expand Down

0 comments on commit 5fb2de8

Please sign in to comment.