In [None]:
%pip install elasticsearch==8.19.0
%restart_python

In [None]:
"""
Sync Vector Embeddings to Elasticsearch

Bulk updates works-v32 documents with vector_embedding field from the
work_embeddings_v2 table. Uses partial document updates to add embeddings
without overwriting existing fields.

Run modes:
- is_full_sync=true: Load all 217M embeddings (initial load, ~4 hours)
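- is_full_sync=false: Intended to load only recent embeddings (for ongoing sync);
  incremental tracking is not implemented yet, so this mode currently reads the
  full table as well and only uses a smaller partition count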
"""

from pyspark.sql import functions as F
from pyspark.sql.types import *
from elasticsearch import Elasticsearch, helpers
import logging

# Only warnings and above are emitted, each prefixed with a timestamp
logging.basicConfig(format='[%(asctime)s]: %(message)s', level=logging.WARNING)
log = logging.getLogger(__name__)

# Target index name; the cluster URL is read from the "elastic" secret scope
ELASTIC_INDEX = "works-v32"
ELASTIC_URL = dbutils.secrets.get(key="elastic_url", scope="elastic")

# The widget value is compared case-insensitively; anything other than "true"
# selects the non-full (incremental) mode
_is_full_sync_widget = dbutils.widgets.get("is_full_sync")
IS_FULL_SYNC = _is_full_sync_widget.lower() == "true"

# Echo the effective settings at the top of the run output
print(f"IS_FULL_SYNC: {IS_FULL_SYNC}")
print(f"Target index: {ELASTIC_INDEX}")

In [None]:
# Load embeddings from Databricks table
# work_id is stored as string (e.g. "3043519958"), embedding is array<double> (1024 dims)

# Full and incremental runs currently issue the exact same query: tracking of
# newly added embeddings is not implemented yet, so IS_FULL_SYNC only changes
# the partitioning strategy further down.
SQL_QUERY = """
    SELECT work_id, embedding
    FROM openalex.vector_search.work_embeddings_v2
    """

df = spark.sql(SQL_QUERY)

# The row count is only used for the progress message below
total_count = df.count()
print(f"Total embeddings to sync: {total_count:,}")

# Spread the rows over many partitions so they are pushed in parallel.
# A full load (~217M rows) over 8000 range partitions is roughly 27K rows each;
# the incremental path uses a plain 100-way repartition.
if not IS_FULL_SYNC:
    df = df.repartition(100)
    print(f"Using {df.rdd.getNumPartitions()} partitions for incremental sync")
else:
    df = df.repartitionByRange(8000, "work_id")
    print(f"Repartitioned to {df.rdd.getNumPartitions()} partitions")

In [None]:
# Result schema: one nullable-column summary row per processed partition
# (its index, three counters, and a list of sampled error messages)
log_schema = (
    StructType()
    .add("partition_id", IntegerType(), True)
    .add("updated_count", IntegerType(), True)
    .add("error_count", IntegerType(), True)
    .add("not_found_count", IntegerType(), True)
    .add("errors", ArrayType(StringType()), True)
)

def generate_update_actions(partition):
    """
    Generate ES bulk update actions for partial document updates.

    Each action is a partial ("doc") update keyed by the OpenAlex work URL, so
    only vector_embedding is written and the other fields of the document are
    left untouched.

    Args:
        partition: Iterable of rows exposing ``work_id`` and ``embedding``
            attributes.

    Yields:
        dict: One bulk-helper action per row that has both a work_id and an
        embedding. Rows missing either value are skipped.
    """
    for row in partition:
        # A null embedding cannot be iterated and a null work_id cannot form a
        # valid document id. Skip such rows so a single bad record does not
        # raise inside the generator and abort the rest of the partition.
        if row.work_id is None or row.embedding is None:
            continue

        # Build the OpenAlex ID format: https://openalex.org/W{work_id}
        doc_id = f"https://openalex.org/W{row.work_id}"

        # Coerce every element to a plain Python float so the payload is a
        # simple list of JSON-serialisable numbers. Note this does not change
        # precision: Python floats are still 64-bit.
        embedding = [float(x) for x in row.embedding]

        yield {
            "_op_type": "update",
            "_index": ELASTIC_INDEX,
            "_id": doc_id,
            "doc": {
                "vector_embedding": embedding
            },
            # Never create a new document from the embedding alone. When the
            # work is not in the index yet, the update for that item fails with
            # a 404, which send_partition_to_elastic counts as "not found".
            "doc_as_upsert": False
        }

def send_partition_to_elastic(partition, partition_id):
    """
    Push one partition's embeddings into Elasticsearch and report the outcome.

    Opens a dedicated client, streams the partition's update actions through
    helpers.parallel_bulk and tallies the per-document results. This is a
    generator that yields exactly one summary dict with the keys partition_id,
    updated_count, error_count, not_found_count and errors.
    """
    es = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True,
        http_compress=True,
    )

    tally = {"updated_count": 0, "error_count": 0, "not_found_count": 0}
    sample_errors = []

    try:
        outcomes = helpers.parallel_bulk(
            es,
            generate_update_actions(partition),
            chunk_size=500,
            thread_count=4,
            queue_size=10,
            raise_on_error=False
        )
        for ok, item in outcomes:
            if ok:
                tally["updated_count"] += 1
                continue

            # 404 = document not found in ES (work not synced yet)
            if item.get("update", {}).get("status", 0) == 404:
                tally["not_found_count"] += 1
                continue

            # Any other failure is a real error; keep at most 10 truncated samples
            tally["error_count"] += 1
            if len(sample_errors) < 10:
                sample_errors.append(str(item)[:500])

    except Exception as exc:
        # A failure of the bulk call itself is recorded as a single error entry
        tally["error_count"] += 1
        sample_errors.append(f"Bulk error: {str(exc)[:500]}")
    finally:
        es.close()

    yield {"partition_id": partition_id, **tally, "errors": sample_errors}

In [None]:
# Execute the sync
print(f"Starting embedding sync to {ELASTIC_INDEX}...")

logs_rdd = df.rdd.mapPartitionsWithIndex(
    lambda idx, partition: send_partition_to_elastic(partition, idx)
)

# The map above writes to Elasticsearch as a side effect, so it should be
# evaluated once. Caching the resulting DataFrame is only best-effort: if a
# cached block is lost, any later action recomputes the lineage and re-sends
# that partition. Collecting the summaries (one small row per partition) to the
# driver and building the DataFrame from the local list cuts the lineage, so
# later aggregations never re-trigger the sync.
log_rows = logs_rdd.collect()
logs_df = spark.createDataFrame(log_rows, log_schema)
partition_count = len(log_rows)

print(f"Processed {partition_count} partitions")

In [None]:
# Aggregate results
# Sum each per-partition counter into a job-wide total (source column -> alias)
_total_aliases = {
    "updated_count": "total_updated",
    "error_count": "total_errors",
    "not_found_count": "total_not_found",
}
stats = logs_df.agg(
    *[F.sum(source).alias(alias) for source, alias in _total_aliases.items()]
).collect()[0]

print("\n=== Sync Complete ===")
for label, value in (
    ("Total updated", stats.total_updated),
    ("Total errors", stats.total_errors),
    ("Total not found (work not in ES)", stats.total_not_found),
):
    print(f"{label}: {value:,}")

# Show sample errors if any: up to 5 partitions, 2 messages from each
if stats.total_errors > 0:
    print("\nSample errors:")
    error_sample = (
        logs_df
        .filter(F.size("errors") > 0)
        .select("errors")
        .limit(5)
        .collect()
    )
    for message in (msg for entry in error_sample for msg in entry.errors[:2]):
        print(f"  - {message}")

In [None]:
# Refresh index and verify
# The client is built before entering the try block: if construction fails
# there is nothing to close, and the finally clause cannot hit an undefined
# name and mask the original error.
client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
try:
    if client.indices.exists(index=ELASTIC_INDEX):
        # Refresh so the counts below include the updates that were just written
        client.indices.refresh(index=ELASTIC_INDEX)
        print(f"Refreshed index {ELASTIC_INDEX}")

        # Count documents with embeddings
        result = client.count(
            index=ELASTIC_INDEX,
            body={"query": {"exists": {"field": "vector_embedding"}}}
        )
        print(f"Documents with vector_embedding: {result['count']:,}")

        total_docs = client.count(index=ELASTIC_INDEX)['count']
        print(f"Total documents: {total_docs:,}")

        # An empty index would otherwise raise ZeroDivisionError here
        if total_docs > 0:
            print(f"Coverage: {result['count'] / total_docs * 100:.1f}%")
        else:
            print("Coverage: n/a (index is empty)")
    else:
        print(f"Index {ELASTIC_INDEX} does not exist")
finally:
    client.close()

In [None]:
# Test kNN search with a sample query
print("\nTesting kNN search...")

# Get a sample embedding to use as query
sample_rows = spark.sql("""
        SELECT work_id, embedding
        FROM openalex.vector_search.work_embeddings_v2
        LIMIT 1
    """).collect()

if not sample_rows:
    # Indexing [0] on an empty result would raise IndexError
    print("No embeddings available; skipping kNN test")
else:
    sample = sample_rows[0]
    query_vector = [float(x) for x in sample.embedding]

    # Build the client outside the try block so the finally clause never
    # references an undefined name when construction fails.
    client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
    try:
        # Run kNN search. The whole request, including size, lives in body so
        # that body is not mixed with per-field keyword arguments.
        result = client.search(
            index=ELASTIC_INDEX,
            body={
                "knn": {
                    "field": "vector_embedding",
                    "query_vector": query_vector,
                    "k": 5,
                    "num_candidates": 50
                },
                "_source": ["id", "title"],
                "size": 5
            }
        )

        print(f"Query work_id: {sample.work_id}")
        print(f"\nTop 5 similar works:")
        for hit in result['hits']['hits']:
            # title may be absent or stored as null; fall back to 'N/A' in both cases
            title = hit['_source'].get('title') or 'N/A'
            print(f"  {hit['_score']:.4f}: {title[:80]}")
    finally:
        client.close()