### Initialize LM Classifier

In [0]:
%pip install tf-keras torch

In [0]:
from transformers import TFAutoModelForSequenceClassification, pipeline, AutoTokenizer

# Build a Hugging Face pipeline for the OpenAlex topic-classification checkpoint.
# top_k=3 asks the pipeline for the three best-scoring labels per input.
classifier_multi = pipeline(
    model="OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract",
    top_k=3,
)

# Smoke test with a title-only input; this bare call is the cell's displayed output.
classifier_multi(
    """<TITLE>Supplemental Material: Estimating paleotidal constituents from Pliocene “tidal gauges”—an example from the paleo-Orinoco Delta, Trinidad"""
)

### Test with samples

In [0]:
# Two sample inputs in the "<TITLE>...<ABSTRACT>..." text layout used throughout
# this notebook: the first has a title and an abstract, the second a title only.
olfactory_input = """<TITLE>The Shape of the Olfactory Bulb Predicts Olfactory Function<ABSTRACT>The olfactory bulb (OB) plays a key role in the processing of olfactory information. A large body of research has shown that OB volumes correlate with olfactory function, which provides diagnostic and prognostic information in olfactory dysfunction. Still, the potential value of the OB shape remains unclear. Based on our clinical experience we hypothesized that the shape of the OB predicts olfactory function, and that it is linked to olfactory loss, age, and gender. The aim of this study was to produce a classification of OB shape in the human brain, scalable to clinical and research applications. Results from patients with the five most frequent causes of olfactory dysfunction (n = 192) as well as age/gender-matched healthy controls (n = 77) were included. Olfactory function was examined in great detail using the extended “Sniffin’ Sticks” test. A high-resolution structural T2-weighted MRI scan was obtained for all. The planimetric contours (surface in mm2) of OB were delineated manually, and then all surfaces were added and multiplied to obtain the OB volume in mm3. OB shapes were outlined manually and characterized on a selected slice through the posterior coronal plane tangential to the eyeballs. We looked at OB shapes in terms of convexity and defined two patterns/seven categories based on OB contours: convex"""

paleotidal_input = """<TITLE>Supplemental Material: Estimating paleotidal constituents from Pliocene “tidal gauges”—an example from the paleo-Orinoco Delta, Trinidad"""

# Classify each sample in turn and print the raw pipeline output.
for sample_text in (olfactory_input, paleotidal_input):
    print(classifier_multi(sample_text))

# Reference topics recorded from production for comparison.
# IN PROD (first 2 topics match)
# 10971	Olfactory and Sensory Function Studies
# 11667	Advanced Chemical Sensor Technologies
# 14144	Neurological Disease Mechanisms and Treatments

### Setup dependencies

In [0]:
import json
import os
import re
import unicodedata

import torch
from transformers import TFAutoModelForSequenceClassification, pipeline, AutoTokenizer

# Report how many CUDA devices torch can see from this process.
visible_cuda_devices = torch.cuda.device_count()
print(f"CUDA device count: {visible_cuda_devices}")

class ModelCache:
    """
    Per-process lazy cache for the topic-classification pipeline.

    The pipeline is stored on the class, so within one Python process it is
    built at most once: the first call to load() builds it and later calls
    return immediately.
    """

    # Hugging Face model id, used for both the model and its tokenizer.
    MODEL_NAME = "OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract"

    # The loaded pipeline; None until load() has completed once.
    model = None
    # CUDA device index picked by load(); None until load() has picked one.
    assigned_device = None

    @classmethod
    def load(cls):
        """
        Build the pipeline on one CUDA device if it has not been built yet.

        The device index is `pid % device_count`, so different worker
        processes on the same host tend to land on different GPUs.

        Raises:
        RuntimeError: if torch reports no CUDA devices
        """
        if cls.model is not None:
            return

        # Determine available GPU count
        num_devices = torch.cuda.device_count()
        if num_devices == 0:
            raise RuntimeError("No CUDA devices available.")

        # Assign device using the process id
        pid = os.getpid()
        device_id = pid % num_devices
        cls.assigned_device = device_id

        print(f"Loading model on GPU:{device_id} for pid:{pid}")

        cls.model = pipeline(
            model=cls.MODEL_NAME,
            tokenizer=cls.MODEL_NAME,
            device=device_id,
            top_k=5,
            batch_size=12,
            truncation=True,
            max_length=512,
        )
def remove_non_latin_characters(text):
    """
    Function to remove characters that belong to a fixed list of non-latin scripts.

    A character's script is taken to be the first word of its Unicode name.
    Characters from scripts that are not in the skip list (e.g. GREEK, digits,
    punctuation, spaces) are kept. Characters that have no Unicode name, such
    as tabs and newlines, are dropped.

    Input:
    text: string of characters

    Output:
    final_char: string of characters with the listed scripts removed
    """
    groups_to_skip = {'HIRAGANA', 'CJK', 'KATAKANA', 'ARABIC', 'HANGUL', 'THAI', 'DEVANAGARI', 'BENGALI',
                      'THAANA', 'GUJARATI', 'CYRILLIC'}
    final_char = []
    for char in text:
        try:
            script = unicodedata.name(char).split(" ")[0]
        except (TypeError, ValueError):
            # ValueError: the character has no Unicode name (control characters).
            # TypeError: the element is not a single character.
            # Either way the element is skipped, as before.
            continue
        if script not in groups_to_skip:
            final_char.append(char)
    return "".join(final_char)
    
def group_non_latin_characters(text):
    """
    Function to group non-latin characters and return the number of latin characters.

    Periods and spaces are removed before counting. A character's group is the
    first word of its Unicode name; characters without a Unicode name are
    reported under the group "UNK".

    Input:
    text: string of characters

    Output:
    groups: list of distinct non-LATIN character groups, in order of first appearance
    latin_chars: number of latin characters
    """
    groups = []
    latin_count = 0
    text = text.replace(".", "").replace(" ", "")
    for char in text:
        try:
            script = unicodedata.name(char).split(" ")[0]
        except ValueError:
            # unicodedata.name raises ValueError for characters with no name
            script = "UNK"
        if script == 'LATIN':
            latin_count += 1
        elif script not in groups:
            groups.append(script)
    return groups, latin_count

def check_for_non_latin_characters(text):
    """
    Function to decide whether a text should be used, based on its character groups.

    The text is used when name_to_keep_ind accepts its character groups, or
    when it contains more than 20 latin characters.

    Input:
    text: value to check (converted with str() first)

    Output:
    0: if text should be not used
    1: if text should be used
    """
    # NOTE(review): name_to_keep_ind is not defined in the visible part of this
    # notebook; confirm where it comes from before running.
    char_groups, latin_count = group_non_latin_characters(str(text))
    should_use = name_to_keep_ind(char_groups) == 1 or latin_count > 20
    return 1 if should_use else 0
    
# Literal markup fragments removed from titles, applied one after another in
# this exact order. Some entries repeat on purpose: an earlier removal can
# expose a new match for a later one.
_TITLE_FRAGMENTS_TO_STRIP = (
    "<i>", "</i>",
    "<sub>", "</sub>",
    "<sup>", "</sup>",
    "<em>", "</em>",
    "<b>", "</b>",
    "<I>", "</I>",
    "<SUB>", "</SUB>",
    "<scp>", "</scp>",
    "<font>", "</font>",
    "<inf>", "</inf>",
    "<i /> ",
    "<p>", "</p>",
    "<![CDATA[<B>", "</B>]]>",
    "<italic>", "</italic>",
    "<title>", "</title>",
    "<br>", "</br>", "<br/>",
    "<B>", "</B>",
    "<em>", "</em>",
    "<BR>", "</BR>",
    "<title>", "</title>",
    "<strong>", "</strong>",
    "<formula>", "</formula>",
    "<roman>", "</roman>",
    "<SUP>", "</SUP>",
    "<SSUP>", "</SSUP>",
    "<sc>", "</sc>",
    "<subtitle>", "</subtitle>",
    "<emph/>", "<emph>", "</emph>",
    """<p class="Body">""",
    "<TITLE>", "</TITLE>",
    "<sub />", "<sub/>",
    "<mi>", "</mi>",
    "<bold>", "</bold>",
    "<mtext>", "</mtext>",
    "<msub>", "</msub>",
    "<mrow>", "</mrow>",
    "</mfenced>", "</math>",
)


def clean_title(old_title):
    """
    Function to check if title should be kept and then remove non-latin characters. Also
    removes some HTML/XML tags from the title.

    Input:
    old_title: string of title, or None

    Output:
    new_title: string of title with non-latin characters and known tags removed;
               '' when the title is None or should not be kept
    """
    # A missing title (e.g. a null column value) has nothing to clean.
    if old_title is None:
        return ''

    if check_for_non_latin_characters(old_title) != 1:
        return ''

    new_title = remove_non_latin_characters(old_title)
    if '<' in new_title:
        for fragment in _TITLE_FRAGMENTS_TO_STRIP:
            new_title = new_title.replace(fragment, "")

        if '<mml' in new_title:
            # Split the title around <mml:math ... mml:math> sections; for pieces
            # containing ">text<" runs keep only that text, padded with spaces.
            all_parts = [x for y in [i.split("mml:math>") for i in new_title.split("<mml:math")] for x in y if x]
            final_parts = []
            for part in all_parts:
                pull_out = re.findall(r"\>[$%#!^*\w.,/()+-]*\<", part)
                if pull_out:
                    final_pieces = [piece.replace(">", "").replace("<", "") for piece in pull_out]
                    final_parts.append(" " + "".join(final_pieces) + " ")
                else:
                    final_parts.append(part)

            new_title = "".join(final_parts).strip()

        # Drop whole elements, including their content.
        if '<xref' in new_title:
            new_title = re.sub(r"\<xref[^/]*\/xref\>", "", new_title)

        if '<inline-formula' in new_title:
            new_title = re.sub(r"\<inline-formula[^/]*\/inline-formula\>", "", new_title)

        if '<title' in new_title:
            new_title = re.sub(r"\<title[^/]*\/title\>", "", new_title)

        # Drop opening tags that carry attributes.
        if '<p class=' in new_title:
            new_title = re.sub(r"\<p class=[^>]*\>", "", new_title)

        if '<span class=' in new_title:
            new_title = re.sub(r"\<span class=[^>]*\>", "", new_title)

        if 'mfenced open' in new_title:
            new_title = re.sub(r"\<mfenced open=[^>]*\>", "", new_title)

        if 'math xmlns' in new_title:
            new_title = re.sub(r"\<math xmlns=[^>]*\>", "", new_title)

    if '<' in new_title:
        # NOTE(review): ">i<" / ">b<" look like tags with reversed brackets;
        # kept exactly as they were.
        new_title = new_title.replace(">i<", "").replace(">/i<", "") \
                             .replace(">b<", "").replace(">/b<", "") \
                             .replace("<inline-formula>", "").replace("</inline-formula>", "")

    return new_title
    
def clean_abstract(raw_abstract, inverted=False):
    """
    Function to clean abstract and return it in a format for the model.

    Input:
    raw_abstract: abstract as a plain string; or, when inverted=True, an inverted
                  index given as a dict or as a JSON string; or None
    inverted: boolean to determine if abstract is inverted index or not

    Output:
    final_abstract: abstract text cut to at most 2500 characters, or None when the
                    abstract is missing, too short, of an unsupported type, or
                    rejected by check_for_non_latin_characters
    """
    # A missing abstract (e.g. a null column value) cannot be cleaned.
    if raw_abstract is None:
        return None

    if not inverted:
        if len(raw_abstract) <= 30:
            return None
        final_abstract = raw_abstract[:2500]
    else:
        if not isinstance(raw_abstract, (dict, str)):
            return None
        if isinstance(raw_abstract, dict):
            invert_abstract = raw_abstract
        else:
            invert_abstract = json.loads(raw_abstract)

        if invert_abstract.get('IndexLength'):
            # Layout with an explicit length and the word -> positions map
            # stored under 'InvertedIndex'.
            ab_len = invert_abstract['IndexLength']
            if not ab_len > 20:
                return None
            abstract = [" "] * ab_len
            for key, value in invert_abstract['InvertedIndex'].items():
                for i in value:
                    abstract[i] = key
            final_abstract = " ".join(abstract)[:2500]
        else:
            # Layout where the dict itself is the word -> positions map.
            if len(invert_abstract) <= 20:
                return None
            abstract = [" "] * 1200
            for key, value in invert_abstract.items():
                for i in value:
                    try:
                        abstract[i] = key
                    except (IndexError, TypeError):
                        # positions beyond the 1200-slot buffer, or non-integer
                        # positions, are ignored
                        pass
            final_abstract = " ".join(abstract)[:2500].strip()

    if check_for_non_latin_characters(final_abstract) == 1:
        return final_abstract
    return None

### Load inference input

In [0]:
# Read the comparison table and redistribute it into 96 partitions.
df = spark.sql("SELECT * FROM openalex.works.works_topics_compare").repartition(96)

# count() is an action, so this line triggers the read.
input_row_count = df.count()
print(f"Input Row Count: {input_row_count}")

### Run Inference

In [0]:
import json
from pyspark.sql.functions import *
from pyspark.sql.types import *

import torch
import re

def process_partition(rows, batch_size=12):
    """
    Classify every row of one partition and yield the rows with topics attached.

    Input:
    rows: iterator of rows that support row['title'], row['abstract'] and
          row.asDict(); None entries are skipped
    batch_size: number of texts sent to the model per call

    Output:
    generator of dicts: the row's fields plus "lm_topics", a list of
    {"topic_id", "score"} dicts, or None when the model returned nothing for
    the row (including when inference failed for its batch)
    """
    ModelCache.load()
    model = ModelCache.model

    batch_rows = []
    batch_texts = []

    def yield_batch(rows_batch, texts_batch):
        try:
            batch_outputs = model(texts_batch)
        except Exception as e:
            # Best-effort: keep the rows, but report the failure instead of hiding it.
            print(f"Inference failed for a batch of {len(texts_batch)} rows; emitting null topics: {e!r}")
            batch_outputs = [[] for _ in texts_batch]  # fail-safe: empty predictions

        for row, output in zip(rows_batch, batch_outputs):
            row_dict = row.asDict()
            # The integer before ":" in each label is offset by 10000 to form the topic id.
            lm_output = [
                {
                    "topic_id": 10000 + int(topic["label"].split(":")[0]),
                    "score": float(topic["score"])
                }
                for topic in output
            ] if output else None

            row_dict["lm_topics"] = lm_output
            yield row_dict

    for row in rows:
        if row is None:
            continue

        title = clean_title(row['title']) or ""
        abstract = clean_abstract(row['abstract']) or ""
        full_text = f"<TITLE>{title.strip()}<ABSTRACT>{abstract.strip()}"

        batch_rows.append(row)
        batch_texts.append(full_text)

        if len(batch_texts) >= batch_size:
            yield from yield_batch(batch_rows, batch_texts)
            batch_rows = []
            batch_texts = []

    # Process remaining rows
    if batch_rows:
        yield from yield_batch(batch_rows, batch_texts)

# One predicted topic: numeric topic id plus the model score.
topic_struct = (
    StructType()
    .add("topic_id", IntegerType(), True)
    .add("score", FloatType(), True)
)

# Output rows: the four selected input columns plus the list of predicted topics.
output_schema = (
    StructType()
    .add("work_id", StringType(), True)
    .add("title", StringType(), True)
    .add("abstract", StringType(), True)
    .add("journal_name", StringType(), True)
    .add("lm_topics", ArrayType(topic_struct), True)
)

# mapPartitions (a transformation) feeds each partition's rows through
# process_partition and keeps what it yields. foreachPartition is an action
# that returns None, so it cannot be used to build res_rdd.
# NOTE(review): this uses the RDD API; confirm it is available on the compute
# this notebook runs on.
res_rdd = (
    df.select("work_id", "title", "abstract", "journal_name")
      .rdd
      .mapPartitions(process_partition)
)
res_df = spark.createDataFrame(res_rdd, output_schema).cache()
print(f"Output Row count: {res_df.count()}")

# lm_primary_topic is the first entry of lm_topics.
res_df = (res_df.select("work_id", "title", "abstract", "journal_name", "lm_topics")
                .withColumn("lm_primary_topic", col("lm_topics")[0])
                .withColumn("created_timestamp", current_timestamp()))
display(res_df)

### Profile with Serverless GPU (A10)
Batching is effective, but the gain over single-input inference is modest (10-15%).

In [0]:
import time

# Repeat the same two inputs until there are X of them
X = 1000
inputs = [olfactory_input, paleotidal_input] * (X // 2)

# Warm up the model (important for GPU / XLA compilation)
_ = classifier_multi(inputs[:2])

# Time the loop. perf_counter is a monotonic, high-resolution clock meant for
# measuring intervals; time.time() follows the wall clock, which can be adjusted.
start_time = time.perf_counter()

for i in range(X):
    _ = classifier_multi(inputs[i]) # run with single input only

total_time = time.perf_counter() - start_time
print(f"🔥 Ran {X} inferences in {total_time:.2f} seconds")
print(f"⚡ Avg per inference: {total_time / X:.3f} sec")


In [0]:
from transformers import pipeline
import time

# Prepare inputs
X = 640
inputs = [olfactory_input, paleotidal_input] * (X // 2)

# Try different batch sizes
for bs in [4, 8, 16, 32, 64]:
    print(f"\n🚀 Testing batch_size={bs}")

    # Re-initialize pipeline with new batch size
    # (this rebinds classifier_multi, now with top_k=5)
    classifier_multi = pipeline(
        model="OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract",
        top_k=5,
        batch_size=bs
    )

    # Warm-up (important for GPU/XLA)
    _ = classifier_multi(inputs[:2])

    # Time batch inference with a monotonic, high-resolution interval clock
    start = time.perf_counter()
    _ = classifier_multi(inputs) # takes in the whole dataset and batches internally
    duration = time.perf_counter() - start

    print(f"⏱️  Total time: {duration:.2f} sec for {X} inputs")
    print(f"⚡ Avg time per input: {duration / X:.3f} sec")
