# Batch Work Embeddings at Scale

Generates embeddings for all OpenAlex works using `ai_query` with OpenAI text-embedding-3-small.

**Token limit handling**: Truncates title to 500 chars and abstract to 5500 chars (~6K total).
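This keeps the input well under the 8192-token limit even for CJK text, where one character is roughly one token (1 char ≈ 1 token). Very rare characters can encode to more than one token, so treat this as a wide safety margin rather than a strict guarantee.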

In [None]:
# Configuration
# Table names and the embedding endpoint are interpolated into the SQL below.
SOURCE_TABLE = "openalex.works.openalex_works"  # works read for embedding
OUTPUT_TABLE = "openalex.vector_search.work_embeddings"  # embeddings appended here
ENDPOINT_NAME = "openai-embedding-3-small"  # endpoint name passed to ai_query
BATCH_SIZE = 50_000  # max works per INSERT (used as the SQL LIMIT)

In [None]:
# Check current progress.
# The denominator applies the same filter as the INSERT statements in this
# notebook (type != 'dataset', abstract AND title present). Works without a
# title are never embedded, so counting them would keep the percentage from
# ever reaching 100%.
progress = spark.sql(f"""
    SELECT 
        (SELECT COUNT(*) FROM {OUTPUT_TABLE}) as embedded,
        (SELECT COUNT(*) FROM {SOURCE_TABLE}
         WHERE type != 'dataset' AND abstract IS NOT NULL AND title IS NOT NULL) as total_with_abstract
""").collect()[0]

embedded = progress['embedded']
total = progress['total_with_abstract']

print(f"Progress: {embedded:,} / {total:,} works with abstracts and titles")
print(f"Remaining: {total - embedded:,}")
if total > 0:
    print(f"Percent complete: {100 * embedded / total:.4f}%")
else:
    # Avoid ZeroDivisionError when no eligible works are found.
    print("Percent complete: n/a (no eligible works found)")

In [None]:
import time

# Count existing embeddings first, so the reported rate reflects the rows this
# batch actually inserted (the last batch can be smaller than BATCH_SIZE).
rows_before = spark.sql(f"SELECT COUNT(*) as n FROM {OUTPUT_TABLE}").collect()[0]['n']

start_time = time.time()

# Process one batch
# Hardcoded truncation: title=500, abstract=5500 chars.
# The input-text expression is built once. This keeps the text embedded by
# ai_query and the text hashed into text_hash identical.
text_expr = "CONCAT('Title: ', LEFT(w.title, 500), '\n\nAbstract: ', LEFT(w.abstract, 5500))"
result = spark.sql(f"""
    INSERT INTO {OUTPUT_TABLE}
    SELECT 
        CAST(w.id AS STRING) as work_id,
        ai_query(
            '{ENDPOINT_NAME}',
            {text_expr}
        ) as embedding,
        md5({text_expr}) as text_hash,
        w.publication_year,
        w.type,
        w.open_access.is_oa as is_oa,
        true as has_abstract,
        current_timestamp() as created_at,
        current_timestamp() as updated_at
    FROM {SOURCE_TABLE} w
    WHERE w.type != 'dataset'
      AND w.abstract IS NOT NULL
      AND w.title IS NOT NULL
      AND NOT EXISTS (
          SELECT 1 FROM {OUTPUT_TABLE} e WHERE e.work_id = CAST(w.id AS STRING)
      )
    LIMIT {BATCH_SIZE}
""")

elapsed = time.time() - start_time
rows_after = spark.sql(f"SELECT COUNT(*) as n FROM {OUTPUT_TABLE}").collect()[0]['n']
inserted = rows_after - rows_before

print(f"Batch complete in {elapsed:.1f} seconds ({inserted:,} works inserted)")
if elapsed > 0:
    print(f"Rate: {inserted / elapsed:.0f} works/second")

## Continuous Processing Loop

Run this to continuously process batches until complete (or notebook is stopped).

In [None]:
import time
from datetime import datetime

def get_embedded_count():
    """Return the number of rows currently stored in OUTPUT_TABLE."""
    count_query = f"SELECT COUNT(*) as n FROM {OUTPUT_TABLE}"
    first_row = spark.sql(count_query).first()
    return first_row['n']

def process_batch():
    """Process a single batch."""
    # Hardcoded truncation: title=500, abstract=5500 chars
    spark.sql(f"""
        INSERT INTO {OUTPUT_TABLE}
        SELECT 
            CAST(w.id AS STRING) as work_id,
            ai_query(
                '{ENDPOINT_NAME}',
                CONCAT('Title: ', LEFT(w.title, 500), '\n\nAbstract: ', LEFT(w.abstract, 5500))
            ) as embedding,
            md5(CONCAT('Title: ', LEFT(w.title, 500), '\n\nAbstract: ', LEFT(w.abstract, 5500))) as text_hash,
            w.publication_year,
            w.type,
            w.open_access.is_oa as is_oa,
            true as has_abstract,
            current_timestamp() as created_at,
            current_timestamp() as updated_at
        FROM {SOURCE_TABLE} w
        WHERE w.type != 'dataset'
          AND w.abstract IS NOT NULL
          AND w.title IS NOT NULL
          AND NOT EXISTS (
              SELECT 1 FROM {OUTPUT_TABLE} e WHERE e.work_id = CAST(w.id AS STRING)
          )
        LIMIT {BATCH_SIZE}
    """)

# Main loop
# Retry failed batches, but stop after this many failures in a row instead of
# retrying a permanent error (e.g. a bad endpoint name) forever.
MAX_CONSECUTIVE_ERRORS = 10
RETRY_DELAY_SECONDS = 60

start_time = time.time()
batch_num = 0
start_count = get_embedded_count()
previous_count = start_count  # embedded count after the last successful batch
consecutive_errors = 0
gave_up = False
target = 217000000  # hardcoded total, used only for the ETA estimate

print(f"Starting from {start_count:,} embeddings")
print(f"Batch size: {BATCH_SIZE:,}")
print("="*70)

while True:
    batch_start = time.time()
    batch_num += 1

    try:
        process_batch()
        batch_elapsed = time.time() - batch_start
        current = get_embedded_count()
    except Exception as e:
        consecutive_errors += 1
        print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Error: {e}")
        if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
            print(f"Giving up after {consecutive_errors} consecutive errors.")
            gave_up = True
            break
        print(f"Waiting {RETRY_DELAY_SECONDS}s before retry...")
        time.sleep(RETRY_DELAY_SECONDS)
        continue

    consecutive_errors = 0

    # Measure what this batch actually inserted. Do not assume every earlier
    # batch added exactly BATCH_SIZE rows: that is false after a failed batch
    # (batch_num still advances) and would end the loop too early.
    added_this_batch = current - previous_count
    previous_count = current

    if added_this_batch <= 0:
        print("\nNo new rows added - complete!")
        break

    total_added = current - start_count
    total_elapsed = time.time() - start_time
    rate = total_added / total_elapsed if total_elapsed > 0 else 0
    remaining = max(target - current, 0)  # no negative ETA once past target
    eta_hours = remaining / rate / 3600 if rate > 0 else 0

    print(f"[{datetime.now().strftime('%H:%M:%S')}] Batch {batch_num}: "
          f"{batch_elapsed:.0f}s | "
          f"+{added_this_batch:,} | "
          f"Total: {current:,} | "
          f"Rate: {rate:.0f}/s | "
          f"ETA: {eta_hours:.1f}h")

print("="*70)
final = get_embedded_count()
print(f"{'Stopped' if gave_up else 'Done'}! {final:,} total embeddings")
print(f"Added {final - start_count:,} in {(time.time() - start_time)/3600:.1f}h")

## Verify Results

In [None]:
# Summarize the output table: row count plus the earliest and latest created_at.
summary_query = f"""
    SELECT 
        COUNT(*) as total_embeddings,
        MIN(created_at) as oldest,
        MAX(created_at) as newest
    FROM {OUTPUT_TABLE}
"""
summary_df = spark.sql(summary_query)
summary_df.show()