Skip to content
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ FILES := $(shell git diff --name-only --diff-filter=AM $$(git merge-base origin/

.PHONY: test
test:
pytest
pytest -vv
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is consistent across our other repos for make test


.PHONY: lint
lint:
Expand Down
46 changes: 13 additions & 33 deletions sync/_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import datetime, timezone
from pathlib import Path
from time import sleep
from typing import Any, Collection, Dict, List, Set, Tuple, TypeVar, Union
from typing import Collection, Dict, List, Set, Tuple, Union
from urllib.parse import urlparse

import boto3 as boto
Expand All @@ -21,12 +21,13 @@
from sync.models import (
DatabricksAPIError,
DatabricksClusterReport,
DatabricksError,
DatabricksComputeType,
DatabricksError,
DatabricksPlanType,
Response
Response,
)
from sync.utils.dbfs import format_dbfs_filepath, read_dbfs_file
from sync.utils.json import deep_update

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,7 +103,7 @@ def create_submission_with_cluster_info(
cluster_activity_events=cluster_activity_events,
tasks=tasks,
plan_type=plan_type,
compute_type=compute_type
compute_type=compute_type,
)
eventlog = _get_event_log_from_cluster(cluster, tasks).result

Expand Down Expand Up @@ -308,12 +309,12 @@ def _get_cluster_report(


def _create_cluster_report(
    cluster: dict,
    cluster_info: dict,
    cluster_activity_events: dict,
    tasks: List[dict],
    plan_type: DatabricksPlanType,
    compute_type: DatabricksComputeType,
) -> DatabricksClusterReport:
    """Build a cluster report from raw cluster data.

    Placeholder: this generic implementation always raises and is presumably
    overridden/replaced per cloud provider — confirm against the callers.

    :param cluster: cluster definition
    :param cluster_info: additional cluster metadata
    :param cluster_activity_events: cluster activity event payload
    :param tasks: Databricks run tasks
    :param plan_type: Databricks plan type
    :param compute_type: Databricks compute type
    :return: cluster report
    :raises NotImplementedError: always, in this implementation
    """
    raise NotImplementedError()

Expand Down Expand Up @@ -659,7 +660,7 @@ def get_recommendation_cluster(
if "autoscale" in cluster:
del cluster["autoscale"]

recommendation_cluster = _deep_update(cluster, recommendation["configuration"])
recommendation_cluster = deep_update(cluster, recommendation["configuration"])

return Response(result=recommendation_cluster)
return recommendation_response
Expand Down Expand Up @@ -727,7 +728,7 @@ def get_project_cluster(cluster: dict, project_id: str, region_name: str = None)
project_settings_response = get_project_cluster_settings(project_id, region_name)
project_cluster_settings = project_settings_response.result
if project_cluster_settings:
project_cluster = _deep_update(cluster, project_cluster_settings)
project_cluster = deep_update(cluster, project_cluster_settings)

return Response(result=project_cluster)
return project_settings_response
Expand Down Expand Up @@ -1638,24 +1639,3 @@ def _update_monitored_timelines(
retired_inst_timeline_list.append(active_timelines_by_id.pop(id))

return active_timelines_by_id, retired_inst_timeline_list


KeyType = TypeVar("KeyType")


def _deep_update(
    mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]
) -> Dict[KeyType, Any]:
    """Recursively merge *updating_mappings* into a copy of *mapping*.

    Nested dicts are merged key-by-key, lists are concatenated, and any other
    value is overwritten by the later mapping. Later mappings win on conflict.

    :param mapping: base mapping (not modified)
    :param updating_mappings: one or more mappings to merge on top of the base
    :return: a new merged dict
    """
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for key, value in updating_mapping.items():
            current = updated_mapping.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                updated_mapping[key] = _deep_update(current, value)
            elif isinstance(current, list) and isinstance(value, list):
                # Build a NEW list: `+=` would extend the list object shared
                # with the shallow-copied input, mutating the caller's data.
                updated_mapping[key] = current + value
            else:
                updated_mapping[key] = value
    return updated_mapping
92 changes: 90 additions & 2 deletions sync/api/projects.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
"""Project functions
"""
import io
import json
import logging
from time import sleep
from typing import List
from typing import List, Union
from urllib.parse import urlparse

import httpx

from sync.clients.sync import get_default_client
from sync.models import Platform, ProjectError, RecommendationError, Response, SubmissionError
from sync.models import (
AWSProjectConfiguration,
AzureProjectConfiguration,
Platform,
ProjectError,
RecommendationError,
Response,
SubmissionError,
)
from sync.utils.json import deep_update

from . import generate_presigned_url

Expand Down Expand Up @@ -396,3 +406,81 @@ def get_project_submission(project_id: str, submission_id: str) -> Response[dict
return Response(**response)

return Response(result=response["result"])


def get_latest_project_config_recommendation(
    project_id: str,
) -> Response[Union[AWSProjectConfiguration, AzureProjectConfiguration]]:
    """Get the latest project configuration recommendation.

    :param project_id: project ID
    :type project_id: str
    :return: project configuration recommendation, or an error Response when
        the project has no recommendations
    :rtype: Response[AWSProjectConfiguration or AzureProjectConfiguration]
    """
    latest_recommendation = get_default_client().get_latest_project_recommendation(project_id)
    if latest_recommendation.get("result"):
        return Response(
            result=latest_recommendation["result"][0]["recommendation"]["configuration"]
        )
    # Previously this fell through and implicitly returned None, which made
    # callers that dereference `.result`/`.error` crash with AttributeError.
    # Always return a Response so callers see a uniform contract.
    return Response(
        error=RecommendationError(
            message=f"No recommendation found for project {project_id}"
        )
    )


def get_cluster_definition_and_recommendation(
    project_id: str, cluster_spec_str: str
) -> Response[dict]:
    """Return the current cluster definition alongside the latest project
    configuration recommendation.

    Returns an error Response if no cluster recommendation is found for the
    project.

    :param project_id: project ID
    :type project_id: str
    :param cluster_spec_str: current cluster definition as a JSON string
    :type cluster_spec_str: str
    :return: current cluster definition and project configuration recommendation
    :rtype: Response[dict]
    """
    recommendation_response = get_latest_project_config_recommendation(project_id)
    # Guard against both a missing response (None) and an error Response so we
    # never dereference `.result` on a failed lookup. (Pydantic models are
    # always truthy, so the original `not recommendation_response` check could
    # not catch an error Response.)
    if not recommendation_response or recommendation_response.error:
        logger.info(f"No cluster recommendation found for {project_id}")
        return Response(error=RecommendationError(message="Recommendation failed"))
    # The json round-trip deep-copies the recommendation before exposing it.
    response_str = json.dumps(recommendation_response.result)
    return Response(
        result={
            "cluster_recommendation": json.loads(response_str),
            "cluster_definition": json.loads(cluster_spec_str),
        }
    )


def get_updated_cluster_defintion(
    project_id: str, cluster_spec_str: str
) -> Response[Union[AWSProjectConfiguration, AzureProjectConfiguration]]:
    """Return the cluster definition merged with the latest project
    configuration recommendation.

    NOTE(review): the misspelling "defintion" is kept for backward
    compatibility with existing callers.

    :param project_id: project ID
    :type project_id: str
    :param cluster_spec_str: current cluster definition as a JSON string
    :type cluster_spec_str: str
    :return: updated cluster definition with recommendation applied
    :rtype: Response[AWSProjectConfiguration or AzureProjectConfiguration]
    """
    rec_response = get_latest_project_config_recommendation(project_id)
    # Guard against a missing response (None) as well as an error Response —
    # checking only `.error` crashes with AttributeError when None comes back.
    if not rec_response or rec_response.error:
        return rec_response or Response(
            error=RecommendationError(message="Recommendation failed")
        )

    # json round-trip deep-copies the recommendation so the merge below cannot
    # mutate shared/cached structures.
    latest_recommendation = json.loads(json.dumps(rec_response.result))
    cluster_definition = json.loads(cluster_spec_str)

    # num_workers/autoscale are mutually exclusive settings, and we rely on the
    # recommendation to set them appropriately. Since we may recommend a static
    # cluster (i.e. one with `num_workers`) for a cluster that was originally
    # autoscaled, remove any prior configuration of either key.
    cluster_definition.pop("num_workers", None)
    cluster_definition.pop("autoscale", None)

    recommendation_cluster = deep_update(cluster_definition, latest_recommendation)
    return Response(result=recommendation_cluster)
7 changes: 7 additions & 0 deletions sync/clients/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ def get_project_recommendation(self, project_id: str, recommendation_id: str) ->
)
)

def get_latest_project_recommendation(self, project_id: str) -> dict:
    """Fetch the most recent recommendation for a project (page 0, size 1)."""
    request = self._client.build_request(
        "GET", f"/v1/projects/{project_id}/recommendations?page=0&per_page=1"
    )
    return self._send(request)

def get_project_submissions(self, project_id: str, params: dict = None) -> dict:
return self._send(
self._client.build_request(
Expand Down
40 changes: 39 additions & 1 deletion sync/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataclasses import dataclass
from enum import Enum, unique
from typing import Callable, Generic, List, TypeVar, Union
from typing import Callable, Dict, Generic, List, TypeVar, Union

from botocore.exceptions import ClientError
from pydantic import BaseModel, Field, root_validator, validator
Expand Down Expand Up @@ -137,3 +137,41 @@ def check_consistency(cls, err, values):
if err is None and values.get("result") is None:
raise ValueError("must provide result or error")
return err


class S3ClusterLogConfiguration(BaseModel):
    """S3 destination settings for Databricks cluster log delivery."""

    destination: str  # target path — presumably an s3:// URI; confirm with API
    region: str  # AWS region of the destination bucket
    enable_encryption: bool  # whether delivered logs are encrypted
    canned_acl: str  # canned ACL applied to delivered log objects


class DBFSClusterLogConfiguration(BaseModel):
    """DBFS destination settings for Databricks cluster log delivery."""

    destination: str  # target path — presumably a dbfs:/ URI; confirm with API


class AWSProjectConfiguration(BaseModel):
    """Recommended Databricks cluster configuration for an AWS-hosted project.

    Mirrors the fields of a Databricks cluster spec returned by the
    recommendation API.
    """

    node_type_id: str  # worker node instance type
    driver_node_type: str  # driver node instance type
    custom_tags: Dict  # arbitrary cluster tags
    cluster_log_conf: Union[S3ClusterLogConfiguration, DBFSClusterLogConfiguration]
    cluster_name: str
    num_workers: int  # static worker count (mutually exclusive with autoscale)
    spark_version: str  # Databricks runtime version string
    runtime_engine: str
    autoscale: Dict  # autoscale bounds (mutually exclusive with num_workers)
    spark_conf: Dict  # Spark configuration key/value pairs
    aws_attributes: Dict  # AWS-specific cluster attributes
    spark_env_vars: Dict  # environment variables for the Spark processes


class AzureProjectConfiguration(BaseModel):
    """Recommended Databricks cluster configuration for an Azure-hosted project.

    Mirrors the fields of a Databricks cluster spec returned by the
    recommendation API.
    """

    node_type_id: str  # worker node instance type
    driver_node_type: str  # driver node instance type
    cluster_log_conf: DBFSClusterLogConfiguration  # Azure clusters log to DBFS only
    custom_tags: Dict  # arbitrary cluster tags
    num_workers: int  # static worker count
    spark_conf: Dict  # Spark configuration key/value pairs
    spark_version: str  # Databricks runtime version string
    runtime_engine: str
    azure_attributes: Dict  # Azure-specific cluster attributes
22 changes: 22 additions & 0 deletions sync/utils/json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from json import JSONEncoder
from typing import Any, Dict, TypeVar


class DefaultDateTimeEncoder(JSONEncoder):
Expand Down Expand Up @@ -34,3 +35,24 @@ def default(self, obj):
date = date.isoformat()
date = date.replace("+00:00", "Z")
return date


KeyType = TypeVar("KeyType")


def deep_update(
    mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]
) -> Dict[KeyType, Any]:
    """Recursively merge *updating_mappings* into a copy of *mapping*.

    Nested dicts are merged key-by-key, lists are concatenated, and any other
    value is overwritten by the later mapping. Later mappings win on conflict.

    :param mapping: base mapping (not modified)
    :param updating_mappings: one or more mappings to merge on top of the base
    :return: a new merged dict
    """
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for key, value in updating_mapping.items():
            current = updated_mapping.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                updated_mapping[key] = deep_update(current, value)
            elif isinstance(current, list) and isinstance(value, list):
                # Build a NEW list: `+=` would extend the list object shared
                # with the shallow-copied input, mutating the caller's data.
                updated_mapping[key] = current + value
            else:
                updated_mapping[key] = value
    return updated_mapping
Loading