### Load and Broadcast Concept Tagger `vocabularies` (the `model` itself is loaded per partition below)

In [0]:
import pickle
import pandas as pd

import tensorflow as tf
import numpy as np

from pyspark.sql.functions import *
from pyspark.sql.types import *
    
# Every concept-tagger artifact (vocabularies and the SavedModel) lives under this
# Volume directory; change the version suffix here when a new tagger is released.
CONCEPT_TAGGER_DIR = "/Volumes/openalex/works/models/concept_tagger_v3"


def load_concept_tagger_pickle(name):
    """Unpickle ``<CONCEPT_TAGGER_DIR>/<name>.pkl`` and return the stored object.

    NOTE: ``pickle.load`` can execute arbitrary code while loading, so only point
    this at the trusted model Volume above.
    """
    with open(f"{CONCEPT_TAGGER_DIR}/{name}.pkl", "rb") as f:
        return pickle.load(f)


title_vocab = load_concept_tagger_pickle("paper_title_vocab")
doc_type_vocab = load_concept_tagger_pickle("doc_type_vocab")
journal_vocab = load_concept_tagger_pickle("journal_name_vocab")
tag_id_vocab = load_concept_tagger_pickle("tag_id_vocab")
topics_vocab = load_concept_tagger_pickle("topics_vocab")
# Reverse lookup of topics_vocab (value -> key).
inverted_topics_vocab = {v: k for k, v in topics_vocab.items()}

# Broadcast the vocabularies so each executor receives one read-only copy.
# The TF model itself is NOT broadcast: doing so raised a pickle error, so it is
# loaded once per partition in the inference cell instead.
# bc_model = spark.sparkContext.broadcast(model)
bc_title_vocab = spark.sparkContext.broadcast(title_vocab)
bc_journal_vocab = spark.sparkContext.broadcast(journal_vocab)
bc_doc_type_vocab = spark.sparkContext.broadcast(doc_type_vocab)
bc_tag_id_vocab = spark.sparkContext.broadcast(tag_id_vocab)
bc_topics_vocab = spark.sparkContext.broadcast(topics_vocab)
bc_inverted_topics_vocab = spark.sparkContext.broadcast(inverted_topics_vocab)

#### Prepare `openalex_works` Input

In [0]:
# Candidate works to tag, deduplicated by a content hash (concept_key) of
# title | abstract | source display_name | source type:
#   * only works that have no concepts yet, have a title, and carry "enough" text;
#   * FIRST() picks one representative row per concept_key;
#   * the LEFT ANTI JOIN skips keys already in the predictions table;
#   * LIMIT caps how many keys are tagged in one run.
# NOTE(review): LIMIT without ORDER BY (and FIRST()) is non-deterministic. If cached
# partitions are lost and recomputed, a different row set could be selected than
# the one the inference step saw — confirm this is acceptable.
# NOTE(review): concat_ws skips NULL arguments, so different NULL patterns
# (e.g. NULL abstract vs NULL journal) can produce the same joined string and key.
works_to_tag_sql = """
    WITH works_with_concept_keys AS (
        SELECT
            xxhash64(
                concat_ws('|',
                    title,
                    abstract,
                    primary_location.source.display_name,
                    primary_location.source.type
                )
            ) AS concept_key,
            FIRST(title) as title,
            FIRST(abstract) as abstract,
            FIRST(primary_location.source.display_name) AS journal,
            FIRST(primary_location.source.id) AS source_id,
            FIRST(primary_location.source.type) AS doc_type
        FROM openalex.works.openalex_works
        WHERE (concepts IS NULL OR size(concepts) = 0)
            AND title IS NOT NULL
            AND (
                (length(title) > 20 AND length(abstract) > 50)
                OR length(title) > 50
                OR length(abstract) > 150
            )
        GROUP BY concept_key
    )
    SELECT *
    FROM works_with_concept_keys w
    LEFT ANTI JOIN openalex.works.openalex_works_concepts_predicted p
    ON w.concept_key = p.concept_key
    LIMIT 20480000
"""

df = (
    spark.sql(works_to_tag_sql)
    .repartition(4096)  # spread the rows over many inference tasks
    .cache()            # reused by the inference cell (df.rdd) and the final join
)

print(f"Total number of rows to process: {df.count()}")

### Load model per partition, execute via `mapPartitions`

In [0]:
import tensorflow as tf
import numpy as np

# Output schema of the inference step: one row per concept_key together with its
# ranked list of predicted concepts (id + score).
schema = (
    StructType()
    .add("concept_key", LongType(), nullable=False)
    .add(
        "concepts",
        ArrayType(
            StructType()
            .add("id", LongType(), nullable=False)
            .add("score", DoubleType(), nullable=False)
        ),
        nullable=True,
    )
)

def tokenize(text, vocab, max_len):
    """Map lower-cased, whitespace-separated tokens of ``text`` to vocabulary ids.

    Args:
        text: Input string. ``None`` or blank text yields an all-padding vector.
        vocab: Mapping from lower-cased token to integer id. Tokens missing from
            the vocabulary map to 0, the same id used for padding.
        max_len: Fixed output length. Longer inputs are truncated; shorter ones
            are right-padded with 0.

    Returns:
        ``np.ndarray`` of dtype int64 with ``max_len`` entries.
    """
    if text is None or text.strip() == "":
        # Fully padded zeros for missing or empty text
        return np.zeros(max_len, dtype=np.int64)
    # Truncate *before* the vocabulary lookups so long abstracts don't pay for
    # tokens that would be discarded anyway.
    tokens = text.lower().split()[:max_len]
    token_ids = [vocab.get(token, 0) for token in tokens]
    token_ids += [0] * (max_len - len(token_ids))
    return np.array(token_ids, dtype=np.int64)

def process_batch(batch, model, title_vocab, journal_vocab, doc_type_vocab, tag_id_vocab,
                  score_threshold=0.25, top_k=10, max_concepts=65,
                  title_len=32, abstract_len=256):
    """Run the concept tagger on one batch of rows and yield per-row predictions.

    Args:
        batch: Sequence of rows exposing ``concept_key``, ``title``, ``abstract``,
            ``doc_type`` and ``journal`` attributes.
        model: Loaded SavedModel whose ``serving_default`` signature is called with
            ``paper_title_ids``, ``abstract_ids``, ``doc_type_id`` and ``journal_id``
            and whose output dict holds a ``'cls'`` score matrix.
        title_vocab: Token -> id mapping, used for both titles and abstracts.
        journal_vocab: Lower-cased journal name -> id; unknown/missing -> 0.
        doc_type_vocab: Lower-cased doc type -> id; unknown/missing -> 0.
        tag_id_vocab: Model output index -> concept id; unknown index -> 0.
        score_threshold: Every concept scoring at least this is kept.
        top_k: The ``top_k`` best concepts are always kept, even below threshold
            (clamped to the number of model outputs).
        max_concepts: Maximum number of concepts emitted per row, best first.
        title_len: Token length fed to the model for titles.
        abstract_len: Token length fed to the model for abstracts.

    Yields:
        ``(concept_key, [{"id": int, "score": float}, ...])`` per input row, in
        input order, with concepts sorted by descending score.
    """
    import tensorflow as tf
    import numpy as np

    # np.stack() below raises on an empty list; an empty batch has no predictions.
    if not batch:
        return

    concept_keys = []
    title_ids_batch = []
    abstract_ids_batch = []
    doc_type_ids = []
    journal_ids = []

    for row in batch:
        concept_keys.append(row.concept_key)
        title_ids_batch.append(tokenize(row.title or "", title_vocab, title_len))
        # Abstract tokens are looked up in title_vocab as well (no abstract vocab is passed in).
        abstract_ids_batch.append(tokenize(row.abstract or "", title_vocab, abstract_len))
        # Categorical features are single-element id lists; missing/unknown -> 0.
        doc_type_ids.append([doc_type_vocab.get((row.doc_type or '').lower(), 0)])
        journal_ids.append([journal_vocab.get((row.journal or '').lower(), 0)])

    outputs = model.signatures['serving_default'](
        paper_title_ids=tf.constant(np.stack(title_ids_batch)),
        abstract_ids=tf.constant(np.stack(abstract_ids_batch)),
        doc_type_id=tf.constant(np.array(doc_type_ids, dtype=np.int64)),
        journal_id=tf.constant(np.array(journal_ids, dtype=np.int64)),
    )
    # NOTE(review): previously named "logits", but the values are compared against a
    # 0.25 threshold and emitted as "score", so they are presumably already
    # probabilities — confirm against the model's output head.
    scores_batch = outputs['cls'].numpy()  # shape: [batch_size, num_concepts]

    no_indices = np.empty(0, dtype=np.intp)
    for i, concept_key in enumerate(concept_keys):
        scores = scores_batch[i]
        # argpartition requires kth < len(scores); clamp so small label sets work.
        k = min(top_k, len(scores))
        top_idx = np.argpartition(-scores, kth=k - 1)[:k] if k > 0 else no_indices
        above_idx = np.flatnonzero(scores >= score_threshold)
        # Union of "over threshold" and "top-k", ranked best first, then capped.
        candidates = np.union1d(above_idx, top_idx)
        ranked = candidates[np.argsort(-scores[candidates])][:max_concepts]

        concepts = [
            {"id": int(tag_id_vocab.get(int(idx), 0)), "score": float(scores[idx])}
            for idx in ranked
        ]
        yield (concept_key, concepts)

# Now apply on partitions:
def process_partition(rows_iter, batch_size=64):
    """Tag every row of one Spark partition with concepts (for ``RDD.mapPartitions``).

    The SavedModel is loaded once, when this generator starts, instead of being
    broadcast. Rows are then handed to ``process_batch`` in chunks of
    ``batch_size``, and a final partial chunk is flushed at the end.

    Args:
        rows_iter: Iterator over the partition's rows.
        batch_size: Number of rows per model call.

    Yields:
        The items produced by ``process_batch`` for each chunk, in input order.
    """
    tagger = tf.saved_model.load("/Volumes/openalex/works/models/concept_tagger_v3/model")

    def run(chunk):
        # Broadcast values are resolved each time a chunk is processed.
        return process_batch(chunk, tagger, bc_title_vocab.value,
                             bc_journal_vocab.value, bc_doc_type_vocab.value, bc_tag_id_vocab.value)

    pending = []
    for row in rows_iter:
        pending.append(row)
        if len(pending) < batch_size:
            continue
        yield from run(pending)
        pending = []

    if pending:
        yield from run(pending)

# Apply with mapPartitions
result_rdd = df.rdd.mapPartitions(process_partition)

# Turn the (concept_key, concepts) tuples into a typed DataFrame and cache it, so
# later actions can reuse the predictions instead of re-running inference (as
# long as the cached blocks are retained).
inferred_concepts_df = (
    spark
    .createDataFrame(result_rdd, schema)
    .cache()
)

print(f"Total number of rows processed: {inferred_concepts_df.count()}")

In [0]:
# array<struct<id:bigint,wikidata:string,display_name:string,level:int,score:float>>
concepts_enriched_schema = ArrayType(
    StructType()
    .add("id", LongType(), nullable=True)
    .add("wikidata", StringType(), nullable=True)
    .add("display_name", StringType(), nullable=True)
    .add("level", IntegerType(), nullable=True)
    .add("score", FloatType(), nullable=True)
)
# keywords: array<struct<id:string,display_name:string,score:float>>
keywords_schema = ArrayType(
    StructType()
    .add("id", StringType(), nullable=True)
    .add("display_name", StringType(), nullable=True)
    .add("score", FloatType(), nullable=True)
)

# Attach the predictions to the input rows (the left join keeps every row of df).
# concepts_enriched and keywords are not computed in this cell, so they are written
# as typed NULLs. Each row is stamped with its write time.
final_df = (
    df
    .join(inferred_concepts_df, on="concept_key", how="left")
    .withColumn("concepts_enriched", lit(None).cast(concepts_enriched_schema))
    .withColumn("keywords", lit(None).cast(keywords_schema))
    .withColumn("created_timestamp", current_timestamp())
)

# Appending here makes the input query's LEFT ANTI JOIN skip these keys next run.
final_df.write.mode("append").saveAsTable("openalex.works.openalex_works_concepts_predicted")

### Show stats

In [0]:
%sql
-- Health of the predictions table. total_rows and distinct_concept_keys are expected
-- to match, because the input query deduplicates by concept_key and anti-joins
-- against this table. The *_not_null columns count rows with NON-EMPTY arrays
-- (size() of a NULL array is not > 0).
SELECT count(*) AS total_rows,
  count(DISTINCT concept_key) AS distinct_concept_keys,
  count_if(size(concepts) > 0) AS concepts_not_null,
  count_if(size(concepts_enriched) > 0) AS concepts_enriched_not_null
FROM openalex.works.openalex_works_concepts_predicted

In [0]:
%sql
-- How many predicted rows still lack keywords. Use the fully qualified
-- catalog.schema.table name, as the rest of the notebook does; a two-part name
-- resolves against the session's current catalog.
SELECT count(*) AS total_rows, count_if(keywords IS NULL) AS keywords_null
FROM openalex.works.openalex_works_concepts_predicted

In [0]:
%sql
-- Inspect a sample of predictions; each run appends up to millions of rows, so
-- avoid an unbounded SELECT *.
SELECT * FROM openalex.works.openalex_works_concepts_predicted
LIMIT 1000

In [0]:
# %sql
# WITH new_keyword_id AS (
#   SELECT display_name, 
#   regexp_replace(
#     regexp_replace(
#       regexp_replace(replace(lower(display_name), '\'', ''), '\\s*\\([^)]*\\)', ''),  -- remove " ( ... )"
#       '[^^\\p{L}\\p{N}\./–\*#]+', '-'                               -- non-alnum -> "-"
#     ),
#     '(^-+|-+$)', ''                                               -- trim leading/trailing "-"
#   ) as new_keyword_id, keyword_id 
#   from openalex.common.concepts
# )
# SELECT * FROM new_keyword_id 
# WHERE keyword_id is not null 
#   and new_keyword_id <> keyword_id

In [0]:
# %sql
# WITH concept_metadata AS (
#   SELECT
#     concept_id,
#     wikidata_id as wikidata,
#     display_name,
#     keyword_id,
#     use_as_keyword,
#     level
#   FROM openalex.common.concepts --WHERE wikidata_id IS NOT NULL
# ),
# concepts_exploded AS (
#   SELECT
#     concept_key,
#     explode(concepts_enriched) as concept
#   FROM openalex.works.openalex_works_concepts_predicted
#   WHERE keywords IS NULL AND concepts_enriched IS NOT NULL
# ),
# enriched_concepts_exploded AS (
#   SELECT
#     wc.concept_key,
#     concept,
#     STRUCT(
#         concat('https://openalex.org/keywords/',   
#           regexp_replace(
#             regexp_replace(
#               regexp_replace(replace(lower(display_name), '\'', ''), '\\s*\\([^)]*\\)', ''),  -- remove " ( ... )"
#               '[^^\\p{L}\\p{N}\./–\*#]+', '-' -- non-alnum -> "-"
#             ),
#             '(^-+|-+$)', '' -- trim leading/trailing "-"
#           )
#         ) as id,
#         c.display_name,
#         wc.concept.score
#       ) AS keyword
#   FROM concepts_exploded wc
#   JOIN concept_metadata c ON wc.concept.id = c.concept_id
#   WHERE c.use_as_keyword = true
# ),
# updates AS (
# select
#   first(concept_key) as concept_key,
#   array_sort(
#     array_agg(keyword),
#     (left, right) -> CASE
#       WHEN left.score > right.score THEN -1
#       WHEN left.score < right.score THEN 1
#       ELSE 0
#     END
#   ) AS keywords
# from enriched_concepts_exploded
# group by concept_key
# )
# MERGE INTO openalex.works.openalex_works_concepts_predicted AS target
# USING updates AS src ON target.concept_key = src.concept_key
# WHEN MATCHED THEN UPDATE SET target.keywords = src.keywords;
# -- SELECT * FROM work_concepts_keywords; --7092554937, first update 219,825,895