In [None]:
%pip install psycopg2-binary requests

In [22]:
# %% Cell 2: Imports & cleanup unused
import math
import os
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pandas as pd
import psycopg2
import requests
from sqlalchemy import create_engine, text
from tqdm import tqdm


In [23]:
# %% Cell 3: Configuration
# Credentials and hosts are read from the environment so they never end up in
# the notebook file or its saved outputs, e.g.
#   export DATABASE_URI='postgresql+psycopg2://USER:PASSWORD@HOST:5432/cleaned_meta_data_db'
BATCH_SIZE       = 16   # rows fetched, embedded, and written back per batch
NUM_WORKERS      = 4    # batches processed concurrently (threads)
DATABASE_URI     = os.environ.get("DATABASE_URI", "")
REMOTE_EMBED_URL = os.environ.get(
    "REMOTE_EMBED_URL", "http://localhost:8000/batch-embed"
)  # hosted FastAPI batch-embed URL

if not DATABASE_URI:
    warnings.warn(
        "DATABASE_URI is not set; export it before creating the SQLAlchemy engine."
    )


In [24]:
# %% Cell 4: Ensure the vector column exists
# Each worker thread holds at most one connection at a time. With
# max_overflow=0, a pool smaller than NUM_WORKERS would make the extra threads
# block (and eventually time out) waiting for a connection, so size the pool to
# the worker count, keeping 8 as the floor.
engine = create_engine(DATABASE_URI, pool_size=max(8, NUM_WORKERS), max_overflow=0)

# Idempotent DDL: adds the 768-dimensional vector column only if it is missing.
# NOTE(review): `vector(768)` is presumably the pgvector type, so this assumes
# the extension is already installed in the target database.
with engine.begin() as conn:
    conn.execute(text("""
      ALTER TABLE arxiv_chunks_eval_5
      ADD COLUMN IF NOT EXISTS chunk_embedding_768 vector(768)
    """))


In [25]:
# %% Cell 5: Count total rows
# Count every row in the table (already-embedded or not); this total drives the
# batch plan and the final progress check.
with engine.connect() as conn:
    count_result = conn.execute(text("SELECT COUNT(*) FROM arxiv_chunks_eval_5"))
    total = count_result.scalar_one()
print(f"Total rows to embed: {total}")


Total rows to embed: 52554


In [26]:
# %% Cell 6: Remote embedding via batch-embed endpoint
def embed_texts(texts: list[str], timeout: float = 60.0) -> list[list[float]]:
    """
    Send a list of texts to the remote batch-embed endpoint and return the
    embeddings it sends back, one list of floats per input text.

    Callers pair the results with their inputs positionally, so the server is
    assumed to preserve input order (its contract is not visible here).

    Args:
        texts: Texts to embed. An empty list returns [] without a request.
        timeout: Seconds ``requests`` waits to connect and for each read before
            raising a timeout error. Without one, a stalled server would hang
            the calling worker thread indefinitely.

    Returns:
        The ``"embeddings"`` array from the endpoint's JSON response.

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
        ValueError: if the response does not hold exactly one embedding per
            input text.
    """
    if not texts:
        return []

    resp = requests.post(REMOTE_EMBED_URL, json={"texts": texts}, timeout=timeout)
    resp.raise_for_status()
    embeddings = resp.json()["embeddings"]

    # A short or long response would otherwise be silently mis-paired or
    # truncated by the caller's zip().
    if len(embeddings) != len(texts):
        raise ValueError(
            f"batch-embed returned {len(embeddings)} embeddings for {len(texts)} texts"
        )
    return embeddings


In [27]:
# %% Cell 7: Fetch, embed, and update one batch
def process_batch(offset: int) -> int:
    """
    Fetch one page of chunks, embed them remotely, and write the vectors back.

    The page is ``BATCH_SIZE`` rows of ``arxiv_chunks_eval_5`` starting at
    ``offset`` in ``(paper_id, chunk_id)`` order. Each row's ``chunk_data`` is
    sent to ``embed_texts`` and the result is stored in ``chunk_embedding_768``.

    Args:
        offset: Row offset of the page within the ordered table.

    Returns:
        Number of rows fetched and updated; 0 when ``offset`` is past the end.

    Raises:
        ValueError: if the embedding service returns a different number of
            vectors than rows fetched. Without this check, zip() would silently
            drop the extra rows while they were still counted as done.
    """
    # 1) fetch batch.
    # NOTE(review): OFFSET paging is only stable if (paper_id, chunk_id) is
    # unique; otherwise pages may overlap or skip rows. Postgres also computes
    # and discards the first `offset` rows on every query, so later batches get
    # progressively slower.
    with engine.connect() as conn:
        rows = conn.execute(
            text("""
              SELECT paper_id, chunk_id, chunk_data
                FROM arxiv_chunks_eval_5
               ORDER BY paper_id, chunk_id
               LIMIT :limit OFFSET :offset
            """),
            {"limit": BATCH_SIZE, "offset": offset}
        ).fetchall()
    if not rows:
        return 0

    # 2) compute embeddings remotely
    texts = [r.chunk_data for r in rows]
    embs = embed_texts(texts)
    if len(embs) != len(rows):
        raise ValueError(
            f"Got {len(embs)} embeddings for {len(rows)} rows at offset {offset}"
        )

    # 3) bulk update back into Postgres; a list of parameter dicts makes
    # SQLAlchemy run the statement once per row (executemany).
    params = [
        {"pid": r.paper_id, "cid": r.chunk_id, "vec": vec}
        for r, vec in zip(rows, embs)
    ]
    with engine.begin() as conn:
        conn.execute(
            text("""
              UPDATE arxiv_chunks_eval_5
                 SET chunk_embedding_768 = :vec
               WHERE paper_id = :pid
                 AND chunk_id   = :cid
            """),
            params
        )

    return len(rows)


In [28]:
# %% Cell 8: Compute all offsets
# One starting offset per batch: 0, BATCH_SIZE, 2*BATCH_SIZE, ... while < total.
offsets = list(range(0, total, BATCH_SIZE))
n_batches = len(offsets)  # equals ceil(total / BATCH_SIZE)
print(f"{n_batches} batches, offsets: {offsets[:5]}…")


3285 batches, offsets: [0, 16, 32, 48, 64]…


In [30]:
# %% Cell: Sample Test with Detailed Logging
import logging
import time
import numpy as np

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)

# Must match the `chunk_embedding_768 vector(768)` column the batch job writes to.
EXPECTED_DIM = 768

sample_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "OpenAI's GPT models are powerful for NLP tasks.",
    "FastAPI + ONNX Runtime is great for serving ML models!"
]

logging.info("Prepared %d sample texts for embedding test", len(sample_texts))

# Measure round-trip time for the remote call. perf_counter is monotonic and
# high-resolution; time.time() can jump if the wall clock is adjusted.
logging.info("Sending request to remote /batch-embed endpoint")
t0 = time.perf_counter()
embs = embed_texts(sample_texts)
elapsed = time.perf_counter() - t0
logging.info("Received embeddings in %.3f seconds", elapsed)

# Validate shape before the full run: count must match and every vector must
# fit the database column, or the bulk UPDATEs would fail mid-run.
assert isinstance(embs, list) and len(embs) == len(sample_texts), "Unexpected response format"
dims = {len(vec) for vec in embs}
assert dims == {EXPECTED_DIM}, f"Expected {EXPECTED_DIM}-dim embeddings, got {sorted(dims)}"
dim = EXPECTED_DIM
logging.info("Each embedding has dimension: %d", dim)

# In the recorded run the L2 norms were ~11, i.e. vectors are not
# unit-normalized, so cosine vs. inner-product distance will rank differently.
for i, vec in enumerate(embs):
    norm = np.linalg.norm(vec)
    logging.info(" Sample %d: first 5 dims = %s, L2 norm = %.4f", i, vec[:5], norm)


19:48:50 INFO Prepared 3 sample texts for embedding test
19:48:50 INFO Sending request to remote /batch-embed endpoint
19:48:50 INFO Received embeddings in 0.246 seconds
19:48:50 INFO Each embedding has dimension: 768
19:48:50 INFO  Sample 0: first 5 dims = [-0.1988082490473365, 0.003090559815367063, -0.30425620086801547, -0.05810393524977068, 0.40287098614498973], L2 norm = 11.5390
19:48:50 INFO  Sample 1: first 5 dims = [-0.4815902259142604, -0.19317454716656357, -0.12821805587009294, -0.17972209495928837, -0.15113638225011528], L2 norm = 11.8691
19:48:50 INFO  Sample 2: first 5 dims = [-0.15580386124715648, -0.4081037603866528, 0.6036769642549402, -0.05800980198032716, 0.04591636808917803], L2 norm = 10.8987


In [31]:
# %% Cell 9: Run batches in parallel and report progress
# Fan the batches out to a thread pool (each batch is an HTTP call plus two DB
# round-trips) and accumulate per-batch row counts as they finish.
processed = 0
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
    futures = {executor.submit(process_batch, off): off for off in offsets}
    try:
        for fut in tqdm(as_completed(futures), total=len(futures), desc="batches"):
            try:
                processed += fut.result()
            except Exception as exc:
                # Name the failing page so it can be inspected or retried alone.
                raise RuntimeError(f"Batch at offset {futures[fut]} failed") from exc
    except BaseException:
        # Leaving the `with` block calls shutdown(wait=True); without cancelling,
        # every still-queued batch would run before the error surfaced.
        for pending in futures:
            pending.cancel()
        raise

print(f"✅ Finished embedding & updating {processed} rows")
if processed != total:
    print(f"⚠️ Expected {total} rows from the earlier COUNT(*); the table may have changed during the run")


… done 52554/52554 rows
✅ Finished embedding & updating 52554 rows
