Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ updates:
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
open-pull-requests-limit: 5
open-pull-requests-limit: 0
17 changes: 17 additions & 0 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Summary

*Please add a few lines to give the reviewer context on the changes.*

## Checklist

Before formally opening this PR, please adhere to the following standards:

- [ ] Branch/PR names begin with the related Jira ticket id (i.e., PROD-31) for Jira integration
- [ ] File names are lower_snake_case
- [ ] Relevant unit tests have been added or not applicable
- [ ] Relevant documentation has been added or not applicable
- [ ] Mark yourself as the assignee (makes it easier to scan the PR list)

[Related Jira Ticket](https://synccomputing.atlassian.net/browse/PROD-) (add id)

Add any relevant testing examples or screenshots.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ __pycache__
*.swo
venv/*
.idea/*
.python-version
.python-version
.DS_Store
.venv
2 changes: 1 addition & 1 deletion sync/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Library for leveraging the power of Sync"""
__version__ = "0.6.3"
__version__ = "0.6.4"

TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
72 changes: 54 additions & 18 deletions sync/awsdatabricks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from pathlib import Path
from time import sleep
from typing import List, Tuple
from urllib.parse import urlparse
Expand Down Expand Up @@ -335,7 +336,11 @@ def _get_aws_cluster_info_from_s3(bucket: str, file_key: str, cluster_id):
logger.warning(f"Failed to retrieve cluster info from S3 with key, '{file_key}': {err}")


def monitor_cluster(cluster_id: str, polling_period: int = 20) -> None:
def monitor_cluster(
cluster_id: str,
polling_period: int = 20,
cluster_report_destination_override: dict = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it better to type the None as Optional too?

Suggested change
cluster_report_destination_override: dict = None,
cluster_report_destination_override: Optional[dict] = None,

) -> None:
cluster = get_default_client().get_cluster(cluster_id)
spark_context_id = cluster.get("spark_context_id")

Expand All @@ -347,16 +352,26 @@ def monitor_cluster(cluster_id: str, polling_period: int = 20) -> None:
spark_context_id = cluster.get("spark_context_id")

(log_url, filesystem, bucket, base_prefix) = _cluster_log_destination(cluster)
if log_url:
if cluster_report_destination_override:
filesystem = cluster_report_destination_override.get("filesystem", filesystem)
base_prefix = cluster_report_destination_override.get("base_prefix", base_prefix)

if log_url or cluster_report_destination_override:
_monitor_cluster(
(log_url, filesystem, bucket, base_prefix), cluster_id, spark_context_id, polling_period
(log_url, filesystem, bucket, base_prefix),
cluster_id,
spark_context_id,
polling_period,
)
else:
logger.warning("Unable to monitor cluster due to missing cluster log destination - exiting")


def _monitor_cluster(
cluster_log_destination, cluster_id: str, spark_context_id: int, polling_period: int
cluster_log_destination,
cluster_id: str,
spark_context_id: int,
polling_period: int,
) -> None:

(log_url, filesystem, bucket, base_prefix) = cluster_log_destination
Expand All @@ -368,20 +383,7 @@ def _monitor_cluster(
aws_region_name = DB_CONFIG.aws_region_name
ec2 = boto.client("ec2", region_name=aws_region_name)

if filesystem == "s3":
s3 = boto.client("s3")

def write_file(body: bytes):
logger.info("Saving state to S3")
s3.put_object(Bucket=bucket, Key=file_key, Body=body)

elif filesystem == "dbfs":
path = format_dbfs_filepath(file_key)
dbx_client = get_default_client()

def write_file(body: bytes):
logger.info("Saving state to DBFS")
write_dbfs_file(path, body, dbx_client)
write_file = _define_write_file(file_key, filesystem, bucket)

all_inst_by_id = {}
active_timelines_by_id = {}
Expand Down Expand Up @@ -428,6 +430,40 @@ def write_file(body: bytes):
sleep(polling_period)


def _define_write_file(file_key, filesystem, bucket):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice moving logic into sub-functions!

if filesystem == "file":
file_path = Path(f"{Path.home()}{file_key}")

def ensure_path_exists(report_path: Path):
logger.info(f"Ensuring path exists for {report_path}")
report_path.parent.mkdir(parents=True, exist_ok=True)

def write_file(body: bytes):
logger.info("Saving state to local file")
ensure_path_exists(file_path)
with open(file_path, "wb") as f:
f.write(body)

elif filesystem == "s3":
s3 = boto.client("s3")

def write_file(body: bytes):
logger.info("Saving state to S3")
s3.put_object(Bucket=bucket, Key=file_key, Body=body)

elif filesystem == "dbfs":
path = format_dbfs_filepath(file_key)
dbx_client = get_default_client()

def write_file(body: bytes):
logger.info("Saving state to DBFS")
write_dbfs_file(path, body, dbx_client)

else:
raise ValueError(f"Unsupported filesystem: {filesystem}")
return write_file


def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> List[dict]:

filters = [
Expand Down
64 changes: 63 additions & 1 deletion tests/test_awsdatabricks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import io
import json
import unittest
from datetime import datetime
from unittest.mock import Mock, patch
from uuid import uuid4
Expand All @@ -10,7 +11,7 @@
from botocore.stub import Stubber
from httpx import Response

from sync.awsdatabricks import create_prediction_for_run
from sync.awsdatabricks import create_prediction_for_run, monitor_cluster
from sync.config import DatabricksConf
from sync.models import DatabricksAPIError, DatabricksError
from sync.models import Response as SyncResponse
Expand Down Expand Up @@ -1089,3 +1090,64 @@ def client_patch(name, **kwargs):
result = create_prediction_for_run("75778", "Premium", "Jobs Compute", "my-project-id")

assert result.result


@patch("sync.awsdatabricks._monitor_cluster")
@patch("sync.clients.databricks.DatabricksClient.get_cluster")
@patch(
"sync.awsdatabricks._cluster_log_destination",
)
class TestMonitorCluster(unittest.TestCase):
def test_monitor_cluster_with_override(
self,
mock_cluster_log_destination,
mock_get_cluster,
mock_monitor_cluster,
):
mock_cluster_log_destination.return_value = ("s3://bucket/path", "s3", "bucket", "path")

mock_get_cluster.return_value = {
"cluster_id": "0101-214342-tpi6qdp2",
"spark_context_id": 1443449481634833945,
}

cluster_report_destination_override = {
"filesystem": "file",
"base_prefix": "test_file_path",
}

monitor_cluster("0101-214342-tpi6qdp2", 1, cluster_report_destination_override)

expected_log_destination_override = ("s3://bucket/path", "file", "bucket", "test_file_path")
mock_monitor_cluster.assert_called_with(
expected_log_destination_override, "0101-214342-tpi6qdp2", 1443449481634833945, 1
)

mock_cluster_log_destination.return_value = (None, "s3", None, "path")
monitor_cluster("0101-214342-tpi6qdp2", 1, cluster_report_destination_override)
expected_log_destination_override = (None, "file", None, "test_file_path")
mock_monitor_cluster.assert_called_with(
expected_log_destination_override, "0101-214342-tpi6qdp2", 1443449481634833945, 1
)

def test_monitor_cluster_without_override(
self,
mock_cluster_log_destination,
mock_get_cluster,
mock_monitor_cluster,
):
mock_cluster_log_destination.return_value = ("s3://bucket/path", "s3", "bucket", "path")

mock_get_cluster.return_value = {
"cluster_id": "0101-214342-tpi6qdp2",
"spark_context_id": 1443449481634833945,
}

monitor_cluster("0101-214342-tpi6qdp2", 1)

mock_monitor_cluster.assert_called_with(
mock_cluster_log_destination.return_value,
"0101-214342-tpi6qdp2",
1443449481634833945,
1,
)