In [None]:
%pip install elasticsearch==8.19.0
%restart_python

In [None]:
"""
Sync Vector Index (works-vectors-v1)

Bulk loads embeddings + 14 flat filter fields into a dedicated lightweight
ES index optimized for kNN vector search (12 shards vs 72 on works-v33).

Two-phase semantic search: kNN here returns IDs → mget full docs from works-v33.

Run modes:
- is_full_sync=true: Load all ~413M embeddings (initial load, ~6-8 hours)
- is_full_sync=false: Load recent updates only (daily incremental)
"""

import json
from datetime import datetime
from dataclasses import dataclass

from pyspark.sql import functions as F
from pyspark.sql.types import *
from elasticsearch import Elasticsearch, helpers
import logging

logging.basicConfig(format='[%(asctime)s]: %(message)s', level=logging.WARNING)
log = logging.getLogger(__name__)

# Target index and cluster endpoint (the URL is read from the "elastic" secret scope).
ELASTIC_INDEX = "works-vectors-v1"
ELASTIC_URL = dbutils.secrets.get(key="elastic_url", scope="elastic")

# Job parameter: the string "true" (any case) selects a full rebuild;
# any other value selects the incremental sync.
IS_FULL_SYNC = "true" == dbutils.widgets.get("is_full_sync").lower()

print(f"IS_FULL_SYNC: {IS_FULL_SYNC}\nTarget index: {ELASTIC_INDEX}")

In [None]:
# Set replicas to 0 and disable refresh (-1) for faster bulk indexing during full sync.
if IS_FULL_SYNC:
    # Build the client before `try`: if construction itself fails, `finally` must not
    # hit an unbound `client` and replace the real error with a NameError.
    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True
    )
    try:
        if client.indices.exists(index=ELASTIC_INDEX):
            # `settings=` is the documented keyword for this API (replaces `body=`).
            client.indices.put_settings(index=ELASTIC_INDEX, settings={
                "index": {
                    "number_of_replicas": 0,
                    "refresh_interval": "-1"
                }
            })
            print(f"Set replicas=0, refresh=-1 on {ELASTIC_INDEX} for full sync")
        else:
            print(f"Index {ELASTIC_INDEX} does not exist")
    finally:
        client.close()

### Prepare Data

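Join `work_embeddings_v2` with `openalex_works` to get the 14 filter fields.
Flatten authorships into arrays (`author_ids`, `institution_ids`, `country_codes`) and collect `funder_ids` from `w.funders`. The joined rows are materialized into a staging table partitioned by `batch_id`, so an interrupted full sync can resume from the checkpoint table.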

In [None]:
import time

# --- Batch processing configuration ---
# The staging table holds the materialized join; the checkpoint table records
# which batch_ids have already been pushed so an interrupted full sync can resume.
STAGING_TABLE = "openalex.vector_search.vector_sync_staging"
CHECKPOINT_TABLE = "openalex.vector_search.vector_sync_checkpoint"
if IS_FULL_SYNC:
    NUM_BATCHES = 200   # full load: ~413M rows spread over 200 batches
else:
    NUM_BATCHES = 10    # incremental load: only recently updated works

# --- Checkpoint table (created on first use, reused across runs) ---
spark.sql(f"""
    CREATE TABLE IF NOT EXISTS {CHECKPOINT_TABLE} (
        batch_id INT, indexed_count LONG, skipped_count LONG,
        error_count LONG, completed_at TIMESTAMP
    )
""")
completed_rows = spark.sql(f"SELECT batch_id FROM {CHECKPOINT_TABLE}").collect()
completed_batches = set(r["batch_id"] for r in completed_rows)

# --- Staging table (materialized join, partitioned by batch_id) ---
staging_exists = spark.catalog.tableExists(STAGING_TABLE)

if IS_FULL_SYNC and staging_exists and len(completed_batches) > 0:
    # Resuming an interrupted full sync — reuse existing staging data so batch_ids
    # line up with the batches already recorded in the checkpoint table.
    total_staged = spark.sql(f"SELECT COUNT(*) FROM {STAGING_TABLE}").collect()[0][0]
    print(f"RESUMING full sync: {len(completed_batches)}/{NUM_BATCHES} batches done")
    print(f"Staging table has {total_staged:,} records")
else:
    # Fresh start — clear checkpoint and rebuild staging
    spark.sql(f"TRUNCATE TABLE {CHECKPOINT_TABLE}")
    completed_batches = set()

    # One projection for both modes (embedding + the 14 filter fields);
    # the incremental run only appends a recency filter on updated_date.
    SQL_QUERY = """
    SELECT
      concat('https://openalex.org/W', e.work_id) as id,
      e.embedding,
      w.publication_year,
      lower(w.type) as type,
      w.open_access.is_oa as is_oa,
      lower(w.language) as language,
      array_compact(transform(w.authorships, a -> a.author.id)) as author_ids,
      array_distinct(array_compact(flatten(
        transform(w.authorships, a -> transform(a.institutions, i -> i.id))
      ))) as institution_ids,
      array_distinct(array_compact(flatten(
        transform(w.authorships, a -> transform(a.institutions, i -> lower(i.country_code)))
      ))) as country_codes,
      w.is_retracted,
      w.primary_location.source.id as source_id,
      w.cited_by_count,
      array_compact(transform(coalesce(w.funders, array()), f -> f.id)) as funder_ids,
      w.fulltext IS NOT NULL as has_fulltext,
      w.has_abstract,
      w.primary_location.license_id as license_id
    FROM openalex.vector_search.work_embeddings_v2 e
    JOIN openalex.works.openalex_works w ON e.work_id = CAST(w.id AS STRING)
    """
    if not IS_FULL_SYNC:
        SQL_QUERY += "WHERE w.updated_date >= current_date() - INTERVAL 2 days\n"

    print(f"Running query and writing staging table ({NUM_BATCHES} batches)...")
    staging_start = time.time()

    df = spark.sql(SQL_QUERY)
    # Use pmod rather than abs(hash) % n: hash() is a 32-bit int, and abs(-2**31)
    # overflows back to a negative value (or raises under ANSI mode). That row would
    # land in a negative batch_id that the batch loop (range(NUM_BATCHES)) never visits.
    # pmod always yields a value in [0, NUM_BATCHES).
    df = df.withColumn("batch_id", F.expr(f"pmod(hash(id), {NUM_BATCHES})"))

    if staging_exists:
        spark.sql(f"DROP TABLE {STAGING_TABLE}")
    df.write.format("delta").partitionBy("batch_id").saveAsTable(STAGING_TABLE)

    total_staged = spark.sql(f"SELECT COUNT(*) FROM {STAGING_TABLE}").collect()[0][0]
    staging_min = (time.time() - staging_start) / 60
    print(f"Staging complete: {total_staged:,} records in {staging_min:.1f}min")

remaining = NUM_BATCHES - len(completed_batches)
print(f"\nReady: {remaining} batches to process")

### Bulk Load Helpers

In [None]:
# Schema of the per-partition log records produced by send_partition_to_elastic:
# the Spark partition index, indexed/skipped/error counters, and error messages.
log_schema = (
    StructType()
    .add("partition_id", IntegerType(), True)
    .add("indexed_count", IntegerType(), True)
    .add("skipped_count", IntegerType(), True)
    .add("error_count", IntegerType(), True)
    .add("errors", ArrayType(StringType()), True)
)


def generate_actions(partition, errors_list):
    """Yield one Elasticsearch bulk ``index`` action per staging row.

    Args:
        partition: Iterable of rows exposing the staging columns as attributes
            (``id``, ``embedding`` and the filter fields).
        errors_list: Sink with an ``append`` method. It receives one
            "Parse error for ..." message per row that cannot be converted;
            such rows are skipped instead of raising.

    Yields:
        Bulk action dicts targeting ``ELASTIC_INDEX`` whose ``_id`` is the row id.
    """

    def as_list(values):
        # Null or empty arrays are sent as [] so the field is always present.
        return list(values) if values else []

    def as_flag(value):
        # Null booleans default to False.
        return False if value is None else value

    for record in partition:
        try:
            vector = list(map(float, record.embedding))
            source = {
                "id": record.id,
                "vector_embedding": vector,
                "publication_year": record.publication_year,
                "type": record.type,
                "is_oa": as_flag(record.is_oa),
                "language": record.language,
                "author_ids": as_list(record.author_ids),
                "institution_ids": as_list(record.institution_ids),
                "country_codes": as_list(record.country_codes),
                "is_retracted": as_flag(record.is_retracted),
                "source_id": record.source_id,
                "cited_by_count": record.cited_by_count or 0,
                "funder_ids": as_list(record.funder_ids),
                "has_fulltext": as_flag(record.has_fulltext),
                "has_abstract": as_flag(record.has_abstract),
                "license_id": record.license_id,
            }
            yield {"_op_type": "index", "_index": ELASTIC_INDEX, "_id": record.id, "_source": source}
        except Exception as exc:
            errors_list.append(f"Parse error for {getattr(record, 'id', '?')}: {str(exc)[:200]}")


class _ErrorTally:
    """List-like error sink: counts every message but keeps only the first few."""

    def __init__(self, max_samples):
        self.count = 0
        self.samples = []
        self._max_samples = max_samples

    def append(self, message):
        # Same name as list.append so an instance can be handed to generate_actions.
        self.count += 1
        if len(self.samples) < self._max_samples:
            self.samples.append(message)


def send_partition_to_elastic(partition, partition_id, max_error_samples=10):
    """Bulk-index one Spark partition into ``ELASTIC_INDEX`` and yield a log record.

    Args:
        partition: Iterator of staging rows, as passed by ``mapPartitionsWithIndex``.
        partition_id: Spark partition index, echoed back in the log record.
        max_error_samples: Maximum number of error messages kept in ``errors``.
            ``error_count`` always counts every error, regardless of this cap.

    Yields:
        A single dict matching ``log_schema``: partition_id, indexed_count,
        skipped_count (409 responses), error_count, and errors (sampled messages).
    """
    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True,
        http_compress=True,
    )

    indexed_count = 0
    skipped_count = 0
    # Separate tallies: parallel_bulk may consume the action generator on a worker
    # thread, so parse errors and per-item bulk errors do not share one counter.
    parse_errors = _ErrorTally(max_error_samples)
    item_errors = _ErrorTally(max_error_samples)
    fatal_error = None

    try:
        for success, info in helpers.parallel_bulk(
            client,
            generate_actions(partition, parse_errors),
            chunk_size=500,
            thread_count=4,
            queue_size=10,
            raise_on_error=False
        ):
            if success:
                indexed_count += 1
                continue
            status = info.get("index", {}).get("status", 0)
            if status == 409:
                skipped_count += 1
            else:
                item_errors.append(str(info)[:500])
    except Exception as e:
        # A transport-level failure aborts the rest of the partition; rows not yet
        # sent are not counted anywhere, only this single fatal error is.
        fatal_error = f"Bulk error: {str(e)[:500]}"
    finally:
        client.close()

    errors = (parse_errors.samples + item_errors.samples)[:max_error_samples]
    if fatal_error is not None:
        errors.append(fatal_error)  # always surface the fatal error
    error_count = parse_errors.count + item_errors.count + int(fatal_error is not None)

    yield {
        "partition_id": partition_id,
        "indexed_count": indexed_count,
        "skipped_count": skipped_count,
        "error_count": error_count,
        "errors": errors
    }

### Execute Bulk Load

In [None]:
# Process batches with checkpointing + progress
batches_todo = sorted(b for b in range(NUM_BATCHES) if b not in completed_batches)

if not batches_todo:
    print("All batches already completed! Nothing to do.")
else:
    # Prior progress from checkpoint
    prior_stats = spark.sql(f"""
        SELECT COALESCE(SUM(indexed_count), 0) as indexed,
               COALESCE(SUM(error_count), 0) as errors
        FROM {CHECKPOINT_TABLE}
    """).collect()[0]
    prior_indexed = prior_stats.indexed

    print(f"=== Starting bulk load to {ELASTIC_INDEX} ===")
    print(f"Batches: {len(batches_todo)} remaining of {NUM_BATCHES}")
    if prior_indexed > 0:
        print(f"Previously indexed: {prior_indexed:,}")
    print()

    start_time = time.time()
    session_indexed = 0
    session_errors = 0

    for i, batch_id in enumerate(batches_todo):
        batch_start = time.time()

        # Read this batch from staging table (fast — partitioned by batch_id)
        batch_df = spark.read.table(STAGING_TABLE).filter(F.col("batch_id") == batch_id)

        # Send to Elasticsearch: each Spark partition yields one log record.
        batch_rdd = batch_df.rdd.mapPartitionsWithIndex(
            lambda idx, part: send_partition_to_elastic(part, idx)
        )
        # Materialize the per-partition logs exactly once. The RDD is not cached, so
        # every Spark action on a DataFrame built from it would re-run
        # send_partition_to_elastic and re-send the whole batch to Elasticsearch.
        # Totals and error samples below are therefore computed from this list.
        partition_logs = spark.createDataFrame(batch_rdd, log_schema).collect()

        batch_indexed = sum(row.indexed_count or 0 for row in partition_logs)
        batch_skipped = sum(row.skipped_count or 0 for row in partition_logs)
        batch_errors = sum(row.error_count or 0 for row in partition_logs)

        # Checkpoint — survives cluster restarts.
        # NOTE(review): a batch is recorded as done even when batch_errors > 0, so its
        # failed documents are not retried on resume; check the error totals.
        spark.sql(f"""
            INSERT INTO {CHECKPOINT_TABLE}
            VALUES ({batch_id}, {batch_indexed}, {batch_skipped}, {batch_errors}, current_timestamp())
        """)
        # Keep the in-memory set in sync so re-running only this cell skips done batches.
        completed_batches.add(batch_id)

        # Update session totals
        session_indexed += batch_indexed
        session_errors += batch_errors
        total_indexed = prior_indexed + session_indexed

        # Rate and ETA
        elapsed = time.time() - start_time
        batch_time = time.time() - batch_start
        rate = session_indexed / elapsed if elapsed > 0 else 0
        batches_done = i + 1
        batches_left = len(batches_todo) - batches_done
        avg_batch_time = elapsed / batches_done
        eta_min = (batches_left * avg_batch_time) / 60

        print(
            f"[{batches_done}/{len(batches_todo)}] "
            f"batch {batch_id}: +{batch_indexed:,} ({batch_time:.0f}s) | "
            f"Total: {total_indexed:,} | "
            f"{rate:,.0f}/s | "
            f"ETA: {eta_min:.0f}min"
        )

        # Show sample errors if any (first message from up to 3 partitions)
        if batch_errors > 0:
            error_lists = [row.errors for row in partition_logs if row.errors]
            for errs in error_lists[:3]:
                for err in errs[:1]:
                    print(f"  ERROR: {err[:300]}")

    total_elapsed = (time.time() - start_time) / 60
    print(f"\n=== Session Complete ===")
    print(f"Session: {session_indexed:,} indexed, {session_errors:,} errors in {total_elapsed:.1f}min")
    print(f"All sessions: {prior_indexed + session_indexed:,} total indexed")

In [None]:
# Final summary from checkpoint table.
# COUNT(DISTINCT batch_id): a batch checkpointed twice (e.g. the bulk-load cell re-run
# without reloading the checkpoint) must not count twice toward batches completed.
summary = spark.sql(f"""
    SELECT
        COUNT(DISTINCT batch_id) as batches_done,
        COALESCE(SUM(indexed_count), 0) as total_indexed,
        COALESCE(SUM(skipped_count), 0) as total_skipped,
        COALESCE(SUM(error_count), 0) as total_errors,
        MIN(completed_at) as first_batch_at,
        MAX(completed_at) as last_batch_at
    FROM {CHECKPOINT_TABLE}
""").collect()[0]

print(f"=== Sync Summary (from checkpoint) ===")
print(f"Batches completed: {summary.batches_done}/{NUM_BATCHES}")
print(f"Total indexed: {summary.total_indexed:,}")
print(f"Total skipped: {summary.total_skipped:,}")
print(f"Total errors:  {summary.total_errors:,}")
if summary.first_batch_at and summary.last_batch_at:
    print(f"Time span: {summary.first_batch_at} → {summary.last_batch_at}")

if summary.batches_done < NUM_BATCHES:
    missing = NUM_BATCHES - summary.batches_done
    # Resuming needs the Prepare Data cell, which reloads completed batches from the
    # checkpoint (and, for a full sync, reuses the existing staging table).
    print(f"\nWARNING: {missing} batches not yet completed. "
          f"Re-run the notebook from the 'Prepare Data' cell to resume.")

In [None]:
# Post-sync: refresh, set refresh_interval, verify doc count, clean up
# NOTE: Do NOT force merge here — run graduated merge manually (see PLAN.md)
# NOTE: Do NOT set replicas=1 — keep at 0 permanently (9.3TB, rebuildable in ~4h)

# Build the client before `try` so a construction failure is not masked by a
# NameError from `client.close()` in `finally`.
client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
try:
    if client.indices.exists(index=ELASTIC_INDEX):
        # Refresh so the document count below reflects everything just indexed
        client.indices.refresh(index=ELASTIC_INDEX)
        print(f"Refreshed index {ELASTIC_INDEX}")

        # Doc count
        total_docs = client.count(index=ELASTIC_INDEX)['count']
        print(f"Total documents: {total_docs:,}")

        # Restore a periodic refresh after a full sync (replicas intentionally stay at 0)
        if IS_FULL_SYNC:
            client.indices.put_settings(index=ELASTIC_INDEX, settings={
                "index": {
                    "refresh_interval": "30s"
                }
            })
            print(f"Set refresh_interval=30s on {ELASTIC_INDEX} (replicas stay at 0)")
    else:
        print(f"Index {ELASTIC_INDEX} does not exist")
finally:
    client.close()

# Clean up staging and checkpoint tables, but only once every batch is checkpointed:
# dropping them earlier would throw away the state needed to resume the sync.
if spark.catalog.tableExists(CHECKPOINT_TABLE):
    checkpointed_batches = spark.sql(
        f"SELECT COUNT(DISTINCT batch_id) FROM {CHECKPOINT_TABLE}"
    ).collect()[0][0]
else:
    checkpointed_batches = NUM_BATCHES  # nothing to resume from; safe to clean up

if checkpointed_batches < NUM_BATCHES:
    print(f"\nSkipping cleanup: only {checkpointed_batches}/{NUM_BATCHES} batches checkpointed; "
          f"keeping {STAGING_TABLE} and {CHECKPOINT_TABLE} so the sync can resume.")
else:
    print(f"\nCleaning up temp tables...")
    spark.sql(f"DROP TABLE IF EXISTS {STAGING_TABLE}")
    spark.sql(f"DROP TABLE IF EXISTS {CHECKPOINT_TABLE}")
    print("Staging and checkpoint tables dropped.")

In [None]:
# Test kNN search with a sample query
print("Testing kNN search...")


def _knn_search(es_client, query_vector, source_fields, knn_filter=None, k=5, num_candidates=50):
    """Run an approximate kNN query on ``vector_embedding`` in ELASTIC_INDEX.

    Args:
        es_client: Open Elasticsearch client.
        query_vector: Query embedding as a list of floats.
        source_fields: ``_source`` fields to return for each hit.
        knn_filter: Optional query DSL filter applied inside the kNN search.
        k: Number of hits to return (also used as ``size``).
        num_candidates: Candidates considered per shard.

    Returns:
        The raw search response.
    """
    knn = {
        "field": "vector_embedding",
        "query_vector": query_vector,
        "k": k,
        "num_candidates": num_candidates,
    }
    if knn_filter is not None:
        knn["filter"] = knn_filter
    return es_client.search(index=ELASTIC_INDEX, knn=knn, source=source_fields, size=k)


# Build the client before `try` so `finally` never closes an unbound name.
client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
try:
    # Get a sample embedding to use as query
    sample_rows = spark.sql("""
        SELECT work_id, embedding
        FROM openalex.vector_search.work_embeddings_v2
        LIMIT 1
    """).collect()

    if not sample_rows:
        print("No embeddings found in work_embeddings_v2; skipping kNN test.")
    else:
        sample = sample_rows[0]
        query_vector = [float(x) for x in sample.embedding]

        # Unfiltered kNN search on the vector index
        result = _knn_search(client, query_vector, ["id", "publication_year", "type", "cited_by_count"])

        print(f"Query work_id: {sample.work_id}")
        print(f"\nTop 5 similar works:")
        for hit in result['hits']['hits']:
            src = hit['_source']
            print(f"  {hit['_score']:.4f}: {src.get('id', 'N/A')} ({src.get('type', '?')}, {src.get('publication_year', '?')}, cited: {src.get('cited_by_count', 0)})")

        # Same query restricted by the flat filter fields
        filtered_result = _knn_search(
            client,
            query_vector,
            ["id", "publication_year", "is_oa"],
            knn_filter={
                "bool": {
                    "must": [
                        {"term": {"is_oa": True}},
                        {"range": {"publication_year": {"gte": 2020}}}
                    ]
                }
            },
        )

        print(f"\nFiltered (is_oa=true, year>=2020):")
        for hit in filtered_result['hits']['hits']:
            src = hit['_source']
            print(f"  {hit['_score']:.4f}: {src.get('id', 'N/A')} ({src.get('publication_year', '?')}, is_oa={src.get('is_oa', '?')})")

finally:
    client.close()