# Create Work Embeddings for Vector Search

Generates text embeddings for all OpenAlex works using OpenAI `text-embedding-3-small`.

**Format**: `Title: {title}\n\nAbstract: {abstract}`

**Output**: `openalex.vector_search.work_embeddings` Delta table

**Exclusions**: Works with type='dataset' (non-semantic titles)

In [None]:
# Configuration
# NOTE(review): EMBEDDING_MODEL, EMBEDDING_DIMENSIONS and BATCH_SIZE are not
# referenced by any other cell visible in this notebook; only OUTPUT_TABLE
# and SOURCE_TABLE are read below.

# Name of the OpenAI embedding model this notebook targets.
EMBEDDING_MODEL = "text-embedding-3-small"
# Presumably the expected length of each embedding vector - confirm against
# the output of the "Check embedding dimensions" query.
EMBEDDING_DIMENSIONS = 1536
BATCH_SIZE = 100  # OpenAI allows up to 2048 per request
# Fully qualified name of the table that receives the embeddings.
OUTPUT_TABLE = "openalex.vector_search.work_embeddings"
# Fully qualified name of the table the works are read from.
SOURCE_TABLE = "openalex.works.openalex_works"

## Step 1: Create External Model Endpoint (run once)

First, create a Databricks Model Serving endpoint for OpenAI embeddings.
This only needs to be done once - skip if endpoint already exists.

In [None]:
# Run this cell ONCE to create the external model endpoint
# The OpenAI API key should be stored in Databricks secrets

import mlflow.deployments

ENDPOINT_NAME = "openai-embedding-3-small"

client = mlflow.deployments.get_deploy_client("databricks")

# Served entity: the OpenAI embedding model exposed through the endpoint.
# The API key is given as a secret reference (scope "openalex"), not inline.
openai_embedding_entity = {
    "name": "openai-embeddings",
    "external_model": {
        "name": "text-embedding-3-small",
        "provider": "openai",
        "task": "llm/v1/embeddings",
        "openai_config": {
            "openai_api_key": "{{secrets/openalex/openai_api_key}}"
        },
    },
}

# Endpoint-wide rate limit: 1000 calls per minute.
endpoint_rate_limit = {
    "calls": 1000,
    "key": "endpoint",
    "renewal_period": "minute",
}

endpoint_config = {
    "served_entities": [openai_embedding_entity],
    "rate_limits": [endpoint_rate_limit],
}

# Look the endpoint up first; any failure of the lookup is treated as
# "endpoint missing" and triggers creation.
try:
    existing = client.get_endpoint(ENDPOINT_NAME)
    print(f"Endpoint '{ENDPOINT_NAME}' already exists")
except Exception:
    endpoint = client.create_endpoint(name=ENDPOINT_NAME, config=endpoint_config)
    print(f"Created endpoint: {endpoint}")

## Step 2: Create output schema and table

In [None]:
%%sql
-- Create schema if not exists
-- The embeddings table created in the following cell lives in this schema.
-- IF NOT EXISTS makes the statement safe to re-run.
CREATE SCHEMA IF NOT EXISTS openalex.vector_search;

In [None]:
%%sql
-- Create embeddings table if not exists
-- One row is written per embedded work by the batch-processing cell.
-- IF NOT EXISTS makes the statement safe to re-run; it does not alter an
-- existing table.
CREATE TABLE IF NOT EXISTS openalex.vector_search.work_embeddings (
    work_id STRING NOT NULL,  -- Source work `id`, cast to string
    embedding ARRAY<FLOAT>,  -- Vector returned by the embedding endpoint; NULL if the call failed
    text_hash STRING,  -- Hash of input text for change detection
    publication_year INT,  -- Copied from the source work
    type STRING,  -- Copied from the source work
    is_oa BOOLEAN,  -- Source `open_access.is_oa`
    has_abstract BOOLEAN,  -- Whether the work had an abstract when it was embedded
    created_at TIMESTAMP,  -- Write-time timestamp set by the batch-processing cell
    updated_at TIMESTAMP  -- Write-time timestamp set by the batch-processing cell
)
USING DELTA
CLUSTER BY (work_id)
-- NOTE(review): change data feed is presumably enabled so a downstream
-- consumer can read row-level changes - confirm which consumer needs it.
TBLPROPERTIES (
    'delta.enableChangeDataFeed' = 'true'
);

## Step 3: Define embedding function

In [None]:
import mlflow.deployments
import hashlib
from pyspark.sql.functions import udf, col, concat_ws, lit, md5, when, coalesce
from pyspark.sql.types import ArrayType, FloatType, StringType

# Initialize MLflow client
# Deployment client for the "databricks" target; get_embedding() below uses
# it to query the serving endpoint.
mlflow_client = mlflow.deployments.get_deploy_client("databricks")

def format_text_for_embedding(title, abstract):
    """Build the text that is sent to the embedding model.

    Each non-empty field becomes a labelled section ("Title: ..." and
    "Abstract: ..."), and the sections are separated by a blank line.

    Args:
        title: Work title; skipped when falsy (None or empty).
        abstract: Work abstract; skipped when falsy (None or empty).

    Returns:
        The formatted text, or None when both fields are missing.
    """
    labelled_fields = (("Title", title), ("Abstract", abstract))
    sections = [f"{label}: {value}" for label, value in labelled_fields if value]
    if not sections:
        return None
    return "\n\n".join(sections)

def get_embedding(text):
    """Request the embedding for one piece of text from the serving endpoint.

    This is deliberately best-effort: any exception raised while calling the
    endpoint or reading its response is printed and turned into a None
    result instead of propagating.

    Args:
        text: Text to embed. Falsy values (None, "") are not sent.

    Returns:
        The `embedding` field of the first item in the response's `data`
        list, or None when `text` is falsy or the request failed.
    """
    if text:
        try:
            payload = {"input": text}
            result = mlflow_client.predict(endpoint=ENDPOINT_NAME, inputs=payload)
            return result["data"][0]["embedding"]
        except Exception as e:
            print(f"Error getting embedding: {e}")
    return None

# Register as UDF for Spark
# Wraps get_embedding as a column function returning ARRAY<FLOAT>, the same
# type as the `embedding` column of the output table. get_embedding returns
# None on failure, which surfaces as a NULL array.
get_embedding_udf = udf(get_embedding, ArrayType(FloatType()))

## Step 4: Get works that need embeddings

This finds works that either:
1. Don't have embeddings yet
2. Have changed (title/abstract modified)

In [None]:
# Read source works
# Note: abstract is already a string column (not inverted index)
works_df = spark.table(SOURCE_TABLE).filter(
    # Exclude datasets - their titles are non-semantic.
    # A bare `type != 'dataset'` evaluates to NULL when type is NULL, and
    # filter() discards NULL results, which would silently drop every work
    # without a type. Only datasets are meant to be excluded, so keep
    # NULL-typed works explicitly.
    col("type").isNull() | (col("type") != "dataset")
).select(
    col("id").cast("string").alias("work_id"),
    col("title"),
    col("abstract"),
    col("publication_year"),
    col("type"),
    col("open_access.is_oa").alias("is_oa")
)

# count() triggers a full scan of the source table.
print(f"Total works (excluding datasets): {works_df.count():,}")

In [None]:
# Abstract is already a string column, so it does not need to be reconstructed.
# This cell is kept for compatibility; the reconstruction UDF is no longer used.

In [None]:
# Add embedding text (abstract is already a string)
from pyspark.sql.functions import col, concat, concat_ws, lit, md5, when

# A field contributes a section only when it is present and non-empty, the
# same rule format_text_for_embedding() applies. Both flags are always
# true/false, never NULL, because `isNotNull() AND <NULL>` is false.
title_present = col("title").isNotNull() & (col("title") != "")
abstract_present = col("abstract").isNotNull() & (col("abstract") != "")

# when() without otherwise() yields NULL for a missing field, and concat_ws
# skips NULL arguments, so a missing field leaves no dangling label (the
# previous expression produced the bare text "Title" for a NULL title).
title_section = when(title_present, concat(lit("Title: "), col("title")))
abstract_section = when(abstract_present, concat(lit("Abstract: "), col("abstract")))

works_with_text = works_df.withColumn(
    "embedding_text", concat_ws("\n\n", title_section, abstract_section)
).filter(
    # concat_ws returns "" when every section is NULL: nothing to embed.
    col("embedding_text") != ""
).withColumn(
    "text_hash", md5(col("embedding_text"))
).withColumn(
    "has_abstract", abstract_present
)

# Show sample
works_with_text.select("work_id", "title", "has_abstract", "embedding_text").show(5, truncate=80)

In [None]:
# Find works that need new/updated embeddings
existing_embeddings = spark.table(OUTPUT_TABLE).select("work_id", "text_hash")

# Rename the stored hash so it cannot clash with the freshly computed one.
stored_hashes = existing_embeddings.withColumnRenamed("text_hash", "stored_text_hash")

# One left join replaces the previous anti-join UNION inner-join, so the
# source is scanned once instead of twice. The comparison is null-safe:
#   - work not in the table        -> stored hash is NULL -> selected
#   - stored hash differs          -> selected (text changed)
#   - stored hash is NULL in table -> selected (a plain `!=` yields NULL
#     here, so such rows were never refreshed before)
#   - hashes equal                 -> skipped
works_to_embed = (
    works_with_text.join(stored_hashes, on="work_id", how="left")
    .filter(~col("text_hash").eqNullSafe(col("stored_text_hash")))
    .drop("stored_text_hash")
)

print(f"Works needing embeddings: {works_to_embed.count():,}")

## Step 5: Generate embeddings in batches

Process in batches to manage memory and allow checkpointing.

In [None]:
from pyspark.sql.functions import current_timestamp, expr, monotonically_increasing_id

# Add batch ID for processing
RECORDS_PER_BATCH = 10000

# Number of batches needed for roughly RECORDS_PER_BATCH works each
# (ceiling division; 0 when there is nothing to embed).
pending_works = works_to_embed.count()
total_batches = (pending_works + RECORDS_PER_BATCH - 1) // RECORDS_PER_BATCH

# monotonically_increasing_id() is unsuitable here: it encodes the partition
# index in the upper bits, so id / RECORDS_PER_BATCH gives sparse batch ids
# (and can overflow an int), and iterating range(total_batches) would skip
# most rows. Hashing work_id into total_batches buckets instead gives ids
# in exactly 0..total_batches-1 that stay the same when the DataFrame is
# recomputed. Bucket sizes are approximately, not exactly, RECORDS_PER_BATCH.
bucket_count = max(total_batches, 1)  # avoid a zero divisor when nothing is pending
works_batched = works_to_embed.withColumn(
    "batch_id", expr(f"CAST(pmod(xxhash64(work_id), {bucket_count}) AS INT)")
)

print(f"Total batches: {total_batches}")

In [None]:
# Process batches
# Note: For production, use ai_query() with batch inference for better performance

# get_embedding calls an external service and returns None on failure, so it
# is not deterministic. Declaring that stops Spark from duplicating or
# re-invoking the UDF when its result is both filtered and selected below.
embed_text = get_embedding_udf.asNondeterministic()

BATCH_VIEW = "work_embeddings_batch"

# Upsert by work_id. A plain append would add a second row for a work whose
# text changed; the stale row would keep the old hash, so the work would be
# selected and re-embedded on every later run. created_at is kept on update.
merge_sql = f"""
MERGE INTO {OUTPUT_TABLE} AS target
USING {BATCH_VIEW} AS source
ON target.work_id = source.work_id
WHEN MATCHED THEN UPDATE SET
    target.embedding = source.embedding,
    target.text_hash = source.text_hash,
    target.publication_year = source.publication_year,
    target.type = source.type,
    target.is_oa = source.is_oa,
    target.has_abstract = source.has_abstract,
    target.updated_at = source.updated_at
WHEN NOT MATCHED THEN INSERT *
"""

for batch_num in range(total_batches):
    print(f"Processing batch {batch_num + 1}/{total_batches}...")

    # MERGE fails when several source rows match one target row, so keep a
    # single row per work_id within the batch.
    batch_df = works_batched.filter(col("batch_id") == batch_num).dropDuplicates(["work_id"])

    # Generate embeddings
    embedded_df = batch_df.withColumn(
        "embedding", embed_text(col("embedding_text"))
    ).filter(
        # Failed calls come back as NULL. Do not store them: a stored row
        # with a matching text_hash would never be retried.
        col("embedding").isNotNull()
    ).withColumn(
        "created_at", current_timestamp()
    ).withColumn(
        "updated_at", current_timestamp()
    ).select(
        "work_id",
        "embedding",
        "text_hash",
        "publication_year",
        "type",
        "is_oa",
        "has_abstract",
        "created_at",
        "updated_at"
    )

    # Write to Delta table (upsert)
    embedded_df.createOrReplaceTempView(BATCH_VIEW)
    spark.sql(merge_sql)

    print(f"  Completed batch {batch_num + 1}")

## Step 6: Verify output

In [None]:
%%sql
-- Summary of the embeddings table: row count, split by has_abstract, and the
-- range of created_at. Rows whose has_abstract is NULL are counted in
-- total_embeddings but in neither with_abstract nor title_only.
SELECT 
    COUNT(*) as total_embeddings,
    SUM(CASE WHEN has_abstract THEN 1 ELSE 0 END) as with_abstract,
    SUM(CASE WHEN NOT has_abstract THEN 1 ELSE 0 END) as title_only,
    MIN(created_at) as oldest,
    MAX(created_at) as newest
FROM openalex.vector_search.work_embeddings;

In [None]:
%%sql
-- Check embedding dimensions
-- Spot check only: LIMIT 5 without ORDER BY returns arbitrary rows.
-- Compare embedding_dims with EMBEDDING_DIMENSIONS from the configuration cell.
SELECT 
    work_id,
    SIZE(embedding) as embedding_dims
FROM openalex.vector_search.work_embeddings
LIMIT 5;

## Alternative: Use ai_query for better batch performance

For production, `ai_query()` with the external model endpoint provides better batch performance.

In [None]:
# Alternative approach using ai_query (recommended for production)
# This requires the external model endpoint to be created first

# Nothing in this cell executes the statement: sql_query is only defined, and
# the print() below tells the reader to run it elsewhere.
#
# What the statement does, as written:
#   - embeds works that are not datasets, have a title, and whose id is not
#     yet in the embeddings table (works whose text changed are NOT refreshed)
#   - inserts at most 10000 rows per execution (LIMIT 10000)
#   - when abstract is NULL, the inner CONCAT is NULL and COALESCE replaces it
#     with '', so only the title section is embedded
#
# This is a regular (non-raw) Python string, so Python converts each \n into
# a real newline character before the text reaches the SQL engine.

# Example SQL (run in separate notebook or SQL editor):
sql_query = """
INSERT INTO openalex.vector_search.work_embeddings
SELECT 
    CAST(id AS STRING) as work_id,
    ai_query(
        'openai-embedding-3-small',
        CONCAT('Title: ', title, COALESCE(CONCAT('\n\nAbstract: ', abstract), ''))
    ) as embedding,
    md5(CONCAT('Title: ', title, COALESCE(CONCAT('\n\nAbstract: ', abstract), ''))) as text_hash,
    publication_year,
    type,
    open_access.is_oa as is_oa,
    abstract IS NOT NULL as has_abstract,
    current_timestamp() as created_at,
    current_timestamp() as updated_at
FROM openalex.works.openalex_works
WHERE type != 'dataset'
  AND title IS NOT NULL
  AND CAST(id AS STRING) NOT IN (SELECT work_id FROM openalex.vector_search.work_embeddings)
LIMIT 10000
"""
print("To use ai_query, run the above SQL in a Databricks SQL notebook")