In [0]:
%pip install elasticsearch==8.19.0
%restart_python

In [0]:
import json
import logging
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

from elasticsearch import Elasticsearch, helpers
from pyspark.sql import functions as F
from pyspark.sql.types import *

logging.basicConfig(level=logging.WARNING, format='[%(asctime)s]: %(message)s')
log = logging.getLogger(__name__)

# Target index and connection settings.
ELASTIC_INDEX = "works-v26"
ELASTIC_URL = dbutils.secrets.get(scope="elastic", key="elastic_url")
MAX_LENGTH = 32000  # Slightly below the 32766 limit


def _widget_flag(name: str) -> bool:
    """Read a text widget as a boolean: only a case-insensitive "true" enables it."""
    return dbutils.widgets.get(name).lower() == "true"


# Job parameters (Databricks widgets); widgets are read in the same order as before.
OP_TYPE = dbutils.widgets.get("op_type")
DELETE_INDEX = _widget_flag("delete_index")  # default is false
REPARTITION = _widget_flag("repartition")    # default is false
IS_FULL_SYNC = _widget_flag("is_full_sync")  # default is incremental
OREO_ONLY = _widget_flag("oreo_only")

print(f"OP_TYPE: {OP_TYPE}")
print(f"DELETE_INDEX: {DELETE_INDEX}")
print(f"REPARTITION: {REPARTITION}")
print(f"IS_FULL_SYNC: {IS_FULL_SYNC}")
print(f"OREO_ONLY: {OREO_ONLY}")

### Delete Index if Needed

In [0]:
if DELETE_INDEX:
    print(f"WARNING: Deleting {ELASTIC_INDEX}...")
    # Build the client before entering try/finally: if construction fails
    # (e.g. a malformed ELASTIC_URL) that error surfaces as-is instead of being
    # replaced by a NameError on `client` inside `finally`.
    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True
    )
    try:
        # delete index if exists
        if client.indices.exists(index=ELASTIC_INDEX):
            client.indices.delete(index=ELASTIC_INDEX)
            print(f"Deleted index {ELASTIC_INDEX}")
        else:
            print(f"Index {ELASTIC_INDEX} does not exist")
    finally:
        client.close()

### Prepare Input

In [0]:
# Incremental sync by default: works created or updated in the last two days.
SQL_QUERY = """SELECT *
                FROM openalex.works.openalex_works
                WHERE openalex_works.created_date >= current_date() - INTERVAL 2 days 
                OR openalex_works.updated_date >= current_date() - INTERVAL 2 days
              """
# A full sync reads every work (without the fulltext column); OREO_ONLY takes precedence over both.
if IS_FULL_SYNC:
    SQL_QUERY = """SELECT * EXCEPT (fulltext) FROM openalex.works.openalex_works"""
if OREO_ONLY:
    SQL_QUERY = """SELECT * FROM openalex.works.openalex_works_oreo"""

# Date-like columns: name -> (parser applied first, SQL type of the NULL used when out of range).
_DATE_COLUMNS = {
    "created_date": (F.to_timestamp, "timestamp"),
    "updated_date": (F.to_timestamp, "timestamp"),
    "publication_date": (F.to_date, "date"),
}


def _concept_with_url(concept):
    """Rebuild one concept struct with its id expanded to a full OpenAlex concept URL."""
    return F.struct(
        F.concat(F.lit("https://openalex.org/C"), concept.id).alias("id"),
        concept.wikidata.alias("wikidata"),
        concept.display_name.alias("display_name"),
        concept.level.alias("level"),
        concept.score.alias("score"),
    )


def _null_if_out_of_range(column_name, null_type):
    """Keep the value only when it lies between 1000-01-01 and 9999-12-31, else a typed NULL."""
    value = F.col(column_name)
    in_range = value.between(F.lit("1000-01-01"), F.lit("9999-12-31"))
    return F.when(in_range, value).otherwise(F.lit(None).cast(null_type))


# TEST: spark.sql is used instead of spark.read.table("openalex.works.openalex_works")
# to avoid RDD access to the underlying Parquet files.
df = spark.sql(SQL_QUERY).withColumn("display_name", F.col("title"))
# First cast to date/timestamp ...
for column_name, (parse_fn, null_type) in _DATE_COLUMNS.items():
    df = df.withColumn(column_name, parse_fn(column_name))
df = df.withColumn("concepts", F.transform(F.col("concepts"), _concept_with_url))
# ... then null out anything outside the supported range.
for column_name, (parse_fn, null_type) in _DATE_COLUMNS.items():
    df = df.withColumn(column_name, _null_if_out_of_range(column_name, null_type))

In [0]:
@F.udf(StringType())
def truncate_abstract_index_string(raw_json: str, max_bytes: int = 32760) -> str:
    """Shrink a serialized abstract inverted index so it fits in `max_bytes` UTF-8 bytes.

    Strings already within the limit are returned unchanged. Longer ones are cut
    back to the last ``]`` that ends a complete entry and closed with ``}``. A
    candidate is only returned if it parses as JSON, so a ``]`` that happens to
    sit inside a key cannot produce a broken document.

    Args:
        raw_json: JSON text of the inverted index. The truncation assumes an
            object whose values end in ``]`` (presumably word -> position list;
            verify against the openalex_works schema).
        max_bytes: Upper bound on the UTF-8 size of the returned string,
            including the closing brace that truncation appends.

    Returns:
        The original or truncated JSON string, or None when the input is empty,
        no valid prefix is found, or anything fails (best-effort by design).
    """
    # Bound the walk-back so a pathological input cannot trigger many full re-parses.
    max_attempts = 64
    try:
        if not raw_json:
            return None

        # Fast path: a code point is at most 4 UTF-8 bytes, so short strings always fit.
        if len(raw_json) <= (max_bytes // 4):
            return raw_json

        encoded = raw_json.encode('utf-8')
        if len(encoded) <= max_bytes:
            return raw_json

        # Reserve one byte for the '}' appended below. errors='ignore' drops a
        # multi-byte character that the byte cut split in half.
        truncated = encoded[:max(max_bytes - 1, 0)].decode('utf-8', errors='ignore')

        end = len(truncated)
        for _ in range(max_attempts):
            last_bracket = truncated.rfind(']', 0, end)
            if last_bracket == -1:
                return None
            candidate = truncated[:last_bracket + 1] + '}'
            try:
                json.loads(candidate)
                return candidate
            except ValueError:
                # This ']' did not close a complete entry (e.g. it is part of a
                # key); try the previous one.
                end = last_bracket
        return None
    except Exception:
        return None
    
def sanitize_string(col_name: str, max_len: int = MAX_LENGTH):
    """Return column `col_name` cut to its first `max_len` characters.

    Nulls stay null because Spark's substring propagates them, so no explicit
    null guard is needed. The default comes from the notebook's MAX_LENGTH
    setting instead of repeating the literal.

    NOTE(review): F.substring counts characters, not UTF-8 bytes. With
    multi-byte text, MAX_LENGTH characters can exceed the 32766-byte limit
    mentioned next to MAX_LENGTH — confirm whether the affected fields are
    mapped in a way where that limit applies.
    """
    return F.substring(F.col(col_name), 1, max_len)

# Typed empty array used as the fallback for a null sustainable_development_goals column.
empty_sdg_array = F.array().cast("array<struct<id:string,display_name:string,score:double>>")

# Shape each work into the bulk layout: an "id" column plus a "_source" struct
# holding the document body (see generate_prepared_actions).
df_transformed = (
    df
    # Work ids become full OpenAlex URLs; this value is also used as the Elasticsearch _id.
    .withColumn("id", F.concat(F.lit("https://openalex.org/W"), F.col("id")))
    # Year derived from publication_date; rows without a (valid) date get 1800.
    .withColumn("publication_year", F.when(F.col("publication_date").isNotNull(), F.year("publication_date")).otherwise(1800))
    .withColumn("title", sanitize_string("title"))
    .withColumn("display_name", sanitize_string("display_name"))
    # Prefix the "doi" entry of the ids map with https://doi.org/; other entries are unchanged.
    .withColumn("ids", 
        F.transform_values("ids",
            lambda k, v: F.when(k == "doi", 
                    F.concat(F.lit("https://doi.org/"),v)).otherwise(v)
        )
    )
    .withColumn("doi", sanitize_string("doi"))
    .withColumn("language", sanitize_string("language"))
    .withColumn("type", sanitize_string("type"))
    .withColumn("abstract", sanitize_string("abstract"))
    # Referenced work ids become full OpenAlex URLs.
    .withColumn("referenced_works", 
                F.expr("transform(referenced_works, x -> 'https://openalex.org/W' || x)"))
    # A null referenced_works array counts as 0.
    .withColumn("referenced_works_count", 
                F.when(F.col("referenced_works").isNotNull(), F.size("referenced_works")).otherwise(0))
    # Cap the size of the serialized inverted index (see truncate_abstract_index_string).
    .withColumn("abstract_inverted_index", truncate_abstract_index_string(F.col("abstract_inverted_index")))
    # Rebuild open_access with is_oa, a truncated oa_status, and an always-null
    # any_repository_has_fulltext.
    .withColumn("open_access", F.struct(
        F.col("open_access.is_oa"),
        sanitize_string("open_access.oa_status").alias("oa_status"),
        F.lit(None).cast("boolean").alias("any_repository_has_fulltext")
    ))
    # Re-project authorships, truncating author_position, raw_author_name and
    # each raw_affiliation_strings entry to 32000 characters.
    .withColumn("authorships", F.expr("""
        transform(authorships, x -> named_struct(
            'affiliations', x.affiliations,
            'author', x.author,
            'author_position', substring(x.author_position, 1, 32000),
            'countries', x.countries,
            'raw_author_name', substring(x.raw_author_name, 1, 32000),
            'is_corresponding', x.is_corresponding,
            'raw_affiliation_strings', transform(x.raw_affiliation_strings, aff -> substring(aff, 1, 32000)),
            'institutions', x.institutions
        ))
    """))
    # Re-project locations: URLs truncated to 32000 characters and is_published
    # derived from version = 'publishedVersion'.
    .withColumn("locations", F.expr("""
        transform(locations, x -> named_struct(
            'is_oa', x.is_oa,
            'is_published', x.version = 'publishedVersion',
            'landing_page_url', substring(x.landing_page_url, 1, 32000),
            'pdf_url', substring(x.pdf_url, 1, 32000),
            'source', x.source,
            'native_id', x.native_id,
            'provenance', x.provenance,
            'license', x.license,
            'license_id', x.license_id
        ))
    """))
    # limit to a reasonable number (they go up to 130) - mainly for xpac
    .withColumn("concepts", F.slice(F.col("concepts"), 1, 40))
    # Assemble the document body that is sent as the Elasticsearch _source.
    .withColumn("_source", F.struct(
        F.col("id"),
        F.col("doi"),
        F.col("title"),
        F.col("display_name"),
        F.col("ids"),
        # move this to openalex_works at some point if it is useful in other contexts
        # indexed_in: labels derived from each location's provenance, native_id
        # and source (crossref / pubmed / datacite / doaj / arxiv); NULLs (e.g.
        # sources not in DOAJ) are removed by array_compact, then the labels are
        # de-duplicated and sorted.
        F.expr("""
            array_sort(
                array_distinct(
                    array_compact(
                        flatten(
                            TRANSFORM(locations, loc ->
                                CASE
                                WHEN loc.provenance IN ('crossref', 'pubmed', 'datacite')
                                    THEN array(loc.provenance, IF(loc.source.is_in_doaj, 'doaj', NULL))
                                WHEN loc.provenance = 'repo' AND lower(loc.native_id) like 'oai:arxiv.org%'
                                    THEN array('arxiv')
                                WHEN loc.provenance = 'repo' AND lower(loc.native_id) like 'oai:doaj.org/%'
                                    THEN array('doaj')
                                WHEN loc.provenance = 'mag' AND lower(loc.source.display_name) = 'pubmed'
                                    THEN array('pubmed')
                                ELSE array()
                                END
                            )
                        )
                    )
                )
            )
        """).alias("indexed_in"),
        F.col("publication_date"),
        F.col("publication_year"),
        F.col("language"),
        F.col("type"),
        F.col("authorships"),
        # Null arrays are replaced by empty arrays for these fields.
        F.coalesce(F.col("corresponding_author_ids"), F.lit([])).alias("corresponding_author_ids"),
        F.coalesce(F.col("corresponding_institution_ids"), F.lit([])).alias("corresponding_institution_ids"),
        F.col("primary_topic"),
        F.col("topics"),
        F.col("keywords"),
        F.col("concepts"),
        F.col("locations"),
        F.col("locations_count"),
        F.col("primary_location"),
        F.col("best_oa_location"),
        F.coalesce(F.col("sustainable_development_goals"), empty_sdg_array).alias("sustainable_development_goals"),
        F.col("grants"),
        F.col("awards"),
        F.col("open_access"),
        F.col("is_paratext"),
        F.col("is_retracted"),
        F.col("is_xpac"),
        F.col("biblio"),
        F.col("abstract"),
        F.col("referenced_works"),
        F.col("referenced_works_count"),
        F.coalesce(F.col("related_works"), F.lit([])).alias("related_works"),
        F.col("abstract_inverted_index"),
        F.col("cited_by_count"),
        F.col("counts_by_year"),
        F.col("apc_list"),
        F.col("apc_paid"),
        # A null fwci is indexed as 0.
        F.coalesce(F.col("fwci"),F.lit(0)).alias("fwci"),
        F.col("citation_normalized_percentile"),
        F.col("cited_by_percentile_year"),
        F.coalesce(F.col("mesh"), F.lit([])).alias("mesh"),
        F.col("has_abstract"),
        F.col("created_date"),
        F.col("updated_date"),
        # Timestamp taken when the query is evaluated, stored as indexed_timestamp.
        F.current_timestamp().alias("indexed_timestamp")
    ))
    .select("id", "_source")
)
# print(f"Transformed {df_transformed.count()} records.")
# NOTE(review): display() presumably runs a Spark job to render a preview;
# consider dropping it for scheduled runs.
display(df_transformed)

### Create Helpers

In [0]:
# Element type shared by the parsing_errors and indexing_errors arrays.
# row_id must be an int (or None): LongType rejects string ids during
# createDataFrame schema verification.
_row_error_type = StructType([
    StructField("row_id", LongType(), True),
    StructField("error", StringType(), True)
])

# Schema of the per-partition log dicts yielded by send_partition_to_elastic.
# PySpark matches dict keys to field names, so a key without a field here never
# reaches the DataFrame. success/message are therefore declared too (appended
# at the end to keep the existing column order).
log_schema = StructType([
    StructField("index_name", StringType(), True),
    StructField("run_id", StringType(), True),
    StructField("partition_id", IntegerType(), True),
    StructField("indexed_count", IntegerType(), True),
    StructField("skipped_count", IntegerType(), True),
    StructField("parsing_error_count", IntegerType(), True),
    StructField("indexing_error_count", IntegerType(), True),
    StructField("parsing_errors", ArrayType(_row_error_type), True),
    StructField("indexing_errors", ArrayType(_row_error_type), True),
    StructField("success", BooleanType(), True),
    StructField("message", StringType(), True),
])

@dataclass
class DatabricksEnvInfo:
    """Identifiers of the current Databricks execution, read from the notebook context.

    Every field is optional and is None when the corresponding context attribute
    is absent (job_id/run_id are presumably only set inside a job run).
    """
    job_id: Optional[str] = None          # context attribute "jobId"
    run_id: Optional[str] = None          # context attribute "currentRunId"
    command_run_id: Optional[str] = None  # context attribute "commandRunId"
    notebook_path: Optional[str] = None   # context attribute "notebook_path"
    user: Optional[str] = None            # context attribute "user"

def get_databricks_env_info() -> DatabricksEnvInfo:
    """Read job/run/user identifiers from the current notebook context.

    Parses the context JSON exposed through dbutils and copies selected
    entries of its "attributes" object into a DatabricksEnvInfo. Missing
    attributes become None; a context without "attributes" raises KeyError.
    """
    context_json = (
        dbutils.notebook.entry_point.getDbutils().notebook().getContext().safeToJson()
    )
    attributes = json.loads(context_json)["attributes"]

    # DatabricksEnvInfo field -> context attribute it is read from.
    field_to_attribute = {
        "command_run_id": "commandRunId",
        "job_id": "jobId",
        "run_id": "currentRunId",
        "notebook_path": "notebook_path",
        "user": "user",
    }
    return DatabricksEnvInfo(
        **{field: attributes.get(key) for field, key in field_to_attribute.items()}
    )

# Environment identifiers for this run, read once when this cell executes;
# get_run_identifier() builds the run label from it.
dbx_env = get_databricks_env_info()

def get_run_identifier():
    """Build the label stored as `run_id` in each partition log entry.

    With both job_id and run_id present this is "<job_id>-<run_id>"; otherwise
    it falls back to "<user>-<command_run_id>" from the notebook context.
    """
    if not (dbx_env.job_id and dbx_env.run_id):
        return f"{dbx_env.user}-{dbx_env.command_run_id}"
    return f"{dbx_env.job_id}-{dbx_env.run_id}"
    
def _work_id_to_int(work_id):
    """Convert an OpenAlex work URL ("https://openalex.org/W123") to its integer id.

    Returns None for a missing id and -1 for a non-numeric one, matching how
    send_partition_to_elastic reports indexing errors. The result therefore fits
    the LongType `row_id` field of log_schema.
    """
    if not work_id:
        return None
    try:
        return int(str(work_id).replace("https://openalex.org/W", ""))
    except ValueError:
        return -1


def generate_prepared_actions(partition, parsing_errors, op_type = "index"):
    """Yield one Elasticsearch bulk action per row of `partition`.

    Args:
        partition: Iterable of rows exposing `id` and a `_source` struct with asDict().
        parsing_errors: List that receives {"row_id", "error"} dicts for rows whose
            action could not be built. Those rows are skipped, not yielded.
        op_type: Bulk operation name placed in "_op_type".

    Yields:
        Dicts with "_op_type", "_index" (ELASTIC_INDEX), "_id" and "_source".
    """
    for row in partition:
        try:
            action = {
                "_op_type": op_type,
                "_index": ELASTIC_INDEX,
                "_id": row.id,
                "_source": row._source.asDict(True)
            }
        except Exception as e:
            # row_id must be numeric: the run log declares it LongType, and a URL
            # string would make createDataFrame's schema verification fail.
            parsing_errors.append({"row_id": _work_id_to_int(getattr(row, "id", None)), "error": str(e)})
            continue
        yield action

def send_partition_to_elastic(partition, partition_id, op_type="create"):
    """Bulk-send one Spark partition to Elasticsearch and yield a single log entry.

    Used from mapPartitionsWithIndex: each call opens its own client and streams
    the partition's rows through helpers.parallel_bulk. It then yields one dict
    summarising the outcome (the caller turns these into a DataFrame with
    log_schema).

    Args:
        partition: Iterator of rows with `id` and `_source` fields.
        partition_id: Index of the partition, copied into the log entry.
        op_type: Bulk operation, case-insensitive (e.g. "create" or "index").

    Yields:
        One dict with counters, a success flag, a message and up to 1000
        parsing/indexing errors each. A 409 response (document already exists,
        as happens with op_type="create") is counted as skipped, not as an error.
    """
    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True,
        http_compress=True,
    )

    indexed_count = 0
    parsing_errors = []
    indexing_errors = []
    skipped_count = 0
    op_type = op_type.lower()

    try:
        for success, info in helpers.parallel_bulk(
            client,
            generate_prepared_actions(partition, parsing_errors, op_type),
            chunk_size=500,
            thread_count=4,
            queue_size=10,
            # Without this, parallel_bulk raises BulkIndexError on the first
            # rejected document (409s included), aborting the rest of the
            # partition and bypassing the per-document handling below.
            raise_on_error=False,
        ):
            if success:
                indexed_count += 1
                continue

            # Failed items arrive as {op_type: {"_id", "status", "error", ...}}.
            error_info = info.get(op_type, {})
            status = error_info.get("status", 0)

            # Skip 409 (document already exists)
            if status == 409:
                skipped_count += 1
                continue

            id_url = error_info.get('_id')
            row_id = None
            if id_url:
                try:
                    row_id = int(id_url.replace("https://openalex.org/W", ""))
                except ValueError:
                    row_id = -1

            indexing_errors.append({
                "row_id": row_id,
                "error": str(info)[:1000]
            })
    except Exception as e:
        # Anything still raised (e.g. transport failures) stops this partition and
        # is recorded as a single error without a row id.
        indexing_errors.append({
            "row_id": None,
            "error": "Parallel Bulk Error: " + str(e)[:1000]
        })
    finally:
        client.close()

    log_entry = {
        "index_name": ELASTIC_INDEX,
        "run_id": get_run_identifier(),
        "partition_id": partition_id,
        "success": len(indexing_errors) == 0,
        "indexed_count": indexed_count,
        "skipped_count": skipped_count,
        "parsing_error_count": len(parsing_errors),
        "indexing_error_count": len(indexing_errors),
        "message": f"{indexed_count} records indexed. {skipped_count} skipped. {len(parsing_errors)} parsing errors. {len(indexing_errors)} ES errors.",
        "parsing_errors": parsing_errors[:1000],
        "indexing_errors": indexing_errors[:1000],
    }

    yield log_entry

### Re-balance partitions to avoid skew and out-of-memory failures (optional: enabled by the `repartition` widget, intended for full runs)

In [0]:
# Optional range repartition on "id" to even out partition sizes before the
# per-partition bulk indexing. It is slow, so it only runs when the
# `repartition` widget is "true". Nothing is cached here: the sync cell below
# recomputes df_transformed from its source query.
# NOTE(review): 8096 may have been meant as 8192 — confirm the intended count.
# Alternative (unused): materialize once, e.g. to openalex.works.works_api, via
# df_transformed.write.mode("overwrite").saveAsTable("openalex.works.works_api_sync")
df_transformed = (
    df_transformed.repartitionByRange(8096, "id") if REPARTITION else df_transformed
)

### Execute Sync (`mapPartitionsWithIndex` on executors, or driver-side `streaming_bulk` when `oreo_only` is set)

In [0]:

# Two sync paths:
#  * OREO_ONLY: collect the oreo rows to the driver and stream them through a
#    single client with helpers.streaming_bulk, printing every failure.
#  * otherwise: index each Spark partition on the executors with
#    send_partition_to_elastic and aggregate the per-partition log entries.
# (Unused alternative input: spark.read.table("openalex.works.works_api_sync").)
import pprint  # pretty-prints Elasticsearch error payloads

if OREO_ONLY:
    df_transformed_rows = df_transformed.collect()
    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True,
        http_compress=True,
    )

    indexed_count = 0
    failed_docs = []
    # generate_prepared_actions appends rows it cannot convert to this list.
    parsing_errors = []

    try:
        for success, info in helpers.streaming_bulk(
            client,
            generate_prepared_actions(df_transformed_rows, parsing_errors, "index"),
            chunk_size=100,  # documents per bulk request
            # Yield (False, info) for rejected documents instead of raising
            # BulkIndexError, so each failure can be reported below.
            raise_on_error=False,
        ):
            if success:
                indexed_count += 1
                if indexed_count % 100 == 0:
                    print(f"Indexed {indexed_count} documents so far...")
                continue

            # info is {op_type: item}; the item carries "_id" and "error".
            _, result = next(iter(info.items()))
            doc_id = result.get("_id", "[unknown_id]")
            error_details = result.get("error", {})

            print("---" * 15)
            print(f"💥 FAILED to index document ID: {doc_id}")
            pprint.pprint(error_details)
            print("---" * 15)

            failed_docs.append(doc_id)
    finally:
        client.close()

    # --- Final Summary ---
    print("\n--- BULK OPERATION COMPLETE ---")
    print(f"Successfully indexed: {indexed_count} documents")
    if failed_docs:
        print(f"Failed to index: {len(failed_docs)} documents")
        print(f"Failed document IDs: {failed_docs}")
    if parsing_errors:
        print(f"Failed to prepare: {len(parsing_errors)} documents")
        pprint.pprint(parsing_errors[:20])
else:
    logs_rdd = df_transformed.rdd.mapPartitionsWithIndex(
        lambda partition_idx, partition: send_partition_to_elastic(partition, partition_idx, OP_TYPE)
    )
    logs_df = spark.createDataFrame(logs_rdd, log_schema)

    # mapPartitionsWithIndex is lazy and logs_df is not cached, so every action on
    # it would re-run the whole indexing job. Fold the partition count and the
    # per-partition counters into one aggregation so the job runs exactly once.
    totals = logs_df.agg(
        F.count(F.lit(1)).alias("partitions"),
        F.sum("indexed_count").alias("indexed"),
        F.sum("skipped_count").alias("skipped"),
        F.sum("parsing_error_count").alias("parsing_errors"),
        F.sum("indexing_error_count").alias("indexing_errors"),
    ).first()
    log_count = totals["partitions"]
    print(f"Processed {log_count} partitions")
    print(
        f"Indexed: {totals['indexed']}, skipped: {totals['skipped']}, "
        f"parsing errors: {totals['parsing_errors']}, indexing errors: {totals['indexing_errors']}"
    )

In [0]:
# Refresh the index so the documents written above are visible to search, then
# report how many documents it holds. The client is created before try/finally
# so a construction error is not masked by a NameError on `client`.
client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=60)
try:
    if client.indices.exists(index=ELASTIC_INDEX):
        client.indices.refresh(index=ELASTIC_INDEX)
        print(f"Refreshed index {ELASTIC_INDEX}")
        print(f"{client.count(index=ELASTIC_INDEX)['count']} documents in {ELASTIC_INDEX}")
    else:
        print(f"Index {ELASTIC_INDEX} does not exist")
finally:
    client.close()