In [0]:
%pip install elasticsearch==8.19.0
%restart_python

### Prepare input: aggregate per-keyword work and citation counts

In [0]:
# Build one Elasticsearch document per keyword:
#   - `id` is the keyword id, used as the document _id when indexing;
#   - `_source` is a struct holding the keyword's work/citation counts plus
#     metadata joined from openalex.common.keywords.
# NOTE: the final JOIN is an inner join, so keywords with no row in
# openalex.common.keywords are dropped. In `dedup`, the window's ORDER BY repeats
# the PARTITION BY keys, so which duplicate row survives is arbitrary.
df = spark.sql("""
WITH exploded AS (
    SELECT id as work_id, cited_by_count, explode(keywords) as keyword
    FROM openalex.works.openalex_works
),
-- one row per (work_id, keyword_id)
dedup AS (
  SELECT work_id, cited_by_count, keyword
  FROM exploded
  QUALIFY row_number() OVER (PARTITION BY work_id, keyword.id ORDER BY work_id, keyword.id) = 1
),
-- Aggregate on unique keywords
aggregated_counts AS (
  SELECT
    keyword.id as id,
    keyword.display_name as display_name,
    count(DISTINCT work_id) as works_count,
    sum(cited_by_count) as cited_by_count
  FROM dedup
  GROUP BY 1, 2
)
-- Join with the common keywords table to get metadata
SELECT
  ac.id as id,
  STRUCT(
    ac.id,
    ac.display_name,
    ac.works_count,
    ac.cited_by_count,
    CONCAT("https://api.openalex.org/works?filter=keywords.id:keywords/", kw.keyword_id) AS works_api_url,
    kw.updated_datetime AS updated_date,
    date(kw.created_datetime) as created_date
  ) as _source
FROM aggregated_counts ac
JOIN openalex.common.keywords kw
  ON kw.keyword_id = replace(ac.id, 'https://openalex.org/keywords/', '')""")

rows = df.collect()

print(f"Keywords count: {len(rows)}")

# `id` becomes the Elasticsearch _id downstream, so duplicate ids would silently
# overwrite each other. Duplicates are possible if one keyword id appears with
# several display_names (GROUP BY 1, 2), or if openalex.common.keywords has more
# than one row per keyword_id. Fail here, before the indexing cell deletes the index.
duplicate_count = len(rows) - len({row.id for row in rows})
if duplicate_count:
    raise ValueError(
        f"{duplicate_count} duplicate keyword id row(s) in input; "
        "documents would overwrite each other in Elasticsearch"
    )

In [0]:
from elasticsearch import Elasticsearch, helpers
import json

# Target index name, and the cluster URL read from the Databricks "elastic" secret scope.
ELASTIC_INDEX = "keywords-v1"
ELASTIC_URL = dbutils.secrets.get(scope="elastic", key="elastic_url")

# Client-wide transport settings; individual calls may override the timeout.
_client_settings = {
    "request_timeout": 180,  # seconds
    "max_retries": 5,
    "retry_on_timeout": True,  # also retry requests that time out
}
client = Elasticsearch(hosts=[ELASTIC_URL], **_client_settings)

def actions_from_spark(rows, op_type = "index"):
    """Lazily turn Spark rows into bulk actions for ELASTIC_INDEX.

    Args:
        rows: iterable of rows exposing `.id` (used as the document _id) and
            `._source` (a Row, converted recursively with `asDict(True)`).
        op_type: bulk operation for every action, "index" by default.

    Yields:
        One action dict per row, with keys _op_type, _index, _id and _source.
    """
    for spark_row in rows:
        action = {"_op_type": op_type, "_index": ELASTIC_INDEX}
        action["_id"] = spark_row.id
        action["_source"] = spark_row._source.asDict(True)
        yield action

# Rebuild the index from scratch: drop any previous copy, then bulk-load every row.
# NOTE(review): nothing here recreates the index with an explicit mapping, so it is
# re-created implicitly on the first bulk write (dynamic mapping, unless an index
# template matches). Readers also see a missing/partial index until the load
# finishes; consider loading a fresh index and swapping an alias instead.
if client.indices.exists(index=ELASTIC_INDEX):
    client.indices.delete(index=ELASTIC_INDEX)

# streaming_bulk defaults to raise_on_error=True, which raises BulkIndexError on the
# first chunk containing a failed document, so a failure counter can never advance.
# With raise_on_error=False, failures are yielded as (False, info) and every chunk is attempted.
# The per-request timeout goes through .options() rather than a bulk kwarg; the
# kwarg form is deprecated in elasticsearch-py 8.x.
ok = fail = 0
failures = []
for success, info in helpers.streaming_bulk(
    client.options(request_timeout=60),
    actions_from_spark(rows),
    chunk_size=2000,
    max_retries=3,  # retries documents rejected with HTTP 429, with backoff
    raise_on_error=False,
):
    if success:
        ok += 1
    else:
        fail += 1
        failures.append(info)

print(f"Indexed ok={ok}, failed={fail}")

# Still fail the run on rejected documents, with the same exception type as before,
# but only after all chunks were attempted and with a sample of errors shown.
if failures:
    for info in failures[:5]:
        print(info)
    raise helpers.BulkIndexError(f"{fail} document(s) failed to index", failures)

In [0]:
# Force a refresh so the documents just bulk-indexed are visible to search/count requests.
client.indices.refresh(index=ELASTIC_INDEX)

In [0]:
# Verify the refreshed index holds exactly one document per collected keyword row.
# A lower count means rows were dropped or overwrote each other by sharing an _id.
count_response = client.count(index=ELASTIC_INDEX)
assert count_response["count"] == len(rows), (
    f"index has {count_response['count']} docs, expected {len(rows)}"
)
count_response