In [0]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents psycopg2-binary uv
dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
dbutils.widgets.text(
    "vector_index",
    "pgvector_index4",
    label="vector index name used for naming model endpoint, function, etc",
)

dbutils.widgets.text(
    "embedding_model",
    "databricks-gte-large-en",
    label="embedding model to use",
)

In [0]:
# Set the Databricks host and token
import os, json

os.environ['DATABRICKS_TOKEN'] = dbutils.secrets.get(scope="gsa-rag", key="endpoint-token")
conn_params = json.loads(dbutils.secrets.get(scope="gsa-rag", key="vector-conn-params"))
os.environ['INDEX_PARAMS'] = json.dumps(conn_params)
display(conn_params)


{'host': 'instance-d25b042e-1736-4852-8dbe-1f1e4fe0efe2.database.cloud.databricks.com',
 'database': 'databricks_postgres',
 'user': 'vector_db',
 'password': 'abc123456789'}

In [0]:
import mlflow.pyfunc
import psycopg2
import json
import os, sys
from databricks.sdk import WorkspaceClient


class VectorIndexModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        None

    def get_embedding(self, input_text):
        client = WorkspaceClient()
        response = client.serving_endpoints.query(
            name="databricks-gte-large-en", input=input_text
        )

        return response.data[0].embedding

    def predict(self, context, model_input):
        query = model_input["input"][0]
        num_docs = model_input["num_docs"][0]
        sensitivities = model_input["sensitivity"][0]

        conn_params=os.getenv("INDEX_PARAMS")
        conn_params = json.loads(conn_params)

        pre_computed_embedding = self.get_embedding(query)
        embedding_str = ",".join(map(str, pre_computed_embedding))
        sensitivitites_pred = ",".join([f"'{element}'" for element in sensitivities])
        
        conn = psycopg2.connect(**conn_params)
        cursor = conn.cursor()

        sql_query = f"""
        SELECT id, doc_url, content, sensitivity
        FROM document_base
        WHERE sensitivity = ANY(ARRAY[{sensitivitites_pred}])
        ORDER BY embedding <-> '[{embedding_str}]'::vector
        LIMIT {num_docs};
        """
        cursor.execute(sql_query)
        results = cursor.fetchall()
        cursor.close()
        conn.close()

        column_names = [desc[0] for desc in cursor.description]
        json_results = [dict(zip(column_names, row)) for row in results]

        return {"outputs": json_results}



In [0]:
vectorIndexModel = VectorIndexModel()

vectorIndexModel.predict(
    context={},
    model_input={
        "input": "what is gsa?", 
        "num_docs": "2", 
        "sensitivity": [["public"]]
    }
)



{'outputs': [{'id': 253,
   'doc_url': '/Volumes/tsfrt/gsa/performance/FY-2026-GSA-Annual-Performance-Plan_5-28-25.pdf',
   'content': "ERROR: Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Request size cannot exceed 4194304 bytes. Please shorten the request.'}",
   'sensitivity': 'public'},
  {'id': 150,
   'doc_url': '/Volumes/tsfrt/gsa/performance/GSA_Annual_Performance_Plan_FY_2023_FINAL_508.pdf',
   'content': "ERROR: Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Invalid base64 string for image\\n'}",
   'sensitivity': 'public'}]}

In [0]:
import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types import Schema, ColSpec
from mlflow.models.resources import DatabricksServingEndpoint
from mlflow.types.schema import Schema, ColSpec, Array
from mlflow.types import DataType

resources = [DatabricksServingEndpoint(endpoint_name="databricks-gte-large-en")]

input_example = {"input": "text chunk", "num_docs": "2", "sensitivity": ["public"]}
conn_params_path = "conn_params.json"

# Start an MLflow run and log the model
with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="pgvector-query",
        signature=ModelSignature(
            inputs=Schema(
                [
                    ColSpec("string", "input"),
                    ColSpec("string", "num_docs"),
                    ColSpec(Array(DataType.string), "sensitivity"),
                ]
            ),
            outputs=Schema(
                [
                    ColSpec("long", "id"),
                    ColSpec("string", "content"),
                    ColSpec("string", "sensitivity"),
                ]
            ),
        ),
        python_model=VectorIndexModel(),
        input_example=input_example,
        resources=resources,
        pip_requirements=[
            "mlflow==3.1.4",
            "cloudpickle==3.0.0",
            "httplib2==0.20.2",
            "psycopg2-binary==2.9.10",
        ],
    )

print(logged_agent_info)

🔗 View Logged Model at: https://e2-demo-field-eng.cloud.databricks.com/ml/experiments/1945536872982684/models/m-aea9bd037e36421f8a2a08fa2cc1febb?o=1444828305810485
2025/08/05 02:03:41 INFO mlflow.pyfunc: Validating input example against model signature
2025/08/05 02:03:41 INFO mlflow.models.model: Found the following environment variables used during model inference: [DATABRICKS_TOKEN]. Please check if you need to set them when deploying the model. To disable this message, set environment variable `MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING` to `false`.


<mlflow.models.model.ModelInfo object at 0x7f830ba667d0>


In [0]:
mlflow.set_registry_uri("databricks-uc")
vector_index = dbutils.widgets.get("vector_index")

# TODO: define the catalog, schema, and model name for your UC model
catalog = "tsfrt"
schema = "gsa"
model_name = vector_index
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

Registered model 'tsfrt.gsa.pgvector_index4' already exists. Creating a new version of this model...


Downloading artifacts:   0%|          | 0/12 [00:00<?, ?it/s]

Uploading artifacts:   0%|          | 0/13 [00:00<?, ?it/s]

🔗 Created version '6' of model 'tsfrt.gsa.pgvector_index4': https://e2-demo-field-eng.cloud.databricks.com/explore/data/models/tsfrt/gsa/pgvector_index4/version/6?o=1444828305810485


In [0]:
import mlflow
from mlflow.deployments import get_deploy_client
from mlflow.exceptions import MlflowException

mlflow.set_registry_uri("databricks-uc")
client = get_deploy_client("databricks")

endpointname = f"{vector_index}-endpoint"
endpoint_config = {
    "served_entities": [
        {
            "entity_name": UC_MODEL_NAME,
            "entity_version": uc_registered_model_info.version,
            "workload_size": "Small",
            "scale_to_zero_enabled": True,
            "environment_vars": {"INDEX_PARAMS": json.dumps(conn_params)},
        }
    ]
}

try:
    endpoint = client.get_endpoint(endpointname)
    print(f"Endpoint {endpointname} already exists.")
    endpoint = client.update_endpoint(endpointname, config=endpoint_config)
except MlflowException as e:
    if "RESOURCE_ALREADY_EXISTS" in str(e):
        print(f"Endpoint {endpointname} does not exist")
        endpoint = client.create_endpoint(endpointname, config=endpoint_config)
    else:
        raise e

context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
workspace_url = context.apiUrl().get()

Endpoint pgvector_index4-endpoint already exists.


  endpoint = client.update_endpoint(endpointname, config=endpoint_config)


In [0]:
token = os.getenv("DATABRICKS_TOKEN")

spark.sql(f"""
CREATE OR REPLACE FUNCTION tsfrt.gsa.vector_query_{vector_index}(input_text STRING, num_docs STRING, sensitivity ARRAY<STRING>)
RETURNS STRING
LANGUAGE PYTHON
AS $$
import requests
import json

token = f'Bearer {token}'

url = '{workspace_url}/serving-endpoints/{endpointname}/invocations'

headers = {{
      'Authorization': token,
      'Content-Type': 'application/json'
}}
data = {{"input": input_text, "num_docs": num_docs, "sensitivity": sensitivity}}
try:
    response = requests.post(url, headers=headers, json=data)
    response.raise_for_status()
except requests.exceptions.RequestException as e:
    raise Exception(f'Request failed: {{e}}')
  
return response.json()['outputs']
$$;
""")

DataFrame[]