### Load SDG Model and Tokenizer


In [None]:
import tensorflow as tf
from transformers import BertTokenizerFast

SRC_SAVED_MODEL = "/Volumes/openalex/works/models/sdg/saved_model"

# Driver-side load of the SDG SavedModel and its "serving_default" signature.
# NOTE(review): in the cells shown here, model/predict/tokenizer are not used by the
# mapPartitions inference, which reloads both inside process_partition; presumably this
# cell only checks that the artifacts load.
model = tf.saved_model.load(SRC_SAVED_MODEL)
predict = model.signatures["serving_default"]

tokenizer = BertTokenizerFast.from_pretrained(
    "/Volumes/openalex/works/models/sdg/tokenizer"
)


### Load works that still need SDG inference from openalex_works


In [None]:
# Works that have a title and abstract of reasonable length but no SDGs yet, excluding
# any work already present in the frontfill table (LEFT ANTI JOIN).
pending_works_query = """
    SELECT
        w.id AS work_id,
        w.title,
        w.abstract
    FROM openalex.works.openalex_works w
    LEFT ANTI JOIN openalex.works.works_sdg_frontfill wsf
        ON w.id = wsf.work_id
    WHERE (w.sustainable_development_goals IS NULL OR size(w.sustainable_development_goals) = 0)
        AND w.title IS NOT NULL
        AND w.abstract IS NOT NULL
        AND length(w.title) > 20
        AND length(w.abstract) > 50
"""

# 4096 partitions; each one loads its own model copy in process_partition. Cached so the
# count below and the later df.rdd inference read the same materialised rows.
df = spark.sql(pending_works_query).repartition(4096).cache()

rows_to_process = df.count()
print(f"Total number of rows to process: {rows_to_process}")


### Run SDG inference using mapPartitions


In [None]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
import tensorflow as tf
import numpy as np
from transformers import BertTokenizerFast

# Number of rows handed to process_batch at a time within a partition.
BATCH_SIZE = 64
# Same SavedModel location as the first cell.
SRC_SAVED_MODEL = "/Volumes/openalex/works/models/sdg/saved_model"

# Output row produced by process_partition: the work id plus a list of scored SDGs,
# each entry carrying the UN SDG URI, its display name and the model score.
schema = StructType(
    [
        StructField("work_id", StringType(), nullable=False),
        StructField(
            "sdg_array",
            ArrayType(
                StructType(
                    [
                        StructField("id", StringType(), nullable=False),
                        StructField("display_name", StringType(), nullable=False),
                        StructField("score", FloatType(), nullable=False),
                    ]
                )
            ),
            nullable=True,
        ),
    ]
)

# Display name for each model output, keyed "Goal 1" .. "Goal 17" in output order.
# NOTE(review): some names differ from the UN short titles (e.g. "Life in Land" vs
# "Life on land"). They are kept exactly as-is because rows already written presumably
# use these strings; confirm before changing any of them.
goal_names = {
    f"Goal {number}": name
    for number, name in enumerate(
        [
            "No poverty",
            "Zero hunger",
            "Good health and well-being",
            "Quality Education",
            "Gender equality",
            "Clean water and sanitation",
            "Affordable and clean energy",
            "Decent work and economic growth",
            "Industry, innovation and infrastructure",
            "Reduced inequalities",
            "Sustainable cities and communities",
            "Responsible consumption and production",
            "Climate action",
            "Life below water",
            "Life in Land",
            "Peace, Justice and strong institutions",
            "Partnerships for the goals",
        ],
        start=1,
    )
}

def _build_text(row):
    """Return the model input for ``row``: stripped title and abstract joined by a newline.

    A missing (None) title or abstract is treated as an empty string.
    """
    title = row.title or ""
    abstract = row.abstract or ""
    return f"{title.strip()}\n{abstract.strip()}"


def _predict_scores(texts, predict_fn, tokenizer_local):
    """Run one forward pass over ``texts`` and return the ``target_layer`` output as numpy.

    Each text is lower-cased, then tokenized with truncation/padding to 512 tokens.
    Row ``i`` of the result corresponds to ``texts[i]``.
    NOTE(review): batching more than one text assumes the serving signature has a dynamic
    batch dimension — TODO confirm; process_batch falls back to one text per call if not.
    """
    enc = tokenizer_local(
        [text.lower() for text in texts],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="tf",
    )
    out = predict_fn(
        input_ids=enc["input_ids"],
        attention_masks=enc["attention_mask"],
    )
    return out["target_layer"].numpy()


def _to_sdg_results(scores):
    """Convert one row of model scores into ``sdg_array`` entries, highest score first.

    Output position ``idx`` maps to SDG number ``idx + 1``. Raises ``KeyError`` if there
    are more outputs than entries in ``goal_names``.
    """
    sdg_results = [
        {
            "id": f"https://metadata.un.org/sdg/{idx + 1}",
            "display_name": goal_names[f"Goal {idx + 1}"],
            "score": float(score),
        }
        for idx, score in enumerate(scores)
    ]
    sdg_results.sort(key=lambda x: x["score"], reverse=True)
    return sdg_results


def process_batch(batch, predict_fn, tokenizer_local):
    """Score a batch of rows, yielding ``(work_id, sdg_array)`` tuples in input order.

    All rows with non-blank text are sent through a single batched model call. If that
    call raises, or returns a different number of rows than were sent, every row is
    scored individually instead, so one bad input only affects its own row.

    Rows that cannot be scored yield an empty list; errors are printed, not raised. This
    covers both blank text and a per-row failure.

    Args:
        batch: Spark Rows exposing ``work_id``, ``title`` and ``abstract``.
        predict_fn: serving signature of the SDG SavedModel.
        tokenizer_local: tokenizer used to encode texts for ``predict_fn``.
    """
    work_ids = [row.work_id for row in batch]
    texts = [_build_text(row) for row in batch]

    # Positions of rows worth sending to the model; blank texts are answered with [].
    scorable = [i for i, text in enumerate(texts) if text.strip()]

    # Maps batch position -> score vector from the batched call (empty if it was unusable).
    batched_scores = {}
    if scorable:
        try:
            outputs = _predict_scores(
                [texts[i] for i in scorable], predict_fn, tokenizer_local
            )
        except Exception as e:
            print(f"Batched inference failed, falling back to per-row inference: {e}")
        else:
            if len(outputs) == len(scorable):
                batched_scores = dict(zip(scorable, outputs))
            else:
                print(
                    f"Batched inference returned {len(outputs)} rows for "
                    f"{len(scorable)} inputs, falling back to per-row inference"
                )

    for i, (work_id, text) in enumerate(zip(work_ids, texts)):
        if not text.strip():
            yield (work_id, [])
            continue

        try:
            scores = batched_scores.get(i)
            if scores is None:
                scores = _predict_scores([text], predict_fn, tokenizer_local)[0]
            sdg_results = _to_sdg_results(scores)
        except Exception as e:
            print(f"Error processing work_id {work_id}: {e}")
            sdg_results = []

        # Yield outside the try so consumer-side errors are not reported as row failures.
        yield (work_id, sdg_results)

def process_partition(rows_iter, batch_size=BATCH_SIZE):
    """Score every row of one Spark partition, yielding ``(work_id, sdg_array)`` tuples.

    The SavedModel, its serving signature and the tokenizer are loaded once when the
    generator starts. They are then reused for every chunk of up to ``batch_size`` rows
    passed to process_batch; a final, shorter chunk is flushed at the end.
    """
    loaded = tf.saved_model.load(SRC_SAVED_MODEL)
    serving_fn = loaded.signatures['serving_default']
    tok = BertTokenizerFast.from_pretrained("/Volumes/openalex/works/models/sdg/tokenizer")

    pending = []
    for record in rows_iter:
        pending.append(record)
        if len(pending) < batch_size:
            continue
        yield from process_batch(pending, serving_fn, tok)
        pending = []

    # Leftover rows that did not fill a whole chunk.
    if pending:
        yield from process_batch(pending, serving_fn, tok)

# Run inference partition by partition; process_partition yields (work_id, sdg_array) tuples.
result_rdd = df.rdd.mapPartitions(process_partition)

# Give the scored rows the declared schema, then cache them so the count below and the
# later write reuse the scored rows rather than recomputing them.
inferred_sdg_df = spark.createDataFrame(result_rdd, schema=schema)
inferred_sdg_df = inferred_sdg_df.cache()

processed_total = inferred_sdg_df.count()
print(f"Total number of rows processed: {processed_total}")


### Write to works_sdg_frontfill table


In [None]:
from pyspark.sql.functions import current_timestamp

# Rows to append: the scored SDGs plus the time they were written. The scored column is
# named "sdg_array" (see `schema`), which is also what the verification query reads.
final_df = (
    inferred_sdg_df
    .withColumn("created_timestamp", current_timestamp())
    .select("work_id", "sdg_array", "created_timestamp")
)

# Append to the Delta frontfill table; mergeSchema allows Delta to add any columns the
# table does not have yet.
(
    final_df.write.format("delta")
    .mode("append")
    .option("mergeSchema", "true")
    .saveAsTable("openalex.works.works_sdg_frontfill")
)

### Verify table structure and sample results


In [None]:
-- Show sample results: number of SDG entries per row and the first three entries
-- (process_batch writes the array sorted by score, highest first).
SELECT
    f.work_id,
    size(f.sdg_array) AS num_sdgs,
    slice(f.sdg_array, 1, 3) AS top_3_sdgs,
    f.created_timestamp
FROM openalex.works.works_sdg_frontfill AS f
LIMIT 10
