### Utilities

In [0]:
from transformers import TFAutoModelForSequenceClassification, pipeline, AutoTokenizer
import torch
import os

# Report how many CUDA devices torch can see where this cell runs.
print(f"CUDA device count: {torch.cuda.device_count()}")

# Number of texts per inference call: passed to the pipeline as `batch_size`
# and used as the default batch size of process_partition.
BATCH_SIZE = 150
# Model location handed to the transformers pipeline as `model`.
MODEL_PATH = "/Volumes/openalex/works/models/topic-classification-title-abstract"

class ModelCache:
    """Class-level holder for the text-classification pipeline.

    The pipeline is built by the first ``load()`` call and stored on the
    class, so subsequent calls in the same process reuse it.
    """

    model = None

    @classmethod
    def load(cls):
        """Build the pipeline once; later calls are no-ops.

        Sets ``cls.assigned_device`` (-1 when torch reports no CUDA device,
        otherwise a device index derived from the process id) and
        ``cls.model`` (the constructed pipeline).
        """
        if cls.model is None:
            gpu_total = torch.cuda.device_count()
            # No CUDA device -> -1 (CPU); otherwise spread processes over the
            # available devices using the process id.
            cls.assigned_device = os.getpid() % gpu_total if gpu_total != 0 else -1

            pipeline_options = dict(
                task="text-classification",
                model=MODEL_PATH,
                device=cls.assigned_device,
                top_k=3,
                batch_size=BATCH_SIZE,
                truncation=True,
                max_length=512,
            )
            cls.model = pipeline(**pipeline_options)

### Load input data and cache

In [0]:
# Upper bound on the number of works scored in one run.
MAX_INPUT_ROWS = 6700000
# Number of partitions the input is spread over before inference.
NUM_INPUT_PARTITIONS = 384

# Read a bounded slice of the front-fill input and repartition it.
# The frame is cached because it is evaluated more than once (the count below
# and the inference cell). `limit` without an ordering is not guaranteed to
# select the same rows on re-evaluation, so caching keeps the reported count
# and the scored rows consistent and avoids re-reading the source table.
df = (spark.table("openalex.works.work_topics_frontfill_input")
      .select("work_id", "title", "abstract", "journal_name")
      .limit(MAX_INPUT_ROWS)
      .repartition(NUM_INPUT_PARTITIONS)
      .cache()
)

print(f"Input Row Count: {df.count()}")

### Secondary Fill - create abstract from Journal Name and Concepts

In [0]:
# df = spark.sql("""
# SELECT --format_number(count(*),0)
#     id as work_id,
#     title,
#     concat('[Journal Name] ', primary_location.source.display_name, '\n[Key Concepts] ', 
#       concat_ws(', ', slice(concepts.display_name,1,3))) as abstract,
#     primary_location.source.display_name as journal_name
# FROM openalex.works.openalex_works
# LEFT ANTI JOIN openalex.works.work_topics_lm_output lm ON id = lm.work_id
# WHERE (topics IS NULL OR size(topics) = 0)
#     AND id > 6600000000
#     AND length(title) > 10
#     AND length(abstract) > 30
# LIMIT 6760000;
# """).repartition(384).cache()
# print(f"Input Row Count: {df.count()}")

### Run inference

In [0]:
import json
from pyspark.sql.functions import *
from pyspark.sql.types import *
from topic_predictor import *

import time
import re

def process_partition(rows, batch_size=BATCH_SIZE):
    """Run topic inference over one partition of input rows.

    Rows are grouped into batches of ``batch_size`` texts, each batch is sent
    to the pipeline held by ``ModelCache``, and one dict per input row is
    yielded.

    Args:
        rows: Iterable of row objects supporting ``row['title']``,
            ``row['abstract']`` and ``row.asDict()``. ``None`` entries are
            skipped.
        batch_size: Maximum number of texts passed to the model per call.

    Yields:
        dict: The row's fields (``row.asDict()``) plus ``lm_topics``, a list
        of ``{"topic_id": int, "score": float}`` dicts, or ``None`` when the
        model produced no output for the row (including when inference for
        its batch raised).
    """
    # Local import keeps the function self-contained wherever it is executed.
    import logging

    ModelCache.load()
    model = ModelCache.model
    logger = logging.getLogger("process_partition")

    # Added to the integer prefix of each predicted label to form topic_id.
    topic_id_offset = 10000

    batch_rows = []
    batch_texts = []

    def yield_batch(rows_batch, texts_batch):
        """Score one batch and yield an output dict per row."""
        try:
            batch_outputs = model(texts_batch)
        except Exception:
            # Best-effort: keep the partition alive, but record the failure so
            # NULL predictions in the output can be traced back to a cause.
            logger.exception(
                "Inference failed for a batch of %d rows; emitting NULL lm_topics for them",
                len(texts_batch),
            )
            batch_outputs = [[] for _ in texts_batch]  # fail-safe: empty predictions

        for row, output in zip(rows_batch, batch_outputs):
            row_dict = row.asDict()
            # The label's text before the first ":" is parsed as an integer.
            lm_output = [
                {
                    "topic_id": topic_id_offset + int(topic["label"].split(":")[0]),
                    "score": float(topic["score"])
                }
                for topic in output
            ] if output else None

            row_dict["lm_topics"] = lm_output
            yield row_dict

    for row in rows:
        if row is None:
            continue

        # clean_title / clean_abstract are not defined in this notebook
        # (presumably provided by the topic_predictor wildcard import);
        # a falsy result is replaced with an empty string.
        title = clean_title(row['title']) or ""
        abstract = clean_abstract(row['abstract']) or ""
        full_text = f"[CLS]<TITLE> {title.strip()} <ABSTRACT> {abstract.strip()} [SEP]"

        batch_rows.append(row)
        batch_texts.append(full_text)

        if len(batch_texts) >= batch_size:
            yield from yield_batch(batch_rows, batch_texts)
            batch_rows = []
            batch_texts = []

    # Process remaining rows
    if batch_rows:
        yield from yield_batch(batch_rows, batch_texts)

# Element type of the `lm_topics` array: a predicted topic id and its score.
topic_struct = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        ("topic_id", IntegerType()),
        ("score", FloatType()),
    )
])

# Schema of the rows yielded by process_partition: the four string input
# columns followed by the array of predicted topics.
output_schema = StructType(
    [
        StructField(column_name, StringType(), True)
        for column_name in ("work_id", "title", "abstract", "journal_name")
    ]
    + [StructField("lm_topics", ArrayType(topic_struct), True)]
)

start_time = time.time()

# Score every partition of the input with process_partition.
res_rdd = df.select("work_id", "title", "abstract", "journal_name").rdd.mapPartitions(process_partition)
# Cached so the counts below and the table write reuse the same predictions.
res_df = spark.createDataFrame(res_rdd, output_schema).cache()
output_count = res_df.count()
print(f"Output Row count: {output_count}")

# process_partition emits lm_topics = None when the model returns nothing for
# a row (including failed batches); surface how many such rows are about to be
# written instead of letting them pass unnoticed.
failed_count = res_df.filter(col("lm_topics").isNull()).count()
print(f"Rows without predictions: {failed_count}")

# Add the top prediction, a fixed source tag, and the write timestamp.
res_df = res_df.select(
    "work_id",
    "title",
    "abstract",
    "journal_name",
    "lm_topics",
    col("lm_topics")[0].alias("lm_primary_topic"),
    lit("bert_lm").alias("source"),
    current_timestamp().alias("created_timestamp"),
)
res_df.write.mode("append").saveAsTable("openalex.works.work_topics_lm_output")

# Runtime covers inference, the counts, and the table write.
runtime = time.time() - start_time
print(f"Total runtime: {runtime:.4f} seconds")
print(f"Total throughput: {output_count / (runtime):.4f} inferences/sec")

In [0]:
%sql
-- Inspect the rows currently stored in the inference output table.
SELECT * FROM openalex.works.work_topics_lm_output;

### Transform `lm_output` into OpenAlex topic structs and insert into frontfill where the work does not already exist

In [0]:
%sql
-- Build one topics array per work from the inference output and insert it
-- into the front-fill table for works that are not there yet.
MERGE INTO openalex.works.work_topics_frontfill AS target
USING (
  -- Per topic: its display name plus subfield / field / domain structs, each
  -- holding an OpenAlex URL id and a display name.
  WITH topics_metadata AS (
    SELECT
      topic_id,
      t.display_name,
      NAMED_STRUCT(
        'id', concat('https://openalex.org/subfields/', s.subfield_id),
        'display_name', s.display_name
      ) AS subfield,
      NAMED_STRUCT(
        'id', concat('https://openalex.org/fields/', f.field_id),
        'display_name', f.display_name
      ) AS field,
      NAMED_STRUCT(
        'id', concat('https://openalex.org/domains/', d.domain_id),
        'display_name', d.display_name
      ) AS domain
    FROM openalex.common.topics t
    JOIN openalex.common.subfields s USING (subfield_id)
    JOIN openalex.common.fields f USING (field_id)
    JOIN openalex.common.domains d USING (domain_id)
  ),

  -- One row per predicted topic. explode produces no rows for NULL or empty
  -- arrays, so works without predictions drop out here.
  lm_output_exploded AS (
    SELECT 
      work_id,
      explode(lm_topics) AS result,
      source,
      created_timestamp
    FROM openalex.works.work_topics_lm_output
  )

  -- Per work: collect the topic structs, order them by score descending
  -- (the comparator returns -1 when the left score is larger) and keep at
  -- most the first three.
  SELECT
    work_id,
    slice(array_sort(
      array_agg(
        NAMED_STRUCT(
          'id', concat('https://openalex.org/T', result.topic_id),
          'display_name', tm.display_name,
          'score', result.score,
          'subfield', tm.subfield,
          'field', tm.field,
          'domain', tm.domain
        )
      ),
      (left, right) -> CASE
        WHEN left.score > right.score THEN -1
        WHEN left.score < right.score THEN 1
        ELSE 0
      END
    ), 1, 3) AS topics,
    first(wt.source) AS source,
    -- Both timestamps are set to the latest created_timestamp seen for the work.
    max(wt.created_timestamp) AS created_datetime,
    max(wt.created_timestamp) AS updated_datetime
  FROM lm_output_exploded wt
  -- Inner join: predicted topic ids with no row in topics_metadata are dropped.
  JOIN topics_metadata tm ON tm.topic_id = result.topic_id
  GROUP BY work_id
) AS source
ON target.work_id = source.work_id

-- Insert only if the work_id does not exist
-- (there is no WHEN MATCHED clause, so existing target rows are left unchanged)
WHEN NOT MATCHED THEN INSERT (
  work_id,
  topics,
  source,
  created_datetime,
  updated_datetime
) VALUES (
  source.work_id,
  source.topics,
  source.source,
  source.created_datetime,
  source.updated_datetime
);
