Skip to content
2 changes: 1 addition & 1 deletion sync/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Library for leveraging the power of Sync"""
__version__ = "0.4.13"
__version__ = "0.5.0"

TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
45 changes: 44 additions & 1 deletion sync/_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import time
import zipfile
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from time import sleep
from typing import Any, Collection, Dict, List, Tuple, TypeVar, Union
from typing import Any, Collection, Dict, List, Set, Tuple, TypeVar, Union
from urllib.parse import urlparse

import boto3 as boto
Expand Down Expand Up @@ -48,6 +49,7 @@ def create_prediction(
cluster_events: dict,
eventlog: bytes,
instances: dict = None,
instance_timelines: dict = None,
volumes: dict = None,
tasks: List[dict] = None,
project_id: str = None,
Expand Down Expand Up @@ -93,6 +95,7 @@ def create_prediction(
"cluster": cluster,
"cluster_events": cluster_events,
"instances": instances,
"instance_timelines": instance_timelines,
"volumes": volumes,
"tasks": tasks,
},
Expand Down Expand Up @@ -1926,6 +1929,46 @@ def _get_all_cluster_events(cluster_id: str):
return all_events


def _update_monitored_timelines(
running_instance_ids: Set[str],
active_timelines_by_id: Dict[str, dict],
) -> Tuple[Dict[str, dict], List[dict]]:
"""
Shared monitoring method for both Azure and Databricks to reduce complexity.
Compares the current running instances (keyed by id) to the running
instance timelines (also keyed by id). Instance timeline elements that are
still running are updated, while the rest are returned in a "retired" list.
"""

current_datetime = datetime.now(timezone.utc)
for id in running_instance_ids:
if id not in active_timelines_by_id:
# A new instance in the "running" state has been detected, so add it
# the dict of running instances and initialize the times.
logger.info(f"Adding new instance timeline: {id}")
active_timelines_by_id[id] = {
"instance_id": id,
"first_seen_running_time": current_datetime,
"last_seen_running_time": current_datetime,
}

else:
# If an instance was already in the list of running instances then update
# then just update the last_seen_running_time.
active_timelines_by_id[id]["last_seen_running_time"] = current_datetime

# If an instance in the active timeline is no longer in the running state then
# it should be moved over the retired timeline list.
retired_inst_timeline_list = []
ids_to_retire = set(active_timelines_by_id.keys()).difference(running_instance_ids)
if ids_to_retire:
for id in ids_to_retire:
logger.info(f"Retiring instance: {id}")
retired_inst_timeline_list.append(active_timelines_by_id.pop(id))

return active_timelines_by_id, retired_inst_timeline_list


KeyType = TypeVar("KeyType")


Expand Down
163 changes: 81 additions & 82 deletions sync/awsdatabricks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from time import sleep
from typing import Dict, List, Tuple
from typing import List, Tuple
from urllib.parse import urlparse

import boto3 as boto
Expand All @@ -13,6 +13,7 @@
_cluster_log_destination,
_get_all_cluster_events,
_get_cluster_instances_from_dbfs,
_update_monitored_timelines,
_wait_for_cluster_termination,
apply_prediction,
apply_project_recommendation,
Expand Down Expand Up @@ -203,15 +204,15 @@ def _get_cluster_report(

cluster = cluster_response.result

reservations_response, volumes_response = _get_aws_cluster_info(cluster)
instances_response, timeline_response, volumes_response = _get_aws_cluster_info(cluster)

if reservations_response.error:
if instances_response.error:
if allow_incomplete:
logger.warning(reservations_response.error)
logger.warning(instances_response.error)
else:
return reservations_response
return instances_response

# The volumes data is less critical than reservations, so allow
# The volumes data is less critical than instances, so allow
# the cluster report to get created even if the volumes response
# has an error.
if volumes_response.error:
Expand All @@ -220,6 +221,12 @@ def _get_cluster_report(
else:
volumes = volumes_response.result

if timeline_response.error:
logger.warning(timeline_response.error)
timelines = []
else:
timelines = timeline_response.result

cluster_events = _get_all_cluster_events(cluster_id)
return Response(
result=AWSDatabricksClusterReport(
Expand All @@ -229,7 +236,8 @@ def _get_cluster_report(
cluster_events=cluster_events,
volumes=volumes,
tasks=cluster_tasks,
instances=reservations_response.result,
instances=instances_response.result,
instance_timelines=timelines,
)
)

Expand All @@ -243,10 +251,10 @@ def _get_cluster_report(
setattr(sync._databricks, "__claim", __name__)


def _get_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict]]:
cluster_info = None
aws_region_name = DB_CONFIG.aws_region_name
def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict]]:

cluster_info = None
cluster_id = None
cluster_log_dest = _cluster_log_destination(cluster)

if cluster_log_dest:
Expand All @@ -271,40 +279,51 @@ def _get_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict]
# If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
# are associated with this cluster
if not cluster_info:

try:
ec2 = boto.client("ec2", region_name=aws_region_name)
reservations = _get_ec2_instances(cluster_id, ec2)
volumes = _get_ebs_volumes_for_reservations(reservations, ec2)
ec2 = boto.client("ec2", region_name=DB_CONFIG.aws_region_name)
instances = _get_ec2_instances(cluster_id, ec2)
volumes = _get_ebs_volumes_for_instances(instances, ec2)

cluster_info = {
"Reservations": reservations,
"Volumes": volumes,
"instances": instances,
"volumes": volumes,
}

except Exception as exc:
logger.warning(exc)

return cluster_info, cluster_id


def _get_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict], Response[dict]]:

aws_region_name = DB_CONFIG.aws_region_name

cluster_info, cluster_id = _load_aws_cluster_info(cluster)

def missing_message(input: str) -> str:
return (
f"Unable to find any active or recently terminated {input} for cluster `{cluster_id}` in `{aws_region_name}`. "
+ "Please refer to the following documentation for options on how to address this - "
+ "https://docs.synccomputing.com/sync-gradient/integrating-with-gradient/databricks-workflows"
)

if not cluster_info or not cluster_info.get("Reservations"):
reservations_response = Response(
error=DatabricksError(message=missing_message("instances"))
)
if not cluster_info or not cluster_info.get("instances"):
instances_response = Response(error=DatabricksError(message=missing_message("instances")))
else:
instances_response = Response(result=cluster_info["instances"])

if not cluster_info or not cluster_info.get("instance_timelines"):
timeline_response = Response(error=DatabricksError(message=missing_message("timelines")))
else:
reservations_response = Response(result={"Reservations": cluster_info["Reservations"]})
timeline_response = Response(result=cluster_info["instance_timelines"])

if not cluster_info or not cluster_info.get("Volumes"):
if not cluster_info or not cluster_info.get("volumes"):
volumes_response = Response(error=DatabricksError(message=missing_message("ebs volumes")))
else:
volumes_response = Response(result={"Volumes": cluster_info.get("Volumes")})
volumes_response = Response(result=cluster_info.get("volumes"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just want to confirm that the capitalization change of "volume" here is expected

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep this is expected, that was part of the "de-nesting" that I talked about.


return reservations_response, volumes_response
return instances_response, timeline_response, volumes_response


def _get_aws_cluster_info_from_s3(bucket: str, file_key: str, cluster_id):
Expand All @@ -315,7 +334,7 @@ def _get_aws_cluster_info_from_s3(bucket: str, file_key: str, cluster_id):
logger.warning(f"Failed to retrieve cluster info from S3 with key, '{file_key}': {err}")


def monitor_cluster(cluster_id: str, polling_period: int = 30) -> None:
def monitor_cluster(cluster_id: str, polling_period: int = 20) -> None:
cluster = get_default_client().get_cluster(cluster_id)
spark_context_id = cluster.get("spark_context_id")

Expand Down Expand Up @@ -363,67 +382,48 @@ def write_file(body: bytes):
logger.info("Saving state to DBFS")
write_dbfs_file(path, body, dbx_client)

old_reservations = []
recorded_volumes = {}
all_inst_by_id = {}
active_timelines_by_id = {}
retired_timelines = []
recorded_volumes_by_id = {}
while True:
try:
new_reservations = _get_ec2_instances(cluster_id, ec2)
new_volumes = _get_ebs_volumes_for_reservations(new_reservations, ec2)

recorded_volumes.update({v["VolumeId"]: v for v in new_volumes})

new_instance_id_to_reservation = dict(
zip(
(res["Instances"][0]["InstanceId"] for res in new_reservations),
new_reservations,
)
current_insts = _get_ec2_instances(cluster_id, ec2)
recorded_volumes_by_id.update(
{v["VolumeId"]: v for v in _get_ebs_volumes_for_instances(current_insts, ec2)}
)

old_instance_id_to_reservation = dict(
zip(
(res["Instances"][0]["InstanceId"] for res in old_reservations),
old_reservations,
)
# Record new (or overwrite) existing instances.
# Separately record the ids of those that are in the "running" state.
running_inst_ids = set({})
for inst in current_insts:
all_inst_by_id[inst["InstanceId"]] = inst
if inst["State"]["Name"] == "running":
running_inst_ids.add(inst["InstanceId"])

active_timelines_by_id, new_retired_timelines = _update_monitored_timelines(
running_inst_ids, active_timelines_by_id
)

old_instance_ids = set(old_instance_id_to_reservation)
new_instance_ids = set(new_instance_id_to_reservation)

# If we have the exact same set of instances, prefer the new set...
if old_instance_ids == new_instance_ids:
reservations = new_reservations
else:
# Otherwise, update old references and include any new instances in the list
newly_added_instance_ids = new_instance_ids.difference(old_instance_ids)
updated_instance_ids = new_instance_ids.intersection(old_instance_ids)
removed_instance_ids = old_instance_ids.difference(new_instance_ids)

removed_reservations = [
old_instance_id_to_reservation[id] for id in removed_instance_ids
]
updated_reservations = [
new_instance_id_to_reservation[id] for id in updated_instance_ids
]
new_reservations = [
new_instance_id_to_reservation[id] for id in newly_added_instance_ids
]

reservations = [*removed_reservations, *updated_reservations, *new_reservations]
retired_timelines.extend(new_retired_timelines)
all_timelines = retired_timelines + list(active_timelines_by_id.values())

write_file(
orjson.dumps(
{"Reservations": reservations, "Volumes": list(recorded_volumes.values())}
{
"instances": list(all_inst_by_id.values()),
"instance_timelines": all_timelines,
"volumes": list(recorded_volumes_by_id.values()),
}
)
)

old_reservations = reservations
except Exception as e:
logger.error(f"Exception encountered while polling cluster: {e}")

sleep(polling_period)


def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> List[Dict]:
def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> List[dict]:

filters = [
{"Name": "tag:Vendor", "Values": ["Databricks"]},
Expand All @@ -438,25 +438,24 @@ def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> Li
reservations += response.get("Reservations", [])
next_token = response.get("NextToken")

num_instances = 0
if reservations:
num_instances = len(reservations[0].get("Instances", []))
logger.info(f"Identified {num_instances} instances in cluster")
instances = []
for res in reservations:
for inst in res.get("Instances", []):
instances.append(inst)
logger.info(f"Identified {len(instances)} instances in cluster")

return reservations
return instances


def _get_ebs_volumes_for_reservations(
reservations: List[Dict], ec2_client: "botocore.client.ec2"
) -> List[Dict]:
def _get_ebs_volumes_for_instances(
instances: List[dict], ec2_client: "botocore.client.ec2"
) -> List[dict]:
"""Get all ebs volumes associated with a list of instance reservations"""

instance_ids = []
if reservations:
for res in reservations:
for instance in res.get("Instances", []):
if instance:
instance_ids.append(instance.get("InstanceId"))
if instances:
for instance in instances:
instance_ids.append(instance.get("InstanceId"))

volumes = []
if instance_ids:
Expand Down
Loading