### Utilities


In [0]:
import tensorflow as tf
from transformers import BertTokenizerFast
import os

BATCH_SIZE = 64
SRC_SAVED_MODEL = "/Volumes/openalex/works/models/sdg/saved_model"
TOKENIZER_PATH = "/Volumes/openalex/works/models/sdg/tokenizer"

# Display names for the 17 Sustainable Development Goals, keyed "Goal 1" ..
# "Goal 17" in goal order. These strings are used verbatim as the
# display_name of each prediction, so they are kept exactly as-is.
# NOTE(review): some names differ from the official UN wording (e.g.
# "Life in Land" vs "Life on land"); changing them would change stored
# display_name values, so confirm with downstream consumers first.
goal_names = {
    f"Goal {number}": display_name
    for number, display_name in enumerate(
        (
            "No poverty",
            "Zero hunger",
            "Good health and well-being",
            "Quality Education",
            "Gender equality",
            "Clean water and sanitation",
            "Affordable and clean energy",
            "Decent work and economic growth",
            "Industry, innovation and infrastructure",
            "Reduced inequalities",
            "Sustainable cities and communities",
            "Responsible consumption and production",
            "Climate action",
            "Life below water",
            "Life in Land",
            "Peace, Justice and strong institutions",
            "Partnerships for the goals",
        ),
        start=1,
    )
}

class ModelCache:
    """Process-wide lazy cache for the SDG model, tokenizer and serving signature.

    State is stored on the class itself, so every call in the same Python
    process reuses one loaded model; ``load`` is a no-op once it has succeeded.
    """

    model = None       # SavedModel object returned by tf.saved_model.load
    tokenizer = None   # BertTokenizerFast loaded from TOKENIZER_PATH
    predict_fn = None  # model.signatures['serving_default']

    @classmethod
    def load(cls):
        """Load the model, tokenizer and serving signature if not already loaded.

        Everything is loaded into locals first and published to the class only
        after all three steps succeed, with ``model`` assigned last. Therefore
        ``model is not None`` implies the cache is complete. A failure part-way
        through (e.g. a bad tokenizer path) leaves the cache empty so the next
        call retries. It never leaves ``model`` set while ``tokenizer`` /
        ``predict_fn`` remain ``None``.
        """
        if cls.model is not None:
            return

        model = tf.saved_model.load(SRC_SAVED_MODEL)
        tokenizer = BertTokenizerFast.from_pretrained(TOKENIZER_PATH)
        predict_fn = model.signatures['serving_default']

        cls.tokenizer = tokenizer
        cls.predict_fn = predict_fn
        cls.model = model  # assigned last: marks the cache as fully loaded

    @classmethod
    def predict(cls, text):
        """Predict SDG scores for an arbitrary text string.

        Args:
            text: Input text. It is lower-cased and tokenized to exactly 512
                tokens (truncated or padded).

        Returns:
            A list of ``{"id", "display_name", "score"}`` dicts, one per value
            in the model's ``target_layer`` output for this text, sorted by
            score descending. ``id`` is ``https://metadata.un.org/sdg/<n>`` and
            ``display_name`` is ``goal_names["Goal <n>"]``.
            Returns ``[]`` for ``None``/empty/whitespace-only text, or when
            tokenization/inference raises. Errors are printed, not raised, so
            one bad record does not fail the caller.
        """
        if cls.model is None:
            cls.load()

        if not text or not text.strip():
            return []

        try:
            enc = cls.tokenizer(
                text.lower(),
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="tf",
            )

            # NOTE(review): keyword names must match the SavedModel's serving
            # signature ("attention_masks", plural, is passed here) -- confirm
            # against predict_fn.structured_input_signature.
            out = cls.predict_fn(
                input_ids=enc["input_ids"],
                attention_masks=enc["attention_mask"],
            )

            # A single string was tokenized, so take row 0 of the batch.
            # Assumed to be float32 with 17 scores, one per goal -- TODO confirm.
            scores = out["target_layer"].numpy()[0]

            sdg_results = [
                {
                    "id": f"https://metadata.un.org/sdg/{number}",
                    "display_name": goal_names[f"Goal {number}"],
                    "score": float(score),
                }
                for number, score in enumerate(scores, start=1)
            ]

            # Highest-scoring goal first.
            sdg_results.sort(key=lambda x: x["score"], reverse=True)

            return sdg_results
        except Exception as e:
            print(f"Error predicting SDG for text: {e}")
            return []


In [0]:
### Create tables (if needed)

### Load input data and cache


In [0]:
# Input batch for this run. limit() has no ordering, so the chosen rows are
# not deterministic across re-evaluations. Caching pins the batch (while the
# cache is not evicted), so the count below and the inference step see the
# same rows, and the scan/limit/shuffle is not repeated for inference.
df = (spark.table("openalex.works.works_sdg_frontfill_input")
      .select("work_id", "title", "abstract")
      .limit(2048000)
      .repartition(512)
      .cache()
)

# count() also materializes the cache.
print(f"Input Row Count: {df.count()}")

### Run inference


In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
import time

def process_partition(rows):
    """Run SDG inference over one Spark partition (for use with ``RDD.mapPartitions``).

    Ensures the model is loaded in this Python worker, then scores each row's
    title + abstract with ``ModelCache.predict``. Results are yielded as plain
    tuples rather than collected.

    Args:
        rows: Iterator of Rows exposing ``work_id``, ``title`` and
            ``abstract``; ``None`` entries are skipped.

    Yields:
        ``(work_id, sdg_results)`` tuples, exactly one per non-None input row.
        ``sdg_results`` is ``[]`` when prediction fails for that row.
    """
    ModelCache.load()

    for row in rows:
        if row is None:
            continue

        # Title and abstract form a single prediction input; missing fields
        # become empty strings.
        title = (row.title or "").strip()
        abstract = (row.abstract or "").strip()
        combined_text = f"{title}\n{abstract}"

        # ModelCache.predict scores one text per call, so rows are handled
        # individually. A failure affects only its own row, and every
        # work_id is emitted exactly once (no duplicates from a batch-wide
        # fallback after some rows were already yielded).
        try:
            sdg_results = ModelCache.predict(combined_text)
        except Exception as e:
            print(f"Error processing work_id {row.work_id}: {e}")
            sdg_results = []

        yield (row.work_id, sdg_results)

start_time = time.time()

# Shape of the inference output: one row per work with its ranked SDG scores.
sdg_struct = StructType([
    StructField("id", StringType(), True),
    StructField("display_name", StringType(), True),
    StructField("score", FloatType(), True),
])
output_schema = StructType([
    StructField("work_id", StringType(), nullable=False),
    StructField("sdg_array", ArrayType(sdg_struct), nullable=True),
])

# Score every partition on the executors via mapPartitions. The resulting
# DataFrame is cached because it is counted here and reused when writing
# and cleaning up afterwards.
res_rdd = (df
           .select("work_id", "title", "abstract")
           .rdd
           .mapPartitions(process_partition))
inferred_sdg_df = spark.createDataFrame(res_rdd, output_schema)
inferred_sdg_df = inferred_sdg_df.cache()

# count() is the action that actually runs the inference.
output_count = inferred_sdg_df.count()
print(f"Output Row count: {output_count}")

runtime = time.time() - start_time
print(f"Total runtime: {runtime:.4f} seconds")
print(f"Total throughput: {output_count / runtime:.4f} inferences/sec")


### Write to works_sdg_frontfill table


In [0]:
from pyspark.sql.functions import col, current_timestamp

# Inference results carry the predictions in a column named "sdg_array", but
# the target table is queried as "sdg" (see the verification query), so rename
# it here. Selecting a non-existent "sdg" column fails with an AnalysisException.
final_df = (inferred_sdg_df
    .withColumn("created_timestamp", current_timestamp())
    .select("work_id", col("sdg_array").alias("sdg"), "created_timestamp"))

# Append this run's predictions to the output Delta table.
final_df.write.format("delta") \
    .mode("append") \
    .option("mergeSchema", "true") \
    .saveAsTable("openalex.works.works_sdg_frontfill")

# Expose only the processed work_ids for the cleanup DELETE below.
# NOTE(review): the view is evaluated lazily. If the cached inferred_sdg_df has
# been evicted, it is recomputed from works_sdg_frontfill_input through an
# unordered .limit(). The DELETE could then target different work_ids than
# were just written, and inference would rerun. Consider materializing the
# processed ids (e.g. checkpoint or staging table) before deleting.
inferred_sdg_df.select("work_id").createOrReplaceTempView("res_df_temp")

# Remove the work_ids that were just scored from the input queue.
spark.sql("""
    DELETE FROM openalex.works.works_sdg_frontfill_input 
    WHERE work_id IN (SELECT work_id FROM res_df_temp)
""")
print("Removed processed work_ids from works_sdg_frontfill_input")

### Verify table structure and sample results


In [0]:
%sql
-- Spot-check the most recently written predictions instead of scanning the
-- whole (append-only, growing) table. num_sdgs is the number of scored goals
-- per work (0 when no prediction was produced); top_3_sdgs are the first three
-- entries, which are the highest-scoring goals because the array is written
-- sorted by score descending.
SELECT 
    work_id,
    size(sdg) AS num_sdgs,
    slice(sdg, 1, 3) AS top_3_sdgs,
    created_timestamp
FROM openalex.works.works_sdg_frontfill
ORDER BY created_timestamp DESC
LIMIT 100;