Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
Signed-off-by: Prithvi Kannan <prithvi.kannan@databricks.com>
  • Loading branch information
prithvikannan committed May 2, 2024
1 parent 72a5b8c commit 1c1e0c6
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions mlflow/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@
from mlflow.artifacts import download_artifacts
from mlflow.exceptions import MlflowException
from mlflow.models.resources import Resource, ResourceType, _ResourceBuilder
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST
from mlflow.protos.databricks_pb2 import (
INVALID_PARAMETER_VALUE,
RESOURCE_DOES_NOT_EXIST,
)
from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository
from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking._tracking_service.utils import _resolve_tracking_uri
from mlflow.tracking.artifact_utils import _download_artifact_from_uri, _upload_artifact_to_uri
from mlflow.tracking.artifact_utils import (
_download_artifact_from_uri,
_upload_artifact_to_uri,
)
from mlflow.utils.annotations import experimental
from mlflow.utils.databricks_utils import get_databricks_runtime_version
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
Expand Down Expand Up @@ -489,7 +495,9 @@ def resources(self) -> Dict[str, Dict[ResourceType, List[Dict]]]:
def resources(self, value: Optional[Union[str, List[Resource]]]):
if isinstance(value, (Path, str)):
serialized_resource = _ResourceBuilder.from_yaml_file(value)
elif isinstance(value, List) and all(isinstance(resource, Resource) for resource in value):
elif isinstance(value, List) and all(
isinstance(resource, Resource) for resource in value
):
serialized_resource = _ResourceBuilder.from_resources(value)
else:
serialized_resource = value
Expand Down Expand Up @@ -658,7 +666,10 @@ def log(
if run_id is None:
run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
mlflow_model = cls(
artifact_path=artifact_path, run_id=run_id, metadata=metadata, resources=resources
artifact_path=artifact_path,
run_id=run_id,
metadata=metadata,
resources=resources,
)
flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)

Expand Down Expand Up @@ -688,12 +699,15 @@ def log(
# We check signature presence here as some flavors have a default signature as a
# fallback when not provided by user, which is set during flavor's save_model() call.
if mlflow_model.signature is None and (
tracking_uri == "databricks" or get_uri_scheme(tracking_uri) == "databricks"
tracking_uri == "databricks"
or get_uri_scheme(tracking_uri) == "databricks"
):
_logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING)
mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path, run_id)
mlflow.tracking.fluent.log_artifacts(
local_path, mlflow_model.artifact_path, run_id
)

# if the model_config kwarg is passed in, then log the model config as an params
# if the model_config kwarg is passed in, then log the model config as an params
try:
if "model_config" in kwargs:
model_config = kwargs["model_config"]
Expand All @@ -709,7 +723,9 @@ def log(
except MlflowException:
# We need to swallow all mlflow exceptions to maintain backwards compatibility with
# older tracking servers. Only print out a warning for now.
_logger.warning(_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri())
_logger.warning(
_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri()
)
_logger.debug("", exc_info=True)
if registered_model_name is not None:
mlflow.tracking._model_registry.fluent._register_model(
Expand Down Expand Up @@ -867,13 +883,17 @@ def update_model_requirements(
old_requirements_reqs = _get_requirements_from_file(requirements_txt_path)

if operation == "add":
updated_conda_reqs = _add_or_overwrite_requirements(requirement_list, old_conda_reqs)
updated_conda_reqs = _add_or_overwrite_requirements(
requirement_list, old_conda_reqs
)
updated_requirements_reqs = _add_or_overwrite_requirements(
requirement_list, old_requirements_reqs
)
else:
updated_conda_reqs = _remove_requirements(requirement_list, old_conda_reqs)
updated_requirements_reqs = _remove_requirements(requirement_list, old_requirements_reqs)
updated_requirements_reqs = _remove_requirements(
requirement_list, old_requirements_reqs
)

_write_requirements_to_file(conda_yaml_path, updated_conda_reqs)
_write_requirements_to_file(requirements_txt_path, updated_requirements_reqs)
Expand Down

0 comments on commit 1c1e0c6

Please sign in to comment.