### Utilities


In [0]:
import tensorflow as tf
from transformers import BertTokenizerFast
import os

# Batch size for TensorFlow inference - tune based on GPU memory
# L40S (48GB): Recommended 128-256 for optimal throughput
# Start with 64 and increase if memory allows and throughput improves
# Used by process_partition as the number of rows buffered per predict_batch call.
BATCH_SIZE = 64
# SavedModel directory handed to tf.saved_model.load() in ModelCache.load().
SRC_SAVED_MODEL = "/Volumes/openalex/works/models/sdg/saved_model"
# Tokenizer directory handed to BertTokenizerFast.from_pretrained() in ModelCache.load().
TOKENIZER_PATH = "/Volumes/openalex/works/models/sdg/tokenizer"

class ModelCache:
    """Lazily loaded, class-level holder for the SDG model and its tokenizer.

    All state is stored on the class itself, so every caller in the same
    Python process shares one loaded model. ``load()`` is a no-op once
    ``model`` has been set.
    """

    model = None
    tokenizer = None
    predict_fn = None
    # Index into tf.config.list_physical_devices('GPU') picked by load();
    # -1 when no GPU was found.
    assigned_device = None
    gpu_device = '/cpu:0'  # Default to CPU, will be set during load()

    # Keyed by "Goal <n>". The values are emitted verbatim as "display_name"
    # in prediction results, so editing a string here changes output data.
    goal_names = {
        "Goal 1": "No poverty",
        "Goal 2": "Zero hunger",
        "Goal 3": "Good health and well-being",
        "Goal 4": "Quality Education",
        "Goal 5": "Gender equality",
        "Goal 6": "Clean water and sanitation",
        "Goal 7": "Affordable and clean energy",
        "Goal 8": "Decent work and economic growth",
        "Goal 9": "Industry, innovation and infrastructure",
        "Goal 10": "Reduced inequalities",
        "Goal 11": "Sustainable cities and communities",
        "Goal 12": "Responsible consumption and production",
        "Goal 13": "Climate action",
        "Goal 14": "Life below water",
        "Goal 15": "Life in Land",
        "Goal 16": "Peace, Justice and strong institutions",
        "Goal 17": "Partnerships for the goals"
    }

    @classmethod
    def load(cls):
        """Load the SavedModel, tokenizer and serving signature once per process.

        Picks a GPU by ``pid % gpu_count`` when GPUs are present, restricts
        TensorFlow to that GPU, and loads the model under the matching device
        scope. Falls back to CPU when no GPU is found.

        Raises:
            Whatever TensorFlow / transformers raise while configuring devices
            or loading artifacts. On failure no cache attribute is left
            half-populated, so a later call retries from scratch.
        """
        if cls.model is not None:
            return

        # Determine available GPU count (TensorFlow native detection)
        gpus = tf.config.list_physical_devices('GPU')
        num_devices = len(gpus)

        if num_devices == 0:
            cls.assigned_device = -1
            cls.gpu_device = '/cpu:0'
        else:
            # Spread worker processes across the physical GPUs by pid.
            cls.assigned_device = os.getpid() % num_devices
            chosen_gpu = gpus[cls.assigned_device]

            tf.config.set_visible_devices(chosen_gpu, 'GPU')
            tf.config.experimental.set_memory_growth(chosen_gpu, True)

            # Logical devices are numbered over the *visible* GPUs only. With
            # a single GPU visible it is always GPU:0, whatever its physical
            # index was, so the physical index must not be used here.
            cls.gpu_device = '/gpu:0'

        # Load into locals first and publish at the end: `model is not None`
        # doubles as the "fully loaded" flag, so it must not become true
        # while the tokenizer or signature is still missing.
        with tf.device(cls.gpu_device):
            model = tf.saved_model.load(SRC_SAVED_MODEL)
        tokenizer = BertTokenizerFast.from_pretrained(TOKENIZER_PATH)
        predict_fn = model.signatures['serving_default']

        cls.tokenizer = tokenizer
        cls.predict_fn = predict_fn
        cls.model = model

    @classmethod
    def _top_sdgs(cls, scores, score_threshold=0.1, top_k=3):
        """Convert one row of per-goal scores into the best-scoring SDG dicts.

        Args:
            scores: Iterable of numeric scores; position ``i`` maps to
                "Goal i+1".
            score_threshold: Entries must score strictly above this value.
            top_k: Maximum number of entries returned.

        Returns:
            Up to ``top_k`` dicts with keys ``id``, ``display_name`` and
            ``score``, ordered by descending score.

        Raises:
            KeyError: If ``scores`` has more entries than ``goal_names``.
        """
        sdgs = [
            {
                "id": f"https://metadata.un.org/sdg/{sdg_number}",
                "display_name": cls.goal_names[f"Goal {sdg_number}"],
                "score": float(score)
            }
            for sdg_number, score in enumerate(scores, start=1)
        ]
        sdgs.sort(key=lambda sdg: sdg["score"], reverse=True)
        return [sdg for sdg in sdgs if sdg["score"] > score_threshold][:top_k]

    @classmethod
    def predict_batch(cls, texts):
        """Predict SDG scores for a batch of text strings - much faster than individual calls

        Args:
            texts: Sequence of strings (``None`` is treated as empty).

        Returns:
            A list aligned with ``texts``. Each element is a list of up to
            three SDG dicts (``id``, ``display_name``, ``score``) with score
            above 0.1, best first. Empty or whitespace-only inputs always map
            to ``[]``. If tokenization or inference fails, the error is
            printed and every element is ``[]``.
        """
        if cls.model is None:
            cls.load()

        if not texts:
            return []

        # Normalise inputs; None and whitespace-only entries become "".
        cleaned = [t.strip().lower() if t and t.strip() else "" for t in texts]

        # Only non-empty texts are sent to the model. Scoring "" would give a
        # text-less input SDGs, and only when it happened to share a batch
        # with real text (an all-empty batch already returned nothing).
        scored_positions = [i for i, text in enumerate(cleaned) if text]
        if not scored_positions:
            return [[] for _ in texts]

        try:
            # Batch tokenize all non-empty texts at once
            enc = cls.tokenizer(
                [cleaned[i] for i in scored_positions],
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="tf"
            )

            # Call SavedModel on batch with explicit device placement
            with tf.device(cls.gpu_device):
                out = cls.predict_fn(
                    input_ids=enc["input_ids"],
                    attention_masks=enc["attention_mask"]
                )

            # Expected shape [n_scored, 17]: one row per scored text, one
            # column per goal.
            scores_batch = out["target_layer"].numpy()

            # Scatter model rows back to the positions of their source texts.
            results = [[] for _ in texts]
            for position, scores in zip(scored_positions, scores_batch):
                results[position] = cls._top_sdgs(scores)
            return results
        except Exception as e:
            # Deliberate best-effort: one bad batch must not abort the caller.
            print(f"Error predicting SDG batch: {e}")
            return [[] for _ in texts]

    @classmethod
    def predict(cls, text):
        """Predict SDG scores for arbitrary text string (single)"""
        results = cls.predict_batch([text])
        return results[0] if results else []


In [0]:
### Create tables (if needed)

### Load input data and cache


In [0]:
# Take a bounded slice of the pending-input table, spread it over 1024
# partitions and mark it for caching; cache() hands back the same DataFrame.
df = (
    spark.table("openalex.works.works_sdg_frontfill_input")
    .select("work_id", "title", "abstract")
    .limit(4096000)
    .repartition(1024)
    .cache()
)

# count() is the first action run on the cached DataFrame.
print(f"Input Row Count: {df.count()}")

### Run inference


In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
import time

def process_partition(rows):
    """
    Score one partition's rows in groups of BATCH_SIZE (for mapPartitions).

    Generator: yields one (work_id, sdg_results) tuple per non-None input row.
    Each row's title and abstract are stripped and joined with a newline
    before being passed to ModelCache.predict_batch. If scoring a group
    raises, the error is printed and every row of that group is yielded with
    an empty result list.
    """
    ModelCache.load()

    def score_pending(pending):
        # `pending` is a list of (row, combined_text) pairs.
        buffered_rows = [entry[0] for entry in pending]
        try:
            predictions = ModelCache.predict_batch([entry[1] for entry in pending])
            for buffered_row, sdg_list in zip(buffered_rows, predictions):
                yield (buffered_row.work_id, sdg_list)
        except Exception as e:
            print(f"Error processing batch: {e}")
            # Emit an empty result for each row in the failed group.
            for buffered_row in buffered_rows:
                yield (buffered_row.work_id, [])

    pending = []
    for record in rows:
        if record is None:
            continue

        # Missing title/abstract values are treated as empty strings.
        title_part = (record.title or "").strip()
        abstract_part = (record.abstract or "").strip()
        pending.append((record, f"{title_part}\n{abstract_part}"))

        if len(pending) >= BATCH_SIZE:
            yield from score_pending(pending)
            pending = []

    # Flush the final, possibly short, group.
    if pending:
        yield from score_pending(pending)

# Timer spans schema setup through the count() action below.
start_time = time.time()

# One scored goal inside the "sdg" array.
sdg_struct = (
    StructType()
    .add("id", StringType(), True)
    .add("display_name", StringType(), True)
    .add("score", FloatType(), True)
)

# Layout of the (work_id, sdg list) tuples yielded by process_partition.
output_schema = (
    StructType()
    .add("work_id", StringType(), nullable=False)
    .add("sdg", ArrayType(sdg_struct), nullable=True)
)

# Run inference partition by partition on the RDD, then apply the schema.
res_rdd = (
    df.select("work_id", "title", "abstract")
    .rdd
    .mapPartitions(process_partition)
)
inferred_sdg_df = spark.createDataFrame(res_rdd, output_schema).cache()

# count() is the action that actually runs the inference.
output_count = inferred_sdg_df.count()
print(f"Output Row count: {output_count}")

runtime = time.time() - start_time
print(f"Total runtime: {runtime:.4f} seconds")
if runtime > 0:
    print(f"Total throughput: {output_count / runtime:.4f} inferences/sec")


### Write to works_sdg_frontfill table


In [0]:
from pyspark.sql.functions import current_timestamp

# Append the scored rows, stamped with the write time, to the output table.
(
    inferred_sdg_df
    .select(
        "work_id",
        "sdg",
        current_timestamp().alias("created_timestamp"),
    )
    .write
    .option("mergeSchema", "true")
    .mode("append")
    .saveAsTable("openalex.works.works_sdg_frontfill")
)

# Drop the rows just written from the input table so they are not picked up again.
# NOTE(review): rows whose batch failed carry an empty sdg list and are removed
# here as well - confirm that is intended.
inferred_sdg_df.createOrReplaceTempView("processed_ids")
spark.sql("""
    DELETE FROM openalex.works.works_sdg_frontfill_input 
    WHERE work_id IN (SELECT work_id FROM processed_ids)
""")
print("Removed processed work_ids from works_sdg_frontfill_input")

### Verify table structure and sample results


In [0]:
%sql
-- Spot-check the output table: per work, the number of SDG entries stored,
-- the first three entries of the array, and when the row was written.
SELECT 
    work_id,
    size(sdg) AS num_sdgs,
    slice(sdg, 1, 3) AS top_3_sdgs,
    created_timestamp
FROM openalex.works.works_sdg_frontfill;