In [0]:
pip install snowflake

In [0]:
pip install mlflow

In [0]:
import os
import sys
import io
import shutil
import pandas as pd
import joblib
import snowflake.connector
import mlflow
from mlflow.tracking import MlflowClient

In [0]:
# Snowflake connection settings, read from environment variables.
# Any variable that is not set resolves to None.
SNOWFLAKE_ACCOUNT = os.environ.get('SNOWFLAKE_ACCOUNT')
SNOWFLAKE_USER = os.environ.get('SNOWFLAKE_USER')
SNOWFLAKE_PASSWORD = os.environ.get('SNOWFLAKE_PASSWORD')
SNOWFLAKE_WAREHOUSE = os.environ.get('SNOWFLAKE_WAREHOUSE')
SNOWFLAKE_DATABASE = os.environ.get('SNOWFLAKE_DATABASE')
SNOWFLAKE_SCHEMA = os.environ.get('SNOWFLAKE_SCHEMA')

# Output locations inside the Databricks workspace, rooted at the folder of
# the user named by DATABRICKS_EMAIL.
DATABRICKS_EMAIL = os.environ.get("DATABRICKS_EMAIL")
PREDICTION_DIR = "/Workspace/Users/{}/CREDITCARD/Predictions".format(DATABRICKS_EMAIL)
CHAMPION_MODEL_DIR = "/Workspace/Users/{}/CREDITCARD/Champion_Model".format(DATABRICKS_EMAIL)

In [0]:
# Registered model name in Unity Catalog; must be three-part: catalog.schema.name.
UC_MODEL_NAME = "workspace.default.CreditCardFraudModel"

# Fully qualified Snowflake tables (database.schema.table) for the batch
# inputs and for the predictions written back.
BATCH_INPUT_TABLE = "{}.{}.CREDITCARD_BATCH_INPUTS".format(SNOWFLAKE_DATABASE, SNOWFLAKE_SCHEMA)
BATCH_PREDICTIONS_TABLE = "{}.{}.BATCH_PREDICTIONS".format(SNOWFLAKE_DATABASE, SNOWFLAKE_SCHEMA)


In [0]:
def get_snowflake_connection():
    """Open a new Snowflake connection from the module-level SNOWFLAKE_* settings.

    Returns:
        The connection object returned by ``snowflake.connector.connect``.
    """
    connection_settings = {
        "user": SNOWFLAKE_USER,
        "password": SNOWFLAKE_PASSWORD,
        "account": SNOWFLAKE_ACCOUNT,
        "warehouse": SNOWFLAKE_WAREHOUSE,
        "database": SNOWFLAKE_DATABASE,
        "schema": SNOWFLAKE_SCHEMA,
    }
    return snowflake.connector.connect(**connection_settings)

In [0]:
def fetch_batch_data():
    """Read every row of BATCH_INPUT_TABLE from Snowflake into a DataFrame.

    Returns:
        pandas.DataFrame: the result of ``SELECT *`` on the batch input table.
    """
    print(f"📥 Fetching batch input data from Snowflake: {BATCH_INPUT_TABLE}")
    query = f"SELECT * FROM {BATCH_INPUT_TABLE}"
    with get_snowflake_connection() as connection:
        batch_df = pd.read_sql(query, connection)
        n_rows, n_cols = batch_df.shape
        print(f"✅ Loaded {n_rows} rows, {n_cols} columns.")
    return batch_df

In [0]:
def get_champion_model():
    """Load the champion model version of UC_MODEL_NAME from the Unity Catalog registry.

    Points MLflow at the Databricks tracking server and the Unity Catalog
    registry, then scans the registered versions (newest first) for one tagged
    ``status=production`` and ``role=champion``.

    Returns:
        tuple: ``(model, model_uri)`` where ``model`` is loaded with
        ``mlflow.sklearn.load_model`` and ``model_uri`` is
        ``models:/<UC_MODEL_NAME>/<version>``.

    Raises:
        LookupError: if no version carries both champion tags.
    """
    print(f"🔍 Looking for champion model in UC: {UC_MODEL_NAME}")
    mlflow.set_tracking_uri("databricks")
    mlflow.set_registry_uri("databricks-uc")
    client = MlflowClient()

    # Search all versions and examine the newest first, so that the choice is
    # deterministic if more than one version carries the champion tags.
    versions = sorted(
        client.search_model_versions(f"name='{UC_MODEL_NAME}'"),
        key=lambda mv: int(mv.version),
        reverse=True,
    )

    # Get the one tagged as production and role=champion
    for v in versions:
        # NOTE(review): tags are read from a per-version lookup rather than the
        # search result — presumably the search does not return tags; confirm.
        full_version = client.get_model_version(name=UC_MODEL_NAME, version=v.version)
        tags = full_version.tags or {}
        if tags.get("status") == "production" and tags.get("role") == "champion":
            print(f"🏆 Champion model found: Version {v.version}, Run ID: {v.run_id}")
            model_uri = f"models:/{UC_MODEL_NAME}/{v.version}"
            return mlflow.sklearn.load_model(model_uri), model_uri

    # LookupError is still an Exception subclass, so broad handlers keep working.
    raise LookupError("❌ No champion model (status=production, role=champion) found.")

In [0]:
def generate_predictions(df, model):
    """Score a batch of records and return them with prediction columns appended.

    The input frame is never modified; all changes are made on a copy.

    Args:
        df: pandas DataFrame of input records. If it has no ``ID`` column, a
            1-based sequential ``ID`` is added as the first column of the
            result. ``ID`` and (if present) ``CLASS`` are excluded from the
            features passed to the model.
        model: object with ``predict(features)``; if it also has
            ``predict_proba(features)``, column 1 of its result is used as the
            probability.

    Returns:
        pandas.DataFrame: copy of ``df`` (plus ``ID`` if it was missing) with
        ``PREDICTION`` and ``PREDICTION_PROB`` columns. ``PREDICTION_PROB`` is
        all ``None`` when the model has no ``predict_proba``.
    """
    # Copy first: DataFrame.insert works in place and would otherwise add the
    # ID column to the caller's frame as a hidden side effect.
    result_df = df.copy()

    # Ensure ID column exists
    if 'ID' not in result_df.columns:
        result_df.insert(0, 'ID', range(1, len(result_df) + 1))

    non_feature_cols = ['ID'] + (['CLASS'] if 'CLASS' in result_df.columns else [])
    features = result_df.drop(columns=non_feature_cols)

    print(f"🔮 Generating predictions for {features.shape[0]} records...")
    preds = model.predict(features)

    if hasattr(model, "predict_proba"):
        # Column 1 of predict_proba's output is taken as the reported probability.
        probs = model.predict_proba(features)[:, 1]
    else:
        probs = [None] * len(preds)

    result_df['PREDICTION'] = preds
    result_df['PREDICTION_PROB'] = probs

    return result_df

In [0]:
def save_predictions_to_snowflake(df):
    """Replace the contents of BATCH_PREDICTIONS_TABLE with the rows of ``df``.

    The table is truncated and then every row of ``df`` is inserted, using the
    DataFrame's column names as the target column list.

    Args:
        df: pandas DataFrame of predictions; its column names must match
            columns of the target table.
    """
    print(f"🧾 Inserting predictions into Snowflake: {BATCH_PREDICTIONS_TABLE}")

    # Build the statement and the row tuples before touching the table, so a
    # failure while preparing them cannot leave the table truncated but unfilled.
    cols = [str(col) for col in df.columns]
    placeholders = ', '.join(['%s'] * len(cols))
    insert_query = f"INSERT INTO {BATCH_PREDICTIONS_TABLE} ({', '.join(cols)}) VALUES ({placeholders})"
    data = [tuple(row) for row in df.to_numpy()]

    with get_snowflake_connection() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(f"TRUNCATE TABLE {BATCH_PREDICTIONS_TABLE}")
            conn.commit()

            if not data:
                # Nothing to bind; the connector may reject an empty parameter
                # sequence, so skip the insert for an empty batch.
                print("⚠️ No prediction rows to insert; table left empty.")
                return

            cursor.executemany(insert_query, data)
            conn.commit()
            print("✅ Predictions written to Snowflake.")
        finally:
            cursor.close()

In [0]:
def save_predictions_to_databricks(df, path):
    """Write ``df`` to ``<path>/batch_predictions.csv`` without the index column.

    Args:
        df: pandas DataFrame to export.
        path: target directory; created (including parents) if missing.
    """
    os.makedirs(path, exist_ok=True)
    csv_path = os.path.join(path, "batch_predictions.csv")
    df.to_csv(path_or_buf=csv_path, index=False)
    print(f"💾 Predictions saved to: {csv_path}")

In [0]:
def copy_champion_model_to_workspace(model_uri):
    """Download the model's artifacts and copy its ``model.pkl`` into CHAMPION_MODEL_DIR.

    Args:
        model_uri: MLflow URI of the model whose artifacts are downloaded.

    Raises:
        FileNotFoundError: if the downloaded artifacts contain no ``model.pkl``
            at their top level.
    """
    # Download artifacts
    artifacts_dir = mlflow.artifacts.download_artifacts(model_uri)
    pickled_model = os.path.join(artifacts_dir, "model.pkl")

    if os.path.exists(pickled_model):
        os.makedirs(CHAMPION_MODEL_DIR, exist_ok=True)
        destination = os.path.join(CHAMPION_MODEL_DIR, "champion_model.pkl")
        shutil.copy(pickled_model, destination)
        print(f"📦 Champion model copied to: {destination}")
    else:
        raise FileNotFoundError("❌ model.pkl not found in downloaded artifacts.")

In [0]:
def main():
    """Run the batch inference pipeline end to end.

    Steps, in order: fetch the batch inputs from Snowflake, load the champion
    model, score the batch, write the predictions to Snowflake and to a CSV
    under PREDICTION_DIR, then copy the model pickle to the workspace.
    """
    print("🚀 Starting batch inference pipeline")

    # Inputs and model
    input_frame = fetch_batch_data()
    champion, champion_uri = get_champion_model()

    # Scoring
    scored_frame = generate_predictions(input_frame, champion)

    # Outputs: Snowflake table (truncated first), CSV on Databricks, model pickle
    save_predictions_to_snowflake(scored_frame)
    save_predictions_to_databricks(scored_frame, PREDICTION_DIR)
    copy_champion_model_to_workspace(champion_uri)

    print("✅ Batch inference completed.")

# Run the pipeline only when this code executes as the main module,
# not when it is imported from elsewhere.
if __name__ == "__main__":
    main()