forked from flyteorg/flytekit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remote entrypoint serialize (flyteorg#733)
Signed-off-by: Emirhan Karagül <emirhan350z@gmail.com> Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com> Signed-off-by: Robert Everson <reverson@lyft.com>
- Loading branch information
Showing
2 changed files
with
63 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from flytekitplugins.spark import Spark | ||
from mock import MagicMock, patch | ||
|
||
from flytekit import task | ||
from flytekit.remote.remote import FlyteRemote | ||
|
||
|
||
@patch("flytekit.configuration.platform.URL") | ||
@patch("flytekit.configuration.platform.INSECURE") | ||
def test_spark_template_with_remote(mock_insecure, mock_url): | ||
@task(task_config=Spark(spark_conf={"spark": "1"})) | ||
def my_spark(a: str) -> int: | ||
return 10 | ||
|
||
@task | ||
def my_python_task(a: str) -> int: | ||
return 10 | ||
|
||
mock_url.get.return_value = "localhost" | ||
|
||
mock_insecure.get.return_value = True | ||
mock_client = MagicMock() | ||
|
||
remote = FlyteRemote.from_config("p1", "d1") | ||
|
||
remote._image_config = MagicMock() | ||
remote._client = mock_client | ||
|
||
remote.register(my_spark) | ||
serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"] | ||
|
||
# Check if the serialized spark task has mainApplicaitonFile field set. | ||
assert serialized_spec.template.custom["mainApplicationFile"] | ||
assert serialized_spec.template.custom["sparkConf"] | ||
|
||
remote.register(my_python_task) | ||
serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"] | ||
|
||
# Check if the serialized python task has no mainApplicaitonFile field set by default. | ||
assert serialized_spec.template.custom is None |