In [0]:
#dbutils.widgets.removeAll()


### Batch Inference

 - This notebook is designed to take rows from input_table_name, and use ai_query to create results in output_table_name.
 - You can keep adding rows to input_table_name, and run this notebook. The new rows will be appended (with results) into output_table_name.
 - The notebook uses Structured Streaming to process rows in batches of approximately inference_rows_per_batch rows each. This allows recovery from unexpected failures (for example, of the endpoint): just restart the job.
 - Each row is intended to be processed exactly once. (Currently, if the job fails, there is a very small chance that one batch is reprocessed, which would append duplicate rows. This will be fixed shortly.)

#### To use

 - Specify all the non-optional parameters and run the notebook.
 - You can schedule this notebook as a job. Pass all the required parameters to the job, and trigger it whenever new rows in input_table_name need to be processed.

#### How does it work?

 - There are two streams.
 - Stream1
   - Copies (input_table_pk, input_column_name) into an intermediate table (in tmp_schema), written as partitions (files) of approximately inference_rows_per_batch rows each. This is required so that the actual ai_query processing happens in batches of inference_rows_per_batch rows.
 - Stream2
   - Creates the smallest batch size possible (one batch per file in the intermediate table), calls ai_query and appends the results into output_table_name.
 - Structured Streaming is resilient to crashes because its checkpoints and its writes to Delta are transactional, so if the job crashes (for example, due to a transient error), simply restarting it will work.

In [0]:
# This is required for provisioning_utils (see below)
# Make sure MLflow's Databricks helpers can be imported. If they cannot, install MLflow
# and restart the Python process so that the freshly installed package is picked up.
try:
    from mlflow.utils.databricks_utils import get_databricks_host_creds, get_workspace_url
    print("MLflow utilities imported successfully.")
except ImportError as e:
    print("Failed to import MLflow utilities. Installing MLflow...")
    # Install MLflow
    # NOTE(review): the version is not pinned, so later runs may install a different MLflow release.
    %pip install mlflow
    # Restart Python to ensure the newly installed library is loaded
    # (this resets the Python process; the imports and variables below are created by later cells).
    dbutils.library.restartPython()


In [0]:
import sys, os
import math
import datetime, time

In [0]:
%run ./utils

In [0]:
import requests, json, math, os, time, uuid
import numpy as np
import pandas as pd
from pyspark.sql.functions import col, rand
import datetime
import math

# Set up widgets.
# Each widget is declared as (name, default, label, choices): choices is None for a
# free-text widget and the list of allowed values for a dropdown. Widgets are created
# in the order listed.
def _register_widgets():
    """Create every notebook parameter widget (data, model and batch-inference settings)."""
    widget_specs = [
        # Input and output data.
        ("input_table_name", "", "01. Input Table Name", None),
        ("input_column_name", "", "02. Input Column Name", None),
        ("input_table_pk", "", "03. (optional) Input Table Primary Key (comma-separated if composite)", None),
        ("output_table_name", "", "04. Output Table Name", None),
        ("output_column_name", "results", "05. Output Column Name", None),
        ("tmp_schema", "", "06. Temp Schema for temp tables", None),
        # Model parameters.
        ("endpoint_name", "", "07. (Optional) Endpoint Name", None),
        ("use_existing_endpoint", "false", "071. Use existing endpoint (see instructions)", ["true", "false"]),
        ("model_uc_path", "", "08. Model UC Path", None),
        ("model_uc_version", "", "09. (optional) Model Version", None),
        ("ptus", "", "10. Provisioned Throughput Units", None),
        ("prompt", "", "11. Prompt", None),
        ("checkpoint_base_path", "", "12. Base path for stream checkpoint", None),
        ("model_param_max_tokens", "", "13. (optional) Model Param: Max Tokens", None),
        ("model_param_temperature", "0", "14. (optional) Model Param: Temperature", None),
        # Batch inference. The processing rate is derived from experimentation; rows per
        # batch is the approximate number of rows per streaming mini batch.
        ("inference_processing_rate_rows_per_second", "", "15. (Optional) Inference Processing Rate: Rows per Second per PTU", None),
        ("inference_rows_per_batch", "", "16. Rows per streaming mini batch", None),
        ("repartition_stream", "", "17. (optional) Number of repartitions to the stream", None),
        # When scheduled, set is_debug to false.
        ("is_debug", "false", "18. Is Debug", ["true", "false"]),
    ]
    for widget_name, default_value, label, choices in widget_specs:
        if choices is None:
            dbutils.widgets.text(widget_name, default_value, label)
        else:
            dbutils.widgets.dropdown(widget_name, default_value, choices, label)


_register_widgets()


In [0]:
# Read every widget's current value into a dict (widget name -> value).
params_dict = dbutils.widgets.getAll()
# Last expression of the cell: display the parameters as JSON for the run log.
json.dumps(params_dict)

In [0]:
# DEBUG: when the is_debug widget is "true", the widget values are discarded and every
# parameter is loaded from test_batch_parameters.json (relative to the working directory).
# TODO: Work out how to get this in a debug environment because I need to remove this cell
debug_requested = params_dict.get("is_debug", "false").lower() == "true"
if debug_requested:
    print("DEBUG: Overriding with debug parameters")
    with open("test_batch_parameters.json", "r") as debug_params_file:
        params_dict = json.load(debug_params_file)


In [0]:

# Wrap the raw parameter dict so values can be read and updated as attributes
# (e.g. params.ptus); later cells rely on this.
# NOTE(review): AttributeDict is not defined in this notebook — presumably it comes from ./utils.
params = AttributeDict(params_dict)

In [0]:
# Show the raw (still string-typed) parameters before validation converts them.
print("Pre-parsing Parameters:", params, sep="\n")

In [0]:
# Validate parameters.
# Every check raises AssertionError with a message naming the offending parameter.
# Numeric parameters are converted in place (e.g. params.ptus becomes an int), so later
# cells can rely on their types.


def _parse_positive_int(name, value):
    """Parse `value` as an integer >= 1.

    `value` is normally a widget string, but may already be an int when parameters come
    from the debug JSON file, so it is converted with str() before checking.

    Raises:
        AssertionError: if `value` is not a positive integer (message names `name`).
    """
    text = str(value).strip()
    assert text.isdecimal(), f"{name} should be a positive integer"
    number = int(text)
    assert number >= 1, f"{name} should be a positive integer"
    return number


# Data
assert params.input_table_name.count(".") == 2, "input_table_name should be in form catalog.schema.table"
assert params.input_column_name, "input_column_name can not be empty"

# Composite keys are comma-separated. Blank entries (e.g. from a trailing comma) are dropped
# so they do not become empty column names in the generated SELECT lists.
input_table_pk = []
if params.input_table_pk:
    input_table_pk = [column.strip() for column in params.input_table_pk.split(",") if column.strip()]
assert params.output_table_name.count(".") == 2, "output_table_name should be in form catalog.schema.table"
assert params.output_column_name, "output_column_name can not be empty"
assert params.tmp_schema.count(".") == 1, "tmp_schema should be in form catalog.schema"
# TODO: check if tmp_schema exists
assert params.checkpoint_base_path, "checkpoint_base_path can not be empty and ideally should be exclusively for this batch inference notebook"
# TODO: check checkpoint_base_path exists

# Model
if params.endpoint_name:
    assert params.endpoint_name.replace('-', '').replace('_', '').isalnum(), "Endpoint name must be alphanumeric with hyphens and underscores allowed in between."
params.use_existing_endpoint = str_to_bool(params.use_existing_endpoint, default_value='false')
if params.use_existing_endpoint:
    assert params.endpoint_name, "endpoint_name can not be empty if use_existing_endpoint is set to true"
assert params.model_uc_path.count(".") == 2, "model_uc_path should be in form catalog.schema.model_name"
if params.model_uc_version:
    params.model_uc_version = _parse_positive_int("model_uc_version", params.model_uc_version)
# Required: an empty value now fails with a clear message instead of a bare int() ValueError.
params.ptus = _parse_positive_int("ptus", params.ptus)
if not params.prompt:
    print("WARNING: There is no prompt. This means your input column should have a prompt.")

if params.model_param_max_tokens:
    params.model_param_max_tokens = _parse_positive_int("model_param_max_tokens", params.model_param_max_tokens)

if params.model_param_temperature:
    try:
        params.model_param_temperature = float(params.model_param_temperature)
    except ValueError:
        raise AssertionError("model_param_temperature should be a float")
    # Both bounds are inclusive, so the common value 1.0 is accepted as the message promises.
    assert 0 <= params.model_param_temperature <= 1, "model_param_temperature should be between 0 and 1"

# Batch inference
if params.inference_processing_rate_rows_per_second:
    params.inference_processing_rate_rows_per_second = _parse_positive_int(
        "inference_processing_rate_rows_per_second", params.inference_processing_rate_rows_per_second
    )

params.inference_rows_per_batch = _parse_positive_int("inference_rows_per_batch", params.inference_rows_per_batch)

if params.repartition_stream:
    params.repartition_stream = _parse_positive_int("repartition_stream", params.repartition_stream)

In [0]:
# Show the parameters after validation and type conversion.
print("Post-parsing Parameters:", params, sep="\n")

In [0]:
# # Get parameters
# print("Post-parsing Parameters:")
# print("----------")
# print("\nData parameters:")
# print(f"input_table_name: {input_table_name}")
# print(f"input_column_name: {input_column_name}")
# print(f"output_table_name: {output_table_name}")
# print(f"output_column_name: {output_column_name}")
# print(f"tmp_schema: {tmp_schema}")

# print("\nModel parameters:")
# print(f"endpoint_name: {endpoint_name}")
# print(f"model_uc_path: {model_uc_path}")
# print(f"prompt: {prompt}")
# print(f"model_uc_version: {model_uc_version}")
# print(f"max_tokens: {max_tokens}")
# print(f"temperature: {temperature}")
# print(f"ptus: {ptus}")
# print(f"checkpoint_base_path: {checkpoint_base_path}")
# print(f"max_tokens: {max_tokens}")
# print(f"temperature: {temperature}")

# print("\nBatch inference parameters:")
# print(f"inference_processing_rate_rows_per_second: {inference_processing_rate_rows_per_second}")
# print(f"inference_rows_per_batch: {inference_rows_per_batch}")
# print("----------")



In [0]:
# Only estimable when a per-PTU processing rate was supplied: the rate scales linearly with PTUs.
if params.inference_processing_rate_rows_per_second:
    expected_time_per_batch_seconds = params.inference_rows_per_batch / (params.inference_processing_rate_rows_per_second * params.ptus)
    print("expected_time_per_batch_seconds:", expected_time_per_batch_seconds)

In [0]:
# Rough workload estimate. The output table receives one row per processed input row,
# so (input rows - output rows) approximates the rows still to be processed.
total_source_size_rows = spark.table(params.input_table_name).count()
print(f"total_source_size_rows: {total_source_size_rows}")

total_target_size_rows = 0
if spark.catalog.tableExists(params.output_table_name):
  total_target_size_rows = spark.table(params.output_table_name).count()
print(f"total_target_size_rows: {total_target_size_rows}")

# Clamp at zero: the output can hold more rows than the input (e.g. a reprocessed batch
# appended twice, or rows deleted from the input), which would otherwise give a negative estimate.
new_rows_to_process = max(0, total_source_size_rows - total_target_size_rows)
print(f"new_rows_to_process: {new_rows_to_process}")

if params.inference_processing_rate_rows_per_second:
  expected_time_for_workload_seconds = new_rows_to_process / (params.inference_processing_rate_rows_per_second * params.ptus)
  print(f"expected_time_for_workload_seconds: {expected_time_for_workload_seconds} ({expected_time_for_workload_seconds / 60:.2f} minutes)")

In [0]:
# Informational: how many batch-sized partitions the pending rows correspond to.
num_partitions = math.ceil(new_rows_to_process / params.inference_rows_per_batch)
print("num_partitions:", num_partitions)

In [0]:
# The intermediate table name is derived only from the source and target tables, so it is
# deterministic and identical across restarts of the job.
tmp_input_table = "{}.tmp_{}_{}".format(
    params.tmp_schema,
    params.input_table_name.replace(".", "_"),
    params.output_table_name.replace(".", "_"),
)
print(tmp_input_table)

In [0]:
# Build the DDL for the intermediate table. "create table if not exists" means an existing
# table (holding rows copied by a previous run) is reused untouched; "limit 0" copies only
# the schema. Auto-compaction and optimized writes are disabled so the small, batch-sized
# files written by stream 1 stay separate.
select_str = ", ".join((*input_table_pk, params.input_column_name, "current_timestamp as ctime"))
sql_str = f"""
create table if not exists {tmp_input_table}
tblproperties('delta.autoOptimize.autoCompact' = false, 'delta.autoOptimize.optimizeWrite' = false)
as
-- select {params.input_table_pk}, {params.input_column_name}, current_timestamp as ctime
select {select_str}
from {params.input_table_name}
limit 0;
"""
print(sql_str)

In [0]:
# Execute the CREATE TABLE IF NOT EXISTS built above (a no-op when the table already exists).
spark.sql(sql_str)

In [0]:
# Predictive optimization would compact the intermediate table into larger files,
# defeating the one-file-per-batch layout, so it is switched off for this table.
sql_str = "alter table {} disable predictive optimization".format(tmp_input_table)
print(sql_str)

In [0]:
import traceback

# Best effort: if disabling predictive optimization fails, report the error and carry on;
# the rest of the job does not depend on this statement succeeding.
try:
  spark.sql(sql_str)
except Exception:
  # `except Exception` instead of a bare `except:` so that KeyboardInterrupt / SystemExit
  # (e.g. cancelling the notebook) still propagate instead of being swallowed.
  traceback.print_exc()

In [0]:
def write_with_partitions(df, batch_id):
  """foreachBatch sink for stream 1: append one micro-batch to the intermediate table.

  The batch is repartitioned into ceil(rows / params.inference_rows_per_batch) partitions,
  so each file written to tmp_input_table (which has optimized writes and auto-compaction
  turned off) holds roughly params.inference_rows_per_batch rows. Stream 2 later reads
  one such file per micro-batch.

  Args:
    df: the micro-batch DataFrame supplied by foreachBatch.
    batch_id: the micro-batch id supplied by foreachBatch (currently unused).
  """
  num_rows = df.count()
  if num_rows == 0:
    # Nothing to append; repartition(0) would raise because Spark requires >= 1 partition.
    return
  num_partitions = math.ceil(num_rows / params.inference_rows_per_batch)
  # NOTE(review): this append is not idempotent. If the job dies after the write but before
  # the stream checkpoint commits, the batch is appended again on restart. Delta's
  # txnAppId/txnVersion write options (with batch_id as the version) are the documented
  # way to make foreachBatch writes idempotent.
  df.repartition(num_partitions).write.mode("append").saveAsTable(tmp_input_table)

In [0]:
# Stream 1 checkpoint: <base>/<input table>/<tmp table>, i.e. keyed on this stream's own
# source and sink. Because tmp_input_table encodes both the source and the target table,
# changing either one yields a new checkpoint.
stream1_checkpoint_path = "/".join(
    [
        params.checkpoint_base_path,
        params.input_table_name.lower().replace(".", "_"),
        tmp_input_table.replace(".", "_"),
    ]
)
print(stream1_checkpoint_path)

In [0]:
# Columns copied by stream 1: primary key(s), the input column and a copy timestamp.
stream1_select_expr = (*input_table_pk, params.input_column_name, "current_timestamp as ctime")
print(stream1_select_expr)


In [0]:
# Stream 1 definition (despite its name, stream1_df is a DataStreamWriter; nothing runs
# until .start()). It copies the selected input columns into the intermediate table via
# write_with_partitions. With availableNow, the query processes the data available at
# start-up and then stops.
stream1_df = (
  spark.readStream
  .table(params.input_table_name)
  .selectExpr(*stream1_select_expr)
  .writeStream
  .option("checkpointLocation", stream1_checkpoint_path)
  .foreachBatch(write_with_partitions)
  .trigger(availableNow=True)
)

In [0]:
# Start stream 1 and block until the availableNow run has finished copying into the intermediate table.
stream1_df.start().awaitTermination()

In [0]:
# Helper used below to create, monitor and stop the serving endpoint.
# NOTE(review): ProvisioningUtils is not defined in this notebook — presumably it comes from ./utils.
provisioning_utils = ProvisioningUtils(use_existing_endpoint=params.use_existing_endpoint)


In [0]:
# Use the configured endpoint name, or generate a timestamped one for this run.
if params.endpoint_name:
  model_endpoint_name = params.endpoint_name
else:
  ts = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
  model_endpoint_name = "batch_inference_" + ts
print(f"model_endpoint_name: {model_endpoint_name}")

# NOTE(review): params.model_uc_version is validated earlier but not passed here — confirm
# whether create_endpoint should receive it. workload_size is hard-coded to "Large".
provisioning_utils.create_endpoint(
  name=model_endpoint_name,
  model_name=params.model_uc_path,
  ptu=params.ptus,
  workload_size="Large",
)


In [0]:
# NOTE(review): presumably blocks until the endpoint is ready to serve before inference starts — confirm in ./utils.
provisioning_utils.monitor_endpoint(name=model_endpoint_name)

In [0]:
def _sql_string_literal(text):
  """Return `text` as a single-quoted Spark SQL string literal.

  Backslashes and single quotes are backslash-escaped, so a prompt such as
  "Summarize the customer's review: " no longer ends the literal early and breaks
  (or injects into) the generated SQL. Assumes the default
  spark.sql.parser.escapedStringLiterals=false, under which backslash escapes apply.
  """
  escaped = str(text).replace("\\", "\\\\").replace("'", "\\'")
  return f"'{escaped}'"


# Build the optional modelParameters struct from whichever model params were provided.
# After validation, max_tokens is "" or a positive int, and temperature is "" (widget cleared)
# or a float. Because 0.0 is a valid temperature, it is compared against ""/None rather than
# tested for truthiness.
model_param_fields = []
if params.model_param_max_tokens:
  model_param_fields.append(f"'max_tokens', {params.model_param_max_tokens}")
if params.model_param_temperature not in ("", None):
  model_param_fields.append(f"'temperature', {params.model_param_temperature}")
model_params_str = f"named_struct({', '.join(model_param_fields)})" if model_param_fields else ""

ai_query_args = [
  _sql_string_literal(model_endpoint_name),
  f"CONCAT({_sql_string_literal(params.prompt)}, {params.input_column_name})",
]
if model_params_str:
  # With no model params at all, omit modelParameters rather than emit an invalid empty struct.
  ai_query_args.append(f"modelParameters => {model_params_str}")
ai_query_expr = f"ai_query({', '.join(ai_query_args)}) as {params.output_column_name}"
print(ai_query_expr)

# Output columns: primary key(s), the original input, the model result and a processing timestamp.
select_expr = tuple([*input_table_pk, params.input_column_name, ai_query_expr, "current_timestamp as ctime"])
print(select_expr)


In [0]:
# Stream 2 checkpoint: <base>/<tmp table>/<output table>, keyed on this stream's own source and sink.
checkpoint_path = "/".join(
    [
        params.checkpoint_base_path,
        tmp_input_table.replace(".", "_"),
        params.output_table_name.lower().replace(".", "_"),
    ]
)
print(checkpoint_path)

In [0]:
# Stream 2 source. A maxBytesPerTrigger of 1 byte is a soft limit that a Delta source still
# satisfies with one file, so every micro-batch is a single file of the intermediate table,
# i.e. roughly inference_rows_per_batch rows.
read_stream_df = spark.readStream.option("maxBytesPerTrigger", 1).table(tmp_input_table)

In [0]:
# Optionally spread each micro-batch over a fixed number of partitions.
if params.repartition_stream:
  if params.repartition_stream > 0:
    print("repartitioning stream [" + str(params.repartition_stream) + "]")
    read_stream_df = read_stream_df.repartition(params.repartition_stream)

In [0]:
# Project the output columns, including the ai_query call built from select_expr above.
read_stream_df =  read_stream_df.selectExpr(*select_expr)


In [0]:
# Record the wall-clock start so the stream duration can be reported afterwards.
stream_start = time.time()
print("stream_start:", stream_start)

# Stream 2: append ai_query results to the output table, processing everything currently
# available (availableNow) and allowing schema evolution of the output table (mergeSchema).
write_stream_df = (
  read_stream_df.writeStream
  .option("checkpointLocation", checkpoint_path)
  .option("mergeSchema", "true")
  .outputMode("append")
  .trigger(availableNow=True)
  .toTable(params.output_table_name)
)

In [0]:
# Block until stream 2 has drained the intermediate table, then report the timing.
write_stream_df.awaitTermination()
stream_end = time.time()
print("stream_end:", stream_end)

print("stream_duration_seconds:", stream_end - stream_start)

In [0]:
# Stop the endpoint unless this run reused a named, pre-existing one.
# NOTE(review): if the stream cell above raises, this cell never runs and the endpoint keeps
# running; a restart with an auto-generated name would then create another endpoint.
if not (params.endpoint_name and params.use_existing_endpoint):
  print(f"stopping endpoint {model_endpoint_name}")
  r = provisioning_utils.stop_endpoint(name=model_endpoint_name)
  print(r)