# Export Embeddings to Parquet

Exports the 217M embeddings from `openalex.vector_search.work_embeddings_v2` to S3 as Parquet.

This is a safe prep step that doesn't touch Elasticsearch. The Parquet files can then be
used by `sync_embeddings_to_es.ipynb` to bulk load embeddings into ES.

- **Expected runtime**: 2-4 hours
- **Output**: `s3://openalex-ingest/embeddings/work_embeddings_v2/`

In [None]:
from pyspark.sql import functions as F
from datetime import datetime, timezone
import time

# Export settings used by the cells below.
SOURCE_TABLE = "openalex.vector_search.work_embeddings_v2"  # table to read from
S3_OUTPUT_PATH = "s3://openalex-ingest/embeddings/work_embeddings_v2"  # Parquet destination
NUM_PARTITIONS = 1000  # ~217K rows per partition for parallel processing

# Echo the settings so they are recorded in the notebook output.
for _label, _value in (
    ("Source", SOURCE_TABLE),
    ("Output", S3_OUTPUT_PATH),
    ("Partitions", NUM_PARTITIONS),
):
    print(f"{_label}: {_value}")

In [None]:
# Open the source table and count its rows. `total_count` is reused later for
# the throughput figure and for the post-export verification.
print("Loading embeddings table...")
t0 = time.time()

df = spark.table(SOURCE_TABLE)
total_count = df.count()

print(f"Total embeddings: {total_count:,}")
count_seconds = time.time() - t0  # covers the table lookup and the count
print(f"Count took {count_seconds:.1f}s")

In [None]:
# Inspect the table layout before exporting.
print("Schema:")
df.printSchema()

# Fetch a single row and report the embedding's length and element type.
print("\nSample row:")
rows = df.limit(1).collect()
sample = rows[0]
embedding = sample.embedding
print(f"  work_id: {sample.work_id}")
print(f"  embedding dims: {len(embedding)}")
print(f"  embedding type: {type(embedding[0])}")

In [None]:
# Range-partition on work_id into NUM_PARTITIONS pieces so the Parquet write
# runs in parallel over evenly sized slices.
# NOTE(review): DataFrame transformations are lazy, so the elapsed time printed
# below presumably excludes the actual shuffle (done during the write) — confirm.
print(f"Repartitioning to {NUM_PARTITIONS} partitions...")
t0 = time.time()

df_partitioned = df.repartitionByRange(NUM_PARTITIONS, "work_id")
actual_partitions = df_partitioned.rdd.getNumPartitions()

print(f"Actual partitions: {actual_partitions}")
repartition_seconds = time.time() - t0
print(f"Repartition took {repartition_seconds:.1f}s")

In [None]:
# Export the repartitioned embeddings to S3 as Parquet ("overwrite" mode) and
# report wall-clock duration and throughput.
print(f"\nWriting to {S3_OUTPUT_PATH}...")
started_at = datetime.now(timezone.utc)
print(f"Started at: {started_at.isoformat()}")
t0 = time.time()

writer = df_partitioned.write.mode("overwrite")
writer.parquet(S3_OUTPUT_PATH)

elapsed = time.time() - t0
print("\nWrite complete!")

elapsed_minutes = elapsed / 60
elapsed_hours = elapsed / 3600
print(f"Elapsed: {elapsed_minutes:.1f} minutes ({elapsed_hours:.2f} hours)")

# Throughput uses the source row count taken before the export.
rows_per_second = total_count / elapsed
print(f"Throughput: {rows_per_second:,.0f} rows/sec")

In [None]:
# Verify the output: re-read the Parquet files that were just written and
# compare their row count with the source count taken before the export.
print("Verifying output...")

df_verify = spark.read.parquet(S3_OUTPUT_PATH)
verify_count = df_verify.count()

print(f"Source count: {total_count:,}")
print(f"Output count: {verify_count:,}")

if verify_count == total_count:
    print("\n✓ Counts match! Export successful.")
else:
    # Report the direction of the discrepancy; a surplus would otherwise be
    # shown as a negative number of "missing" rows.
    row_delta = verify_count - total_count
    if row_delta < 0:
        detail = f"Missing {-row_delta:,} rows"
    else:
        detail = f"Found {row_delta:,} extra rows"
    print(f"\n✗ COUNT MISMATCH! {detail}")
    # Raise so the mismatch cannot be overlooked; RuntimeError is still caught
    # by any `except Exception` handler, and the message carries both counts.
    raise RuntimeError(
        "Export verification failed: "
        f"source={total_count:,}, output={verify_count:,}"
    )

In [None]:
# Summarise the export: number of Parquet files at the output path and their
# combined size (computed with 1024**3 bytes per GB).
print("\nOutput file stats:")
listing_path = S3_OUTPUT_PATH.replace("s3://", "s3a://")  # list via the s3a:// scheme
files = dbutils.fs.ls(listing_path)
parquet_files = [entry for entry in files if entry.name.endswith(".parquet")]
total_bytes = sum(entry.size for entry in parquet_files)
total_size_gb = total_bytes / 1024**3

print(f"  Files: {len(parquet_files)}")
print(f"  Total size: {total_size_gb:.2f} GB")
print("\nReady for bulk load to Elasticsearch!")