diff --git a/sync/__init__.py b/sync/__init__.py index 6920b47..977e309 100644 --- a/sync/__init__.py +++ b/sync/__init__.py @@ -1,4 +1,4 @@ """Library for leveraging the power of Sync""" -__version__ = "0.4.13" +__version__ = "0.5.0" TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" diff --git a/sync/_databricks.py b/sync/_databricks.py index de20d6e..cb85fd1 100644 --- a/sync/_databricks.py +++ b/sync/_databricks.py @@ -7,9 +7,10 @@ import time import zipfile from collections import defaultdict +from datetime import datetime, timezone from pathlib import Path from time import sleep -from typing import Any, Collection, Dict, List, Tuple, TypeVar, Union +from typing import Any, Collection, Dict, List, Set, Tuple, TypeVar, Union from urllib.parse import urlparse import boto3 as boto @@ -48,6 +49,7 @@ def create_prediction( cluster_events: dict, eventlog: bytes, instances: dict = None, + instance_timelines: dict = None, volumes: dict = None, tasks: List[dict] = None, project_id: str = None, @@ -93,6 +95,7 @@ def create_prediction( "cluster": cluster, "cluster_events": cluster_events, "instances": instances, + "instance_timelines": instance_timelines, "volumes": volumes, "tasks": tasks, }, @@ -1926,6 +1929,46 @@ def _get_all_cluster_events(cluster_id: str): return all_events +def _update_monitored_timelines( + running_instance_ids: Set[str], + active_timelines_by_id: Dict[str, dict], +) -> Tuple[Dict[str, dict], List[dict]]: + """ + Shared monitoring method for both Azure and Databricks to reduce complexity. + Compares the current running instances (keyed by id) to the running + instance timelines (also keyed by id). Instance timeline elements that are + still running are updated, while the rest are returned in a "retired" list. 
+ """ + + current_datetime = datetime.now(timezone.utc) + for id in running_instance_ids: + if id not in active_timelines_by_id: + # A new instance in the "running" state has been detected, so add it + # the dict of running instances and initialize the times. + logger.info(f"Adding new instance timeline: {id}") + active_timelines_by_id[id] = { + "instance_id": id, + "first_seen_running_time": current_datetime, + "last_seen_running_time": current_datetime, + } + + else: + # If an instance was already in the list of running instances then update + # then just update the last_seen_running_time. + active_timelines_by_id[id]["last_seen_running_time"] = current_datetime + + # If an instance in the active timeline is no longer in the running state then + # it should be moved over the retired timeline list. + retired_inst_timeline_list = [] + ids_to_retire = set(active_timelines_by_id.keys()).difference(running_instance_ids) + if ids_to_retire: + for id in ids_to_retire: + logger.info(f"Retiring instance: {id}") + retired_inst_timeline_list.append(active_timelines_by_id.pop(id)) + + return active_timelines_by_id, retired_inst_timeline_list + + KeyType = TypeVar("KeyType") diff --git a/sync/awsdatabricks.py b/sync/awsdatabricks.py index 550e420..b7fcb89 100644 --- a/sync/awsdatabricks.py +++ b/sync/awsdatabricks.py @@ -1,6 +1,6 @@ import logging from time import sleep -from typing import Dict, List, Tuple +from typing import List, Tuple from urllib.parse import urlparse import boto3 as boto @@ -13,6 +13,7 @@ _cluster_log_destination, _get_all_cluster_events, _get_cluster_instances_from_dbfs, + _update_monitored_timelines, _wait_for_cluster_termination, apply_prediction, apply_project_recommendation, @@ -203,15 +204,15 @@ def _get_cluster_report( cluster = cluster_response.result - reservations_response, volumes_response = _get_aws_cluster_info(cluster) + instances_response, timeline_response, volumes_response = _get_aws_cluster_info(cluster) - if 
reservations_response.error: + if instances_response.error: if allow_incomplete: - logger.warning(reservations_response.error) + logger.warning(instances_response.error) else: - return reservations_response + return instances_response - # The volumes data is less critical than reservations, so allow + # The volumes data is less critical than instances, so allow # the cluster report to get created even if the volumes response # has an error. if volumes_response.error: @@ -220,6 +221,12 @@ def _get_cluster_report( else: volumes = volumes_response.result + if timeline_response.error: + logger.warning(timeline_response.error) + timelines = [] + else: + timelines = timeline_response.result + cluster_events = _get_all_cluster_events(cluster_id) return Response( result=AWSDatabricksClusterReport( @@ -229,7 +236,8 @@ def _get_cluster_report( cluster_events=cluster_events, volumes=volumes, tasks=cluster_tasks, - instances=reservations_response.result, + instances=instances_response.result, + instance_timelines=timelines, ) ) @@ -243,10 +251,10 @@ def _get_cluster_report( setattr(sync._databricks, "__claim", __name__) -def _get_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict]]: - cluster_info = None - aws_region_name = DB_CONFIG.aws_region_name +def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict]]: + cluster_info = None + cluster_id = None cluster_log_dest = _cluster_log_destination(cluster) if cluster_log_dest: @@ -271,20 +279,29 @@ # If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that # are associated with this cluster if not cluster_info: - try: - ec2 = boto.client("ec2", region_name=aws_region_name) - reservations = _get_ec2_instances(cluster_id, ec2) - volumes = _get_ebs_volumes_for_reservations(reservations, ec2) + try: + ec2 = boto.client("ec2", region_name=DB_CONFIG.aws_region_name) 
+ instances = _get_ec2_instances(cluster_id, ec2) + volumes = _get_ebs_volumes_for_instances(instances, ec2) cluster_info = { - "Reservations": reservations, - "Volumes": volumes, + "instances": instances, + "volumes": volumes, } except Exception as exc: logger.warning(exc) + return cluster_info, cluster_id + + +def _get_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict], Response[dict]]: + + aws_region_name = DB_CONFIG.aws_region_name + + cluster_info, cluster_id = _load_aws_cluster_info(cluster) + def missing_message(input: str) -> str: return ( f"Unable to find any active or recently terminated {input} for cluster `{cluster_id}` in `{aws_region_name}`. " @@ -292,19 +308,22 @@ def missing_message(input: str) -> str: + "https://docs.synccomputing.com/sync-gradient/integrating-with-gradient/databricks-workflows" ) - if not cluster_info or not cluster_info.get("Reservations"): - reservations_response = Response( - error=DatabricksError(message=missing_message("instances")) - ) + if not cluster_info or not cluster_info.get("instances"): + instances_response = Response(error=DatabricksError(message=missing_message("instances"))) + else: + instances_response = Response(result=cluster_info["instances"]) + + if not cluster_info or not cluster_info.get("instance_timelines"): + timeline_response = Response(error=DatabricksError(message=missing_message("timelines"))) else: - reservations_response = Response(result={"Reservations": cluster_info["Reservations"]}) + timeline_response = Response(result=cluster_info["instance_timelines"]) - if not cluster_info or not cluster_info.get("Volumes"): + if not cluster_info or not cluster_info.get("volumes"): volumes_response = Response(error=DatabricksError(message=missing_message("ebs volumes"))) else: - volumes_response = Response(result={"Volumes": cluster_info.get("Volumes")}) + volumes_response = Response(result=cluster_info.get("volumes")) - return reservations_response, volumes_response + return 
instances_response, timeline_response, volumes_response def _get_aws_cluster_info_from_s3(bucket: str, file_key: str, cluster_id): @@ -315,7 +334,7 @@ def _get_aws_cluster_info_from_s3(bucket: str, file_key: str, cluster_id): logger.warning(f"Failed to retrieve cluster info from S3 with key, '{file_key}': {err}") -def monitor_cluster(cluster_id: str, polling_period: int = 30) -> None: +def monitor_cluster(cluster_id: str, polling_period: int = 20) -> None: cluster = get_default_client().get_cluster(cluster_id) spark_context_id = cluster.get("spark_context_id") @@ -363,67 +382,48 @@ def write_file(body: bytes): logger.info("Saving state to DBFS") write_dbfs_file(path, body, dbx_client) - old_reservations = [] - recorded_volumes = {} + all_inst_by_id = {} + active_timelines_by_id = {} + retired_timelines = [] + recorded_volumes_by_id = {} while True: try: - new_reservations = _get_ec2_instances(cluster_id, ec2) - new_volumes = _get_ebs_volumes_for_reservations(new_reservations, ec2) - - recorded_volumes.update({v["VolumeId"]: v for v in new_volumes}) - - new_instance_id_to_reservation = dict( - zip( - (res["Instances"][0]["InstanceId"] for res in new_reservations), - new_reservations, - ) + current_insts = _get_ec2_instances(cluster_id, ec2) + recorded_volumes_by_id.update( + {v["VolumeId"]: v for v in _get_ebs_volumes_for_instances(current_insts, ec2)} ) - old_instance_id_to_reservation = dict( - zip( - (res["Instances"][0]["InstanceId"] for res in old_reservations), - old_reservations, - ) + # Record new (or overrwite) existing instances. + # Separately record the ids of those that are in the "running" state. 
+ running_inst_ids = set({}) + for inst in current_insts: + all_inst_by_id[inst["InstanceId"]] = inst + if inst["State"]["Name"] == "running": + running_inst_ids.add(inst["InstanceId"]) + + active_timelines_by_id, new_retired_timelines = _update_monitored_timelines( + running_inst_ids, active_timelines_by_id ) - old_instance_ids = set(old_instance_id_to_reservation) - new_instance_ids = set(new_instance_id_to_reservation) - - # If we have the exact same set of instances, prefer the new set... - if old_instance_ids == new_instance_ids: - reservations = new_reservations - else: - # Otherwise, update old references and include any new instances in the list - newly_added_instance_ids = new_instance_ids.difference(old_instance_ids) - updated_instance_ids = new_instance_ids.intersection(old_instance_ids) - removed_instance_ids = old_instance_ids.difference(new_instance_ids) - - removed_reservations = [ - old_instance_id_to_reservation[id] for id in removed_instance_ids - ] - updated_reservations = [ - new_instance_id_to_reservation[id] for id in updated_instance_ids - ] - new_reservations = [ - new_instance_id_to_reservation[id] for id in newly_added_instance_ids - ] - - reservations = [*removed_reservations, *updated_reservations, *new_reservations] + retired_timelines.extend(new_retired_timelines) + all_timelines = retired_timelines + list(active_timelines_by_id.values()) write_file( orjson.dumps( - {"Reservations": reservations, "Volumes": list(recorded_volumes.values())} + { + "instances": list(all_inst_by_id.values()), + "instance_timelines": all_timelines, + "volumes": list(recorded_volumes_by_id.values()), + } ) ) - - old_reservations = reservations except Exception as e: logger.error(f"Exception encountered while polling cluster: {e}") sleep(polling_period) -def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> List[Dict]: +def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> List[dict]: filters = [ {"Name": 
"tag:Vendor", "Values": ["Databricks"]}, @@ -438,25 +438,24 @@ def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> Li reservations += response.get("Reservations", []) next_token = response.get("NextToken") - num_instances = 0 - if reservations: - num_instances = len(reservations[0].get("Instances", [])) - logger.info(f"Identified {num_instances} instances in cluster") + instances = [] + for res in reservations: + for inst in res.get("Instances", []): + instances.append(inst) + logger.info(f"Identified {len(instances)} instances in cluster") - return reservations + return instances -def _get_ebs_volumes_for_reservations( - reservations: List[Dict], ec2_client: "botocore.client.ec2" -) -> List[Dict]: +def _get_ebs_volumes_for_instances( + instances: List[dict], ec2_client: "botocore.client.ec2" +) -> List[dict]: """Get all ebs volumes associated with a list of instance reservations""" instance_ids = [] - if reservations: - for res in reservations: - for instance in res.get("Instances", []): - if instance: - instance_ids.append(instance.get("InstanceId")) + if instances: + for instance in instances: + instance_ids.append(instance.get("InstanceId")) volumes = [] if instance_ids: diff --git a/sync/azuredatabricks.py b/sync/azuredatabricks.py index 8200a56..fb346fa 100644 --- a/sync/azuredatabricks.py +++ b/sync/azuredatabricks.py @@ -17,6 +17,7 @@ _cluster_log_destination, _get_all_cluster_events, _get_cluster_instances_from_dbfs, + _update_monitored_timelines, _wait_for_cluster_termination, apply_prediction, apply_project_recommendation, @@ -227,7 +228,8 @@ def _get_cluster_report( cluster=cluster, cluster_events=cluster_events, tasks=cluster_tasks, - instances=instances.result, + instances=instances.result.get("instances"), + instance_timelines=instances.result.get("timelines"), ) ) @@ -254,7 +256,7 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]: cluster_id = cluster["cluster_id"] spark_context_id = cluster["spark_context_id"] 
cluster_instances_file_key = ( - f"{base_prefix}/sync_data/{spark_context_id}/cluster_instances.json" + f"{base_prefix}/sync_data/{spark_context_id}/azure_cluster_info.json" ) cluster_instances_file_response = None @@ -278,16 +280,18 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]: compute = _get_azure_client(ComputeManagementClient) if resource_group_name: vms = compute.virtual_machines.list(resource_group_name=resource_group_name) else: logger.warning("Failed to find Databricks managed resource group") vms = compute.virtual_machines.list_all() - cluster_instances = [ - vm.as_dict() - for vm in vms - if vm.tags.get("Vendor") == "Databricks" - and vm.tags.get("ClusterId") == cluster["cluster_id"] - ] + cluster_instances = { + "instances": [ + vm.as_dict() + for vm in vms + if vm.tags.get("Vendor") == "Databricks" + and vm.tags.get("ClusterId") == cluster["cluster_id"] + ] + } if not cluster_instances: no_instances_message = ( @@ -300,10 +309,9 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]: return Response(result=cluster_instances) -def monitor_cluster(cluster_id: str, polling_period: int = 30) -> None: +def monitor_cluster(cluster_id: str, polling_period: int = 20) -> None: cluster = get_default_client().get_cluster(cluster_id) spark_context_id = cluster.get("spark_context_id") - while not spark_context_id: # This is largely just a convenience for when this command is run by someone locally logger.info("Waiting for cluster startup...") @@ -327,7 +335,10 @@ def _monitor_cluster( # If the event log destination is just a *bucket* without any sub-path, then we don't want to include # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so # we make sure to re-strip our final Prefix - file_key = f"{base_prefix}/sync_data/{spark_context_id}/cluster_instances.json".strip("/") + file_key = 
f"{base_prefix}/sync_data/{spark_context_id}/azure_cluster_info.json".strip("/") + + azure_logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy") + azure_logger.setLevel(logging.WARNING) if filesystem == "dbfs": path = format_dbfs_filepath(file_key) @@ -342,64 +353,34 @@ def write_file(body: bytes): logger.warning("Failed to find Databricks managed resource group") compute = _get_azure_client(ComputeManagementClient) - previous_instances = {} + all_vms_by_id = {} + active_timelines_by_id = {} + retired_timelines = [] while True: try: - if resource_group_name: - vms = compute.virtual_machines.list(resource_group_name=resource_group_name) - else: - vms = compute.virtual_machines.list_all() - new_instances = [ - vm.as_dict() - for vm in vms - if vm.tags.get("Vendor") == "Databricks" and vm.tags.get("ClusterId") == cluster_id - ] + running_vms_by_id = _get_running_vms_by_id(compute, resource_group_name, cluster_id) - new_instance_id_to_reservation = dict( - zip( - (vm["id"] for vm in new_instances), - new_instances, - ) - ) + for vm in running_vms_by_id.values(): + all_vms_by_id[vm["name"]] = vm - old_instance_id_to_reservation = dict( - zip( - (vm["id"] for vm in previous_instances), - previous_instances, + active_timelines_by_id, new_retired_timelines = _update_monitored_timelines( + set(running_vms_by_id.keys()), active_timelines_by_id + ) + retired_timelines.extend(new_retired_timelines) + all_timelines = retired_timelines + list(active_timelines_by_id.values()) + + write_file( + orjson.dumps( + { + "instances": list(all_vms_by_id.values()), + "timelines": all_timelines, + } ) ) - old_instance_ids = set(old_instance_id_to_reservation) - new_instance_ids = set(new_instance_id_to_reservation) - - # If we have the exact same set of instances, prefer the new set... 
- if old_instance_ids == new_instance_ids: - instances = new_instances - else: - # Otherwise, update old references and include any new instances in the list - newly_added_instance_ids = new_instance_ids.difference(old_instance_ids) - updated_instance_ids = new_instance_ids.intersection(old_instance_ids) - removed_instance_ids = old_instance_ids.difference(new_instance_ids) - - removed_instances = [ - old_instance_id_to_reservation[id] for id in removed_instance_ids - ] - updated_instances = [ - new_instance_id_to_reservation[id] for id in updated_instance_ids - ] - new_instances = [ - new_instance_id_to_reservation[id] for id in newly_added_instance_ids - ] - - instances = [*removed_instances, *updated_instances, *new_instances] - - write_file(orjson.dumps(instances)) - - previous_instances = instances except Exception as e: logger.error(f"Exception encountered while polling cluster: {e}") - sleep(polling_period) @@ -435,3 +416,35 @@ def _get_azure_client(azure_client_class: Type[AzureClient]) -> AzureClient: def _get_azure_subscription_id(): return os.getenv("AZURE_SUBSCRIPTION_ID") or get_cli_profile().get_login_credentials()[1] + + +def _get_running_vms_by_id( + compute: AzureClient, resource_group_name: str | None, cluster_id: str +) -> dict[str, dict]: + + if resource_group_name: + vms = compute.virtual_machines.list(resource_group_name=resource_group_name) + else: + vms = compute.virtual_machines.list_all() + + current_vms = [ + vm.as_dict() + for vm in vms + if vm.tags.get("Vendor") == "Databricks" and vm.tags.get("ClusterId") == cluster_id + ] + + # A separate api call is required for each vm to see its power state. + # Use a conservative default of assuming the vm is running if the resource group name + # is missing and the api call can't be made. 
+ running_vms_by_id = {} + for vm in current_vms: + if resource_group_name: + vm_state = compute.virtual_machines.instance_view( + resource_group_name=resource_group_name, vm_name=vm["name"] + ) + if len(vm_state.statuses) > 1 and vm_state.statuses[1].code == "PowerState/running": + running_vms_by_id[vm["name"]] = vm + else: + running_vms_by_id[vm["name"]] = vm + + return running_vms_by_id diff --git a/sync/models.py b/sync/models.py index 4be06fb..197a2f9 100644 --- a/sync/models.py +++ b/sync/models.py @@ -112,15 +112,16 @@ class DatabricksClusterReport(BaseModel): cluster: dict cluster_events: dict tasks: List[dict] + instances: Union[List[dict], None] + instance_timelines: Union[List[dict], None] class AWSDatabricksClusterReport(DatabricksClusterReport): - instances: Union[dict, None] - volumes: Union[dict, None] + volumes: Union[List[dict], None] class AzureDatabricksClusterReport(DatabricksClusterReport): - instances: Union[List[dict], None] + pass class DatabricksError(Error): diff --git a/tests/test_awsdatabricks.py b/tests/test_awsdatabricks.py index e256de1..ebceae7 100644 --- a/tests/test_awsdatabricks.py +++ b/tests/test_awsdatabricks.py @@ -813,7 +813,12 @@ def test_create_prediction_for_run_success_with_cluster_instance_file(respx_mock s3_stubber = Stubber(s3) mock_cluster_info_bytes = orjson.dumps( - {**MOCK_INSTANCES, **MOCK_VOLUMES}, + { + "volumes": MOCK_VOLUMES["Volumes"], + "instances": [ + inst for res in MOCK_INSTANCES["Reservations"] for inst in res["Instances"] + ], + }, option=orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS | orjson.OPT_NAIVE_UTC, ) s3_stubber.add_response(