# Test Load Embeddings to ES

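Incremental test of loading embeddings into Elasticsearch. Embeddings are written as partial updates to documents that already exist in the index; works that are missing from the index are counted as "not found" rather than created.
Use the `limit` widget to control how many records to load.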

**Incremental test plan:**
1. 1,000 records - validate it works
2. 1,000,000 records - test throughput
3. 10,000,000 records - stress test
4. Full load (217M) - production

In [None]:
%pip install elasticsearch==8.17.0
%restart_python

In [None]:
from pyspark.sql import functions as F
from elasticsearch import Elasticsearch, helpers
import time

# Configuration - read from Databricks table (not S3, to avoid IAM issues)
EMBEDDINGS_TABLE = "openalex.vector_search.work_embeddings_v2"
ELASTIC_INDEX = "works-v32"
# The Elasticsearch URL is read from the "elastic" secret scope, not hard-coded.
ELASTIC_URL = dbutils.secrets.get(scope="elastic", key="elastic_url")

# Widget for limit - default 1000 for first test
dbutils.widgets.text("limit", "1000", "Number of records to load")
_raw_limit = dbutils.widgets.get("limit").strip()
try:
    # Accept thousands separators ("1,000,000") so the test-plan sizes can be
    # typed into the widget as written.
    LIMIT = int(_raw_limit.replace(",", ""))
except ValueError:
    raise ValueError(
        f"Widget 'limit' must be a whole number of records, got {_raw_limit!r}"
    ) from None
if LIMIT < 1:
    # A limit of zero leaves nothing to sample or index; fail here with a
    # clear message instead of further down the notebook.
    raise ValueError(f"Widget 'limit' must be at least 1, got {LIMIT}")

print(f"Loading {LIMIT:,} embeddings from {EMBEDDINGS_TABLE}")
print(f"Target index: {ELASTIC_INDEX}")

In [None]:
# Load embeddings from Databricks table with limit
print(f"Reading {LIMIT:,} records from table...")
t0 = time.time()

df = spark.table(EMBEDDINGS_TABLE).limit(LIMIT)

# Force evaluation to get actual count (can be lower than LIMIT when the
# table holds fewer rows)
actual_count = df.count()
print(f"Loaded {actual_count:,} records in {time.time() - t0:.1f}s")

# first() returns None for an empty DataFrame; stop with a clear message
# rather than an AttributeError when the sample's fields are read below.
sample = df.first()
if sample is None:
    raise RuntimeError(
        f"No records returned from {EMBEDDINGS_TABLE}; nothing to load"
    )

# Show sample
print("\nSample record:")
print(f"  work_id: {sample.work_id}")
print(f"  embedding dims: {len(sample.embedding)}")

In [None]:
# Choose how the embeddings reach Elasticsearch: requests of up to 100,000
# records are collected to the driver, anything larger is written from the
# Spark partitions.
use_partitions = LIMIT > 100000

if use_partitions:
    print(f"Using partition-based processing for {actual_count:,} records")
    # At least 100 partitions; beyond 1M rows this works out to ~10K rows each
    num_partitions = max(100, actual_count // 10000)
    df = df.repartition(num_partitions)
    print(f"Repartitioned to {df.rdd.getNumPartitions()} partitions")
    # Nothing is materialised on the driver in this mode
    records = None
else:
    print(f"Collecting {actual_count:,} records to driver...")
    records = df.collect()
    print(f"Collected {len(records):,} records")

In [None]:
from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType, StringType
import time

def generate_update_actions(records_iter, index_name):
    """Yield one bulk ``update`` action per embedding row.

    Args:
        records_iter: Iterable of rows exposing ``work_id`` and ``embedding``
            attributes.
        index_name: Name of the index the actions are addressed to.

    Yields:
        A dict in the shape expected by the ``elasticsearch.helpers`` bulk
        functions. The document id is the OpenAlex work URL built from
        ``work_id``, and the embedding values are converted to ``float``.
        ``doc_as_upsert`` is ``False``, so the action only updates an
        existing document.
    """
    for record in records_iter:
        yield {
            "_op_type": "update",
            "_index": index_name,
            "_id": f"https://openalex.org/W{record.work_id}",
            "doc": {"vector_embedding": list(map(float, record.embedding))},
            "doc_as_upsert": False,
        }

def send_partition_to_elastic(partition, partition_id):
    """Bulk-update one partition of embeddings and yield a summary of the outcome.

    Args:
        partition: Iterable of rows with ``work_id`` and ``embedding``.
        partition_id: Identifier copied into the summary.

    Yields:
        A single dict with ``partition_id``, ``updated_count``,
        ``error_count``, ``not_found_count`` and ``errors``. A failed item
        whose ``update`` status is 404 is counted as not found; any other
        failure is an error. ``errors`` holds up to five failed items, each
        truncated to 200 characters, plus a message for an exception that
        aborted the bulk run, which is counted as one error.
    """
    # Imported inside the function, as in the original, so the names are
    # resolved where the function body actually runs.
    from elasticsearch import Elasticsearch, helpers

    es = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=3,
        retry_on_timeout=True,
    )

    tally = {"updated_count": 0, "error_count": 0, "not_found_count": 0}
    sample_errors = []

    try:
        outcomes = helpers.parallel_bulk(
            es,
            generate_update_actions(partition, ELASTIC_INDEX),
            chunk_size=500,
            thread_count=4,
            raise_on_error=False,
        )
        for ok, item in outcomes:
            if ok:
                tally["updated_count"] += 1
                continue
            if item.get("update", {}).get("status") == 404:
                tally["not_found_count"] += 1
                continue
            tally["error_count"] += 1
            if len(sample_errors) < 5:
                sample_errors.append(str(item)[:200])
    except Exception as exc:
        # Best effort: report the failure in the summary instead of raising.
        tally["error_count"] += 1
        sample_errors.append(f"Error: {str(exc)[:200]}")
    finally:
        es.close()

    yield {"partition_id": partition_id, **tally, "errors": sample_errors}

if use_partitions:
    # Partition-based processing for large datasets: every Spark partition is
    # handed to send_partition_to_elastic, which yields one summary row.
    print("\nBulk updating via partitions...")
    t0 = time.time()

    log_schema = StructType([
        StructField("partition_id", IntegerType(), True),
        StructField("updated_count", IntegerType(), True),
        StructField("error_count", IntegerType(), True),
        StructField("not_found_count", IntegerType(), True),
        StructField("errors", ArrayType(StringType()), True)
    ])

    logs_rdd = df.rdd.mapPartitionsWithIndex(
        lambda idx, partition: send_partition_to_elastic(partition, idx)
    )

    logs_df = spark.createDataFrame(logs_rdd, log_schema)
    # Cached so the error-sample query below is intended to reuse the stored
    # summaries rather than re-run the partitions (and the ES writes).
    logs_df.cache()

    try:
        stats = logs_df.agg(
            F.sum("updated_count").alias("total_updated"),
            F.sum("error_count").alias("total_errors"),
            F.sum("not_found_count").alias("total_not_found")
        ).collect()[0]

        elapsed = time.time() - t0
        print("\n=== Results ===")
        print(f"Success: {stats.total_updated:,}")
        print(f"Not found (work not in ES): {stats.total_not_found:,}")
        print(f"Errors: {stats.total_errors:,}")
        print(f"Time: {elapsed:.1f}s")
        print(f"Rate: {stats.total_updated/elapsed:,.0f} docs/sec")

        # Show sample errors if any
        if stats.total_errors > 0:
            print("\nSample errors:")
            error_sample = logs_df.filter(F.size("errors") > 0).select("errors").limit(3).collect()
            for row in error_sample:
                for err in row.errors[:2]:
                    print(f"  - {err}")
    finally:
        # Release the cached summaries whether or not reporting succeeded.
        logs_df.unpersist()

    success_count = stats.total_updated
    not_found_count = stats.total_not_found
    error_count = stats.total_errors

else:
    # Driver-based processing for small datasets
    print(f"\nBulk updating {len(records):,} records to ES...")
    t0 = time.time()

    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=3,
        retry_on_timeout=True,
    )

    success_count = 0
    error_count = 0
    not_found_count = 0

    try:
        for success, info in helpers.parallel_bulk(
            client,
            generate_update_actions(records, ELASTIC_INDEX),
            chunk_size=500,
            thread_count=4,
            raise_on_error=False
        ):
            if success:
                success_count += 1
            else:
                error_info = info.get("update", {})
                if error_info.get("status") == 404:
                    # The work is not in the index; reported separately from errors.
                    not_found_count += 1
                else:
                    error_count += 1
                    # Only print the first few errors to keep the output readable.
                    if error_count <= 5:
                        print(f"Error: {info}")
    finally:
        # Close the client even when the bulk helper raises.
        client.close()

    elapsed = time.time() - t0
    print("\n=== Results ===")
    print(f"Success: {success_count:,}")
    print(f"Not found (work not in ES): {not_found_count:,}")
    print(f"Errors: {error_count:,}")
    print(f"Time: {elapsed:.1f}s")
    print(f"Rate: {success_count/elapsed:,.0f} docs/sec")

In [None]:
# Refresh index and verify
print("\nRefreshing index and verifying...")

client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
try:
    # Refresh to make new docs searchable
    client.indices.refresh(index=ELASTIC_INDEX)

    # Count documents with embeddings
    result = client.count(
        index=ELASTIC_INDEX,
        body={"query": {"exists": {"field": "vector_embedding"}}}
    )
    print(f"Documents with vector_embedding: {result['count']:,}")
finally:
    # Close the client even if the refresh or the count request fails.
    client.close()

In [None]:
# Test kNN search
print("\nTesting kNN search...")

# Get a sample embedding to use as query. `records` is None in partition mode
# and may be an empty list in driver mode; both fall back to the DataFrame.
if records:
    sample = records[0]
else:
    sample_rows = df.limit(1).collect()
    if not sample_rows:
        raise RuntimeError("No embedding rows available to build a kNN query")
    sample = sample_rows[0]
sample_embedding = [float(x) for x in sample.embedding]
sample_work_id = sample.work_id

client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
try:
    # Run kNN search. `size` is part of the request body so that body and
    # keyword parameters are not mixed in a single call.
    t0 = time.time()
    result = client.search(
        index=ELASTIC_INDEX,
        body={
            "knn": {
                "field": "vector_embedding",
                "query_vector": sample_embedding,
                "k": 5,
                "num_candidates": 50
            },
            "_source": ["id", "title"],
            "size": 5
        }
    )
    latency = (time.time() - t0) * 1000
finally:
    # Close the client even if the search request fails.
    client.close()

hits = result['hits']['hits']

print(f"Query work_id: W{sample_work_id}")
print(f"Latency: {latency:.0f}ms")
print(f"\nTop 5 similar works:")
for hit in hits:
    # A missing or null title is shown as N/A.
    title = hit.get('_source', {}).get('title') or 'N/A'
    title = title[:70] + "..." if len(title) > 70 else title
    print(f"  {hit['_score']:.4f}: {title}")

# Only report success when the search actually returned neighbours.
if hits:
    print(f"\n✓ Test complete! kNN search working.")
else:
    print("\n✗ Warning: kNN search returned no hits.")

In [None]:
# Test kNN with pre-filter
# Uses `sample_embedding` produced by the kNN search cell as the query vector.
print("\nTesting kNN with pre-filter (is_oa:true)...")

client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
try:
    # `size` is part of the request body so that body and keyword parameters
    # are not mixed in a single call.
    t0 = time.time()
    result = client.search(
        index=ELASTIC_INDEX,
        body={
            "knn": {
                "field": "vector_embedding",
                "query_vector": sample_embedding,
                "k": 5,
                "num_candidates": 50,
                "filter": {
                    "term": {"is_oa": True}
                }
            },
            "_source": ["id", "title", "is_oa"],
            "size": 5
        }
    )
    latency = (time.time() - t0) * 1000
finally:
    # Close the client even if the search request fails.
    client.close()

hits = result['hits']['hits']

print(f"Latency: {latency:.0f}ms")
print(f"\nTop 5 similar OA works:")
all_oa = True
for hit in hits:
    source = hit.get('_source', {})
    # A missing or null title is shown as N/A.
    title = source.get('title') or 'N/A'
    is_oa = source.get('is_oa', False)
    if not is_oa:
        all_oa = False
    title = title[:60] + "..." if len(title) > 60 else title
    print(f"  {hit['_score']:.4f} [OA:{is_oa}]: {title}")

if not hits:
    # With zero hits `all_oa` is vacuously true, so nothing was verified.
    print("\n✗ Warning: Filtered kNN search returned no hits - pre-filter could not be verified.")
elif all_oa:
    print(f"\n✓ Pre-filter working! All results are OA.")
else:
    print(f"\n✗ Warning: Some results are not OA - filter may not be working correctly.")