# Create Mosaic AI Vector Search Index (Managed Embeddings)

Sets up a Vector Search endpoint and Delta Sync index with **Databricks-managed embeddings**.

**Source**: `openalex.vector_search.works_for_embedding`
**Embedding Model**: `databricks-gte-large-en` (1024 dims, managed by Databricks)
**Endpoint**: Storage-optimized for cost efficiency at scale
**Sync**: Delta Sync in `TRIGGERED` mode (the index is refreshed from the source table each time a sync is run; see Step 5)

In [None]:
# Settings shared by every cell in this notebook

# Vector Search endpoint that hosts the index
ENDPOINT_NAME = "openalex-vector-search"

# Three-part names of the index and of the Delta table it is built from
INDEX_NAME = "openalex.vector_search.work_embeddings_index"
SOURCE_TABLE = "openalex.vector_search.works_for_embedding"

# Row identifier and the text column that gets embedded
PRIMARY_KEY = "work_id"
EMBEDDING_SOURCE_COLUMN = "text_to_embed"

# Embedding model endpoint; Databricks computes the vectors with it
EMBEDDING_MODEL = "databricks-gte-large-en"

# Additional columns carried into the index so queries can filter on them
METADATA_COLUMNS = [
    "publication_year",
    "type",
    "is_oa",
    "has_abstract",
    "has_content_pdf",
    "has_content_grobid_xml",
]

## Step 1: Create Vector Search Endpoint (storage-optimized)

Storage-optimized endpoints are up to 7x cheaper than standard endpoints.
- 1 unit = 64M vectors @ 768 dimensions
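- For 250M vectors @ 1536 dimensions = ~8 units (this notebook's `databricks-gte-large-en` embeddings are 1024-dimensional, so the same corpus would need roughly 5–6 units if capacity scales linearly with dimension)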

In [None]:
from databricks.vector_search.client import VectorSearchClient

# Initialize client
# `vsc` is the shared handle that the later cells use for every endpoint and
# index call (get_endpoint, create_endpoint, get_index, create_delta_sync_index).
vsc = VectorSearchClient()

In [None]:
# Look up the endpoint; if the lookup succeeds, report its state and type.
# Any lookup error is reported as "does not exist" along with the error text.
try:
    endpoint = vsc.get_endpoint(ENDPOINT_NAME)
    print(f"Endpoint '{ENDPOINT_NAME}' already exists")
    status_block = endpoint.get('endpoint_status', {})
    print(f"  Status: {status_block.get('state')}")
    print(f"  Type: {endpoint.get('endpoint_type')}")
except Exception as lookup_error:
    print(f"Endpoint does not exist, will create: {lookup_error}")

In [None]:
# Create the storage-optimized endpoint unless a lookup shows it is already there.
# Safe to re-run: an existing endpoint is left untouched.

try:
    vsc.get_endpoint(ENDPOINT_NAME)
except Exception:
    # Lookup failed, so treat the endpoint as missing and create it.
    endpoint = vsc.create_endpoint(
        endpoint_type="STORAGE_OPTIMIZED",  # 7x cheaper than STANDARD
        name=ENDPOINT_NAME,
    )
    print(f"Created endpoint: {endpoint}")
else:
    print(f"Endpoint '{ENDPOINT_NAME}' already exists, skipping creation")

In [None]:
# Wait for endpoint to be ready
#
# Polls the endpoint until it reports ONLINE. The wait is bounded so that an
# endpoint stuck in a non-terminal state cannot hang the notebook forever.
#
# Raises:
#   RuntimeError: the endpoint reported OFFLINE or FAILED.
#   TimeoutError: the endpoint was not ONLINE within ENDPOINT_WAIT_TIMEOUT_S.
import time

ENDPOINT_POLL_INTERVAL_S = 30       # seconds between status checks
ENDPOINT_WAIT_TIMEOUT_S = 60 * 60   # give up after one hour; raise if provisioning is slower

# Monotonic clock: the deadline is unaffected by wall-clock adjustments.
deadline = time.monotonic() + ENDPOINT_WAIT_TIMEOUT_S

while True:
    endpoint = vsc.get_endpoint(ENDPOINT_NAME)
    state = endpoint.get('endpoint_status', {}).get('state')
    print(f"Endpoint state: {state}")

    if state == 'ONLINE':
        print("Endpoint is ready!")
        break
    if state in ('OFFLINE', 'FAILED'):
        raise RuntimeError(f"Endpoint failed to start: {endpoint}")
    if time.monotonic() >= deadline:
        raise TimeoutError(
            f"Endpoint '{ENDPOINT_NAME}' not ONLINE after "
            f"{ENDPOINT_WAIT_TIMEOUT_S}s (last state: {state})"
        )

    time.sleep(ENDPOINT_POLL_INTERVAL_S)

## Step 2: Create Source Table (if needed)

The source table contains the text to embed and metadata columns.

In [None]:
%%sql
-- Source table for the Delta Sync index.
--   work_id        : non-null row identifier
--   text_to_embed  : non-null text column to be embedded
--   remaining cols : nullable metadata (year, type, and boolean flags)
-- Create source table with Change Data Feed enabled (if not exists)
CREATE TABLE IF NOT EXISTS openalex.vector_search.works_for_embedding (
    work_id STRING NOT NULL,
    text_to_embed STRING NOT NULL,
    publication_year INT,
    type STRING,
    is_oa BOOLEAN,
    has_abstract BOOLEAN,
    has_content_pdf BOOLEAN,
    has_content_grobid_xml BOOLEAN
)
TBLPROPERTIES ('delta.enableChangeDataFeed' = 'true');

-- Populate from source (only if empty)
-- The statement below is left commented out; uncomment it to load the table.
-- NOTE(review): its WHERE clause keeps only rows whose title and abstract are
-- both non-null, so the COALESCE fallbacks never apply and has_abstract is
-- true for every inserted row.
-- INSERT INTO openalex.vector_search.works_for_embedding
-- SELECT
--     CAST(id AS STRING) as work_id,
--     CONCAT('Title: ', COALESCE(title, ''), '\n\nAbstract: ', COALESCE(abstract, '')) as text_to_embed,
--     publication_year,
--     type,
--     open_access.is_oa as is_oa,
--     CASE WHEN abstract IS NOT NULL THEN true ELSE false END as has_abstract,
--     has_content.pdf as has_content_pdf,
--     has_content.grobid_xml as has_content_grobid_xml
-- FROM openalex.works.openalex_works
-- WHERE type != 'dataset'
--   AND abstract IS NOT NULL
--   AND title IS NOT NULL
--   AND id IS NOT NULL;

## Step 3: Create Delta Sync Index

In [None]:
# Check if index exists
#
# vsc.get_index returns an index object (the same object later cells call
# .similarity_search() and .sync() on), not a dict. Its metadata is read via
# describe(). Only the lookup sits inside the try, so an error while reading
# the status is not misreported as "index does not exist".
try:
    index = vsc.get_index(ENDPOINT_NAME, INDEX_NAME)
except Exception as e:
    print(f"Index does not exist, will create: {e}")
else:
    print(f"Index '{INDEX_NAME}' already exists")
    index_status = index.describe().get('status', {})
    print(f"  Status: {index_status.get('ready')}")

In [None]:
# Build the Delta Sync index with managed embeddings unless it already exists.
# Databricks computes the vectors itself from the configured text column and model.

try:
    vsc.get_index(ENDPOINT_NAME, INDEX_NAME)
except Exception:
    # Lookup failed, so treat the index as missing and create it.
    index = vsc.create_delta_sync_index(
        # Where the index lives and what it is called
        endpoint_name=ENDPOINT_NAME,
        index_name=INDEX_NAME,
        # Table the index is synced from, and its row identifier
        source_table_name=SOURCE_TABLE,
        primary_key=PRIMARY_KEY,
        # Managed embeddings: text column plus the model that embeds it
        embedding_source_column=EMBEDDING_SOURCE_COLUMN,
        embedding_model_endpoint_name=EMBEDDING_MODEL,
        # Columns made available for filtering
        columns_to_sync=METADATA_COLUMNS,
        # Triggered (rather than continuous) sync keeps cost under control
        pipeline_type="TRIGGERED",
    )
    print(f"Created managed embedding index: {index}")
else:
    print(f"Index '{INDEX_NAME}' already exists, skipping creation")

In [None]:
# Wait for index to sync
#
# Polls the index description once a minute until it reports ready.
# vsc.get_index returns an index object, so its status is read through
# describe() rather than dict-style .get() on the object itself.
#
# Raises:
#   RuntimeError: the reported detailed state indicates a failure.
import time

INDEX_POLL_INTERVAL_S = 60  # seconds between status checks

index = vsc.get_index(ENDPOINT_NAME, INDEX_NAME)

while True:
    status = index.describe().get('status', {})
    ready = status.get('ready', False)
    # `or` guards against the key being present but null, which would break `:,`
    indexed_count = status.get('indexed_row_count') or 0
    detailed_state = status.get('detailed_state') or ''

    print(f"Index ready: {ready}, Indexed rows: {indexed_count:,}")

    if ready:
        print("Index is ready!")
        break
    # NOTE(review): assumes failure states contain "FAILED" (e.g. OFFLINE_FAILED);
    # confirm against the index status documentation.
    if 'FAILED' in detailed_state:
        raise RuntimeError(
            f"Index sync failed ({detailed_state}): {status.get('message')}"
        )

    time.sleep(INDEX_POLL_INTERVAL_S)

## Step 4: Test similarity search

In [None]:
# Smoke-test the index with a plain-text query.
# Managed embeddings mean the query is passed as text (query_text), not as a vector.
test_query = "climate change impacts on coral reef ecosystems"

index = vsc.get_index(ENDPOINT_NAME, INDEX_NAME)

results = index.similarity_search(
    columns=["work_id", "publication_year", "type", "is_oa"],
    query_text=test_query,
    num_results=10,
)

print(f"Query: {test_query}")

# Rows are positional: the four requested columns in order; the last element
# is printed as the score.
hits = results.get('result', {}).get('data_array', [])
print(f"Found {len(hits)} results")
for hit in hits:
    print(f"  work_id: {hit[0]}, year: {hit[1]}, type: {hit[2]}, is_oa: {hit[3]}, score: {hit[-1]:.4f}")

In [None]:
# Repeat the search restricted by a metadata filter.
# Reuses `index` and `test_query` from the unfiltered test cell.
results_filtered = index.similarity_search(
    columns=["work_id", "publication_year", "type", "is_oa"],
    filters="publication_year > 2020",
    query_text=test_query,
    num_results=10,
)

# Same positional row layout as the unfiltered search.
filtered_hits = results_filtered.get('result', {}).get('data_array', [])
print(f"Found {len(filtered_hits)} results (year > 2020)")
for hit in filtered_hits:
    print(f"  work_id: {hit[0]}, year: {hit[1]}, type: {hit[2]}, is_oa: {hit[3]}, score: {hit[-1]:.4f}")

## Step 5: Trigger manual sync (for updates)

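Call this after rows are added to or changed in the source table. The index uses a `TRIGGERED` pipeline, so it only picks up those changes (and Databricks only computes embeddings for them) when a sync is run.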

In [None]:
# Trigger manual sync (for TRIGGERED pipeline type)
# The index is created with pipeline_type="TRIGGERED", so this cell must be
# re-run whenever the index should be refreshed from the source table.
index = vsc.get_index(ENDPOINT_NAME, INDEX_NAME)
sync_result = index.sync()
# Echo whatever sync() returned so the run leaves a trace in the cell output.
print(f"Sync triggered: {sync_result}")

## Index Info

In [None]:
# Get current index info
#
# vsc.get_index returns an index object, so the configuration is read from
# describe() rather than dict-style .get() on the object itself.
index = vsc.get_index(ENDPOINT_NAME, INDEX_NAME)
index_info = index.describe()
sync_spec = index_info.get('delta_sync_index_spec', {})

print("Index configuration:")
print(f"  Name: {index_info.get('name')}")
print(f"  Source table: {sync_spec.get('source_table')}")
# Managed-embedding indexes list their text column and model under
# embedding_source_columns. embedding_vector_columns (which carries the
# dimension) is expected to be set only for self-managed vectors, and is
# printed as well so either kind of index is covered.
print(f"  Embedding source columns: {sync_spec.get('embedding_source_columns')}")
print(f"  Embedding vector columns: {sync_spec.get('embedding_vector_columns')}")
print(f"  Pipeline type: {sync_spec.get('pipeline_type')}")
print(f"  Status: {index_info.get('status')}")