# Batch Work Embeddings

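Generates embeddings for every OpenAlex work that has both a title and an abstract (datasets excluded), using `ai_query` against the `openai-embedding-3-small` endpoint (OpenAI text-embedding-3-small).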

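**Optimized for job execution**: skips expensive COUNT queries over the source table and starts processing immediately; only the output table is counted, to track progress and detect completion.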

In [None]:
# Configuration
# Rows embedded per INSERT statement. Kept small at first to gauge endpoint
# throughput; raise it once the pipeline is confirmed to work.
BATCH_SIZE = 1000
OUTPUT_TABLE = "openalex.vector_search.work_embeddings"   # destination of embeddings
SOURCE_TABLE = "openalex.works.openalex_works"            # works to embed
ENDPOINT_NAME = "openai-embedding-3-small"                # endpoint passed to ai_query

# Report how many embeddings the output table already holds before starting.
existing = spark.sql(f"SELECT COUNT(*) as n FROM {OUTPUT_TABLE}").first()["n"]
print(
    f"Existing embeddings: {existing:,}",
    f"Batch size: {BATCH_SIZE:,}",
    "Starting continuous processing...",
    sep="\n",
)

In [None]:
import time
from datetime import datetime

def process_batch():
    """Embed and insert one batch of works that have no embedding yet.

    Selects up to BATCH_SIZE rows from SOURCE_TABLE that:
    - are not of type 'dataset',
    - have both a title and an abstract,
    - and have no matching ``work_id`` in OUTPUT_TABLE.

    For each selected row it embeds the text "Title: <title>", a blank line,
    "Abstract: <abstract>" via ``ai_query`` on ENDPOINT_NAME. It stores the
    embedding with an md5 hash of that same text and a few work attributes.

    Returns:
        Number of rows inserted, read from the ``num_inserted_rows`` metric
        of the INSERT result when the runtime reports it. If no such metric
        is returned, BATCH_SIZE is returned as an upper bound.
    """
    # Single definition of the embedded text, so the embedding and its
    # text_hash are always computed from exactly the same string.
    text_expr = "CONCAT('Title: ', w.title, '\n\nAbstract: ', w.abstract)"
    result = spark.sql(f"""
        INSERT INTO {OUTPUT_TABLE}
        SELECT 
            CAST(w.id AS STRING) as work_id,
            ai_query(
                '{ENDPOINT_NAME}',
                {text_expr}
            ) as embedding,
            md5({text_expr}) as text_hash,
            w.publication_year,
            w.type,
            w.open_access.is_oa as is_oa,
            true as has_abstract,
            current_timestamp() as created_at,
            current_timestamp() as updated_at
        FROM {SOURCE_TABLE} w
        WHERE w.type != 'dataset'
          AND w.abstract IS NOT NULL
          AND w.title IS NOT NULL
          AND NOT EXISTS (
              SELECT 1 FROM {OUTPUT_TABLE} e WHERE e.work_id = CAST(w.id AS STRING)
          )
        LIMIT {BATCH_SIZE}
    """)
    # NOTE(review): Databricks returns a one-row DataFrame with
    # num_affected_rows / num_inserted_rows for INSERT INTO a Delta table;
    # other runtimes may return no rows. Confirm on the DBR version in use.
    metrics = result.collect()
    if metrics:
        row = metrics[0].asDict()
        inserted = row.get("num_inserted_rows")
        if inserted is not None:
            return int(inserted)
    return BATCH_SIZE

def get_count():
    """Return how many rows OUTPUT_TABLE currently holds."""
    count_df = spark.sql(f"SELECT COUNT(*) as n FROM {OUTPUT_TABLE}")
    first_row = count_df.first()
    return first_row["n"]

# Main loop - run until stopped or complete.
# Each iteration inserts one batch via process_batch(). After every
# CHECK_EVERY successful batches the output table is re-counted. If nothing
# was added since the previous check, no eligible works remain and the loop
# stops.
CHECK_EVERY = 10               # successful batches between progress/completion checks
ESTIMATED_TOTAL = 217_000_000  # rough number of works to embed; used only for the ETA
MAX_CONSECUTIVE_ERRORS = 5     # abort (failing the job) after this many failures in a row
RETRY_DELAY_S = 60             # pause after a failed batch before the next attempt

total_processed = 0            # rows added by this run, refreshed at each check
start_time = time.time()
batch_num = 0                  # batch attempts, including failed ones (used for logging)
last_count = get_count()
start_count = last_count
batches_since_check = 0        # successful batches since the last count check
consecutive_errors = 0

print("="*70)
print(f"[{datetime.now().strftime('%H:%M:%S')}] Starting from {last_count:,} embeddings")
print("="*70)

while True:
    batch_start = time.time()
    batch_num += 1

    try:
        process_batch()
        consecutive_errors = 0
        batches_since_check += 1
        batch_elapsed = time.time() - batch_start

        if batches_since_check < CHECK_EVERY:
            print(f"[{datetime.now().strftime('%H:%M:%S')}] Batch {batch_num} done in {batch_elapsed:.0f}s")
            continue

        # Periodic check: count the output table to measure real progress.
        # Reset only after get_count() succeeds, so a failed count is retried
        # after the next successful batch.
        new_count = get_count()
        batches_since_check = 0
        if new_count == last_count:
            print(f"\n[{datetime.now().strftime('%H:%M:%S')}] No new rows added - embedding complete!")
            break
        last_count = new_count

        # Average throughput of this run: rows added since start divided by
        # wall-clock time since start (numerator and denominator share a window).
        total_processed = new_count - start_count
        total_elapsed = time.time() - start_time
        rate = total_processed / total_elapsed if total_elapsed > 0 else 0
        remaining_est = max(ESTIMATED_TOTAL - new_count, 0)
        eta_hours = remaining_est / rate / 3600 if rate > 0 else 0
        print(f"[{datetime.now().strftime('%H:%M:%S')}] Batch {batch_num}: "
              f"Total: {new_count:,} | "
              f"Rate: {rate:.1f}/s | "
              f"ETA: {eta_hours:.1f}h")

    except Exception as e:
        consecutive_errors += 1
        print(f"[{datetime.now().strftime('%H:%M:%S')}] Error in batch {batch_num}: {e}")
        # Transient failures are retried, but a persistent failure must not
        # spin forever: surface it so the job run is marked failed.
        if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
            raise RuntimeError(
                f"Giving up after {consecutive_errors} consecutive failed batches"
            ) from e
        print(f"Waiting {RETRY_DELAY_S}s before retry...")
        time.sleep(RETRY_DELAY_S)

print("="*70)
final_count = get_count()
print(f"Complete! Total embeddings: {final_count:,}")
print(f"Added {final_count - existing:,} in {(time.time() - start_time)/3600:.1f} hours")