Skip to content

Commit

Permalink
feat(api): add option for CPU-only conversion on systems with CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jun 8, 2023
1 parent 8215a1b commit 3f00da9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
5 changes: 4 additions & 1 deletion api/onnx_web/chain/utils.py
Expand Up @@ -71,6 +71,8 @@ def blend_tiles(
overlap: float,
):
adj_tile = int(float(tile) * (1.0 - overlap))
logger.trace("adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap)

scaled_size = (height * scale, width * scale, 3)
count = np.zeros(scaled_size)
value = np.zeros(scaled_size)
Expand All @@ -82,7 +84,7 @@ def blend_tiles(
# gradient blending
points = [0, adj_tile * scale, (tile - adj_tile) * scale, (tile * scale) - 1]
grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
logger.trace("tile gradients: %s, %s", grad_x, grad_y)
logger.trace("tile gradients: %s, %s, %s", points, grad_x, grad_y)

mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
mult_y = [np.interp(i, points, grad_y) for i in range(tile * scale)]
Expand All @@ -98,6 +100,7 @@ def blend_tiles(
# equalized size may be wrong/too much
scaled_bottom = min(scaled_top + equalized.shape[0], scaled_size[0])
scaled_right = min(scaled_left + equalized.shape[1], scaled_size[1])
logger.trace("tile broadcast shapes: %s, %s, %s, %s", scaled_top, scaled_bottom, scaled_left, scaled_right)

# accumulation
value[scaled_top:scaled_bottom, scaled_left:scaled_right, :] += equalized[
Expand Down
12 changes: 10 additions & 2 deletions api/onnx_web/convert/utils.py
Expand Up @@ -46,6 +46,7 @@ def __init__(
reload: bool = True,
share_unet: bool = True,
extract: bool = False,

**kwargs,
) -> None:
super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)
Expand All @@ -64,8 +65,6 @@ def __init__(
else:
self.training_device = "cuda" if torch.cuda.is_available() else "cpu"

self.map_location = torch.device(self.training_device)

@classmethod
def from_environ(cls):
context = super().from_environ()
Expand All @@ -74,8 +73,17 @@ def from_environ(cls):
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))

cpu_only = get_boolean(environ, "ONNX_WEB_CONVERT_CPU_ONLY", False)
if cpu_only:
context.training_device = "cpu"

return context

@property
def map_location(self):
return torch.device(self.training_device)


def download_progress(urls: List[Tuple[str, str]]):
for url, dest in urls:
Expand Down

0 comments on commit 3f00da9

Please sign in to comment.