In [0]:
%pip install "databricks-sdk>=0.28.0" -qU
dbutils.library.restartPython()

In [0]:
# Email-style user name of whoever is running the notebook, read from the notebook context.
current_user = (
    dbutils.notebook.entry_point.getDbutils()
    .notebook()
    .getContext()
    .userName()
    .get()
)
# Part before the "@", lower-cased, with dots turned into underscores.
reformat_current_user = current_user.partition("@")[0].lower().replace(".", "_")

# Catalog and schema used by the rest of the notebook; `dbName` is kept as an alias of `db`.
catalog = "main"
db = "dbdemos_mlops"
dbName = db

# Helper class `DBDemos`: utility functions that streamline common tasks in a Databricks ML workflow

- The setup_schema method lets you create or reset a schema (database) and optionally create a volume. This is useful when you want to ensure your data environment is in a known state before running experiments.
- The download_file_from_git method downloads files from a GitHub repository into a specified local folder. This can be used to pull in code, datasets, or configuration files that are maintained in Git.
- The init_experiment_for_batch method creates a shared MLflow experiment (and the underlying folder) and sets the appropriate permissions using set_experiment_permission. This is helpful when you want a standardized way to start experiment tracking as part of your automated workflow.
- The wait_for_table method polls for the existence and population of a specified table, ensuring that downstream processes only run when the data is available.

In [0]:
import requests
import os
from concurrent.futures import ThreadPoolExecutor

class DBDemos():
    """
    Static helpers that streamline common setup tasks in a Databricks ML workflow.

    Every method is a ``@staticmethod``. Methods that run SQL or read tables rely on
    the notebook-provided global ``spark`` session.
    """

    @staticmethod
    def setup_schema(catalog, db, reset_all_data, volume_name=None):
        """
        Sets up the Unity Catalog schema (database) and optionally drops and recreates it.
        Optionally creates a volume if a volume name is provided.

        Args:
            catalog: Name of the Unity Catalog catalog to use (created if missing).
            db: Name of the schema to create and switch to.
            reset_all_data: When truthy, drop the volume (if named) and the schema first.
            volume_name: Optional name of a volume to create inside the schema.

        Raises:
            AssertionError: If ``catalog`` is 'hive_metastore' or 'spark_catalog'.
        """
        # Validate before any destructive statement runs.
        assert catalog not in ['hive_metastore', 'spark_catalog'], "This demo only supports Unity Catalog. Please change your catalog name."

        if reset_all_data:
            try:
                # Only touch the volume when one was actually requested.
                if volume_name:
                    print(f'Clearing volume: `{catalog}`.`{db}`.`{volume_name}`')
                    spark.sql(f"DROP VOLUME IF EXISTS `{catalog}`.`{db}`.`{volume_name}`")
                spark.sql(f"DROP SCHEMA IF EXISTS `{catalog}`.`{db}` CASCADE")
            except Exception as e:
                # Best effort: a missing catalog/schema must not block the setup,
                # but surface the real error instead of guessing at the cause.
                print(f'Could not reset `{catalog}`.`{db}` (catalog or schema may not exist): {e}. Skipping data reset.')

        current_catalog = spark.sql("SELECT current_catalog()").collect()[0]['current_catalog()']
        if current_catalog != catalog:
            catalogs = [r['catalog'] for r in spark.sql("SHOW CATALOGS").collect()]
            if catalog not in catalogs:
                spark.sql(f"CREATE CATALOG IF NOT EXISTS `{catalog}`")
                # Optionally set ownership if using a specific catalog name
                if catalog == 'dbdemos':
                    spark.sql(f"ALTER CATALOG `{catalog}` OWNER TO `account users`")

        print(f"Using catalog `{catalog}`")
        spark.sql(f"USE CATALOG `{catalog}`")
        spark.sql(f"CREATE DATABASE IF NOT EXISTS `{db}`")

        print(f"Using schema: `{catalog}`.`{db}`")
        spark.sql(f"USE `{catalog}`.`{db}`")

        if volume_name:
            # Quoted and fully qualified, consistent with the DROP statement above.
            spark.sql(f"CREATE VOLUME IF NOT EXISTS `{catalog}`.`{db}`.`{volume_name}`")

    @staticmethod
    def download_file_from_git(dest, owner, repo, path):
        """
        Downloads files from a GitHub repository into the destination folder.

        The files under ``path`` are listed through the GitHub contents API and then
        downloaded in parallel. Entries whose name contains 'NOTICE' and entries
        without a download URL (e.g. sub-directories) are skipped.

        Args:
            dest: Local folder to write into (created if missing).
            owner: GitHub repository owner.
            repo: GitHub repository name.
            path: Path inside the repository, appended as-is after ``/contents``.

        Raises:
            requests.HTTPError: If the listing request or a download fails.
        """
        def download_file(url, destination):
            # Stream the body to disk in chunks so large files are never held in memory.
            local_filename = url.split('/')[-1]
            with requests.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                print(f"Saving {destination}/{local_filename}")
                with open(os.path.join(destination, local_filename), 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            return local_filename

        os.makedirs(dest, exist_ok=True)

        response = requests.get(f'https://api.github.com/repos/{owner}/{repo}/contents{path}', timeout=60)
        # Fail with a clear HTTP error (rate limit, 404, ...) instead of a TypeError
        # while iterating over an error payload.
        response.raise_for_status()
        listing = response.json()
        if isinstance(listing, dict):
            # The contents API returns a single object when `path` points at a file.
            listing = [listing]
        files = [
            entry['download_url']
            for entry in listing
            if entry.get('download_url') and 'NOTICE' not in entry['name']
        ]

        def download_to_dest(url):
            try:
                # Optionally switch to an alternate URL if needed
                s3url = url.replace("https://raw.githubusercontent.com/databricks-demos/dbdemos-dataset/main/",
                                    "https://dbdemos-dataset.s3.amazonaws.com/")
                download_file(s3url, dest)
            except Exception:
                # The mirror failed: fall back to the original GitHub URL.
                download_file(url, dest)

        with ThreadPoolExecutor(max_workers=10) as executor:
            # list() drains the iterator so that worker exceptions are re-raised here.
            list(executor.map(download_to_dest, files))

    @staticmethod
    def init_experiment_for_batch(demo_name, experiment_name):
        """
        Initializes an MLflow experiment in a shared folder and sets permissions.

        Args:
            demo_name: Sub-folder name under ``/Shared/dbdemos/experiments``.
            experiment_name: Name of the experiment inside that folder.

        Returns:
            The result of ``mlflow.get_experiment_by_name`` for the experiment path.

        Raises:
            Exception: Re-raises whatever the workspace client raised while creating the folder.
        """
        import mlflow
        from databricks.sdk import WorkspaceClient

        w = WorkspaceClient()
        xp_root_path = f"/Shared/dbdemos/experiments/{demo_name}"
        try:
            w.workspace.mkdirs(path=xp_root_path)
        except Exception as e:
            print(f"ERROR: Couldn't create folder for experiment under {xp_root_path}. Please create it manually or skip init. Error: {e}")
            # Bare raise keeps the original traceback intact.
            raise

        xp = f"{xp_root_path}/{experiment_name}"
        print(f"Using experiment: {xp}")
        mlflow.set_experiment(xp)
        DBDemos.set_experiment_permission(xp)
        return mlflow.get_experiment_by_name(xp)

    @staticmethod
    def set_experiment_permission(experiment_path):
        """
        Grants the "users" group CAN_MANAGE on the experiment at ``experiment_path``.

        Failures are reported with a printed message and not re-raised (best effort).

        Args:
            experiment_path: Workspace path of the experiment.
        """
        from databricks.sdk import WorkspaceClient
        from databricks.sdk.service import iam

        w = WorkspaceClient()
        try:
            status = w.workspace.get_status(experiment_path)
            w.permissions.set("experiments", request_object_id=status.object_id, access_control_list=[
                iam.AccessControlRequest(group_name="users", permission_level=iam.PermissionLevel.CAN_MANAGE)
            ])
        except Exception as e:
            print(f"Error setting permissions for experiment {experiment_path}: {e}")
        else:
            # Only claim success when the permission call actually went through.
            print(f"Experiment {experiment_path} permissions set to public (users CAN_MANAGE).")

    @staticmethod
    def wait_for_table(table_name, timeout_duration=120):
        """
        Waits for a table to exist and be non-empty, or raises an exception after timeout.

        Args:
            table_name: Name of the table to poll for.
            timeout_duration: Maximum time to wait, in seconds.

        Raises:
            Exception: If the table is still missing or empty once the timeout has elapsed.
        """
        import time
        # Track elapsed time rather than loop iterations: each poll runs Spark work
        # that can itself take far longer than the one-second sleep.
        deadline = time.monotonic() + timeout_duration
        # limit(1) is enough to tell "empty" from "non-empty" without a full count.
        while not spark.catalog.tableExists(table_name) or spark.table(table_name).limit(1).count() == 0:
            if time.monotonic() >= deadline:
                raise Exception(f"Could not find table {table_name} or table is empty.")
            time.sleep(1)



In [0]:
import mlflow
import pandas as pd
import random
import re
import logging
logging.getLogger("mlflow").setLevel(logging.ERROR)

from mlflow import MlflowClient

# Use the Unity Catalog model registry ("databricks-uc") as the default registry.
mlflow.set_registry_uri("databricks-uc")

# MLflow client used by later cells to interact with MLflow
# (tracking experiments and managing models).
client = MlflowClient()


# Setting up a bronze-layer table for the MLflow workflow
Steps:

1) Download the dataset (a CSV file) from GitHub
2) Read it into a pandas DataFrame and clean up the column names
3) Create the bronze table: convert the cleaned pandas DataFrame into a Spark table

In [0]:
bronze_table_name = "mlops_churn_bronze_customers"

# Only download and build the table when it does not exist yet.
if not spark.catalog.tableExists(bronze_table_name):
    import requests
    from io import StringIO
    #Dataset under apache license: https://github.com/IBM/telco-customer-churn-on-icp4d/blob/master/LICENSE
    response = requests.get("https://raw.githubusercontent.com/IBM/telco-customer-churn-on-icp4d/master/data/Telco-Customer-Churn.csv", timeout=60)
    # Stop here on an HTTP error rather than parsing an error page as CSV.
    response.raise_for_status()
    csv = response.text
    df = pd.read_csv(StringIO(csv), sep=",")

    def cleanup_column(pdf):
        """
        Normalize the column names of ``pdf`` to lower-case, underscore-separated names.

        The columns of ``pdf`` are replaced in place with the normalized names; the
        returned DataFrame additionally renames 'streaming_t_v' and 'customer_i_d',
        which the camel-case split breaks apart letter by letter.
        """
        def normalize(raw_name):
            # 1. Put "_" before every capital that is not the first character, then lower-case.
            snake = re.sub(r'(?<!^)(?=[A-Z])', '_', raw_name).lower().replace("__", "_")
            # 2. Strip parentheses.
            without_parens = re.sub(r'[\(\)]', '', snake).lower()
            # 3. Spaces and dashes become underscores.
            return re.sub(r'[ -]', '_', without_parens).lower()

        pdf.columns = [normalize(column) for column in pdf.columns]
        acronym_fixes = {'streaming_t_v': 'streaming_tv', 'customer_i_d': 'customer_id'}
        return pdf.rename(columns = acronym_fixes)
    
    # Normalize the column names before persisting.
    df = cleanup_column(df)
    print(f"creating `{bronze_table_name}` raw table")
    # Convert the in-memory pandas data into a persistent Spark table so it can be
    # read by name from elsewhere. mode("overwrite") together with
    # overwriteSchema=true replaces both the data and the schema of any previous
    # version of the table.
    spark.createDataFrame(df).write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(bronze_table_name)


    

# Integrating and automating AutoML into the workflow

In [0]:
from pyspark.sql.functions import col
#from databricks.feature_store import FeatureStoreClient
import mlflow
import databricks
from datetime import datetime

def get_automl_run(name):
    """Return the most recent AutoML run recorded under ``name``.

    Reads ``field_demos_metadata.automl_experiment`` and returns the collected rows:
    a list holding the newest matching row (by ``date``), or an empty list when
    nothing has been recorded for ``name``.
    """
    recorded_runs = spark.table("field_demos_metadata.automl_experiment")
    newest_first = recorded_runs.where(col("name") == name).sort(col("date").desc())
    return newest_first.limit(1).collect()

# Look up the AutoML run for `name` in field_demos_metadata.automl_experiment;
# when none is recorded yet, start a new run with the given parameters first.
def get_automl_run_or_start(name, model_name, dataset, target_col, timeout_minutes):
    """Return the latest metadata row for ``name``, starting an AutoML run if there is none."""
    # Make sure the metadata database and table exist before querying them.
    for ddl in (
        "create database if not exists field_demos_metadata",
        "create table if not exists field_demos_metadata.automl_experiment (name string, date string)",
    ):
        spark.sql(ddl)

    runs = get_automl_run(name)
    if runs:
        return runs[0]

    print("No run available, start a new Auto ML run, this will take a few minutes...")
    start_automl_run(name, model_name, dataset, target_col, timeout_minutes)
    return get_automl_run(name)[0]


#Start a new auto ml classification task and save it as metadata.
def start_automl_run(name, model_name, dataset, target_col, timeout_minutes = 5):
    """Run an AutoML classification, record it in the metadata table and register the best model.

    Args:
        name: Key under which the run is stored in ``field_demos_metadata.automl_experiment``.
        model_name: Name under which the best trial's model is registered.
        dataset: Training dataset handed to AutoML.
        target_col: Name of the label column.
        timeout_minutes: Time budget passed to AutoML.

    Returns:
        The result of ``get_automl_run(name)`` after the new run has been appended.
    """
    # `import databricks` alone does not load the `automl` submodule, so import it
    # explicitly instead of relying on another cell having done so.
    from databricks import automl

    automl_run = automl.classify(
        dataset = dataset,
        target_col = target_col,
        timeout_minutes = timeout_minutes
    )
    experiment = automl_run.experiment
    experiment_id = experiment.experiment_id
    path = experiment.name
    best_trial_run_id = automl_run.best_trial.mlflow_run_id
    # Run id of the data exploration notebook run; .iloc[0] raises if no run carries that tag.
    data_run_id = mlflow.search_runs(experiment_ids=[experiment_id], filter_string = "tags.mlflow.source.name='Notebook: DataExploration'").iloc[0].run_id
    exploration_notebook_id = experiment.tags["_databricks_automl.exploration_notebook_id"]
    best_trial_notebook_id = experiment.tags["_databricks_automl.best_trial_notebook_id"]

    cols = ["name", "date", "experiment_id", "experiment_path", "data_run_id", "best_trial_run_id", "exploration_notebook_id", "best_trial_notebook_id"]
    row = (name, datetime.today().isoformat(), experiment_id, path, data_run_id, best_trial_run_id, exploration_notebook_id, best_trial_notebook_id)
    # mergeSchema lets the append add the extra columns to a table that was
    # created with only (name, date).
    spark.createDataFrame(data=[row], schema = cols).write.mode("append").option("mergeSchema", "true").saveAsTable("field_demos_metadata.automl_experiment")
    #Create & save the first model version in the MLFlow repo (required to setup hooks etc)
    mlflow.register_model(f"runs:/{best_trial_run_id}/model", model_name)
    return get_automl_run(name)

#Generate nice link for the given auto ml run
def display_automl_link(name, model_name, dataset, target_col, force_refresh=False, timeout_minutes = 5):
    """Render HTML links to the notebooks and MLflow experiment of an AutoML run.

    The run is looked up (or started when missing) through ``get_automl_run_or_start``.

    Args:
        name: Key of the AutoML run in the metadata table.
        model_name: Model name used if a new run has to be started.
        dataset: Dataset used if a new run has to be started.
        target_col: Label column used if a new run has to be started.
        force_refresh: Currently unused; kept so existing callers keep working.
        timeout_minutes: AutoML time budget used if a new run has to be started.
    """
    r = get_automl_run_or_start(name, model_name, dataset, target_col, timeout_minutes)
    html = f"""For exploratory data analysis, open the <a href="/#notebook/{r["exploration_notebook_id"]}">data exploration notebook</a><br/><br/>"""
    html += f"""To view the best performing model, open the <a href="/#notebook/{r["best_trial_notebook_id"]}">best trial notebook</a><br/><br/>"""
    # The anchor is now closed with </a>; the original emitted a malformed "</>".
    html += f"""To view details about all trials, navigate to the <a href="/#mlflow/experiments/{r["experiment_id"]}/s?orderByKey=metrics.%60val_f1_score%60&orderByAsc=false">MLflow experiment</a>"""
    displayHTML(html)


def display_automl_churn_link():
    """Display the AutoML links for the churn demo (``churn_features`` table, ``churn`` label)."""
    # Pass the timeout by keyword: as the fifth positional argument, 5 would bind
    # to `force_refresh` rather than `timeout_minutes`.
    display_automl_link("churn_auto_ml", "field_demos_customer_churn", spark.table("churn_features"), "churn", timeout_minutes=5)

def get_automl_churn_run():
    """Return the AutoML run metadata for the churn demo, starting a run if none is recorded."""
    churn_features = spark.table("churn_features")
    return get_automl_run_or_start(
        name="churn_auto_ml",
        model_name="field_demos_customer_churn",
        dataset=churn_features,
        target_col="churn",
        timeout_minutes=5,
    )