In [0]:
import mlflow

# MLflow model stored in a Unity Catalog volume.
MODEL_PATH = "/Volumes/openalex/works/models/topic-classification-title-abstract"

# Wrap the MLflow model as a Spark UDF and score every column of `df_spark`.
# NOTE(review): `df_spark` is not defined anywhere in this notebook, so this cell fails on a
# fresh kernel - build a DataFrame with the model's expected input columns first (TODO confirm schema).
predict_udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri=MODEL_PATH,
    env_manager="conda",
)
df_result = df_spark.withColumn(
    "prediction",
    predict_udf(*df_spark.columns),
)

In [0]:
# %sql
# CREATE OR REPLACE TABLE openalex.works.works_topics_compare AS
# SELECT id as work_id, title, abstract, primary_location.source.display_name as journal_name, primary_topic, topics
# FROM openalex.works.openalex_works
# WHERE length(title) > 50 AND length(abstract) > 100 
# AND size(topics) > 0 AND cited_by_count > 50 AND primary_location.source.display_name IS NOT NULL
# QUALIFY row_number() over (partition by primary_topic.id order by id asc) <= 20

In [0]:
from transformers import TFAutoModelForSequenceClassification, pipeline, AutoTokenizer
import torch
import os

print(f"CUDA device count: {torch.cuda.device_count()}")

# Driver-side pipeline used for the interactive spot checks below. No `device` is passed here;
# executors load their own GPU-pinned copy through ModelCache.load().
bert_classifier = pipeline(
    model="OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract",
    top_k=5,
    batch_size=16,
    truncation=True,
    max_length=512,
)

# Smoke test. Printed explicitly: this is not the last statement of the cell (the ModelCache
# class definition follows), so a bare expression here would be evaluated and then discarded.
test_input = """<TITLE>Supplemental Material: Estimating paleotidal constituents from Pliocene “tidal gauges”—an example from the paleo-Orinoco Delta, Trinidad"""
print(bert_classifier(test_input))

class ModelCache:
    """Per-process holder for the Hugging Face topic-classification pipeline.

    `load()` builds the pipeline the first time it is called in a Python process and is a
    no-op on later calls in the same process; the pipeline is then read from `ModelCache.model`.
    """

    # The loaded pipeline; stays None until `load()` has run successfully.
    model = None

    @classmethod
    def load(cls):
        """Load the pipeline onto a GPU chosen from this process's PID, unless already loaded.

        Side effects: sets `cls.assigned_device` (CUDA device index) and `cls.model` (pipeline).

        Raises:
            RuntimeError: if `torch.cuda.device_count()` reports no CUDA devices.
        """
        if cls.model is None:
            gpu_total = torch.cuda.device_count()
            if gpu_total == 0:
                raise RuntimeError("No CUDA devices available.")

            # Spread processes over the visible GPUs by PID; processes whose PIDs are equal
            # modulo the GPU count end up sharing a device.
            worker_pid = os.getpid()
            cls.assigned_device = worker_pid % gpu_total

            print(f"Loading model on GPU:{cls.assigned_device} for pid:{worker_pid}")

            hf_model_id = "OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract"
            cls.model = pipeline(
                model=hf_model_id,
                tokenizer=hf_model_id,
                device=cls.assigned_device,
                top_k=5,
                batch_size=12,
                truncation=True,
                max_length=512,
            )




### Testing a `pseudo_abstract` (journal name + key concepts) when the actual `abstract` is missing

In [0]:
# The same title scored twice: on its own, and with a pseudo-abstract built from
# the journal name and key concepts.
# NOTE(review): these strings embed literal [CLS]/[SEP] markers, whereas process_partition builds
# "<TITLE>...<ABSTRACT>..." without them - confirm which input format the model was trained on.
title_only_input = """[CLS]<TITLE>Bupropion and Male Reproductive Health: Exploring Its Role in Sexual Dysfunction and Fertility[SEP]"""
pseudo_abstract_input = """[CLS]<TITLE>Bupropion and Male Reproductive Health: Exploring Its Role in Sexual Dysfunction and Fertility<ABSTRACT>[Journal Name] DataCite API [Key Concepts] Nucleofection, Gestational period, TSG101[SEP]"""

for test_input in (title_only_input, pseudo_abstract_input):
    print(bert_classifier(test_input))

## Examples from Justin's Topic Classification Paper
https://docs.google.com/document/d/1bDopkhuGieQ4F8gGNj7sEc8WSE8mvLZS/edit?tab=t.0

### Example #1
Shown as an example of correct output from limited data (title only) — exact match.

In [0]:
# Title-only input; the prediction is rendered as the cell output.
test_input = (
    "[CLS]<TITLE>Multidisciplinary Team Approach to Cleft Lip and Palate Management"
)
bert_classifier(test_input)

### Example #2
After locating the abstract: the labeled outcome `Corporate Social Responsibility and Sustainability in Business`, which the paper states is wrong, actually captures the point of the paper better.

In [0]:
# Example #2: title only vs. title + (German) abstract.
test_input_1 = """[CLS]<TITLE>Mitverbrennung von Sekundärbrennstoffen[SEP]"""
test_input_2 = """[CLS]<TITLE>Mitverbrennung von Sekundärbrennstoffen<ABSTRACT>Auf europäischer Ebene wird unter festen Sekundärbrennstoffen ein breites Spektrum an für die energetische Verwertung bereitgestellten Fraktionen verstanden, die aus nicht gefährlichen Abfällen hergestellt werden. Dieses reicht dabei von produktionsspezifischen oder gewerblichen Abfällen, über Siedlungsabfall, bis hin zu Bau‐ und Abbruchabfall oder auch Altholz (DIN 2012). Auch auf nationaler Ebene werden Brennstoffe, die aus festen, heizwertreichen, nicht gefährlichen Abfällen hergestellt werden, als Sekundärbrennstoffe bezeichnet, die nach einem umfassenden Aufbereitungsprozess speziell für die Mitverbrennung bereitgestellt werden. Insgesamt hat sich die Mitverbrennung von Sekundärbrennstoffen in Industriefeuerungsanlagen und Zementwerken als feste Säule einer modernen Kreislaufwirtschaft etabliert, da hier heizwertreiche Stoffe als emissionsarme Energieträger hochwertig verwertet werden können. Durch die Einsparung von Primärbrennstoffen, der Rückgewinnung von Eisen‐ und Nichteisenmetallen sowie – im Fall der Mitverbrennung in Zementwerken – der stofflichen Verwertung des mineralischen Anteils kann so zum Klima‐ und Ressourcenschutz beigetragen werden. So wurden allein in der deutschen Zementindustrie im Jahr 2014 durch den Einsatz geeigneter Sekundärbrennstoffe fossile brennstoffspezifische Kohlendioxid(CO2)‐Emissionen von etwa 1,5 Mio. Megagramm (Mg) sowie etwa 2 Mio. Mg Steinkohleeinheiten an fossilen Primärenergieträgern eingespart (Oerter 2017a).[SEP]"""

for example_text in (test_input_1, test_input_2):
    print(bert_classifier(example_text))

# Topics noted for comparison with the predictions above
# (topic id, name - subfield - field - domain; presumably reference labels - TODO confirm source):
# - 12959: Engineering and Materials Science Studies - Mechanical Engineering - Engineering - Physical Sciences
#   Keywords: Lightweight Materials; Automotive Engineering; Composite Structures; Ultrasonic Testing;
#   Thermoplastic Composites; Additive Manufacturing; Structural Analysis; Hot-Dip Galvanization;
#   Machine Elements; Renewable Resources
# - 12565: Physics and Engineering Research Articles - Computational Mechanics - Engineering - Physical Sciences
#   Keywords: Fluid Dynamics; Renewable Energy; Engineering; Materials Science; Renewable Energy Integration;
#   Vibration Analysis; Thermal Management; Sustainability; Turbine Technology; Mathematical Modeling

### Example #3
Exact Match

In [0]:
# Example #3: title + truncated abstract.
test_input = """[CLS]<TITLE>Novel, Potentially Zoonotic Paramyxoviruses from the African Straw-Colored Fruit Bat Eidolon helvum<ABSTRACT>Bats carry a variety of paramyxoviruses that impact human and domestic animal health when spillover occurs. Recent studies have shown a great diversity of paramyxoviruses in an urban-roosting population of straw-colored fruit bats in Ghana. Here, we investigate this further through virus isolation and describe two novel rubulaviruses: Achimota virus 1 (AchPV1) and Achimota virus 2 (AchPV2). The viruses form a phylogenetic cluster with each other and other bat-derived rubulaviruses, such as Tuhoko viruses…[SEP]"""
example_3_predictions = bert_classifier(test_input)
print(example_3_predictions)

### Example 4
Exact Match

In [0]:
# Example #4: title + table-of-contents style abstract.
test_input = """[CLS]<TITLE>Bioinformatics basics: applications in biological science and medicine<ABSTRACT>Contents BIOLOGY AND INFORMATION Bioinformatics-A Rapidly Maturing Science Computers in Biology and Medicine The Virtual Doctor Biological Macromolecules as Information Carriers Proteins: From Sequence to Structure to Function DNA and RNA Structure DNA Cloning and Sequencing Genes, Taxonomy, and Evolution BIOLOGICAL DATABASES Biological Database Organization Public Databases Database Mining Tools GENOME ANALYSIS The Genomic Organization…[SEP]"""
example_4_predictions = bert_classifier(test_input)
print(example_4_predictions)

### Example 5
Highlights where the NN model is wrong: `1122: Educational Data Mining and Learning Analytics` is likely the best match, which BERT also misses, probably because it does not understand the SGPA/CGPA terms. The quality of this abstract is also poor.

In [0]:
# Example #5: title + abstract.
# Fixed input: the string previously contained "[SEP]<ASBTRACT>" (a misspelled tag plus a separator
# in the middle of the text). The other examples, and the format built in process_partition,
# use "<TITLE>...<ABSTRACT>...", and the malformed marker could skew this comparison.
test_input = """[CLS]<TITLE>Machine Learning Approach to Predict SGPA and CGPA<ABSTRACT>The prediction of SGPA and CGPA is beneficial to university students. Students will easily get an estimate of their final outcome from this project. As a result, the students will be able to brace themselves for a successful outcome. Students pass the day by participating in a variety of events. Students use social media sites such as Facebook, Instagram, and Twitter. They engage in various hobbies such as playing mobile games, listening to music, among others. As a result, they were able to move several times with these tasks. As a result, if a student spends so much time doing any of those things, she will not be able…[SEP]"""
print(bert_classifier(test_input))

In [0]:
import pickle
# Vocabulary pickled alongside the v1 (NN) topic classifier. Values are expected to look like
# "<topic_id>: <description>" (they are split on ":" in the next cell).
# NOTE: pickle.load can execute arbitrary code - only load trusted, internally produced artifacts.
INV_TARGET_VOCAB_PATH = "/Volumes/openalex/works/models/topic_classifier_v1/inv_target_vocab.pkl"
with open(INV_TARGET_VOCAB_PATH, "rb") as f:
    inv_target_vocab = pickle.load(f)

print("Length:", len(inv_target_vocab))
print("Sample Keys:", list(inv_target_vocab.keys())[:5])
print("Sample Values:", list(inv_target_vocab.values())[:5])

# `target_vocab` was never defined in this notebook (NameError on a fresh kernel); derive it as the
# inverse of `inv_target_vocab`, i.e. label -> index.
target_vocab = {label: idx for idx, label in inv_target_vocab.items()}

# Look the same label up in both vocabularies. `.get` shows None instead of raising when the label
# text differs between the LM and NN models, and both results are printed because only the last
# expression of a cell would otherwise be displayed.
sample_label = "3338: Effects of Beta-Adrenergic Agonists in Livestock"
print("LM label2id:", bert_classifier.model.config.label2id.get(sample_label))
print("NN target_vocab:", target_vocab.get(sample_label))



In [0]:
# Step 1: Build topic_id -> description from the NN model's inv_target_vocab
# (values are expected to look like "<topic_id>: <description>").
topic_id_to_desc = {}
for raw_label in inv_target_vocab.values():
    topic_id_str, sep, desc = str(raw_label).partition(":")
    if not sep:
        continue  # no "<id>: <desc>" structure - nothing to map
    try:
        topic_id_to_desc[int(topic_id_str.strip())] = desc.strip()
    except ValueError:
        continue  # non-numeric id prefix - skip it instead of aborting the whole cell

# Step 2: Build a candidate id2label for the LM pipeline that uses those descriptions.
# NOTE: this only computes `new_id2label`; it is NOT assigned back to
# bert_classifier.model.config.id2label, so the pipeline's outputs are unchanged.
new_id2label = {}
for k, v in bert_classifier.model.config.id2label.items():
    try:
        topic_id_str, _ = v.split(":", 1)
        topic_id = int(topic_id_str.strip())
    except (AttributeError, ValueError):
        new_id2label[k] = v  # label is not "<id>: <name>" - keep the original
        continue
    # Prefer the NN vocabulary's description when it knows this topic id.
    if topic_id in topic_id_to_desc:
        new_id2label[k] = f"{topic_id}: {topic_id_to_desc[topic_id]}"
    else:
        new_id2label[k] = v

# Step 3: Print the entries whose label text would change.
print("🔍 Differences between old and new id2label entries:")
old_id2label = bert_classifier.model.config.id2label
for k in sorted(old_id2label):
    old_val = old_id2label[k]
    new_val = new_id2label[k]
    if old_val != new_val:
        print(f"ID {k}:\n  OLD: {old_val}\n  NEW: {new_val}\n")

### Load input data and cache

In [0]:
# Load the comparison sample, spread it over 96 partitions for the mapPartitions inference below,
# and cache it so the count here and the inference job reuse the same materialized data.
df = (
    spark.sql("SELECT * FROM openalex.works.works_topics_compare")
    .repartition(96)
    .cache()
)
print(f"Input Row Count: {df.count()}")

### Run inference

In [0]:
import json
from pyspark.sql.functions import *
from pyspark.sql.types import *
from topic_predictor import *
import re

def process_partition(rows, batch_size=12):
    """Run LM topic inference over one Spark partition, `batch_size` rows at a time.

    Intended for `rdd.mapPartitions`: loads the per-process pipeline via ModelCache, builds
    "<TITLE>...<ABSTRACT>..." inputs from the cleaned title and abstract, and yields one dict
    per non-None input row.

    Args:
        rows: iterator of Row objects with at least `title` and `abstract` fields.
        batch_size: number of texts sent to the pipeline per call.

    Yields:
        dict: the row's fields plus `lm_topics`, a list of {"topic_id", "score"} dicts
        (topic_id = 10000 + the numeric prefix of the model label), or None when no prediction
        could be produced for that row.
    """
    import sys

    ModelCache.load()
    model = ModelCache.model

    def _predict(texts):
        """Predict a batch. If the batch call fails, log it and retry each text on its own so
        one bad input does not blank out the whole batch. Texts that still fail get []."""
        try:
            return model(texts)
        except Exception as batch_err:
            print(
                f"[process_partition] batch of {len(texts)} failed ({batch_err!r}); retrying row by row",
                file=sys.stderr,
            )
        outputs = []
        for text in texts:
            try:
                outputs.append(model([text])[0])
            except Exception as row_err:
                print(f"[process_partition] row inference failed: {row_err!r}", file=sys.stderr)
                outputs.append([])  # fail-safe: empty prediction -> lm_topics None
        return outputs

    def _emit(rows_batch, texts_batch):
        """Yield one output dict per row of the batch."""
        batch_outputs = _predict(texts_batch)
        for row, output in zip(rows_batch, batch_outputs):
            row_dict = row.asDict()
            row_dict["lm_topics"] = [
                {
                    # Model labels are expected to look like "<n>: <name>"; stored ids are offset by 10000.
                    "topic_id": 10000 + int(topic["label"].split(":")[0]),
                    "score": float(topic["score"]),
                }
                for topic in output
            ] if output else None
            yield row_dict

    batch_rows, batch_texts = [], []
    for row in rows:
        if row is None:
            continue

        title = clean_title(row['title']) or ""
        abstract = clean_abstract(row['abstract']) or ""
        batch_rows.append(row)
        batch_texts.append(f"<TITLE>{title.strip()}<ABSTRACT>{abstract.strip()}")

        if len(batch_texts) >= batch_size:
            yield from _emit(batch_rows, batch_texts)
            batch_rows, batch_texts = [], []

    # Flush the final, partially filled batch.
    if batch_rows:
        yield from _emit(batch_rows, batch_texts)

# Spark schema of one predicted topic inside `lm_topics`.
topic_struct = StructType(
    [
        StructField("topic_id", IntegerType(), True),
        StructField("score", FloatType(), True),
    ]
)

# Output rows: the four input columns (declared as strings) plus the LM predictions.
# NOTE(review): work_id is declared StringType while other cells filter `id = 2127407834` as a
# number - confirm the source column is a string (createDataFrame verifies rows against the schema by default).
output_schema = StructType(
    [
        StructField(column_name, StringType(), True)
        for column_name in ("work_id", "title", "abstract", "journal_name")
    ]
    + [StructField("lm_topics", ArrayType(topic_struct), True)]
)

# Run inference partition by partition on the executors, then materialize and cache the result.
res_rdd = (
    df.select("work_id", "title", "abstract", "journal_name")
    .rdd
    .mapPartitions(process_partition)
)
res_df = spark.createDataFrame(res_rdd, output_schema).cache()
print(f"Output Row count: {res_df.count()}")

# Add the first LM topic as the primary topic, plus a creation timestamp.
res_df = res_df.withColumn("lm_primary_topic", col("lm_topics")[0]).withColumn(
    "created_timestamp", current_timestamp()
)
display(res_df)

In [0]:
%python
# Persist the LM predictions, replacing any previous run, for the SQL comparisons below.
(
    res_df.write
    .mode("overwrite")
    .saveAsTable("openalex.works.works_topics_compare_lm")
)

In [0]:
%sql 
-- Legacy topic ids for every work in the comparison sample.
SELECT topics.id
FROM openalex.works.works_topics_compare

## Combined Test Results - comparing existing `NN` Model and `LM` (google-bert fine-tuned with OpenAlex Topics)

In [0]:
%sql
-- Random ~1% sample of works whose legacy primary topic differs from the LM's primary topic,
-- leaving out topics T10005 and T10232.
WITH joined AS (
  SELECT
    a.work_id,
    a.primary_topic,
    a.topics,
    a.title,
    a.abstract,
    b.lm_topics,
    b.lm_primary_topic.topic_id AS lm_topic_id,
    concat('https://openalex.org/T', b.lm_primary_topic.topic_id) AS lm_topic_url
  FROM openalex.works.works_topics_compare AS a
  INNER JOIN openalex.works.works_topics_compare_lm AS b
    ON a.work_id = b.work_id
),
comparison AS (
  SELECT
    work_id,
    primary_topic,
    lm_topic_id AS lm_primary_topic,
    topics,
    lm_topics,
    -- Does the LM primary topic equal the legacy primary topic?
    primary_topic.id = lm_topic_url AS primary_topic_match,
    -- Is the LM primary topic anywhere in the legacy topic list?
    array_contains(topics.id, lm_topic_url) AS lm_primary_in_topics,
    title,
    abstract
  FROM joined
)
SELECT *
FROM comparison TABLESAMPLE (1 PERCENT)
WHERE NOT primary_topic_match
  AND primary_topic.id NOT IN ('https://openalex.org/T10005', 'https://openalex.org/T10232')



In [0]:
%sql
-- Number of distinct primary topics across the joined sample: legacy vs. LM.
WITH comparison AS (
  SELECT
    a.primary_topic,
    b.lm_primary_topic.topic_id AS lm_primary_topic
  FROM openalex.works.works_topics_compare AS a
  INNER JOIN openalex.works.works_topics_compare_lm AS b
    ON a.work_id = b.work_id
)
SELECT
  COUNT(DISTINCT primary_topic.id) AS unique_topics,
  COUNT(DISTINCT lm_primary_topic) AS unique_lm_topics
FROM comparison

In [0]:
%sql
-- Agreement between legacy and LM primary topics over all joined works.
-- Works without an LM primary topic evaluate both flags to NULL, so they count as non-matches
-- (COUNT_IF ignores NULL) but still contribute to sample_size.
-- NULLIF guards the percentages against division by zero on an empty join
-- (NULL instead of a DIVIDE_BY_ZERO error under ANSI mode).
WITH comparison AS (
  SELECT
    a.primary_topic.id = concat('https://openalex.org/T', b.lm_primary_topic.topic_id) AS primary_topic_match,
    array_contains(a.topics.id, concat('https://openalex.org/T', b.lm_primary_topic.topic_id))
      AS lm_primary_in_topics
  FROM openalex.works.works_topics_compare a
  JOIN openalex.works.works_topics_compare_lm b
    ON a.work_id = b.work_id
)
SELECT
  COUNT(*) AS sample_size,
  COUNT_IF(primary_topic_match) AS match_count,
  ROUND(100.0 * COUNT_IF(primary_topic_match) / NULLIF(COUNT(*), 0), 1) AS match_percent,
  COUNT_IF(lm_primary_in_topics) AS lm_primary_in_topics_count,
  ROUND(100.0 * COUNT_IF(lm_primary_in_topics) / NULLIF(COUNT(*), 0), 1) AS lm_primary_in_topics_percent
FROM comparison



In [0]:
%sql
-- Per-work LM predictions with a 0/1 flag for agreement with the legacy primary topic.
-- works_topics_compare_lm has no `lm_primary_topic_id` or `primary_topic` column (it stores the
-- `lm_primary_topic` struct), so the legacy topic is joined in from works_topics_compare and the
-- LM topic id is converted to the same URL form before comparing.
SELECT
  b.*,
  a.primary_topic,
  CASE
    WHEN a.primary_topic.id = concat('https://openalex.org/T', b.lm_primary_topic.topic_id) THEN 1
    ELSE 0
  END AS is_primary_same
FROM openalex.works.works_topics_compare_lm AS b
JOIN openalex.works.works_topics_compare AS a
  ON a.work_id = b.work_id

### Determine metadata mismatches for work IDs between PROD and Walden

In [0]:
%sql
-- Highest work id in the PROD export among works that have a DOI.
SELECT MAX(id)
FROM openalex.works.work_prod
WHERE doi IS NOT NULL

In [0]:
%sql
-- Reference points for the id ranges:
--   PROD   MAX(id) as of mid-July: 4,412,109,243
--   Walden MIN(id) minted:         6,600,000,001
describe history openalex.works.openalex_works

In [0]:
%sql
-- Inspect the Walden record for a single work id.
SELECT *
FROM openalex.works.openalex_works
WHERE id = 2127407834

In [0]:
%sql
-- Count totals, intersects, and uniques with formatted output.
-- prod_only / walden_only use LEFT ANTI JOIN instead of NOT IN: `x NOT IN (subquery)` is never
-- true once the subquery contains a NULL id (the count silently becomes 0), and it forces a
-- null-aware anti join that can be very slow at this size. Rows with a NULL id count as unmatched.
WITH
prod_ids AS (
  SELECT id FROM openalex.works.work_prod
),
walden_ids AS (
  SELECT id FROM openalex.works.openalex_works
),
intersect_ids AS (
  SELECT id FROM prod_ids
  INTERSECT
  SELECT id FROM walden_ids
)
SELECT
  format_number((SELECT COUNT(*) FROM prod_ids), 0) AS prod_ids_total,
  format_number((SELECT COUNT(*) FROM walden_ids), 0) AS walden_ids_total,
  format_number((SELECT COUNT(*) FROM intersect_ids), 0) AS id_match,
  format_number((SELECT COUNT(*) FROM prod_ids p LEFT ANTI JOIN intersect_ids i ON p.id = i.id), 0) AS prod_only,
  format_number((SELECT COUNT(*) FROM walden_ids w LEFT ANTI JOIN intersect_ids i ON w.id = i.id), 0) AS walden_only;


In [0]:
%sql
-- Count ID matches where DOI also matches.
-- DOIs are case-insensitive, so both sides are lower-cased before comparing; otherwise
-- case-only differences would be reported as mismatches.
-- NOTE(review): if one side stores a resolver prefix (https://doi.org/...) and the other a bare
-- DOI, pairs will still mismatch - confirm both columns use the same form.
WITH
prod AS (
  SELECT id, lower(doi) AS doi FROM openalex.works.work_prod -- from July 9 export.
),
walden AS (
  SELECT id, lower(ids.doi) AS walden_doi FROM openalex.works.openalex_works
),
intersected AS (
  SELECT
    p.id,
    p.doi AS prod_doi,
    w.walden_doi
  FROM prod p
  JOIN walden w ON p.id = w.id
)
SELECT
  format_number(COUNT(*), 0) AS total_walden_prod_work_id_match,
  format_number(COUNT_IF(prod_doi = walden_doi OR (prod_doi IS NULL AND walden_doi IS NULL)), 0) AS matched_by_doi_or_both_null,
  format_number(COUNT_IF(prod_doi != walden_doi), 0) AS mismatched_doi_non_null,
  format_number(COUNT_IF(prod_doi IS NULL AND walden_doi IS NOT NULL), 0) AS only_prod_doi_is_null,
  format_number(COUNT_IF(prod_doi IS NOT NULL AND walden_doi IS NULL), 0) AS only_walden_doi_is_null
FROM intersected;


In [0]:
%sql
-- Reminder of SQL three-valued logic: `=` with NULL operands yields NULL, not true
-- (use `<=>` for null-safe equality).
SELECT NULL = NULL

In [0]:
%sql
-- Count ID matches where DOI also matches, restricted to works with a DOI on both sides.
-- DOIs are case-insensitive, so both sides are lower-cased before comparing; otherwise
-- case-only differences would be reported as mismatches.
WITH
prod AS (
  SELECT id, lower(doi) AS doi FROM openalex.works.work_prod WHERE doi IS NOT NULL -- from July 9 export.
),
walden AS (
  SELECT id, lower(ids.doi) AS walden_doi FROM openalex.works.openalex_works WHERE ids.doi IS NOT NULL
),
intersected AS (
  SELECT
    p.id,
    p.doi AS prod_doi,
    w.walden_doi
  FROM prod p
  JOIN walden w ON p.id = w.id
)
SELECT
  format_number(COUNT(*), 0) AS total_id_intersection,
  format_number(COUNT_IF(prod_doi = walden_doi), 0) AS matched_by_doi,
  format_number(COUNT(*) - COUNT_IF(prod_doi = walden_doi), 0) AS mismatched_doi
FROM intersected;


In [0]:
%sql
-- Inspect the PROD record for a single work id (an alternative DOI filter is kept for reference).
SELECT *
FROM openalex.works.work_prod
WHERE id = 2127407834
-- WHERE doi = '10.1128/jvi.00116-06'

In [0]:
%sql
-- Seed work_topics_frontfill with 1,000 rows from work_topics_backfill, clustered by work_id.
-- NOTE(review): CREATE OR REPLACE discards any existing contents of this table, and LIMIT without
-- ORDER BY picks an arbitrary 1,000 rows - confirm this is only intended as a test fixture.
CREATE OR REPLACE TABLE openalex.works.work_topics_frontfill
CLUSTER BY (work_id)
AS SELECT *
FROM openalex.works.work_topics_backfill
LIMIT 1000

In [0]:
%sql
-- Topic coverage across all works: how many lack topics, and how many of those have an abstract
-- or full text that could be used to assign one.
-- "No abstract" is tested with coalesce(length(abstract), 0) = 0 so that works with a NULL abstract
-- are included: NOT(length(abstract) > 0) evaluates to NULL for them, and COUNT_IF would skip them.
SELECT
  format_number(count(*), 0) AS total,
  format_number(count_if(topics IS NULL OR size(topics) = 0), 0) AS no_topics,
  format_number(count_if(length(abstract) > 0
    AND (topics IS NULL OR size(topics) = 0)), 0) AS has_abstract_no_topics,
  format_number(count_if(length(fulltext) > 0
    AND (topics IS NULL OR size(topics) = 0)), 0) AS has_fulltext_no_topics,
  format_number(count_if(length(fulltext) > 0
    AND (topics IS NULL OR size(topics) = 0)
    AND coalesce(length(abstract), 0) = 0), 0) AS has_fulltext_no_abstract_no_topics
FROM openalex.works.openalex_works

In [0]:
%sql
-- Works whose topics column is NULL, with the fields needed to classify them.
-- NOTE(review): works with an empty topics array are not included here, unlike the
-- `size(topics) = 0` checks in the coverage query - confirm which definition is wanted.
SELECT
  id,
  title,
  abstract,
  primary_location.source.display_name AS journal_name
FROM openalex.works.openalex_works
WHERE topics IS NULL