Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerate scaler #2677

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ prev = "prev"
creat = "creat"
ret = "ret"
daa = "daa"
cll = "cll"

[default]
locale = "en-us"
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ zenml model-registry models register-version Tensorflow-model \

### Deploy a registered model

Afte you have registered a model in the MLflow model registry, you can also
After you have registered a model in the MLflow model registry, you can also
easily deploy it as a prediction service. Checkout the
[MLflow model deployer documentation](../model-deployers/mlflow.md#deploy-from-model-registry)
for more information on how to do that.
Expand Down
2 changes: 2 additions & 0 deletions src/zenml/config/step_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from zenml.logger import get_logger
from zenml.model.lazy_load import ModelVersionDataLazyLoader
from zenml.model.model import Model
from zenml.models.v2.misc.scaler_models import ScalerModel
from zenml.utils import deprecation_utils

if TYPE_CHECKING:
Expand Down Expand Up @@ -137,6 +138,7 @@ class StepConfigurationUpdate(StrictBaseModel):
failure_hook_source: Optional[Source] = None
success_hook_source: Optional[Source] = None
model: Optional[Model] = None
scaler: Optional[ScalerModel] = None

outputs: Mapping[str, PartialArtifactConfiguration] = {}

Expand Down
10 changes: 10 additions & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,13 @@ class PluginSubType(StrEnum):
WEBHOOK = "webhook"
# Action Subtypes
PIPELINE_RUN = "pipeline_run"


class AggregateFunction(StrEnum):
"""All possible aggregation functions."""

COUNT = "count"
SUM = "sum"
MEAN = "mean"
MIN = "min"
MAX = "max"
20 changes: 20 additions & 0 deletions src/zenml/integrations/accelerate/__init__.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this an accelerate integration or is this just part of our huggingface integration? I'm pretty sure you'll get accelerate already with the packages we have defined there?

Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Accelerate integration for ZenML."""

from zenml.integrations.accelerate.scalers.accelerate_scaler import AccelerateScaler

__all__ = [
"AccelerateScaler",
]
14 changes: 14 additions & 0 deletions src/zenml/integrations/accelerate/scalers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Accelerate scalers for ZenML."""
131 changes: 131 additions & 0 deletions src/zenml/integrations/accelerate/scalers/accelerate_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Utility function to run Accelerate jobs."""

import subprocess
from typing import Any, Callable, Optional, TypeVar

import cloudpickle as pickle

from zenml.logger import get_logger
from zenml.models.v2.misc.scaler_models import ScalerModel
from zenml.utils.function_utils import _cli_arg_name, create_cli_wrapped_script

logger = get_logger(__name__)
F = TypeVar("F", bound=Callable[..., None])


class AccelerateScaler(ScalerModel):
"""Accelerate scaler model.

Accelerate package: https://huggingface.co/docs/accelerate/en/index

Example:
```python
from zenml import step
from zenml.integrations.accelerate import AccelerateScaler

@step(scaler=AccelerateScaler(num_processes=42))
def training_step(some_param: int, ...):
# your training code is below
...
```

Args:
num_processes: The number of processes to use (shall be less or equal to GPUs count).
"""

num_processes: Optional[int] = None

def run(self, step_function: F, **function_kwargs: Any) -> Any:
"""Run a function with accelerate.

Accelerate package: https://huggingface.co/docs/accelerate/en/index

Example:
```python
from zenml import step
from zenml.integrations.accelerate import AccelerateScaler

@step(scaler=AccelerateScaler(num_processes=42))
def training_step(some_param: int, ...):
# your training code is below
...
```

Args:
step_function: The function to run.
**function_kwargs: The keyword arguments to pass to the function.

Returns:
The return value of the function in the main process.

Raises:
CalledProcessError: If the function fails.
"""
import torch

logger.info("Starting accelerate job...")

device_count = torch.cuda.device_count()
if self.num_processes is None:
num_processes = device_count
else:
if self.num_processes > device_count:
logger.warning(
f"Number of processes ({self.num_processes}) is greater than "
f"the number of available GPUs ({device_count}). Using all GPUs."
)
num_processes = device_count
num_processes = self.num_processes

with create_cli_wrapped_script(
step_function, flavour="accelerate"
) as (
script_path,
output_path,
):
command = f"accelerate launch --num_processes {num_processes} "
command += str(script_path.absolute()) + " "
for k, v in function_kwargs.items():
k = _cli_arg_name(k)
if isinstance(v, bool):
if v:
command += f"--{k} "
elif isinstance(v, str):
command += f'--{k} "{v}" '
elif type(v) in (list, tuple, set):
for each in v:
command += f"--{k} {each} "
else:
command += f"--{k} {v} "

logger.info(command)

result = subprocess.run(
command,
shell=True,
stdout=subprocess.PIPE,
universal_newlines=True,
)
for stdout_line in result.stdout.split("\n"):
logger.info(stdout_line)
if result.returncode == 0:
logger.info("Accelerate training job finished.")
return pickle.load(open(output_path, "rb"))
else:
logger.error(
f"Accelerate training job failed. With return code {result.returncode}."
)
raise subprocess.CalledProcessError(result.returncode, command)
84 changes: 84 additions & 0 deletions src/zenml/models/v2/misc/scaler_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Model definitions for ZenML scalers."""

from typing import Any, Callable, ClassVar, Dict, Optional, Set, TypeVar

from pydantic import BaseModel, root_validator

F = TypeVar("F", bound=Callable[..., None])


class ScalerModel(BaseModel):
"""Domain model for scalers."""

scaler_flavor: Optional[str] = None

ALLOWED_SCALER_FLAVORS: ClassVar[Set[str]] = {
"AggregateScaler",
"AccelerateScaler",
}

@root_validator(pre=True)
def validate_scaler_flavor(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate the scaler flavor.

Args:
values: The values to validate.

Returns:
The validated values.

Raises:
ValueError: If the scaler flavor is not supported.
"""
if values.get("scaler_flavor", None) is None:
values["scaler_flavor"] = cls.__name__ # type: ignore[attr-defined]
if values["scaler_flavor"] not in cls.ALLOWED_SCALER_FLAVORS:
raise ValueError(
f"Invalid scaler flavor {values['scaler_flavor']}. "
f"Allowed values are {cls.ALLOWED_SCALER_FLAVORS}"
)
return values

def run(self, step_function: F, **kwargs: Any) -> Any:
"""Run the step using scaler.

Args:
step_function: The step function to run.
**kwargs: Additional arguments to pass to the step function.

Returns:
The result of the step function as per scaler config.

Raises:
NotImplementedError: If the scaler flavor is not supported.
"""
if self.scaler_flavor == "AccelerateScaler":
from zenml.integrations.accelerate import AccelerateScaler

runner = AccelerateScaler(**self.dict())
elif self.scaler_flavor == "AggregateScaler":
from zenml.scalers import AggregateScaler

runner = AggregateScaler(**self.dict()) # type: ignore[assignment]
else:
raise NotImplementedError

return runner.run(step_function, **kwargs)

class Config:
"""Pydantic model configuration."""

extra = "allow"
5 changes: 5 additions & 0 deletions src/zenml/new/steps/step_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from zenml.config.source import Source
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.model.model import Model
from zenml.models.v2.misc.scaler_models import ScalerModel
from zenml.steps import BaseStep

MaterializerClassOrSource = Union[str, Source, Type[BaseMaterializer]]
Expand Down Expand Up @@ -73,6 +74,7 @@ def step(
on_success: Optional["HookSpecification"] = None,
model: Optional["Model"] = None,
model_version: Optional["Model"] = None, # TODO: deprecate me
scaler: Optional["ScalerModel"] = None,
) -> Callable[["F"], "BaseStep"]: ...


Expand All @@ -93,6 +95,7 @@ def step(
on_success: Optional["HookSpecification"] = None,
model: Optional["Model"] = None,
model_version: Optional["Model"] = None, # TODO: deprecate me
scaler: Optional["ScalerModel"] = None,
) -> Union["BaseStep", Callable[["F"], "BaseStep"]]:
"""Decorator to create a ZenML step.

Expand Down Expand Up @@ -124,6 +127,7 @@ def step(
(e.g. `module.my_function`).
model: configuration of the model in the Model Control Plane.
model_version: DEPRECATED, please use `model` instead.
scaler: configuration of the scaler for this step.

Returns:
The step instance.
Expand Down Expand Up @@ -162,6 +166,7 @@ def inner_decorator(func: "F") -> "BaseStep":
on_failure=on_failure,
on_success=on_success,
model=model or model_version,
scaler=scaler,
)

return step_instance
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/scalers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from zenml.scalers.aggregate_scaler import AggregateScaler

__all__ = [
"AggregateScaler",
]
Loading
Loading