In [0]:
%pip install databricks-sdk==0.39.0
# Restart the Python process so the following cells import the pinned databricks-sdk version.
dbutils.library.restartPython()

In [0]:
# Unity Catalog catalog and schema that hold the workshop tables (edit to match your workspace).
catalog, dbName = 'dev_bh_datascience', 'ds_workshop'

In [0]:
# Requires `USE CATALOG` on the catalog and `USE SCHEMA` on the schema.
# Change the catalog / schema names in the previous cell if necessary.

# Source transactions, the predictions table written by this notebook, and the baseline view —
# all addressed with the 3-level `catalog.schema.name` namespace.
TABLE_NAME, TABLE_NAME_PREDICTIONS, BASELINE_PREDICTIONS = (
    f"{catalog}.{dbName}.{suffix}"
    for suffix in (
        "silver_transaction",
        "silver_transaction_predictions",
        "silver_predictions_baseline",
    )
)

# Column holding the event time that the monitor windows on.
TIMESTAMP_COL = "TransactionDate"

## Create a Predictions Table

In [0]:
import pyspark.sql.functions as F


# Build a synthetic predictions table from the source transactions:
#   Prediction   - true ProductRating plus standard-normal noise, clipped to [0, 5]
#   ModelVersion - constant "1", used as the monitor's model id column
#   Critical     - True for rows whose UserRole is "Customer", False otherwise (including NULL)
# The noise is seeded so a re-run reproduces the same predictions, baseline and metrics.
# NOTE: Spark's randn is only reproducible for a fixed seed AND the same data partitioning.
PREDICTION_NOISE_SEED = 42

(spark.table(TABLE_NAME)
   .withColumn("Prediction", F.least(F.greatest(F.col("ProductRating") + F.randn(seed=PREDICTION_NOISE_SEED), F.lit(0.0)), F.lit(5.0)))
   .withColumn("ModelVersion", F.lit("1"))
   .withColumn("Critical", F.when(F.col("UserRole") == "Customer", F.lit(True)).otherwise(F.lit(False)))
   .write
   .option("overwriteSchema", "true")
   # Change data feed is enabled on the predictions table at write time.
   .option("delta.enableChangeDataFeed", "true")
   .mode("overwrite")
   .saveAsTable(TABLE_NAME_PREDICTIONS)
)

In [0]:
# Spot-check the label next to the simulated prediction and the extra monitoring columns.
predictions_preview = spark.sql(f"SELECT ProductRating, Prediction, ModelVersion, Critical from {TABLE_NAME_PREDICTIONS};")
display(predictions_preview)

## Create baseline predictions

Use the predictions from the first day we started making predictions as the baseline.

In [0]:
# Baseline = every prediction whose TIMESTAMP_COL falls on the earliest calendar day present
# in the predictions table, exposed as a view.
baseline_view_sql = f"""
  CREATE OR REPLACE VIEW {BASELINE_PREDICTIONS} AS 
  (SELECT * FROM {TABLE_NAME_PREDICTIONS} 
  WHERE date({TIMESTAMP_COL}) = (select min(date({TIMESTAMP_COL})) as min_date from {TABLE_NAME_PREDICTIONS})
  );"""
display(spark.sql(baseline_view_sql))

In [0]:
# Inspect the rows that make up the baseline.
baseline_rows = spark.sql(f"SELECT * FROM {BASELINE_PREDICTIONS}")
display(baseline_rows)

## Create an Inference Log Monitor

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import MonitorTimeSeries, MonitorInferenceLog, MonitorInferenceLogProblemType
import os


Note that if you create the monitor without a baseline table, you'll see comparisons to prior periods only. With a baseline table, you'll also see how each metric changes relative to your baseline.

In [0]:
# Time windows the monitor aggregates metrics over.
# Each entry must be one of: 5 minutes, 30 minutes, 1 hour, 1 day, 1 month, 1 year,
# or a multiple of 1 week.
GRANULARITIES = [
    "1 day",
    "1 week",
]

In [0]:
# Create a monitor using the InferenceLog profile type (regression problem). After the initial refresh
# completes, you can view the autogenerated dashboard from the Quality tab of the table in Catalog Explorer.
# If a monitor already exists for the table (e.g. on a re-run), its current definition is fetched instead.
print(f"Creating monitor for {TABLE_NAME_PREDICTIONS}")

w = WorkspaceClient()

try:
  lhm_monitor = w.quality_monitors.create(
      table_name=TABLE_NAME_PREDICTIONS, # Always use 3-level namespace
      inference_log=MonitorInferenceLog(
          problem_type=MonitorInferenceLogProblemType.PROBLEM_TYPE_REGRESSION,
          prediction_col="Prediction",
          timestamp_col=TIMESTAMP_COL,
          granularities=GRANULARITIES,
          model_id_col="ModelVersion",
          label_col="ProductRating"
      ),
      baseline_table_name=BASELINE_PREDICTIONS,
      # Directory for the monitor's generated assets: here, the notebook's current working directory.
      assets_dir=os.getcwd(),
      output_schema_name=f"{catalog}.{dbName}"
  )

except Exception as lhm_exception:
  # Only an "already exists" error is tolerated; anything else propagates.
  if "already exist" not in str(lhm_exception):
      raise  # bare raise re-raises the active exception with its original traceback
  print(f"Monitor for {TABLE_NAME_PREDICTIONS} already exists, retrieving monitor info:")
  lhm_monitor = w.quality_monitors.get(table_name=TABLE_NAME_PREDICTIONS)

In [0]:
import time
from databricks.sdk.service.catalog import MonitorInfoStatus, MonitorRefreshInfoState


# Wait for monitor creation to finish: poll while the status is PENDING, with an upper bound so a
# stuck creation cannot hang the notebook indefinitely.
MONITOR_POLL_SECONDS = 10
MONITOR_TIMEOUT_SECONDS = 30 * 60

monitor_deadline = time.monotonic() + MONITOR_TIMEOUT_SECONDS
lhm_monitor = w.quality_monitors.get(table_name=TABLE_NAME_PREDICTIONS)
while lhm_monitor.status == MonitorInfoStatus.MONITOR_STATUS_PENDING:
  if time.monotonic() > monitor_deadline:
    raise TimeoutError(
        f"Monitor for {TABLE_NAME_PREDICTIONS} still pending after {MONITOR_TIMEOUT_SECONDS} seconds"
    )
  # Sleep first, then re-poll, so no extra interval is wasted once the status changes.
  time.sleep(MONITOR_POLL_SECONDS)
  lhm_monitor = w.quality_monitors.get(table_name=TABLE_NAME_PREDICTIONS)

assert lhm_monitor.status == MonitorInfoStatus.MONITOR_STATUS_ACTIVE, f"Error creating monitor (status: {lhm_monitor.status})"

In [0]:
# Wait (bounded) for the most recent monitor refresh to finish and check that it succeeded.
REFRESH_POLL_SECONDS = 30
REFRESH_TIMEOUT_SECONDS = 60 * 60

# `.refreshes` may be None when the table has no refresh history; normalise to an empty list.
refreshes = w.quality_monitors.list_refreshes(table_name=TABLE_NAME_PREDICTIONS).refreshes or []
assert len(refreshes) > 0, f"No refreshes found for {TABLE_NAME_PREDICTIONS}"

# Pick the latest refresh explicitly instead of relying on the (undocumented) list ordering.
run_info = max(refreshes, key=lambda refresh: refresh.start_time_ms or 0)

refresh_deadline = time.monotonic() + REFRESH_TIMEOUT_SECONDS
while run_info.state in (MonitorRefreshInfoState.PENDING, MonitorRefreshInfoState.RUNNING):
  if time.monotonic() > refresh_deadline:
    raise TimeoutError(
        f"Refresh {run_info.refresh_id} of {TABLE_NAME_PREDICTIONS} did not finish within {REFRESH_TIMEOUT_SECONDS} seconds"
    )
  time.sleep(REFRESH_POLL_SECONDS)
  run_info = w.quality_monitors.get_refresh(table_name=TABLE_NAME_PREDICTIONS, refresh_id=run_info.refresh_id)

assert run_info.state == MonitorRefreshInfoState.SUCCESS, f"Monitor refresh failed (state: {run_info.state})"

In [0]:
# Display profile metrics table
profile_table = lhm_monitor.profile_metrics_table_name
display(spark.sql(f"SELECT * FROM {profile_table}"))

# Display the drift metrics table
drift_table = lhm_monitor.drift_metrics_table_name
display(spark.sql(f"SELECT * FROM {drift_table}"))

In [0]:
# Name of the profile metrics table reported by the monitor (rendered as this cell's output).
profile_table

In [0]:
# Name of the drift metrics table reported by the monitor (rendered as this cell's output).
drift_table

In [0]:
# Built-in regression metrics per window; windows without a computed MSE are filtered out.
regression_metrics = spark.sql(f"SELECT window, mean_squared_error, r2_score from {profile_table} where mean_squared_error is not null")
display(regression_metrics)

In [0]:
# Drift rows computed against the baseline (rather than against the previous window).
baseline_drift = spark.sql(f"SELECT * from {drift_table} where drift_type = 'BASELINE';")
display(baseline_drift)

## Add Custom Metrics

We'll create two aggregate metrics and then one derived metric that will contain the weighted mean squared error. We're using the `Critical` column that we added when we created the predictions table to give extra weight to some of the predictions. Remember, we labelled a row as critical if its `UserRole` is `Customer`; rows for any other role (such as an admin) are not critical.

The two aggregate metrics will be the sum of the weights and the weighted sum of squared prediction errors. Then we'll divide the weighted sum by the sum of the weights to get the weighted mean squared error.

In [0]:
from databricks.sdk.service.catalog import MonitorMetric, MonitorMetricType
from pyspark.sql import types as T


# Aggregate metric: per-window sum of row weights (2 for critical rows, 1 otherwise).
# Every row that has both a prediction and a label contributes its weight, including rows whose
# prediction is exactly right. Their squared error is 0, but they must still count in the
# denominator, or weighted_se / weights_sum overstates the error.
# Rows with a NULL prediction or label get weight 0. This matches weighted_se, where
# POWER(NULL, 2) is NULL and therefore ignored by sum().
weights_sum = MonitorMetric(
    type=MonitorMetricType.CUSTOM_METRIC_TYPE_AGGREGATE,
    name="weights_sum",
    input_columns=[":table"],
    definition="""sum(CASE
      WHEN {{prediction_col}} IS NULL OR {{label_col}} IS NULL THEN 0
      WHEN Critical = TRUE THEN 2
      ELSE 1 END)""",
    output_data_type=T.StructField("weights_sum", T.DoubleType()).json(),
)

In [0]:
# Aggregate metric: per-window sum of squared prediction errors, where errors on critical rows
# count double. Exact matches contribute 0.
weighted_se = MonitorMetric(
    name="weighted_se",
    type=MonitorMetricType.CUSTOM_METRIC_TYPE_AGGREGATE,
    output_data_type=T.StructField("weighted_se", T.DoubleType()).json(),
    input_columns=[":table"],
    definition="""sum(CASE
      WHEN {{prediction_col}} = {{label_col}} THEN 0
      WHEN {{prediction_col}} != {{label_col}} AND Critical=TRUE THEN 2 * POWER({{prediction_col}} - {{label_col}}, 2)
      ELSE POWER({{prediction_col}} - {{label_col}}, 2) END)""",
)

In [0]:
# Derived metric: weighted mean squared error = weighted_se / weights_sum, computed from the two
# aggregate metrics of the same window.
# NULLIF turns a zero denominator into NULL instead of dividing by zero, which raises an error
# when Spark ANSI mode is enabled.
weighted_mse = MonitorMetric(
    type=MonitorMetricType.CUSTOM_METRIC_TYPE_DERIVED,
    name="weighted_mse",
    input_columns=[":table"],
    definition="""weighted_se / NULLIF(weights_sum, 0)""",
    output_data_type=T.StructField("weighted_mse", T.DoubleType()).json(),
)

In [0]:
# Drift metric: r2_score of the current window minus r2_score of the comparison window / baseline.
r2_score_delta = MonitorMetric(
    name="r2_score_delta",
    type=MonitorMetricType.CUSTOM_METRIC_TYPE_DRIFT,
    output_data_type=T.StructField("r2_score_delta", T.DoubleType()).json(),
    input_columns=[":table"],
    definition="{{current_df}}.r2_score - {{base_df}}.r2_score",
)

In [0]:
# Update the existing monitor with the custom metrics. The inference-log settings and baseline
# are passed again alongside the new metrics.
try:
  lhm_monitor = w.quality_monitors.update(
      table_name=TABLE_NAME_PREDICTIONS, # Always use 3-level namespace
      inference_log=MonitorInferenceLog(
          problem_type=MonitorInferenceLogProblemType.PROBLEM_TYPE_REGRESSION,
          prediction_col="Prediction",
          timestamp_col=TIMESTAMP_COL,
          granularities=GRANULARITIES,
          model_id_col="ModelVersion",
          label_col="ProductRating", # optional
      ),
      custom_metrics=[weights_sum, weighted_se, weighted_mse, r2_score_delta],
      baseline_table_name=BASELINE_PREDICTIONS,
      output_schema_name=f"{catalog}.{dbName}"
  )

except Exception as lhm_exception:
  # Report the failure, then re-raise. Later cells query the custom metric columns
  # (e.g. weights_sum); swallowing the error would let them fail confusingly or use stale config.
  print(f"Failed to update monitor for {TABLE_NAME_PREDICTIONS}: {lhm_exception}")
  raise

In [0]:
# Trigger a refresh so the custom metrics are computed, then wait (bounded) for it to finish
# and check that it succeeded.
REFRESH_POLL_SECONDS = 30
REFRESH_TIMEOUT_SECONDS = 60 * 60

w = WorkspaceClient()
run_info = w.quality_monitors.run_refresh(table_name=TABLE_NAME_PREDICTIONS)

refresh_deadline = time.monotonic() + REFRESH_TIMEOUT_SECONDS
while run_info.state in (MonitorRefreshInfoState.PENDING, MonitorRefreshInfoState.RUNNING):
  if time.monotonic() > refresh_deadline:
    raise TimeoutError(
        f"Refresh {run_info.refresh_id} of {TABLE_NAME_PREDICTIONS} did not finish within {REFRESH_TIMEOUT_SECONDS} seconds"
    )
  # Sleep first, then re-poll, so no extra interval is wasted once the refresh completes.
  time.sleep(REFRESH_POLL_SECONDS)
  run_info = w.quality_monitors.get_refresh(table_name=TABLE_NAME_PREDICTIONS, refresh_id=run_info.refresh_id)

assert run_info.state == MonitorRefreshInfoState.SUCCESS, f"Monitor refresh failed (state: {run_info.state})"

## View Custom Metrics

In [0]:
# First rows of the profile metrics table, which now carries the custom metric columns.
profile_sample = spark.sql(f"SELECT * from {profile_table} limit 10;")
display(profile_sample)

In [0]:
# Compare the built-in MSE with the custom weighted MSE, for windows where weights_sum was computed.
weighted_metrics = spark.sql(f"SELECT window, mean_squared_error, weights_sum, weighted_mse from {profile_table} where weights_sum is not null;")
display(weighted_metrics)

In [0]:
# Sample of the custom drift metric across drift types.
r2_delta_sample = spark.sql(f"SELECT window, r2_score_delta, drift_type from {drift_table} limit 10")
display(r2_delta_sample)

In [0]:
# Custom drift metric restricted to comparisons against the baseline.
r2_delta_baseline = spark.sql(f"SELECT window, r2_score_delta, drift_type from {drift_table} where drift_type = 'BASELINE';")
display(r2_delta_baseline)

The following cell can be used to delete the monitor on the Delta table (shown on the table's Quality tab). Note that the profile metrics and drift metrics tables the monitor created in Unity Catalog are not deleted.

In [0]:
# Optional cleanup: set CLEANUP_MONITOR = True to delete the monitor (e.g. if you wish to run the
# quickstart on this table again). It defaults to False so that "Run All" does not immediately
# delete the monitor this notebook just built.
# Deleting the monitor does not drop its profile and drift metrics tables in Unity Catalog.
CLEANUP_MONITOR = False

if CLEANUP_MONITOR:
  w.quality_monitors.delete(TABLE_NAME_PREDICTIONS)