In [0]:
# %pip install mlflow==2.19.0

In [0]:
# %pip install databricks-sdk==0.47.0

In [0]:
# dbutils.library.restartPython()

In [0]:
%pip show databricks-sdk 

In [0]:
import mlflow
print(mlflow.__version__)

In [0]:
# dbutils.widgets.removeAll()
# Declare the notebook parameters as text widgets, each defaulting to an empty string:
#   catalog_names       - comma-separated "source_catalog:target_catalog" pairs
#   target_schema_owner - comma-separated schema/owner entries
#   max_workers         - size of the thread pool used for the model sync
#   log_table_name      - table that receives the log records
for widget_name in ("catalog_names", "target_schema_owner", "max_workers", "log_table_name"):
    dbutils.widgets.text(widget_name, "")

In [0]:
# Parse the widget values into job parameters.
# Comma-separated widgets are split and trimmed; blank entries are dropped so that an
# unset widget yields an empty list rather than [""] (which would fail later on ':' split).
catalog_pair_names_list = [
    catalog_pair.strip()
    for catalog_pair in dbutils.widgets.get("catalog_names").split(",")
    if catalog_pair.strip()
]
target_schema_owner_list = [
    catalog_schema_pair.strip()
    for catalog_schema_pair in dbutils.widgets.get("target_schema_owner").split(",")
    if catalog_schema_pair.strip()
]

# Fail fast with a readable message instead of a bare int('') error, or a thread-pool
# error that would only surface after catalogs and schemas were already created.
max_workers_raw = dbutils.widgets.get("max_workers")
try:
    max_workers = int(max_workers_raw)
except ValueError:
    raise ValueError(
        f"Widget 'max_workers' must be a positive integer, got {max_workers_raw!r}."
    ) from None
if max_workers < 1:
    raise ValueError(f"Widget 'max_workers' must be a positive integer, got {max_workers_raw!r}.")

# The log table name is interpolated into SQL later on, so an empty value cannot work.
log_table_name = dbutils.widgets.get("log_table_name").strip()
if not log_table_name:
    raise ValueError("Widget 'log_table_name' must not be empty.")

print(catalog_pair_names_list, target_schema_owner_list, max_workers, log_table_name)

In [0]:
# for catalog_pair in catalog_pair_names_list:
#     second_value = catalog_pair.split(":")[1]
#     print(second_value)
#     spark.sql(f"DROP CATALOG IF EXISTS `{second_value}` CASCADE")

In [0]:
from mlflow.tracking import MlflowClient
from mlflow.exceptions import RestException
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, TimestampType
from datetime import datetime
import traceback
from concurrent.futures import ThreadPoolExecutor


# Resolve this notebook's path from the Databricks context and keep only its
# last '/'-separated component as the notebook name.
notebook_path = (
    dbutils.notebook.entry_point
    .getDbutils()
    .notebook()
    .getContext()
    .notebookPath()
    .get()
)
notebook_name = notebook_path.rsplit('/', 1)[-1]

# Spark schema of one log record: six nullable string columns, a timestamp, and an
# array of key/value string structs holding any result details.
_results_entry_type = StructType(
    [StructField(field_name, StringType(), True) for field_name in ("key", "value")]
)
log_table_schema = StructType(
    [
        StructField(column_name, StringType(), True)
        for column_name in (
            "notebook_name",
            "entity_type",
            "entity_name",
            "action",
            "status",
            "message",
        )
    ]
    + [
        StructField("timestamp", TimestampType(), True),
        StructField("results_data", ArrayType(_results_entry_type), True),
    ]
)

# Ensure the log table exists before any log record is appended to it.
create_log_table_sql = f"""
    CREATE TABLE IF NOT EXISTS {log_table_name}
    (
        notebook_name STRING,
        entity_type STRING,
        entity_name STRING,
        action STRING,
        status STRING,
        message STRING,
        timestamp TIMESTAMP,
        results_data ARRAY<STRUCT<key: STRING, value: STRING>>
    )
"""
spark.sql(create_log_table_sql)

# --- Helper Functions (modified from template) ---
def execute_and_log_sql(sql_command, entity_type, entity_name, action, success_message, error_message):
    """Run a SQL statement through Spark and record the outcome in the log table.

    Args:
        sql_command: SQL text passed to ``spark.sql``.
        entity_type: Value stored in the log record's ``entity_type`` field.
        entity_name: Value stored in the log record's ``entity_name`` field.
        action: Value stored in the log record's ``action`` field.
        success_message: Log message used when the statement succeeds.
        error_message: Prefix of the log message used when the statement fails;
            the exception text is appended to it.

    Any exception raised while running the statement or collecting its rows is
    caught and logged with status ``"error"``; nothing is returned.
    """
    try:
        collected_rows = spark.sql(sql_command).collect()
        # Flatten every cell of every returned row into a {"key", "value"} pair.
        flattened = []
        for record in collected_rows:
            for field_name in record.__fields__:
                flattened.append({"key": field_name, "value": str(record[field_name])})
        status, message = "success", success_message
    except Exception as exc:
        # A failed statement is logged without result details.
        flattened = []
        status, message = "error", f"{error_message}. Error: {exc}"

    insert_log_entry({
        "entity_type": entity_type,
        "entity_name": entity_name,
        "action": action,
        "status": status,
        "message": message,
        "timestamp": datetime.now(),
        "results_data": flattened
    })

def insert_log_entry(log_data):
    """Append one log record to the log table, tagged with this notebook's name.

    Args:
        log_data: Mapping with the log record fields (all fields of
            ``log_table_schema`` except ``notebook_name``, which is added here).
            The mapping is not modified.

    This is best-effort: any failure while building or writing the record is
    printed and swallowed so that logging problems never abort the caller.
    """
    record = log_data
    try:
        # Build a new dict instead of writing into the caller's object, so a dict
        # that is reused or inspected after the call is left untouched.
        record = {**log_data, "notebook_name": notebook_name}
        spark.createDataFrame([record], schema=log_table_schema).write.mode("append").saveAsTable(log_table_name)
    except Exception as e:
        print(f"Failed to insert log entry: {record}. Error: {e}")

def create_catalog_if_not_exists(target_catalog):
    """Issue CREATE CATALOG IF NOT EXISTS for ``target_catalog`` and log the result.

    The catalog name is wrapped in backticks both in the SQL and in the log record.
    """
    quoted_catalog = f"`{target_catalog}`"
    execute_and_log_sql(
        sql_command=f"CREATE CATALOG IF NOT EXISTS {quoted_catalog}",
        entity_type="catalog",
        entity_name=quoted_catalog,
        action="create",
        success_message=f"Catalog {quoted_catalog} created or already exists.",
        error_message=f"Failed to create catalog {quoted_catalog}.",
    )

def create_schema_if_not_exists(target_catalog, schema_name):
    """Issue CREATE SCHEMA IF NOT EXISTS for the given catalog/schema and log the result.

    The fully qualified name is built as `` `catalog`.`schema` `` and used both in
    the SQL statement and as the log record's entity name.
    """
    qualified_schema = ".".join(f"`{part}`" for part in (target_catalog, schema_name))
    execute_and_log_sql(
        sql_command=f"CREATE SCHEMA IF NOT EXISTS {qualified_schema}",
        entity_type="schema",
        entity_name=qualified_schema,
        action="create",
        success_message=f"Schema {qualified_schema} created or already exists.",
        error_message=f"Failed to create schema {qualified_schema}.",
    )

# --- Core Logic for Model Synchronization ---
def get_all_registered_models_once():
    """
    Collects every registered model visible to the MLflow client, page by page.

    Pages of up to 100 models are requested until no continuation token is
    returned. If any page request raises a RestException, the failure is logged
    and an empty list is returned (models from earlier pages are discarded).

    Returns a list of MLflow registered model objects.
    """
    mlflow_client = MlflowClient()
    collected = []
    next_token = None
    has_more = True

    while has_more:
        try:
            page = mlflow_client.search_registered_models(max_results=100, page_token=next_token)
            collected.extend(page)
            # A falsy token marks the last page.
            next_token = page.token
            has_more = bool(next_token)
        except RestException as exc:
            insert_log_entry({
                "entity_type": "model_search",
                "entity_name": "workspace",
                "action": "search",
                "status": "error",
                "message": f"Failed to search all models in the workspace. Error: {exc}",
                "timestamp": datetime.now(),
                "results_data": []
            })
            return []
    return collected

def get_target_model_versions(target_model_fqn):
    """
    Retrieves the version numbers already registered for a model in the target catalog.

    All versions are listed with ``search_model_versions`` (following pagination
    tokens) rather than ``latest_versions``, which only reports the newest version
    per stage. Each version is converted with ``int()`` because callers compare the
    result against integer version numbers.

    Returns a set of integers; the set is empty when the lookup raises a
    RestException (expected for models that do not exist in the target yet).
    """
    client = MlflowClient()
    versions = set()
    page_token = None
    try:
        while True:
            page = client.search_model_versions(
                filter_string=f"name='{target_model_fqn}'",
                page_token=page_token,
            )
            versions.update(int(model_version.version) for model_version in page)
            page_token = page.token
            if not page_token:
                break
    except RestException:
        # Model does not exist, which is expected for new models.
        return set()
    return versions

def update_aliases(source_model_fqn, source_version, target_model_fqn, target_version):
    """
    Mirrors the aliases of one source model version onto a target model version.

    Each alias found on the source version is set on ``target_version`` of
    ``target_model_fqn`` and logged individually. A RestException raised while
    reading the source version or setting an alias is logged once and stops the
    remaining aliases from being processed.
    """
    mlflow_client = MlflowClient()
    try:
        alias_names = mlflow_client.get_model_version(source_model_fqn, source_version).aliases
        # A missing or empty alias collection means there is nothing to copy.
        for alias_name in alias_names or ():
            print(f"  Setting alias '{alias_name}' on target version {target_version}")
            mlflow_client.set_registered_model_alias(target_model_fqn, alias_name, target_version)
            alias_details = [
                {"key": "model_name", "value": target_model_fqn},
                {"key": "version", "value": str(target_version)},
                {"key": "alias", "value": alias_name}
            ]
            insert_log_entry({
                "entity_type": "model_alias",
                "entity_name": f"{target_model_fqn}/{target_version}/{alias_name}",
                "action": "set",
                "status": "success",
                "message": f"Successfully set alias '{alias_name}' on model version {target_version}.",
                "timestamp": datetime.now(),
                "results_data": alias_details
            })
    except RestException as exc:
        insert_log_entry({
            "entity_type": "model_alias",
            "entity_name": f"{target_model_fqn}/{target_version}",
            "action": "set",
            "status": "error",
            "message": f"Failed to copy aliases for version {source_version}. Error: {exc}",
            "timestamp": datetime.now(),
            "results_data": []
        })

def _log_model_version_copy(target_model_fqn, version_number, status, message, results_data=None):
    """Write one 'copy' log record for a model version (success, error or skipped)."""
    insert_log_entry({
        "entity_type": "model_version",
        "entity_name": f"{target_model_fqn}/{version_number}",
        "action": "copy",
        "status": status,
        "message": message,
        "timestamp": datetime.now(),
        "results_data": results_data if results_data is not None else []
    })


def process_model_sync(source_catalog, source_schema, target_catalog, target_schema, model):
    """
    Processes a single registered model, checking and copying new versions.

    Args:
        source_catalog: Source catalog name (not used by this function; the
            source location is taken from ``model.name``).
        source_schema: Source schema name (not used by this function).
        target_catalog: Catalog that receives the copied model.
        target_schema: Schema that receives the copied model.
        model: Registered model object; its ``name`` is the source model's
            fully qualified name.

    Source versions are visited in ascending numeric order. A version whose
    number is already present in the target is logged as skipped; otherwise it
    is copied with ``copy_model_version`` and its aliases are mirrored onto the
    newly created target version. RestExceptions are logged, not raised.
    """
    source_model_fqn = model.name
    # The target model keeps the source's short name under the target catalog/schema.
    model_name = source_model_fqn.split('.')[-1]
    target_model_fqn = f"{target_catalog}.{target_schema}.{model_name}"

    print(f"Processing model: {source_model_fqn} -> {target_model_fqn}")

    # Normalise the existing target versions to int: MLflow may report versions as
    # strings, and comparing them with the int source numbers below would never match,
    # causing every version to be copied again on each run.
    target_versions = {int(existing) for existing in get_target_model_versions(target_model_fqn)}

    # Get all versions from the source model
    client = MlflowClient()
    try:
        source_model_versions = client.search_model_versions(filter_string=f"name='{source_model_fqn}'")
    except RestException as e:
        insert_log_entry({
            "entity_type": "model",
            "entity_name": source_model_fqn,
            "action": "get_versions",
            "status": "error",
            "message": f"Failed to retrieve versions for source model {source_model_fqn}. Error: {e}",
            "timestamp": datetime.now(),
            "results_data": []
        })
        return

    # Copy in ascending version order.
    # NOTE(review): the target version number is assigned by copy_model_version and is
    # presumably only equal to the source number when the source has no gaps - confirm.
    sorted_versions = sorted(source_model_versions, key=lambda version: int(version.version))

    for version in sorted_versions:
        source_version_number = int(version.version)

        if source_version_number in target_versions:
            print(f"  Version {source_version_number} already exists in target. Skipping.")
            _log_model_version_copy(
                target_model_fqn,
                source_version_number,
                "skipped",
                f"Version {source_version_number} for model {target_model_fqn} already exists. Skipping copy.",
            )
            continue

        print(f".  Copying new version {source_version_number} for model {model_name}")
        source_uri = f"models:/{source_model_fqn}/{source_version_number}"
        try:
            copy_result = client.copy_model_version(source_uri, f"{target_model_fqn}")
        except RestException as e:
            _log_model_version_copy(
                target_model_fqn,
                source_version_number,
                "error",
                f"Failed to copy model version {source_version_number} from {source_model_fqn} to {target_model_fqn}. Error: {e}",
            )
            continue

        _log_model_version_copy(
            target_model_fqn,
            source_version_number,
            "success",
            f"Successfully copied model version {source_version_number} from {source_model_fqn} to {target_model_fqn}.",
            results_data=[
                {"key": "copy_status", "value": "success"},
                {"key": "source_uri", "value": source_uri},
                {"key": "target_uri", "value": f"{target_model_fqn}"},
                {"key": "version_copied", "value": str(source_version_number)},
                {"key": "version_result", "value": str(copy_result)}
            ],
        )

        # Mirror the source version's aliases onto the version that was just created.
        update_aliases(source_model_fqn, source_version_number, target_model_fqn, copy_result.version)
            
# --- Main function to orchestrate the process ---
def main(catalog_pair_names_list, target_schema_owner_list, max_workers):
    """Synchronize registered models from source catalogs into target catalogs.

    Args:
        catalog_pair_names_list: Strings of the form "source_catalog:target_catalog".
        target_schema_owner_list: Accepted for interface compatibility; not used
            by this function.
        max_workers: Number of threads used to process models in parallel.

    For every pair the target catalog is created if needed, every source schema
    (except ``information_schema``) is created in the target, and one task per
    registered model in that schema is queued. Tasks are then run on a thread
    pool. Failures while preparing a pair, or raised by a worker, are written to
    the log table instead of aborting the run.
    """
    print("Starting model synchronization process...")

    # Get all models once at the beginning
    all_workspace_models = get_all_registered_models_once()
    print(f"Total models found in workspace: {len(all_workspace_models)}")

    all_tasks = []

    for catalog_pair_string in catalog_pair_names_list:
        try:
            # Tolerate spaces around the ':' separator; a malformed pair raises
            # ValueError here and is logged by the handler below.
            source_catalog, target_catalog = (
                part.strip() for part in catalog_pair_string.split(':')
            )
            print(f"Processing catalog pair: {source_catalog} -> {target_catalog}")

            create_catalog_if_not_exists(target_catalog)

            source_schemas_df = spark.sql(f"SHOW SCHEMAS IN `{source_catalog}`")
            source_schemas = [row.databaseName for row in source_schemas_df.collect()]

            for schema in source_schemas:
                if schema == 'information_schema':
                    continue

                source_schema_fqn = f"`{source_catalog}`.`{schema}`"
                target_schema_fqn = f"`{target_catalog}`.`{schema}`"
                print(f"  Processing schema: {source_schema_fqn} -> {target_schema_fqn}")

                create_schema_if_not_exists(target_catalog, schema)

                # Filter the pre-fetched list of models for the current catalog and schema
                model_prefix = f"{source_catalog}.{schema}."
                models_to_process = [
                    model for model in all_workspace_models
                    if model.name.startswith(model_prefix)
                ]

                if not models_to_process:
                    print(f"  No models found in schema {source_schema_fqn}. Skipping.")
                    continue

                all_tasks.extend(
                    (source_catalog, schema, target_catalog, schema, model)
                    for model in models_to_process
                )

        except Exception as e:
            insert_log_entry({
                "entity_type": "catalog_pair",
                "entity_name": f"Processing of {catalog_pair_string}",
                "action": "unhandled_error",
                "status": "error",
                "message": f"An unhandled error occurred: {e}\n{traceback.format_exc()}",
                "timestamp": datetime.now(),
                "results_data": []
            })

    print(f"Collected {len(all_tasks)} models to process.")

    # Process all models in parallel
    print("Starting parallel model processing...")
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        submitted = [
            (task, executor.submit(process_model_sync, *task))
            for task in all_tasks
        ]
        # Collect every result explicitly: an exception raised inside a worker is only
        # re-raised when its result is requested, so without this it would vanish.
        for task, future in submitted:
            try:
                future.result()
            except Exception as e:
                failed_model = task[-1]
                insert_log_entry({
                    "entity_type": "model",
                    "entity_name": str(getattr(failed_model, "name", failed_model)),
                    "action": "unhandled_error",
                    "status": "error",
                    "message": f"An unhandled error occurred: {e}\n{traceback.format_exc()}",
                    "timestamp": datetime.now(),
                    "results_data": []
                })

    print("Model synchronization complete.")

# --- Run the main job ---
# The arguments below must already be defined in the notebook session.
# Example values:
#   catalog_pair_names_list  = ["source_catalog:target_catalog"]
#   target_schema_owner_list = ["target_catalog.satyendranath_sure:owner_group"]
#   max_workers              = 8
#   log_table_name           = "your_catalog.your_schema.model_sync_log"
main(
    catalog_pair_names_list=catalog_pair_names_list,
    target_schema_owner_list=target_schema_owner_list,
    max_workers=max_workers,
)

In [0]:
%sql
-- Show the log records, most recent first.
SELECT *
FROM ${log_table_name}
ORDER BY timestamp DESC