# Batch Work Embeddings at Scale

Generates embeddings for all OpenAlex works using `ai_query` with OpenAI text-embedding-3-small.

**Strategy**: Process in batches using SQL, respecting OpenAI rate limits.

**Note**: Abstracts are truncated to 30K characters (~7,500 tokens, assuming roughly 4 characters per token) so that title plus abstract stay under the model's 8,192-token input limit.

In [None]:
# Configuration

# --- Tables ---------------------------------------------------------------
# Works are read from SOURCE_TABLE; embeddings are appended to OUTPUT_TABLE.
SOURCE_TABLE = "openalex.works.openalex_works"
OUTPUT_TABLE = "openalex.vector_search.work_embeddings"

# --- Embedding endpoint ---------------------------------------------------
# Name of the serving endpoint handed to ai_query() in the batch cells.
ENDPOINT_NAME = "openai-embedding-3-small"

# --- Batch sizing ---------------------------------------------------------
# Upper bound on the number of works embedded by a single INSERT.
BATCH_SIZE = 50_000

# Abstracts are cut to this many characters (~7500 tokens) so that the title
# plus the abstract still fit within the 8192-token limit.
MAX_ABSTRACT_CHARS = 30_000

In [None]:
# Check current progress
#
# One round trip that returns three counts:
#   embedded            - rows already present in the embeddings table
#   total_with_abstract - non-dataset source works that have an abstract
#   total_all           - all non-dataset source works
#
# NOTE(review): the insert cells additionally require `title IS NOT NULL`,
# which this denominator does not. If the source contains works with an
# abstract but no title, the percentage below cannot reach 100% - confirm
# whether the filter should be added here as well.
progress = spark.sql(f"""
    SELECT
        (SELECT COUNT(*) FROM {OUTPUT_TABLE}) as embedded,
        (SELECT COUNT(*) FROM {SOURCE_TABLE} WHERE type != 'dataset' AND abstract IS NOT NULL) as total_with_abstract,
        (SELECT COUNT(*) FROM {SOURCE_TABLE} WHERE type != 'dataset') as total_all
""").collect()[0]

embedded = progress['embedded']
total_with_abstract = progress['total_with_abstract']

print(f"Progress: {embedded:,} / {total_with_abstract:,} works with abstracts")
print(f"Remaining: {total_with_abstract - embedded:,}")
if total_with_abstract:
    print(f"Percent complete: {100 * embedded / total_with_abstract:.2f}%")
else:
    # Avoid ZeroDivisionError when the source has no eligible works at all.
    print("Percent complete: n/a (no works with abstracts found)")

In [None]:
# Count works still needing embeddings
#
# A work is pending when it is not a dataset, has both an abstract and a
# title, and has no row in the embeddings table with a matching work_id.
remaining = spark.sql(f"""
    SELECT COUNT(*) as remaining
    FROM {SOURCE_TABLE} w
    WHERE w.type != 'dataset'
      AND w.abstract IS NOT NULL
      AND w.title IS NOT NULL
      AND NOT EXISTS (
          SELECT 1 FROM {OUTPUT_TABLE} e WHERE e.work_id = CAST(w.id AS STRING)
      )
""").collect()[0]['remaining']

# Ceiling division: exact multiples of BATCH_SIZE must not be rounded up by an
# extra batch, and zero remaining works means zero batches.
estimated_batches = (remaining + BATCH_SIZE - 1) // BATCH_SIZE

print(f"Works still needing embeddings: {remaining:,}")
print(f"Estimated batches: {estimated_batches}")

In [None]:
import time

start_time = time.time()

# SQL expression for the text that is embedded. It is built once so that
# text_hash is always computed over exactly the string sent to the endpoint.
# The abstract is truncated to avoid token limit errors.
text_expr = (
    f"CONCAT('Title: ', w.title, '\n\nAbstract: ', "
    f"LEFT(w.abstract, {MAX_ABSTRACT_CHARS}))"
)

# Process one batch: embed up to BATCH_SIZE pending works and append them to
# the embeddings table. Pending = non-dataset, has title and abstract, and no
# existing row with the same work_id.
result = spark.sql(f"""
    INSERT INTO {OUTPUT_TABLE}
    SELECT
        CAST(w.id AS STRING) as work_id,
        ai_query(
            '{ENDPOINT_NAME}',
            {text_expr}
        ) as embedding,
        md5({text_expr}) as text_hash,
        w.publication_year,
        w.type,
        w.open_access.is_oa as is_oa,
        true as has_abstract,
        current_timestamp() as created_at,
        current_timestamp() as updated_at
    FROM {SOURCE_TABLE} w
    WHERE w.type != 'dataset'
      AND w.abstract IS NOT NULL
      AND w.title IS NOT NULL
      AND NOT EXISTS (
          SELECT 1 FROM {OUTPUT_TABLE} e WHERE e.work_id = CAST(w.id AS STRING)
      )
    LIMIT {BATCH_SIZE}
""")

elapsed = time.time() - start_time

# Prefer the row count reported by the INSERT itself; a final, partial batch
# inserts fewer than BATCH_SIZE rows. Fall back to BATCH_SIZE when the command
# result carries no such metric.
inserted = BATCH_SIZE
if "num_inserted_rows" in result.columns:
    metrics = result.collect()
    if metrics and metrics[0]["num_inserted_rows"] is not None:
        inserted = int(metrics[0]["num_inserted_rows"])

print(f"Batch complete in {elapsed:.1f} seconds")
if elapsed > 0:
    print(f"Rate: {inserted / elapsed:.0f} works/second")

## Continuous Processing Loop

Run this cell to process batches continuously until all works are embedded (or the notebook is stopped).

In [None]:
import time
from datetime import datetime

def get_remaining_count():
    """Return the number of source works that have not been embedded yet.

    A work counts as pending when it is not a dataset, has a non-null
    abstract and title, and no row in OUTPUT_TABLE carries its id (compared
    as a string) as work_id.
    """
    pending_sql = f"""
        SELECT COUNT(*) as remaining
        FROM {SOURCE_TABLE} w
        WHERE w.type != 'dataset'
          AND w.abstract IS NOT NULL
          AND w.title IS NOT NULL
          AND NOT EXISTS (
              SELECT 1 FROM {OUTPUT_TABLE} e WHERE e.work_id = CAST(w.id AS STRING)
          )
    """
    # The aggregate yields exactly one row; pull the count out of it.
    count_row = spark.sql(pending_sql).collect()[0]
    return count_row['remaining']

def process_batch():
    """Embed and insert one batch of pending works; return rows inserted.

    Selects up to BATCH_SIZE works from SOURCE_TABLE that are not datasets,
    have both a title and an abstract, and have no matching work_id in
    OUTPUT_TABLE, then appends their embeddings (computed with ai_query on
    ENDPOINT_NAME) to OUTPUT_TABLE.

    Returns:
        The number of rows the INSERT reports having written. When the
        command result exposes no ``num_inserted_rows`` metric, BATCH_SIZE is
        returned as an upper-bound estimate.
    """
    # Build the embedded text expression once so text_hash is always computed
    # over exactly the string that is sent to the endpoint.
    text_expr = (
        f"CONCAT('Title: ', w.title, '\n\nAbstract: ', "
        f"LEFT(w.abstract, {MAX_ABSTRACT_CHARS}))"
    )
    result = spark.sql(f"""
        INSERT INTO {OUTPUT_TABLE}
        SELECT
            CAST(w.id AS STRING) as work_id,
            ai_query(
                '{ENDPOINT_NAME}',
                {text_expr}
            ) as embedding,
            md5({text_expr}) as text_hash,
            w.publication_year,
            w.type,
            w.open_access.is_oa as is_oa,
            true as has_abstract,
            current_timestamp() as created_at,
            current_timestamp() as updated_at
        FROM {SOURCE_TABLE} w
        WHERE w.type != 'dataset'
          AND w.abstract IS NOT NULL
          AND w.title IS NOT NULL
          AND NOT EXISTS (
              SELECT 1 FROM {OUTPUT_TABLE} e WHERE e.work_id = CAST(w.id AS STRING)
          )
        LIMIT {BATCH_SIZE}
    """)

    # The last batch usually holds fewer than BATCH_SIZE rows, so report what
    # the INSERT actually wrote whenever that metric is available.
    if "num_inserted_rows" in result.columns:
        metrics = result.collect()
        if metrics and metrics[0]["num_inserted_rows"] is not None:
            return int(metrics[0]["num_inserted_rows"])
    return BATCH_SIZE

# Main loop
#
# Repeatedly calls process_batch() until no pending works are left. The
# pending count is expensive to compute, so between recounts it is estimated
# by subtracting the rows reported by each batch. The loop never exits on the
# estimate alone: the real count is fetched before a zero is trusted.
total_processed = 0
start_time = time.time()
batch_num = 0

remaining = get_remaining_count()
print(f"Starting continuous processing. {remaining:,} works remaining.")
print(f"Batch size: {BATCH_SIZE:,}")
print("="*60)

while remaining > 0:
    batch_start = time.time()
    batch_num += 1

    try:
        rows = process_batch()
        total_processed += rows

        batch_elapsed = time.time() - batch_start
        total_elapsed = time.time() - start_time
        rate = total_processed / total_elapsed if total_elapsed > 0 else 0

        # Recount (expensive query) every 10 batches, and also whenever the
        # estimate would end the loop or the batch inserted nothing, so that
        # termination is always decided on the real number of pending works.
        estimated_remaining = max(0, remaining - rows)
        if batch_num % 10 == 0 or estimated_remaining == 0 or rows == 0:
            remaining = get_remaining_count()
        else:
            remaining = estimated_remaining

        eta_hours = remaining / rate / 3600 if rate > 0 else 0

        print(f"[{datetime.now().strftime('%H:%M:%S')}] Batch {batch_num}: "
              f"+{rows:,} works in {batch_elapsed:.0f}s | "
              f"Total: {total_processed:,} | "
              f"Remaining: {remaining:,} | "
              f"Rate: {rate:.0f}/s | "
              f"ETA: {eta_hours:.1f}h")

    except Exception as e:
        # Deliberate best-effort: log, back off, and try the next batch.
        # `remaining` is left untouched so the loop keeps going; stopping the
        # notebook (KeyboardInterrupt is not an Exception) still ends it.
        print(f"Error in batch {batch_num}: {e}")
        print("Waiting 60s before retry...")
        time.sleep(60)
        continue

print("="*60)
print(f"Complete! Processed {total_processed:,} works in {(time.time() - start_time)/3600:.1f} hours")

## Verify Results

In [None]:
%%sql
-- Sanity check of the embeddings table: row counts, the time span of the
-- inserts, and the average embedding length.
SELECT 
    COUNT(*) as total_embeddings,
    -- rows whose has_abstract flag is true (the batch inserts in this notebook always write true)
    SUM(CASE WHEN has_abstract THEN 1 ELSE 0 END) as with_abstract,
    MIN(created_at) as oldest,
    MAX(created_at) as newest,
    -- mean number of elements per embedding array; presumably 1536 for
    -- text-embedding-3-small at its default size - TODO confirm
    AVG(SIZE(embedding)) as avg_dims
FROM openalex.vector_search.work_embeddings