In [0]:
%pip install elasticsearch==8.19.0
%restart_python

In [0]:
import json
import logging
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

from elasticsearch import Elasticsearch, helpers
from pyspark.sql import functions as F
from pyspark.sql.types import *

# Logging: emit WARNING and above, each line prefixed with its timestamp.
logging.basicConfig(
    level=logging.WARNING,
    format='[%(asctime)s]: %(message)s',
)
log = logging.getLogger(__name__)

# Target index name; the cluster URL is read from the "elastic" secret scope.
ELASTIC_INDEX = "works-v31"
ELASTIC_URL = dbutils.secrets.get(scope="elastic", key="elastic_url")

# Slightly below the 32766 limit.
# NOTE(review): not referenced by the cells visible here - confirm it is still needed.
MAX_LENGTH = 32000

# Sync mode comes from the "is_full_sync" widget: only the value "true" (any casing)
# selects a full sync; every other value means an incremental sync.
_is_full_sync_raw = dbutils.widgets.get("is_full_sync")
IS_FULL_SYNC = _is_full_sync_raw.lower() == "true"

print(f"IS_FULL_SYNC: {IS_FULL_SYNC}")

In [None]:
# Set replicas to 0 for faster bulk indexing during full sync.
if IS_FULL_SYNC:
    # Build the client BEFORE entering `try`: if the constructor raises there is
    # nothing to close, and `finally` must not fail with a NameError that would
    # mask the original error.
    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True
    )
    try:
        if client.indices.exists(index=ELASTIC_INDEX):
            client.indices.put_settings(index=ELASTIC_INDEX, body={
                "index": {"number_of_replicas": 0}
            })
            print(f"Set replicas to 0 on {ELASTIC_INDEX} for full sync")
        else:
            print(f"Index {ELASTIC_INDEX} does not exist yet - will create with default settings")
    finally:
        # Always release the HTTP connections, even when a settings call fails.
        client.close()

### Prepare Input

In [None]:
# Pick the source queries for this run.
# - Full sync: every work; no count query, because count-based partition tuning is skipped.
# - Incremental sync: only works whose updated_date falls within the last two days.
if IS_FULL_SYNC:
    SQL_QUERY = """SELECT * FROM openalex.works.openalex_works"""
    COUNT_QUERY = None
else:
    SQL_QUERY = """SELECT * FROM openalex.works.openalex_works
WHERE updated_date >= current_date() - INTERVAL 2 days
"""
    COUNT_QUERY = """SELECT COUNT(*) as cnt FROM openalex.works.openalex_works
WHERE updated_date >= current_date() - INTERVAL 2 days
"""

# Count the rows BEFORE loading the data: a plain SQL COUNT without any transformations.
record_count = None
if COUNT_QUERY and not IS_FULL_SYNC:
    record_count = spark.sql(COUNT_QUERY).collect()[0]["cnt"]
    print(f"Record count for daily sync: {record_count:,}")

def _null_unless_between(col_name, lower, upper, null_type):
    """Keep `col_name` when it lies in [lower, upper]; otherwise produce a NULL cast to `null_type`."""
    value = F.col(col_name)
    is_in_range = value.between(F.lit(lower), F.lit(upper))
    return F.when(is_in_range, value).otherwise(F.lit(None).cast(null_type))


def _concept_with_url_id(concept):
    """Rebuild one concept struct with the OpenAlex concept URL prefix prepended to its id."""
    return F.struct(
        F.concat(F.lit("https://openalex.org/C"), concept.id).alias("id"),
        concept.wikidata.alias("wikidata"),
        concept.display_name.alias("display_name"),
        concept.level.alias("level"),
        concept.score.alias("score"),
    )


df = (
    spark.sql(SQL_QUERY)
    # display_name is a copy of title
    .withColumn("display_name", F.col("title"))
    # Step 1: cast the date columns to timestamp/date
    .withColumn("created_date", F.to_timestamp("created_date"))
    .withColumn("updated_date", F.to_timestamp("updated_date"))
    .withColumn("publication_date", F.to_date("publication_date"))
    # Turn the concept ids into full URLs
    .withColumn("concepts", F.transform(F.col("concepts"), _concept_with_url_id))
    # Step 2: null out values that fall outside the accepted ranges
    .withColumn(
        "created_date",
        _null_unless_between("created_date", "1000-01-01", "9999-12-31", "timestamp"),
    )
    .withColumn(
        "updated_date",
        _null_unless_between("updated_date", "1000-01-01", "9999-12-31", "timestamp"),
    )
    .withColumn(
        "publication_date",
        _null_unless_between("publication_date", "1000-01-01", "2050-12-31", "date"),
    )
    # Rows without an id are dropped
    .where(F.col("id").isNotNull())
)

# Dynamic partitioning based on record volume.
# Only incremental syncs are tuned; a full sync keeps the partitioning of the query result.
RECORDS_PER_PARTITION = 10000  # Target ~10k records per partition

# Upper bounds (exclusive) of each size tier, in records.
SMALL_UPDATE_LIMIT = 2_000_000
MEDIUM_UPDATE_LIMIT = 10_000_000
LARGE_UPDATE_LIMIT = 20_000_000

# Partition-count floors/caps per tier.
SMALL_MIN_PARTITIONS = 64
MEDIUM_MIN_PARTITIONS = 1024
LARGE_MAX_PARTITIONS = 4096
VERY_LARGE_PARTITIONS = 8096

if not IS_FULL_SYNC and record_count is not None:
    # Tiers (they match the thresholds used in the code below):
    # - Small updates  (< 2M records):   coalesce to at least 64 target partitions
    # - Medium updates (2M - 10M):       repartition to at least 1024 partitions
    # - Large updates  (10M - 20M):      repartition, capped at 4096 partitions
    # - Very large updates (>= 20M):     repartitionByRange on "id" with 8096 partitions
    partitions_by_volume = record_count // RECORDS_PER_PARTITION

    if record_count < SMALL_UPDATE_LIMIT:
        # Small daily update - coalesce to reduce overhead.
        # coalesce() never increases the partition count, so this is an upper bound.
        optimal_partitions = max(SMALL_MIN_PARTITIONS, partitions_by_volume)
        df = df.coalesce(optimal_partitions)
        print(f"Small update: coalesced to {optimal_partitions} partitions")
    elif record_count < MEDIUM_UPDATE_LIMIT:
        # Medium daily update - use more partitions
        optimal_partitions = max(MEDIUM_MIN_PARTITIONS, partitions_by_volume)
        df = df.repartition(optimal_partitions)
        print(f"Medium update: repartitioned to {optimal_partitions} partitions")
    elif record_count < LARGE_UPDATE_LIMIT:
        # Large daily update - repartition for better distribution
        optimal_partitions = min(LARGE_MAX_PARTITIONS, partitions_by_volume)
        df = df.repartition(optimal_partitions)
        print(f"Large update: repartitioned to {optimal_partitions} partitions")
    else:
        # Very large update - range-partition on id for an even distribution
        df = df.repartitionByRange(VERY_LARGE_PARTITIONS, "id")
        print(f"Very large update: using repartitionByRange with {VERY_LARGE_PARTITIONS} partitions")

print(f"SQL query:\n{SQL_QUERY}")

In [None]:
@F.udf(StringType())
def truncate_abstract_index_string(raw_json: str, max_bytes: int = 32760) -> str:
    """
    Truncate an inverted-index JSON string so its UTF-8 encoding fits in `max_bytes`.

    Args:
        raw_json: JSON object text (token -> list of positions). May be None or empty.
        max_bytes: Maximum size of the UTF-8 encoded result.

    Returns:
        - None when the input is None/empty or an unexpected error occurs (best effort);
        - the input unchanged when it already fits;
        - otherwise the longest prefix, cut after a complete array (`]`) and closed
          with `}`, that parses as valid JSON;
        - '{}' when no valid cut point is found.
    """
    try:
        if not raw_json:
            return None

        encoded = raw_json.encode('utf-8')
        if len(encoded) <= max_bytes:
            return raw_json

        # Keep ~100 bytes of headroom. errors='ignore' drops a multi-byte character
        # that the byte cut may have split in half.
        safe_bytes = max(max_bytes - 100, 0)
        truncated = encoded[:safe_bytes].decode('utf-8', errors='ignore')

        # '],' marks the end of an array followed by another entry (it also covers
        # the '],"' form). A token may itself contain '],', so every candidate is
        # validated with json.loads and earlier cut points are tried on failure.
        max_attempts = 50  # bound the work on pathological input
        cut = truncated.rfind('],')
        for _ in range(max_attempts):
            if cut == -1:
                break
            candidate = truncated[:cut + 1] + '}'
            try:
                json.loads(candidate)
                return candidate
            except ValueError:
                cut = truncated.rfind('],', 0, cut)

        return '{}'

    except Exception:
        # Deliberate best effort: a bad value must not fail the whole Spark task.
        return None
    
def sanitize_name(col_name: str):
    """
    Cleans a string column by removing unwanted characters and normalizing whitespace.
    Handles multilingual text by preserving letters, numbers, punctuation, and symbols
    from all Unicode scripts.

    Whitespace matched by `\\s` (tab, newline, ...) is NOT deleted by the first pass;
    it is turned into a single space by the second pass, so the words it separated
    are not glued together.

    Args:
        col_name: The name of the column to sanitize.

    Returns:
        A PySpark Column object with the cleaning transformations applied.
    """
    # Matches any character that is NOT a letter, number, punctuation, symbol,
    # separator or whitespace (in any language). Such characters usually show up
    # due to encoding changes (from windows-1251, etc).
    unwanted_chars_pattern = r"[^\p{L}\p{N}\p{P}\p{S}\p{Z}\s]"
    # Matches one or more whitespace characters.
    multiple_spaces_pattern = r"\s+"

    # Pass 1: drop unwanted characters (replaced with the empty string).
    without_unwanted = F.regexp_replace(F.col(col_name), unwanted_chars_pattern, "")
    # Pass 2: collapse every whitespace run into one space.
    single_spaced = F.regexp_replace(without_unwanted, multiple_spaces_pattern, " ")
    # Finally strip leading/trailing spaces.
    return F.trim(single_spaced)

def sanitize_string(col_name: str, max_len: int = 32000):
    """
    Build a Column with the first `max_len` characters of `col_name`.

    Args:
        col_name: Name (or dotted path) of the string column to clip.
        max_len: Number of characters to keep, counted from position 1.

    Returns:
        A PySpark Column: the clipped value when the source is not NULL, NULL otherwise.
    """
    source = F.col(col_name)
    clipped = F.substring(source, 1, max_len)
    return F.when(source.isNotNull(), clipped).otherwise(None)

# Typed empty array used as the fallback when sustainable_development_goals is NULL.
empty_sdg_array = F.array().cast("array<struct<id:string,display_name:string,score:double>>")

# Shape every work into (id, _source): `id` becomes the Elasticsearch document id and
# `_source` the document body.
df_transformed = (
    df
    # Work ids become full OpenAlex URLs.
    .withColumn("id", F.concat(F.lit("https://openalex.org/W"), F.col("id")))
    # publication_year is derived from publication_date only (NULL when the date is NULL).
    # NOTE(review): a coalesce(publication_year, year(publication_date)) used to precede
    # this line but was immediately overwritten by it, so it had no effect and was
    # removed. Confirm that ignoring the source publication_year is intended.
    .withColumn("publication_year", F.year("publication_date"))
    .withColumn("title", sanitize_name("title"))
    .withColumn("display_name", sanitize_name("display_name"))
    # Only the "doi" entry of the ids map is prefixed with the doi.org URL.
    .withColumn("ids", 
        F.transform_values("ids",
            lambda k, v: F.when(k == "doi", 
                    F.concat(F.lit("https://doi.org/"),v)).otherwise(v)
        )
    )
    # Clip long strings.
    .withColumn("doi", sanitize_string("doi"))
    .withColumn("language", sanitize_string("language"))
    .withColumn("type", sanitize_string("type"))
    .withColumn("abstract", sanitize_string("abstract"))
    # Referenced work ids become full URLs; the count is 0 when the array is NULL.
    .withColumn("referenced_works", 
                F.expr("transform(referenced_works, x -> 'https://openalex.org/W' || x)"))
    .withColumn("referenced_works_count",
                F.when(F.col("referenced_works").isNotNull(), F.size("referenced_works")).otherwise(0))
    .withColumn("abstract_inverted_index", truncate_abstract_index_string(F.col("abstract_inverted_index")))
    # any_repository_has_fulltext is always written as False here.
    .withColumn("open_access", F.struct(
        F.col("open_access.is_oa"),
        sanitize_string("open_access.oa_status").alias("oa_status"),
        F.lit(False).cast("boolean").alias("any_repository_has_fulltext"),
        F.col("open_access.oa_url")
    ))
    # Build full authorships first, then truncate for the limited version
    .withColumn("authorships_full", F.expr("""
        transform(authorships, x -> named_struct(
            'affiliations', x.affiliations,
            'author', x.author,
            'author_position', substring(x.author_position, 1, 32000),
            'countries', x.countries,
            'raw_author_name', substring(x.raw_author_name, 1, 32000),
            'is_corresponding', x.is_corresponding,
            'raw_affiliation_strings', transform(x.raw_affiliation_strings, aff -> substring(aff, 1, 32000)),
            'institutions', x.institutions
        ))
    """))
    # The limited version keeps the first 100 authorships.
    .withColumn("authorships", F.slice(F.col("authorships_full"), 1, 100))
    # is_published is derived from version = 'publishedVersion'; URLs are clipped.
    .withColumn("locations", F.expr("""
        transform(locations, x -> named_struct(
            'is_oa', x.is_oa,
            'is_published', x.version = 'publishedVersion',
            'landing_page_url', substring(x.landing_page_url, 1, 32000),
            'pdf_url', substring(x.pdf_url, 1, 32000),
            'source', x.source,
            'raw_source_name', x.raw_source_name,
            'raw_type', x.raw_type,
            'native_id', x.native_id,
            'provenance', x.provenance,
            'license', x.license,
            'license_id', x.license_id,
            'version', x.version,
            'is_accepted', x.is_accepted
        ))
    """))
    # limit to a reasonable number (they go up to 130) - mainly for xpac
    .withColumn("concepts", F.slice(F.col("concepts"), 1, 40))
    .withColumn("has_fulltext", F.col("fulltext").isNotNull())
    .withColumn("_source", F.struct(
        F.col("id"),
        F.col("doi"),
        F.col("title"),
        F.col("display_name"),
        F.col("ids"),
        # indexed_in: sorted, de-duplicated list of indexing services derived from the
        # locations' provenance / native_id / source.
        # move this to openalex_works at some point if it is useful in other contexts
        F.expr("""
            array_sort(
                array_distinct(
                    array_compact(
                        flatten(
                            TRANSFORM(locations, loc ->
                                CASE
                                WHEN loc.provenance IN ('crossref', 'pubmed', 'datacite')
                                    THEN array(loc.provenance, IF(loc.source.is_in_doaj, 'doaj', NULL))
                                WHEN loc.provenance = 'repo' AND lower(loc.native_id) like 'oai:arxiv.org%'
                                    THEN array('arxiv')
                                WHEN loc.provenance = 'repo' AND lower(loc.native_id) like 'oai:doaj.org/%'
                                    THEN array('doaj')
                                WHEN loc.provenance = 'mag' AND lower(loc.source.display_name) = 'pubmed'
                                    THEN array('pubmed')
                                ELSE array()
                                END
                            )
                        )
                    )
                )
            )
        """).alias("indexed_in"),
        F.col("publication_date"),
        F.col("publication_year"),
        F.col("language"),
        F.col("type"),
        # NULL arrays are replaced by empty arrays for the fields below.
        F.coalesce(F.col("authorships"), F.lit([])).alias("authorships"),
        F.coalesce(F.col("authorships_full"), F.lit([])).alias("authorships_full"),
        F.col("authors_count"),
        F.coalesce(F.col("corresponding_author_ids"), F.lit([])).alias("corresponding_author_ids"),
        F.coalesce(F.col("corresponding_institution_ids"), F.lit([])).alias("corresponding_institution_ids"),
        F.col("primary_topic"),
        F.col("topics"),
        F.col("keywords"),
        F.col("concepts"),
        F.col("locations"),
        F.col("locations_count"),
        F.col("primary_location"),
        F.col("best_oa_location"),
        F.coalesce(F.col("sustainable_development_goals"), empty_sdg_array).alias("sustainable_development_goals"),
        F.col("awards"),
        F.col("funders"),
        F.col("institutions"),
        F.col("countries_distinct_count"),
        F.col("institutions_distinct_count"),
        F.col("open_access"),
        F.col("is_paratext"),
        F.col("is_retracted"),
        F.col("is_xpac"),
        F.col("biblio"),
        F.col("abstract"),
        F.col("referenced_works"),
        F.col("referenced_works_count"),
        F.coalesce(F.col("related_works"), F.lit([])).alias("related_works"),
        F.col("abstract_inverted_index"),
        F.col("cited_by_count"),
        F.col("counts_by_year"),
        F.col("apc_list"),
        F.col("apc_paid"),
        F.col("fwci"),
        F.col("citation_normalized_percentile"),
        F.col("cited_by_percentile_year"),
        F.coalesce(F.col("mesh"), F.lit([])).alias("mesh"),
        F.col("has_abstract"),
        F.col("has_content"),
        F.col("fulltext"),
        F.col("has_fulltext"),
        F.col("created_date"),
        F.col("updated_date"),
        # Time at which this sync built the document.
        F.current_timestamp().alias("indexed_timestamp")
    ))
    # Keep only the document id and its body.
    .select("id", "_source")
)

# For incremental syncs the record count is already printed in the "Prepare Input" cell
# using a lightweight SQL COUNT (a full sync prints no count).
# Do not call df_transformed.count() here - it causes OOM on large datasets.

### Create Helpers

In [0]:
def _error_list_field(field_name):
    """Nullable array field whose elements are (row_id: long, error: string) structs."""
    error_entry = StructType([
        StructField("row_id", LongType(), True),
        StructField("error", StringType(), True),
    ])
    return StructField(field_name, ArrayType(error_entry), True)


# Schema of the per-partition log rows produced by the sync: identifiers first,
# then the counters, then the two error lists. Every field is nullable.
log_schema = StructType(
    [
        StructField(name, StringType(), True)
        for name in ("index_name", "run_id")
    ]
    + [
        StructField(name, IntegerType(), True)
        for name in (
            "partition_id",
            "indexed_count",
            "skipped_count",
            "parsing_error_count",
            "indexing_error_count",
        )
    ]
    + [
        _error_list_field("parsing_errors"),
        _error_list_field("indexing_errors"),
    ]
)

@dataclass
class DatabricksEnvInfo:
    """
    Identifiers describing the current Databricks execution.

    Every field is optional and defaults to None, so the annotations are
    Optional[str] rather than plain str.
    """
    job_id: Optional[str] = None          # job id, when available
    run_id: Optional[str] = None          # run id, when available
    command_run_id: Optional[str] = None  # command run id, when available
    notebook_path: Optional[str] = None   # notebook path, when available
    user: Optional[str] = None            # user, when available

def get_databricks_env_info() -> DatabricksEnvInfo:
    """
    Read the notebook context and pick out the identifiers of the current execution.

    Returns:
        A DatabricksEnvInfo filled from the "attributes" section of the context JSON;
        an attribute missing from the context leaves its field as None.
    """
    notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    attributes = json.loads(notebook_context.safeToJson())["attributes"]

    # dataclass field -> key inside the context attributes
    attribute_key_by_field = {
        "command_run_id": "commandRunId",
        "job_id": "jobId",
        "run_id": "currentRunId",
        "notebook_path": "notebook_path",
        "user": "user",
    }
    field_values = {
        field: attributes.get(key)
        for field, key in attribute_key_by_field.items()
    }
    return DatabricksEnvInfo(**field_values)

# Environment info captured once when this cell runs.
dbx_env = get_databricks_env_info()

def get_run_identifier():
    """
    Build the identifier used to tag this run.

    Returns:
        "<job_id>-<run_id>" when both are set on `dbx_env`,
        otherwise "<user>-<command_run_id>".
    """
    has_job_run = bool(dbx_env.job_id and dbx_env.run_id)
    if has_job_run:
        first, second = dbx_env.job_id, dbx_env.run_id
    else:
        first, second = dbx_env.user, dbx_env.command_run_id
    return f"{first}-{second}"
    
def _numeric_row_id(doc_id):
    """Convert a work id / work URL to an int: None when absent, -1 when not numeric."""
    if doc_id is None:
        return None
    try:
        return int(str(doc_id).replace("https://openalex.org/W", ""))
    except ValueError:
        return -1


def generate_prepared_actions(partition, parsing_errors, op_type = "index"):
    """
    Turn the rows of a partition into Elasticsearch bulk actions.

    Args:
        partition: Iterable of rows exposing `id` and `_source` (with `asDict`).
        parsing_errors: List that receives one {"row_id", "error"} dict for every
            row that could not be converted; it is mutated in place.
        op_type: Bulk operation type written to each action's "_op_type".

    Yields:
        One action dict per successfully converted row.
    """
    for row in partition:
        try:
            action = {
                "_op_type": op_type,
                "_index": ELASTIC_INDEX,
                "_id": row.id,
                "_source": row._source.asDict(True)
            }
        except Exception as e:
            # row.id is the full work URL (a string), but the log schema declares
            # row_id as a long, so store the numeric part - the same convention the
            # indexing errors use (-1 when the id is not numeric).
            parsing_errors.append({"row_id": _numeric_row_id(row.id), "error": str(e)})
            continue
        # Yield outside the try so only conversion failures are recorded as parsing errors.
        yield action

def send_partition_to_elastic(partition, partition_id, op_type="create"):
    """
    Bulk-send one partition to Elasticsearch and yield a single log entry for it.

    Args:
        partition: Iterable of rows exposing `id` and `_source`.
        partition_id: Index of the partition, copied into the log entry.
        op_type: Bulk operation type ("create", "index", ...); lower-cased before use.

    Yields:
        Exactly one dict with the counters and (truncated) error lists of this partition.
        NOTE(review): the keys "success" and "message" have no matching field in
        `log_schema` - confirm whether they should be added there.
    """
    client = Elasticsearch(
        hosts=[ELASTIC_URL],
        request_timeout=180,
        max_retries=5,
        retry_on_timeout=True,
        http_compress=True,
    )

    indexed_count = 0
    parsing_errors = []
    indexing_errors = []
    skipped_count = 0
    op_type = op_type.lower()

    try:
        # raise_on_error=False: by default the helper raises BulkIndexError as soon as
        # a chunk contains a rejected document, which would abort the rest of the
        # partition and make the per-item handling below unreachable. With it disabled,
        # rejected documents come back as (False, info) and are handled one by one.
        # Transport-level failures still raise and are recorded by the except below.
        for success, info in helpers.parallel_bulk(client,
                generate_prepared_actions(partition, parsing_errors, op_type),
                chunk_size=500, thread_count=4, queue_size=10,
                raise_on_error=False
            ):

            if success:
                indexed_count += 1
            else:
                error_info = info.get(op_type, {})
                status = error_info.get("status", 0)

                # Skip 409 (document already exists)
                if status == 409:
                    skipped_count += 1
                    continue

                # Recover the numeric work id from the document URL:
                # None when the item carries no id, -1 when it is not numeric.
                id_url = error_info.get('_id')
                row_id = None
                if id_url:
                    try:
                        row_id = int(id_url.replace("https://openalex.org/W", ""))
                    except ValueError:
                        row_id = -1

                indexing_errors.append({
                    "row_id": row_id,
                    "error": str(info)[:1000]
                })
    except Exception as e:
        # Best effort: record the failure in the log entry instead of failing the task.
        indexing_errors.append({
            "row_id": None,
            "error": "Parallel Bulk Error: " + str(e)[:1000]
        })
    finally:
        client.close()

    log_entry = {
        "index_name": ELASTIC_INDEX,
        "run_id": get_run_identifier(),
        "partition_id": partition_id,
        "success": len(indexing_errors) == 0,
        "indexed_count": indexed_count,
        "skipped_count": skipped_count,
        "parsing_error_count": len(parsing_errors),
        "indexing_error_count": len(indexing_errors),
        "message": f"{indexed_count} records indexed. {skipped_count} skipped. {len(parsing_errors)} parsing errors. {len(indexing_errors)} ES errors.",
        # Counts above are complete; the lists are capped at 1000 entries each.
        "parsing_errors": parsing_errors[:1000],
        "indexing_errors": indexing_errors[:1000],
    }

    yield log_entry

### Execute Sync with `mapPartitionsWithIndex`

In [0]:
logs_rdd = df_transformed.rdd.mapPartitionsWithIndex(
    lambda partition_idx, partition: send_partition_to_elastic(partition, partition_idx, "index")
)
logs_df = spark.createDataFrame(logs_rdd, log_schema)

# mapPartitionsWithIndex is lazy, so an action is needed to run it.
# Use ONE aggregation as that action: logs_df is not cached, so every additional
# action on it would send all partitions to Elasticsearch again.
sync_totals = logs_df.agg(
    F.count(F.lit(1)).alias("partitions"),
    F.sum("indexed_count").alias("indexed"),
    F.sum("skipped_count").alias("skipped"),
    F.sum("parsing_error_count").alias("parsing_errors"),
    F.sum("indexing_error_count").alias("indexing_errors"),
).first()

log_count = sync_totals["partitions"]
print(f"Processed {log_count} partitions")

# The sums are NULL when there were no partitions at all.
indexed_total = sync_totals["indexed"] or 0
skipped_total = sync_totals["skipped"] or 0
parsing_error_total = sync_totals["parsing_errors"] or 0
indexing_error_total = sync_totals["indexing_errors"] or 0
print(
    f"Indexed: {indexed_total:,} | skipped: {skipped_total:,} | "
    f"parsing errors: {parsing_error_total:,} | indexing errors: {indexing_error_total:,}"
)

# Surface failures instead of letting them disappear with the unsaved log rows.
if parsing_error_total or indexing_error_total:
    log.warning(
        "Sync to %s finished with %d parsing error(s) and %d indexing error(s)",
        ELASTIC_INDEX, parsing_error_total, indexing_error_total,
    )

In [0]:
# Build the client BEFORE entering `try`: if the constructor raises there is nothing
# to close, and `finally` must not fail with a NameError that would mask the real error.
client = Elasticsearch(hosts=[ELASTIC_URL], request_timeout=180)
try:
    # refresh index if exists
    if client.indices.exists(index=ELASTIC_INDEX):
        client.indices.refresh(index=ELASTIC_INDEX)
        print(f"Refreshed index {ELASTIC_INDEX}")
        print(f"{client.count(index=ELASTIC_INDEX)['count']} documents in {ELASTIC_INDEX}")

        # Restore replicas after full sync (they were set to 0 before indexing).
        if IS_FULL_SYNC:
            client.indices.put_settings(index=ELASTIC_INDEX, body={
                "index": {"number_of_replicas": 1}
            })
            print(f"Restored replicas to 1 on {ELASTIC_INDEX}")
    else:
        print(f"Index {ELASTIC_INDEX} does not exist")
finally:
    # Always release the HTTP connections, even when a call above fails.
    client.close()