Skip to content

Commit

Permalink
fix(api): sanitize filenames in user input
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 1, 2023
1 parent c99aa67 commit 12fb7f5
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
3 changes: 2 additions & 1 deletion api/onnx_web/convert/diffusion/original.py
Expand Up @@ -53,7 +53,8 @@
CLIPVisionConfig,
)

from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
from ...utils import sanitize_name
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml
from .diffusers import convert_diffusion_diffusers

logger = getLogger(__name__)
Expand Down
7 changes: 0 additions & 7 deletions api/onnx_web/convert/utils.py
Expand Up @@ -189,13 +189,6 @@ def load_yaml(file: str) -> Config:
return Config(data)


safe_chars = "._-"


def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in safe_chars))


def remove_prefix(name, prefix):
if name.startswith(prefix):
return name[len(prefix) :]
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/server/api.py
Expand Up @@ -28,6 +28,7 @@
get_from_map,
get_not_empty,
get_size,
sanitize_name,
)
from ..worker.pool import DevicePoolExecutor
from .config import (
Expand Down Expand Up @@ -428,6 +429,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
if output_file is None:
return error_reply("output name is required")

output_file = sanitize_name(output_file)
cancel = pool.cancel(output_file)

return ready_reply(cancel)
Expand All @@ -438,6 +440,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
if output_file is None:
return error_reply("output name is required")

output_file = sanitize_name(output_file)
done, progress = pool.done(output_file)

if done is None:
Expand Down
6 changes: 6 additions & 0 deletions api/onnx_web/utils.py
Expand Up @@ -10,6 +10,8 @@

logger = getLogger(__name__)

SAFE_CHARS = "._-"


def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
Expand Down Expand Up @@ -100,3 +102,7 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
(mem_total - mem_free),
mem_total,
)


def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))

0 comments on commit 12fb7f5

Please sign in to comment.