In [0]:
%pip install elasticsearch==8.19.0
%restart_python

In [0]:
import uuid
from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import *
from dataclasses import dataclass

from elasticsearch import Elasticsearch, helpers
import logging
import json

# Target index name and cluster endpoint; the endpoint is read from the
# "elastic" secret scope rather than being written in the notebook.
ELASTIC_INDEX = "works-v26"
ELASTIC_URL = dbutils.secrets.get(scope="elastic", key="elastic_url")

# Comma-separated list of columns to push, read from the notebook widget and
# lower-cased.
UPDATE_FIELDS_ONLY = dbutils.widgets.get("update_fields_only").lower()

# Without at least one field there is nothing to update: exit the notebook.
if not UPDATE_FIELDS_ONLY.strip():
    dbutils.notebook.exit("Please provide a comma-separated list of fields to update.")

print(f"UPDATE_FIELDS_ONLY: {UPDATE_FIELDS_ONLY}")

In [0]:
# Parse the widget value once so the SQL text and the Elasticsearch partial
# document are both built from the same list of column names.
# dict.fromkeys() drops duplicate entries while preserving their order.
FIELDS = list(dict.fromkeys(f.strip().lower() for f in UPDATE_FIELDS_ONLY.split(',')))

# The names are spliced into SQL text below, so accept plain column names only.
# This also rejects the empty entries produced by a stray or trailing comma,
# which would otherwise surface as an opaque Spark parse/analysis error.
invalid_fields = [f for f in FIELDS if not f.isidentifier()]
if invalid_fields:
    raise ValueError(
        "update_fields_only must be a comma-separated list of plain column names; "
        f"invalid entries: {invalid_fields}"
    )

SQL_QUERY = f"""SELECT id, {', '.join(FIELDS)}
FROM openalex.works.openalex_works"""
# Optional filters for testing on a subset (append inside the query string above):
# WHERE array_contains(flatten(authorships.institutions.id), 'https://openalex.org/I141945490')"""
# WHERE id = 2151103935"""
print(SQL_QUERY)

# The document id used below is the full OpenAlex work URL, so prefix the raw id.
df = (spark.sql(SQL_QUERY)
    .withColumn("id", F.concat(F.lit("https://openalex.org/W"), F.col("id")))
)

if "concepts" in FIELDS:
    # Rebuild each concept struct so its id is the full OpenAlex concept URL
    # and only the listed attributes are kept.
    df = df.withColumn(
        "concepts",
        F.transform(
            F.col("concepts"),
            lambda c: F.struct(
                F.concat(F.lit("https://openalex.org/C"), c.id).alias("id"),
                c.wikidata.alias("wikidata"),
                c.display_name.alias("display_name"),
                c.level.alias("level"),
                c.score.alias("score")
            )
        )
    )

# Pack the requested columns into a single `doc` struct: this becomes the
# partial document sent with each update action.
df = df.withColumn("doc", F.struct(*FIELDS)).select("id", "doc")
display(df)

### Create Helpers

In [0]:
def _work_id_to_long(id_url):
    """Convert an OpenAlex work URL (``https://openalex.org/W123``) to its numeric id.

    Returns ``None`` when no id is available and ``-1`` when the value cannot
    be parsed, which is the same convention the bulk-indexing error log uses.
    """
    if not id_url:
        return None
    try:
        return int(str(id_url).replace("https://openalex.org/W", ""))
    except ValueError:
        return -1


def generate_prepared_actions(partition, parsing_errors, op_type = "update"):
    """Yield one Elasticsearch bulk action per row of a partition.

    Args:
        partition: Iterable of rows exposing ``id`` (the document id) and
            ``doc`` (an object with ``asDict(recursive)``).
        parsing_errors: List mutated in place; a ``{"row_id", "error"}`` entry
            is appended for every row that cannot be turned into an action.
        op_type: Bulk operation type placed in ``_op_type`` (default "update").

    Yields:
        dict: Action with ``_op_type``, ``_index``, ``_id`` and ``doc`` keys.
    """
    for row in partition:
        try:
            action = {
                "_op_type": op_type,
                "_index": ELASTIC_INDEX,
                "_id": row.id,
                "doc": row.doc.asDict(True)
            }
        except Exception as e:
            # Record the numeric id: the log schema stores row_id as a long, so
            # the raw URL string would fail schema verification later on.
            # getattr() keeps the handler from raising if the id is missing.
            parsing_errors.append({
                "row_id": _work_id_to_long(getattr(row, "id", None)),
                "error": str(e)
            })
            continue
        yield action

def send_partition_to_elastic(partition, partition_id, op_type="update"):
    """Bulk-send one partition to Elasticsearch and yield a single log entry.

    Args:
        partition: Iterable of rows with ``id`` and ``doc`` attributes.
        partition_id: Index of the partition, copied into the log entry.
        op_type: Bulk operation type (case-insensitive), default "update".

    Yields:
        dict: One summary for the partition with counts of indexed, skipped and
        failed documents plus up to 1000 parsing and 1000 indexing errors.
    """
    # Upper bound on error entries kept in memory / returned per partition.
    max_logged_errors = 1000

    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True,
        http_compress=True,
    )

    indexed_count = 0
    skipped_count = 0
    indexing_error_count = 0
    parsing_errors = []
    indexing_errors = []
    op_type = op_type.lower()

    def record_indexing_error(row_id, message):
        # Count every failure, but only retain the first `max_logged_errors`
        # entries so a partition with many failures cannot exhaust memory.
        nonlocal indexing_error_count
        indexing_error_count += 1
        if len(indexing_errors) < max_logged_errors:
            indexing_errors.append({"row_id": row_id, "error": message})

    try:
        # raise_on_error=False is required for the per-item handling below:
        # by default the helper raises BulkIndexError on the first chunk that
        # contains a failed item, which would abort the rest of the partition
        # and never yield the failed items. Transport-level exceptions still
        # propagate (raise_on_exception keeps its default) and are caught below.
        for success, info in helpers.parallel_bulk(client,
                generate_prepared_actions(partition, parsing_errors, op_type),
                chunk_size=500, thread_count=4, queue_size=10,
                raise_on_error=False
            ):

            if success:
                indexed_count += 1
                continue

            error_info = info.get(op_type, {})
            status = error_info.get("status", 0)

            # Skip 409 conflicts (e.g. version conflict on update, or the
            # document already existing for a create) instead of logging them.
            if status == 409:
                skipped_count += 1
                continue

            # Convert the document id (full work URL) back to the numeric id;
            # -1 marks an id that could not be parsed.
            id_url = error_info.get('_id')
            row_id = None
            if id_url:
                try:
                    row_id = int(id_url.replace("https://openalex.org/W", ""))
                except ValueError:
                    row_id = -1

            record_indexing_error(row_id, str(info)[:1000])
    except Exception as e:
        # Anything raised by the bulk helper itself (connection failure,
        # timeout after retries, ...) is logged as one partition-level error.
        record_indexing_error(None, "Parallel Bulk Error: " + str(e)[:1000])
    finally:
        client.close()

    log_entry = {
        "index_name": ELASTIC_INDEX,
        "partition_id": partition_id,
        "success": indexing_error_count == 0,
        "indexed_count": indexed_count,
        "skipped_count": skipped_count,
        "parsing_error_count": len(parsing_errors),
        "indexing_error_count": indexing_error_count,
        "message": f"{indexed_count} records indexed. {skipped_count} skipped. {len(parsing_errors)} parsing errors. {indexing_error_count} ES errors.",
        "parsing_errors": parsing_errors[:max_logged_errors],
        "indexing_errors": indexing_errors[:max_logged_errors],
    }

    yield log_entry

### Execute Sync with `mapPartitionsWithIndex`

In [0]:

# df_input = spark.read.table("openalex.works.works_api_sync")
import pprint # Import the pretty-print library for clean output

# Schema of the per-partition log entries yielded by send_partition_to_elastic.
# Keys of the entry that are not listed here are not carried into the DataFrame.
log_schema = StructType([
    StructField("index_name", StringType(), True),
    StructField("partition_id", IntegerType(), True),
    StructField("indexed_count", IntegerType(), True),
    StructField("skipped_count", IntegerType(), True),
    StructField("parsing_error_count", IntegerType(), True),
    StructField("indexing_error_count", IntegerType(), True),
    StructField(
        "parsing_errors",
        ArrayType(
            StructType([
                StructField("row_id", LongType(), True),
                StructField("error", StringType(), True)
            ])
        ),
        True
    ),
    StructField(
        "indexing_errors",
        ArrayType(
            StructType([
                StructField("row_id", LongType(), True),
                StructField("error", StringType(), True)
            ])
        ),
        True
    )
])

logs_rdd = df.rdd.mapPartitionsWithIndex(
    lambda partition_idx, partition: send_partition_to_elastic(partition, partition_idx, "update")
)
# Cache the logs: every action on an uncached logs_df would re-run the whole
# bulk sync, so the summary below must read the materialized result instead.
logs_df = spark.createDataFrame(logs_rdd, log_schema).cache()

log_count = logs_df.count() # mapPartitionsWithIndex is lazy, so force it to run
print(f"Processed {log_count} partitions")

# Surface the outcome instead of discarding it: totals across all partitions.
totals = logs_df.agg(
    F.coalesce(F.sum("indexed_count"), F.lit(0)).alias("indexed"),
    F.coalesce(F.sum("skipped_count"), F.lit(0)).alias("skipped"),
    F.coalesce(F.sum("parsing_error_count"), F.lit(0)).alias("parsing_errors"),
    F.coalesce(F.sum("indexing_error_count"), F.lit(0)).alias("indexing_errors"),
).first()
print(
    f"{totals['indexed']} indexed, {totals['skipped']} skipped, "
    f"{totals['parsing_errors']} parsing errors, {totals['indexing_errors']} indexing errors"
)

# Show the partitions that reported problems so failures are not silent.
if totals["parsing_errors"] or totals["indexing_errors"]:
    display(
        logs_df.where(
            (F.col("parsing_error_count") > 0) | (F.col("indexing_error_count") > 0)
        )
    )

In [0]:
# Refresh the index so the writes above become searchable, then report how many
# documents it holds.
# The client is created before `try`: if construction fails there is nothing to
# close, and `finally` must not raise NameError and hide the original error.
client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
try:
    if client.indices.exists(index=ELASTIC_INDEX):
        client.indices.refresh(index=ELASTIC_INDEX)
        print(f"Refreshed index {ELASTIC_INDEX}")
        print(f"{client.count(index=ELASTIC_INDEX)['count']} documents in {ELASTIC_INDEX}")
    else:
        print(f"Index {ELASTIC_INDEX} does not exist")
finally:
    client.close()