# Batch Work Embeddings

Generates embeddings for all OpenAlex works using OpenAI text-embedding-3-small.

Uses the OpenAI Python API directly, because `ai_query` only works on SQL warehouses, not on job clusters.

In [None]:
# Install the openai package if it is not already available
%pip install openai --quiet

In [None]:
# Configuration
BATCH_SIZE = 100  # Number of works sent to the embeddings API per request
BATCHES_PER_COMMIT = 10  # Commit to Delta every N batches
OUTPUT_TABLE = "openalex.vector_search.work_embeddings"  # Table the embeddings are appended to
SOURCE_TABLE = "openalex.works.openalex_works"  # Table the works are read from
MODEL = "text-embedding-3-small"  # OpenAI embedding model name

# Get OpenAI API key from secret scope
import os
api_key = dbutils.secrets.get(scope="openalex", key="openai_api_key")
# Print only the key length so the secret itself never appears in cell output
print(f"API key loaded (length: {len(api_key)})")

# Quick check - count existing embeddings
# `existing` is reused in the final summary to report how many rows were added.
# NOTE(review): presumably this fails if OUTPUT_TABLE does not exist yet - confirm
# the table is created before this notebook runs.
existing = spark.sql(f"SELECT COUNT(*) as n FROM {OUTPUT_TABLE}").collect()[0]['n']
print(f"Existing embeddings: {existing:,}")

In [None]:
import time
from datetime import datetime
from openai import OpenAI
import hashlib

# Module-level OpenAI client used by get_embeddings(); authenticated with the
# key read from the secret scope.
client = OpenAI(api_key=api_key)

def get_embeddings(texts):
    """Return one embedding per input text, aligned with the input order.

    Args:
        texts: Sequence of strings to embed with the model named by ``MODEL``.

    Returns:
        A list of embedding vectors where element ``i`` is the embedding of
        ``texts[i]``. An empty input returns ``[]`` without calling the API.

    Raises:
        Any exception raised by the OpenAI client for a failed request is
        propagated unchanged to the caller.
    """
    if not texts:
        # Nothing to embed; skip the network round trip entirely.
        return []

    response = client.embeddings.create(
        model=MODEL,
        input=texts
    )
    # Every returned item carries the index of the input it belongs to; sort
    # on it so the result is aligned with `texts` even if the response order
    # ever differs from the request order.
    ordered = sorted(response.data, key=lambda item: item.index)
    return [item.embedding for item in ordered]

def get_batch_of_works(limit):
    """Fetch up to ``limit`` works that have no row in the embeddings table.

    A work qualifies when its type is not 'dataset' (rows with a NULL type
    are excluded by that comparison too), both title and abstract are
    non-NULL, and OUTPUT_TABLE holds no row with the same work_id.

    Args:
        limit: Maximum number of works to return.

    Returns:
        The collected result rows, each exposing work_id, title, abstract,
        publication_year, type and is_oa.
    """
    pending_query = f"""
        SELECT 
            CAST(w.id AS STRING) as work_id,
            w.title,
            w.abstract,
            w.publication_year,
            w.type,
            w.open_access.is_oa as is_oa
        FROM {SOURCE_TABLE} w
        WHERE w.type != 'dataset'
          AND w.abstract IS NOT NULL
          AND w.title IS NOT NULL
          AND NOT EXISTS (
              SELECT 1 FROM {OUTPUT_TABLE} e WHERE e.work_id = CAST(w.id AS STRING)
          )
        LIMIT {limit}
    """
    pending_works = spark.sql(pending_query)
    # Materialise on the driver: the caller iterates the rows in plain Python.
    return pending_works.collect()

def process_batch(works):
    """Embed a batch of works and build the rows to insert.

    Args:
        works: Rows as returned by ``get_batch_of_works`` (each exposing
            work_id, title, abstract, publication_year, type and is_oa).

    Returns:
        A list of tuples in the column order expected by ``insert_rows``:
        (work_id, embedding, text_hash, publication_year, type, is_oa,
        has_abstract, created_at, updated_at). An empty input returns ``[]``
        without calling the embeddings API.

    Raises:
        ValueError: If the number of embeddings returned differs from the
            number of works, which would otherwise misalign or drop rows.
    """
    if not works:
        return []

    from datetime import timezone  # local import: only `datetime` is imported at file level

    # Build each input text once; the same string is embedded and hashed.
    texts = [f"Title: {w.title}\n\nAbstract: {w.abstract}" for w in works]
    embeddings = get_embeddings(texts)

    # zip() would silently truncate on a short response, so fail loudly instead.
    if len(embeddings) != len(works):
        raise ValueError(
            f"Expected {len(works)} embeddings but received {len(embeddings)}"
        )

    # Timezone-aware UTC timestamp: unambiguous regardless of the driver's
    # local timezone (datetime.utcnow() is naive and deprecated).
    now = datetime.now(timezone.utc)

    return [
        (
            work.work_id,
            embedding,
            # MD5 is used as a content fingerprint of the embedded text,
            # not for any security purpose.
            hashlib.md5(text.encode()).hexdigest(),
            work.publication_year,
            work.type,
            work.is_oa,
            True,  # has_abstract
            now,
            now,
        )
        for work, text, embedding in zip(works, texts, embeddings)
    ]

def insert_rows(all_rows):
    """Append embedding rows to the output Delta table.

    Args:
        all_rows: List of tuples in the column order of the schema below
            (work_id, embedding, text_hash, publication_year, type, is_oa,
            has_abstract, created_at, updated_at).

    Returns:
        The number of rows passed in; ``0`` for an empty list, in which case
        no write is issued.
    """
    if not all_rows:
        # Avoid building a DataFrame and issuing an empty append.
        return 0

    from pyspark.sql.types import (
        ArrayType,
        BooleanType,
        DoubleType,
        IntegerType,
        StringType,
        StructField,
        StructType,
        TimestampType,
    )

    # Explicit schema so column types do not depend on inference from the data.
    schema = StructType([
        StructField("work_id", StringType(), False),
        StructField("embedding", ArrayType(DoubleType()), False),
        StructField("text_hash", StringType(), True),
        StructField("publication_year", IntegerType(), True),
        StructField("type", StringType(), True),
        StructField("is_oa", BooleanType(), True),
        StructField("has_abstract", BooleanType(), True),
        StructField("created_at", TimestampType(), True),
        StructField("updated_at", TimestampType(), True)
    ])

    df = spark.createDataFrame(all_rows, schema)
    df.write.mode("append").saveAsTable(OUTPUT_TABLE)
    return len(all_rows)

def get_count():
    """Return the number of rows currently stored in OUTPUT_TABLE."""
    count_rows = spark.sql(f"SELECT COUNT(*) as n FROM {OUTPUT_TABLE}").collect()
    # The aggregate yields a single row; its only column is aliased `n`.
    first_row = count_rows[0]
    return first_row['n']

# Main loop
#
# Each cycle fetches one commit's worth of un-embedded works in a single
# query, embeds them in BATCH_SIZE chunks, and appends them to the output
# table before fetching again. Fetching per commit (rather than per batch) is
# essential: get_batch_of_works() only excludes works that are already in
# OUTPUT_TABLE, so re-querying while rows are still buffered in memory would
# hand back the same works again and produce duplicate embeddings.
ESTIMATED_TOTAL_WORKS = 217_000_000  # presumably the approximate corpus size; only used for the ETA
WORKS_PER_COMMIT = BATCH_SIZE * BATCHES_PER_COMMIT
RETRY_DELAY_SECONDS = 60

total_processed = 0
start_time = time.time()
batch_num = 0
last_count = get_count()

print("="*70)
print(f"[{datetime.now().strftime('%H:%M:%S')}] Starting from {last_count:,} embeddings")
print(f"Batch size: {BATCH_SIZE}, Commits every {BATCHES_PER_COMMIT} batches")
print("="*70)

while True:
    # Rows embedded in this cycle that have not been written to Delta yet.
    pending_rows = []

    try:
        # Get one commit's worth of works to process
        works = get_batch_of_works(WORKS_PER_COMMIT)

        if not works:
            print(f"\n[{datetime.now().strftime('%H:%M:%S')}] No more works to process!")
            break

        # Embed in API-sized chunks
        for offset in range(0, len(works), BATCH_SIZE):
            batch_start = time.time()
            batch_num += 1

            rows = process_batch(works[offset:offset + BATCH_SIZE])
            pending_rows.extend(rows)
            total_processed += len(rows)

            batch_elapsed = time.time() - batch_start
            print(f"[{datetime.now().strftime('%H:%M:%S')}] Batch {batch_num}: "
                  f"+{len(rows)} in {batch_elapsed:.1f}s")

        # Commit everything embedded in this cycle
        insert_rows(pending_rows)
        pending_rows = []

        new_count = get_count()
        total_elapsed = time.time() - start_time
        rate = (new_count - last_count) / total_elapsed if total_elapsed > 0 else 0
        # Clamp at zero so the ETA never goes negative if the estimate is too low.
        remaining_est = max(0, ESTIMATED_TOTAL_WORKS - new_count)
        eta_hours = remaining_est / rate / 3600 if rate > 0 else 0

        print(f"[{datetime.now().strftime('%H:%M:%S')}] Batch {batch_num}: "
              f"Total: {new_count:,} | "
              f"Rate: {rate:.1f}/s | "
              f"ETA: {eta_hours:.1f}h")

    except Exception as e:
        print(f"[{datetime.now().strftime('%H:%M:%S')}] Error in batch {batch_num}: {e}")
        # Flush what we have so embeddings already paid for are not lost.
        # (If the failure came from the commit itself, this is one immediate retry.)
        if pending_rows:
            try:
                insert_rows(pending_rows)
            except Exception as flush_error:
                # The buffer is discarded at the top of the next cycle; since
                # these works are still missing from OUTPUT_TABLE they will be
                # fetched and embedded again.
                print(f"[{datetime.now().strftime('%H:%M:%S')}] Flush failed, "
                      f"{len(pending_rows)} rows will be re-embedded: {flush_error}")
        # NOTE: retries are unbounded, so a persistently failing batch keeps
        # the loop alive until it is interrupted manually.
        print(f"Waiting {RETRY_DELAY_SECONDS}s before retry...")
        time.sleep(RETRY_DELAY_SECONDS)
        continue

print("="*70)
final_count = get_count()
print(f"Complete! Total embeddings: {final_count:,}")
print(f"Added {final_count - existing:,} in {(time.time() - start_time)/3600:.1f} hours")