diff --git a/integration/Gradient-Train-Job.py b/integration/Gradient-Train-Job.py index f9aedf9..8aa02c2 100644 --- a/integration/Gradient-Train-Job.py +++ b/integration/Gradient-Train-Job.py @@ -18,6 +18,8 @@ # MAGIC * The Gradient Webhook has been configured # MAGIC * The Databricks Job has been Gradient enabled # MAGIC +# MAGIC When bypassing the Gradient Webhook with AWS, the cluster attached to this notebook must have an instance_arn with describe_instances and describe_volumes permissions. When bypassing the Gradient Webhook with Azure, the following environment variables must be set with the correct values: "AZURE_TENANT_ID", "AZURE_SUBSCRIPTION_ID", "AZURE_CLIENT_SECRET", "AZURE_CLIENT_ID" +# MAGIC # MAGIC This job will configure all runs to execute using ON DEMAND nodes only. The orginal settings will be restored after training is complete. # MAGIC @@ -41,7 +43,6 @@ os.environ["SYNC_API_KEY_SECRET"] = dbutils.widgets.get("Sync API Key Secret") os.environ["DATABRICKS_HOST"] = dbutils.widgets.get("Databricks Host").rstrip('\/') - print(f"DATABRICKS_JOB_ID: {DATABRICKS_JOB_ID}") print(f"TRAINING_RUNS: {TRAINING_RUNS}") print(f"BYPASS_WEBHOOK: {BYPASS_WEBHOOK}") @@ -73,18 +74,19 @@ else: raise ValueError(f"Unsupported platform: {platform}") +if BYPASS_WEBHOOK: + access_report = sync_databricks.get_access_report() -access_report = sync_databricks.get_access_report() + for line in access_report: + logger.info(line) -for line in access_report: - logger.info(line) - -assert not any(line.status is AccessStatusCode.RED for line in access_report), "Required access is missing" + assert not any(line.status is AccessStatusCode.RED for line in access_report), "Required access is missing" # COMMAND ---------- +from typing import Optional -def get_cluster_for_job(job: dict | None) -> dict: +def get_cluster_for_job(job: Optional[dict]) -> dict: if job is None: job = sync_databricks_client.get_job(DATABRICKS_JOB_ID) @@ -110,11 +112,10 @@ def get_cluster_for_job(job: dict | None) -> dict: else: raise ValueError("Could not identify a cluster for this job") -def get_tag_for_job(job: dict, tag_key: str) -> str | None: +def get_tag_for_job(job: dict, tag_key: str) -> Optional[str]: cluster = get_cluster_for_job(job) return cluster["custom_tags"].get(tag_key) - def validate_job(): logger.info("Validating Databricks Job") job = sync_databricks_client.get_job(DATABRICKS_JOB_ID) @@ -204,7 +205,7 @@ def validate_job(): # COMMAND ---------- class RecommendationError(Exception): - "Raised something goes wrong with the generation of a GradientML Recommendation" + "Raised when something goes wrong with the generation of a GradientML Recommendation" def __init__(self, error): super().__init__("recommendation Error: " + str(error)) @@ -227,7 +228,7 @@ def run_job(run_job_id: str): break return run -def wait_for_recommendation(starting_recommendation_id: str | None) -> None: +def wait_for_recommendation(starting_recommendation_id: Optional[str]) -> None: logger.info(f"waiting for log submission and rec generation and application") logger.info(f"starting recommendation id: {starting_recommendation_id}")