Skip to content

Commit

Permalink
Improve Artifact Store isolation (#2490)
Browse files Browse the repository at this point in the history
* dir traversal issue

* Auto-update of Starter template

* Auto-update of NLP template

* reroute artifacts and logs via AS

* reroute materializers via AS

* simplify to one deco

* fix materializer tests

* allow local download

* Auto-update of E2E template

* fix test issues

* rework based on comments

* fix bugs

* lint

* Candidate (#2493)

Co-authored-by: Stefan Nica <stefan@zenml.io>

* darglint

---------

Co-authored-by: GitHub Actions <actions@github.com>
Co-authored-by: Stefan Nica <stefan@zenml.io>
  • Loading branch information
3 people committed Mar 5, 2024
1 parent 683e943 commit 00e934f
Show file tree
Hide file tree
Showing 17 changed files with 279 additions and 113 deletions.
167 changes: 113 additions & 54 deletions src/zenml/artifact_stores/base_artifact_store.py
Expand Up @@ -13,8 +13,10 @@
# permissions and limitations under the License.
"""The base interface to extend the ZenML artifact store."""

import inspect
import textwrap
from abc import abstractmethod
from pathlib import Path
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -44,50 +46,93 @@
PathType = Union[bytes, str]


def _sanitize_potential_path(potential_path: Any) -> Any:
"""Sanitizes the input if it is a path.
class _sanitize_paths:
"""Sanitizes path inputs before calling the original function.
If the input is a **remote** path, this function replaces backslash path
separators by forward slashes.
An extra decoration layer is needed to pass in the fixed artifact store root
path for static methods that are called on filesystems directly.
Args:
potential_path: Value that potentially refers to a (remote) path.
func: The function to decorate.
fixed_root_path: The fixed artifact store root path.
is_static: Whether the function is static or not.
Returns:
The original input or a sanitized version of it in case of a remote
path.
Function that calls the input function with sanitized path inputs.
"""
if isinstance(potential_path, bytes):
path = fileio.convert_to_str(potential_path)
elif isinstance(potential_path, str):
path = potential_path
else:
# Neither string nor bytes, this is not a path
return potential_path

if io_utils.is_remote(path):
# If we have a remote path, replace windows path separators with
# slashes
import ntpath
import posixpath
def __init__(self, func: Callable[..., Any], fixed_root_path: str) -> None:
"""Initializes the decorator.
path = path.replace(ntpath.sep, posixpath.sep)
Args:
func: The function to decorate.
fixed_root_path: The fixed artifact store root path.
"""
self.func = func
self.fixed_root_path = fixed_root_path

return path
self.path_args: List[int] = []
self.path_kwargs: List[str] = []
for i, param in enumerate(
inspect.signature(self.func).parameters.values()
):
if param.annotation == PathType:
self.path_kwargs.append(param.name)
if param.default == inspect.Parameter.empty:
self.path_args.append(i)

def _validate_path(self, path: str) -> None:
"""Validates a path.
def _sanitize_paths(_func: Callable[..., Any]) -> Callable[..., Any]:
"""Sanitizes path inputs before calling the original function.
Args:
path: The path to validate.
Args:
_func: The function for which to sanitize the inputs.
Raises:
FileNotFoundError: If the path is outside of the artifact store
bounds.
"""
if not path.startswith(self.fixed_root_path):
raise FileNotFoundError(
f"File `{path}` is outside of "
f"artifact store bounds `{self.fixed_root_path}`"
)

Returns:
Function that calls the input function with sanitized path inputs.
"""
def _sanitize_potential_path(self, potential_path: Any) -> Any:
"""Sanitizes the input if it is a path.
If the input is a **remote** path, this function replaces backslash path
separators by forward slashes.
def inner_function(*args: Any, **kwargs: Any) -> Any:
"""Inner function.
Args:
potential_path: Value that potentially refers to a (remote) path.
Returns:
The original input or a sanitized version of it in case of a remote
path.
"""
if isinstance(potential_path, bytes):
path = fileio.convert_to_str(potential_path)
elif isinstance(potential_path, str):
path = potential_path
else:
# Neither string nor bytes, this is not a path
return potential_path

if io_utils.is_remote(path):
# If we have a remote path, replace windows path separators with
# slashes
import ntpath
import posixpath

path = path.replace(ntpath.sep, posixpath.sep)
self._validate_path(path)
else:
self._validate_path(str(Path(path).absolute().resolve()))

return path

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Decorator function that sanitizes paths before calling the original function.
Args:
*args: Positional args.
Expand All @@ -96,15 +141,28 @@ def inner_function(*args: Any, **kwargs: Any) -> Any:
Returns:
Output of the input function called with sanitized paths.
"""
args = tuple(_sanitize_potential_path(arg) for arg in args)
# verify if `self` is part of the args
has_self = bool(args and isinstance(args[0], BaseArtifactStore))

# sanitize inputs for relevant args and kwargs, keep rest unchanged
args = tuple(
self._sanitize_potential_path(
arg,
)
if i + has_self in self.path_args
else arg
for i, arg in enumerate(args)
)
kwargs = {
key: _sanitize_potential_path(value)
key: self._sanitize_potential_path(
value,
)
if key in self.path_kwargs
else value
for key, value in kwargs.items()
}

return _func(*args, **kwargs)

return inner_function
return self.func(*args, **kwargs)


class BaseArtifactStoreConfig(StackComponentConfig):
Expand Down Expand Up @@ -323,6 +381,7 @@ def stat(self, path: PathType) -> Any:
The stat descriptor.
"""

@abstractmethod
def size(self, path: PathType) -> Optional[int]:
"""Get the size of a file in bytes.
Expand Down Expand Up @@ -376,30 +435,30 @@ def _register(self) -> None:
from zenml.io.filesystem_registry import default_filesystem_registry
from zenml.io.local_filesystem import LocalFilesystem

overloads: Dict[str, Any] = {
"SUPPORTED_SCHEMES": self.config.SUPPORTED_SCHEMES,
}
for abc_method in inspect.getmembers(BaseArtifactStore):
if getattr(abc_method[1], "__isabstractmethod__", False):
sanitized_method = _sanitize_paths(
getattr(self, abc_method[0]), self.path
)
# prepare overloads for filesystem methods
overloads[abc_method[0]] = staticmethod(sanitized_method)

# decorate artifact store methods
setattr(
self,
abc_method[0],
sanitized_method,
)

# Local filesystem is always registered, no point in doing it again.
if isinstance(self, LocalFilesystem):
return

filesystem_class = type(
self.__class__.__name__,
(BaseFilesystem,),
{
"SUPPORTED_SCHEMES": self.config.SUPPORTED_SCHEMES,
"open": staticmethod(_sanitize_paths(self.open)),
"copyfile": staticmethod(_sanitize_paths(self.copyfile)),
"exists": staticmethod(_sanitize_paths(self.exists)),
"glob": staticmethod(_sanitize_paths(self.glob)),
"isdir": staticmethod(_sanitize_paths(self.isdir)),
"listdir": staticmethod(_sanitize_paths(self.listdir)),
"makedirs": staticmethod(_sanitize_paths(self.makedirs)),
"mkdir": staticmethod(_sanitize_paths(self.mkdir)),
"remove": staticmethod(_sanitize_paths(self.remove)),
"rename": staticmethod(_sanitize_paths(self.rename)),
"rmtree": staticmethod(_sanitize_paths(self.rmtree)),
"size": staticmethod(_sanitize_paths(self.size)),
"stat": staticmethod(_sanitize_paths(self.stat)),
"walk": staticmethod(_sanitize_paths(self.walk)),
},
self.__class__.__name__, (BaseFilesystem,), overloads
)

default_filesystem_registry.register(filesystem_class)
Expand Down
10 changes: 7 additions & 3 deletions src/zenml/artifacts/utils.py
Expand Up @@ -152,7 +152,7 @@ def save_artifact(
if not uri.startswith(artifact_store.path):
uri = os.path.join(artifact_store.path, uri)

if manual_save and fileio.exists(uri):
if manual_save and artifact_store.exists(uri):
# This check is only necessary for manual saves as we already check
# it when creating the directory for step output artifacts
other_artifacts = client.list_artifact_versions(uri=uri, size=1)
Expand All @@ -162,7 +162,7 @@ def save_artifact(
f"{uri} because the URI is already used by artifact "
f"{other_artifact.name} (version {other_artifact.version})."
)
fileio.makedirs(uri)
artifact_store.makedirs(uri)

# Find and initialize the right materializer class
if isinstance(materializer, type):
Expand Down Expand Up @@ -752,6 +752,7 @@ def _load_file_from_artifact_store(
Raises:
DoesNotExistException: If the file does not exist in the artifact store.
NotImplementedError: If the artifact store cannot open the file.
IOError: If the artifact store rejects the request.
"""
try:
with artifact_store.open(uri, mode) as text_file:
Expand All @@ -761,6 +762,8 @@ def _load_file_from_artifact_store(
f"File '{uri}' does not exist in artifact store "
f"'{artifact_store.name}'."
)
except IOError as e:
raise e
except Exception as e:
logger.exception(e)
link = "https://docs.zenml.io/stacks-and-components/component-guide/artifact-stores/custom#enabling-artifact-visualizations-with-custom-artifact-stores"
Expand Down Expand Up @@ -819,7 +822,8 @@ def load_model_from_metadata(model_uri: str) -> Any:
The ML model object loaded into memory.
"""
# Load the model from its metadata
with fileio.open(
artifact_store = Client().active_stack.artifact_store
with artifact_store.open(
os.path.join(model_uri, MODEL_METADATA_YAML_FILE_NAME), "r"
) as f:
metadata = read_yaml(f.name)
Expand Down
14 changes: 8 additions & 6 deletions src/zenml/logging/step_logging.py
Expand Up @@ -23,7 +23,7 @@
from uuid import uuid4

from zenml.artifact_stores import BaseArtifactStore
from zenml.io import fileio
from zenml.client import Client
from zenml.logger import get_logger
from zenml.logging import (
STEP_LOGS_STORAGE_INTERVAL_SECONDS,
Expand Down Expand Up @@ -64,6 +64,7 @@ def prepare_logs_uri(
Returns:
The URI of the logs file.
"""
artifact_store = Client().active_stack.artifact_store
if log_key is None:
log_key = str(uuid4())

Expand All @@ -74,16 +75,16 @@ def prepare_logs_uri(
)

# Create the dir
if not fileio.exists(logs_base_uri):
fileio.makedirs(logs_base_uri)
if not artifact_store.exists(logs_base_uri):
artifact_store.makedirs(logs_base_uri)

# Delete the file if it already exists
logs_uri = os.path.join(logs_base_uri, f"{log_key}.log")
if fileio.exists(logs_uri):
if artifact_store.exists(logs_uri):
logger.warning(
f"Logs file {logs_uri} already exists! Removing old log file..."
)
fileio.remove(logs_uri)
artifact_store.remove(logs_uri)
return logs_uri


Expand Down Expand Up @@ -135,12 +136,13 @@ def write(self, text: str) -> None:

def save_to_file(self) -> None:
"""Method to save the buffer to the given URI."""
artifact_store = Client().active_stack.artifact_store
if not self.disabled:
try:
self.disabled = True

if self.buffer:
with fileio.open(self.logs_uri, "a") as file:
with artifact_store.open(self.logs_uri, "a") as file:
for message in self.buffer:
file.write(
remove_ansi_escape_codes(message) + "\n"
Expand Down
3 changes: 2 additions & 1 deletion src/zenml/materializers/base_materializer.py
Expand Up @@ -156,8 +156,9 @@ def save_visualizations(self, data: Any) -> Dict[str, VisualizationType]:
Example:
```
artifact_store = Client().active_stack.artifact_store
visualization_uri = os.path.join(self.uri, "visualization.html")
with fileio.open(visualization_uri, "w") as f:
with artifact_store.open(visualization_uri, "w") as f:
f.write("<html><body>data</body></html>")
visualization_uri_2 = os.path.join(self.uri, "visualization.png")
Expand Down

0 comments on commit 00e934f

Please sign in to comment.