Skip to content

Commit

Permalink
Remote entrypoint serialize (flyteorg#733)
Browse files Browse the repository at this point in the history
Signed-off-by: Emirhan Karagül <emirhan350z@gmail.com>
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Signed-off-by: Robert Everson <reverson@lyft.com>
  • Loading branch information
YmirKhang authored and Robert Everson committed May 27, 2022
1 parent ebd4c75 commit f5dbb1e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
25 changes: 23 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,19 @@

from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.clis.flyte_cli.main import _detect_default_config_file
from flytekit.clis.sdk_in_container import serialize
from flytekit.common import constants
from flytekit.common.exceptions import user as user_exceptions
from flytekit.common.translator import FlyteControlPlaneEntity, FlyteLocalEntity, get_serializable
from flytekit.configuration import auth as auth_config
from flytekit.configuration.internal import DOMAIN, PROJECT
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings, get_image_config
from flytekit.core.context_manager import (
FlyteContextManager,
ImageConfig,
SerializationSettings,
get_image_config,
)
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.type_engine import TypeEngine
Expand Down Expand Up @@ -128,6 +134,7 @@ def from_config(
default_domain: typing.Optional[str] = None,
config_file_path: typing.Optional[str] = None,
grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None,
venv_root: typing.Optional[str] = None,
) -> FlyteRemote:
"""Create a FlyteRemote object using flyte configuration variables and/or environment variable overrides.
Expand All @@ -151,6 +158,11 @@ def from_config(
raw_output_prefix=raw_output_data_prefix,
)

venv_root = venv_root or serialize._DEFAULT_FLYTEKIT_VIRTUALENV_ROOT
entrypoint = context_manager.EntrypointSettings(
path=os.path.join(venv_root, serialize._DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
)

return cls(
flyte_admin_url=platform_config.URL.get(),
insecure=platform_config.INSECURE.get(),
Expand All @@ -169,6 +181,7 @@ def from_config(
common_models.RawOutputDataConfig(raw_output_data_prefix) if raw_output_data_prefix else None
),
grpc_credentials=grpc_credentials,
entrypoint_settings=entrypoint,
)

def __init__(
Expand All @@ -185,6 +198,7 @@ def __init__(
image_config: typing.Optional[ImageConfig] = None,
raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None,
grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None,
entrypoint_settings: typing.Optional[context_manager.EntrypointSettings] = None,
):
"""Initialize a FlyteRemote object.
Expand All @@ -199,7 +213,11 @@ def __init__(
:param annotations: annotation config
:param image_config: image config
:param raw_output_data_config: location for offloaded data, e.g. in S3
:param grpc_credentials: gRPC channel credentials for connecting to flyte admin as returned by :func:`grpc.ssl_channel_credentials`
:param grpc_credentials: gRPC channel credentials for connecting to flyte admin as returned
by :func:`grpc.ssl_channel_credentials`
:param entrypoint_settings: EntrypointSettings object for use with Spark tasks. If supplied, this will be
used when serializing Spark tasks, which need to know the path to the flytekit entrypoint.py file
inside the container.
"""
remote_logger.warning("This feature is still in beta. Its interface and UX is subject to change.")
if flyte_admin_url is None:
Expand All @@ -217,6 +235,8 @@ def __init__(
self._labels = labels
self._annotations = annotations
self._raw_output_data_config = raw_output_data_config
# Not exposing this as a property for now.
self._entrypoint_settings = entrypoint_settings

# Save the file access object locally, but also make it available for use from the context.
FlyteContextManager.with_context(FlyteContextManager.current_context().with_file_access(file_access).build())
Expand Down Expand Up @@ -525,6 +545,7 @@ def _serialize(
self.image_config,
# https://github.com/flyteorg/flyte/issues/1359
env={internal.IMAGE.env_var: self.image_config.default_image.full},
entrypoint_settings=self._entrypoint_settings,
),
entity=entity,
)
Expand Down
40 changes: 40 additions & 0 deletions plugins/flytekit-spark/tests/test_remote_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from flytekitplugins.spark import Spark
from mock import MagicMock, patch

from flytekit import task
from flytekit.remote.remote import FlyteRemote


@patch("flytekit.configuration.platform.URL")
@patch("flytekit.configuration.platform.INSECURE")
def test_spark_template_with_remote(mock_insecure, mock_url):
    """Register a Spark task and a plain Python task through ``FlyteRemote``.

    The remote's client is swapped for a ``MagicMock``, and the test inspects
    the ``task_spec`` keyword argument passed to ``create_task`` after each
    ``register`` call.
    """

    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        return 10

    @task
    def my_python_task(a: str) -> int:
        return 10

    # Values handed back by the two patched platform settings.
    mock_insecure.get.return_value = True
    mock_url.get.return_value = "localhost"

    admin_client = MagicMock()

    def latest_task_spec():
        # The spec given to the most recent create_task call, read from its
        # ``task_spec`` keyword argument.
        return admin_client.create_task.call_args.kwargs["task_spec"]

    remote = FlyteRemote.from_config("p1", "d1")
    # Replace the image config and the client with mocks before registering.
    remote._client = admin_client
    remote._image_config = MagicMock()

    remote.register(my_spark)
    spark_spec = latest_task_spec()

    # The serialized Spark task must have mainApplicationFile and sparkConf set.
    spark_custom = spark_spec.template.custom
    assert spark_custom["mainApplicationFile"]
    assert spark_custom["sparkConf"]

    remote.register(my_python_task)
    python_spec = latest_task_spec()

    # A plain Python task carries no custom section (and so no
    # mainApplicationFile) by default.
    assert python_spec.template.custom is None

0 comments on commit f5dbb1e

Please sign in to comment.