# Arxiv Eval Data Generation Notebook

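This notebook builds a BM25 (Whoosh) index over paper chunks, filters ground-truth citations against BM25 hits, generates semantic-search queries with a locally served LLM (via the LM Studio Python SDK; earlier iterations used Ollama), and updates both a Parquet checkpoint and a Postgres table.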


In [1]:
!pip install pandas fastparquet sqlalchemy whoosh psycopg2-binary lmstudio



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import os
import json
import pandas as pd
from sqlalchemy import create_engine, text
from whoosh.index import create_in, open_dir, exists_in
from whoosh.fields import Schema, TEXT, ID
from whoosh.qparser import QueryParser
from ollama import chat, ChatResponse


In [3]:
# Configuration
# The connection URL is read from the ARXIV_EVAL_DB_URL environment variable so
# that real credentials never need to be written into the notebook. The fallback
# is the local development database the engine was previously hard-wired to.
DB_URL          = os.environ.get(
    "ARXIV_EVAL_DB_URL",
    "postgresql+psycopg2://local:password@localhost:5433/mlops_local",
)
PARQUET_FILE    = "arxiv_eval_results.parquet"
INDEX_DIR       = "whoosh_index"
TOP_N           = 3       # BM25 top-N chunks to consider
MIN_SCORE       = None    # e.g. 1.0 to enforce a minimum BM25 score
OLLAMA_MODEL    = "phi4"   # Ollama model name
NUM_QUERIES     = 3       # Number of queries to generate per chunk

# Create SQLAlchemy engine (single source of truth: DB_URL above)
engine = create_engine(
    DB_URL,
    max_overflow=0,   # disallow “extra” connections beyond pool_size
    pool_timeout=30,  # seconds to wait for an idle connection
)


In [4]:
import os
import logging
from sqlalchemy import create_engine, text
from whoosh.index import create_in, open_dir
from whoosh.fields import Schema, ID, TEXT

# ── Index-build tuning knobs ───────────────────────────────────
BATCH_SIZE = 5000        # rows pulled from Postgres per SQL fetch
NUM_PROCS = 4            # parallel Whoosh writer processes
MEM_LIMIT_MB = 512       # memory cap per writer process (MB)

# ── Logging ────────────────────────────────────────────────────
# Timestamped, level-tagged messages at INFO and above.
_LOGGING_OPTIONS = {
    "level": logging.INFO,
    "format": "%(asctime)s [%(levelname)s] %(message)s",
    "datefmt": "%Y-%m-%d %H:%M:%S",
}
logging.basicConfig(**_LOGGING_OPTIONS)
logger = logging.getLogger(__name__)


def build_whoosh_index():
    """
    Return the Whoosh index stored in INDEX_DIR, building it first when needed.

    When INDEX_DIR is missing or empty, a new index is created and filled from
    the ``arxiv_chunks_backup`` table, read in keyset-paginated batches of
    BATCH_SIZE rows ordered by ``chunk_id``. Otherwise the existing index is
    opened as-is.

    If the build fails or is interrupted, the writer is cancelled and the
    directory created for this build is removed before the error is re-raised.
    Without that cleanup the directory would be non-empty on the next call and
    an empty/partial index would be opened instead of being rebuilt.

    Returns:
        The opened (or freshly built) Whoosh index.

    Raises:
        Whatever the database or Whoosh raised during a failed build.
    """
    import shutil  # only needed to clean up after a failed build

    # A non-empty directory is treated as a finished index.
    if os.path.isdir(INDEX_DIR) and os.listdir(INDEX_DIR):
        logger.info("Opening existing index in %s", INDEX_DIR)
        return open_dir(INDEX_DIR)

    os.makedirs(INDEX_DIR, exist_ok=True)
    logger.info("Creating new Whoosh index in %s", INDEX_DIR)

    # Define schema: ids are stored so hits can return them; the text is only indexed.
    schema = Schema(
        chunk_id   = ID(stored=True),
        paper_id   = ID(stored=True),
        chunk_data = TEXT
    )

    ix = create_in(INDEX_DIR, schema)
    writer = ix.writer(
        procs=NUM_PROCS,
        limitmb=MEM_LIMIT_MB,
        multisegment=True
    )

    # Both statements are loop-invariant, so they are prepared once.
    first_page_sql = text("""
        SELECT chunk_id, paper_id, chunk_data
          FROM arxiv_chunks_backup
         ORDER BY chunk_id
         LIMIT :batch
    """)
    next_page_sql = text("""
        SELECT chunk_id, paper_id, chunk_data
          FROM arxiv_chunks_backup
         WHERE chunk_id > :last
         ORDER BY chunk_id
         LIMIT :batch
    """)

    last_chunk_id = None
    total_indexed = 0
    batch_num = 0

    try:
        with engine.connect() as conn:
            while True:
                batch_num += 1
                if last_chunk_id is None:
                    sql, params = first_page_sql, {"batch": BATCH_SIZE}
                else:
                    sql, params = next_page_sql, {"last": last_chunk_id,
                                                  "batch": BATCH_SIZE}

                # Fetch batch
                result = conn.execution_options(stream_results=True) \
                             .execute(sql, params)
                rows = result.fetchall()
                if not rows:
                    logger.info("No more rows to fetch, ending.")
                    break

                # Index this batch
                logger.info("Batch %d: fetched %d rows (chunk_id > %s)",
                            batch_num, len(rows), last_chunk_id)
                for cid, pid, data in rows:
                    writer.add_document(
                        chunk_id   = str(cid),
                        paper_id   = str(pid),
                        chunk_data = data or ""
                    )
                total_indexed += len(rows)
                logger.info("Batch %d: queued %d docs, total queued %d",
                            batch_num, len(rows), total_indexed)

                # Keyset pagination marker: resume after the last id seen
                last_chunk_id = rows[-1][0]

        # Final commit (merges all parallel segments)
        logger.info("Committing index (this may take a moment)…")
        writer.commit()
    except BaseException:
        # BaseException so that a kernel interrupt is cleaned up as well.
        logger.error("Index build failed; discarding partial index in %s", INDEX_DIR)
        try:
            writer.cancel()
        except Exception as cancel_err:
            logger.warning("Could not cancel Whoosh writer cleanly: %s", cancel_err)
        # Everything in INDEX_DIR was created by this call (it was missing or
        # empty on entry), so removing it cannot destroy a previous index.
        shutil.rmtree(INDEX_DIR, ignore_errors=True)
        raise

    logger.info("Indexing complete: %d documents indexed.", total_indexed)
    return ix

# Build the index on first run, or reopen the one already present in INDEX_DIR.
ix = build_whoosh_index()
logger.info("Whoosh index ready at: %s", INDEX_DIR)


2025-05-11 03:40:58 [INFO] Opening existing index in whoosh_index
2025-05-11 03:40:58 [INFO] Whoosh index ready at: whoosh_index


In [5]:
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

def _open_new_checkpoint_writer():
    """Create PARQUET_FILE with an explicit schema and return its open ParquetWriter."""
    # An explicit schema is used instead of inferring one from an empty
    # DataFrame: inference on empty object columns yields null-typed fields,
    # and the writer then rejects real rows because their schema differs.
    checkpoint_schema = pa.schema([
        ("chunk_id",           pa.string()),
        ("final_ground_truth", pa.list_(pa.string())),
        ("query_list",         pa.list_(pa.string())),
        ("no_related_gt_flag", pa.bool_()),
    ])
    writer = pq.ParquetWriter(PARQUET_FILE, checkpoint_schema)
    logger.info("Initialized new Parquet checkpoint '%s'.", PARQUET_FILE)
    return writer


def initialize_bm25_and_checkpoint(ix):
    """
    Prepare the BM25 query parser and the resumable Parquet checkpoint.

    Args:
        ix: Whoosh index whose schema is used to parse queries against the
            ``chunk_data`` field.

    Returns:
        A tuple ``(parser, parquet_writer, processed)``:
          * ``parser`` - QueryParser over ``chunk_data``.
          * ``parquet_writer`` - an open ParquetWriter when a new checkpoint
            file was started, or ``None`` when an existing checkpoint was
            loaded.
          * ``processed`` - set of chunk ids (as strings) found in the existing
            checkpoint; empty for a new checkpoint.

    An existing checkpoint that cannot be read (``ArrowInvalid``) is deleted
    and replaced by a fresh one.
    """
    # Prepare BM25 query parser
    parser = QueryParser("chunk_data", schema=ix.schema)
    logger.info("Initialized BM25 query parser.")

    # Setup Parquet checkpoint for resumability, with recovery on corrupt file
    if os.path.exists(PARQUET_FILE):
        try:
            df_existing = pd.read_parquet(PARQUET_FILE)
        except pa.lib.ArrowInvalid as e:
            logger.warning(
                "Corrupt Parquet '%s' (size=%d bytes): %s. Reinitializing checkpoint.",
                PARQUET_FILE, os.path.getsize(PARQUET_FILE), e
            )
            os.remove(PARQUET_FILE)
            # fall through and start a new checkpoint below
        else:
            processed = set(df_existing["chunk_id"].astype(str).tolist())
            logger.info(
                "Loaded existing checkpoint '%s' with %d processed chunks.",
                PARQUET_FILE, len(processed)
            )
            # No writer: an existing file is not reopened for writing here.
            return parser, None, processed

    return parser, _open_new_checkpoint_writer(), set()

# Build the query parser and load (or create) the Parquet checkpoint.
# `processed` is the set of chunk ids already present in the checkpoint.
parser, parquet_writer, processed = initialize_bm25_and_checkpoint(ix)

logger.info("BM25 parser and Parquet checkpoint ready. %d chunks already processed.",
            len(processed))


2025-05-11 03:40:58 [INFO] Initialized BM25 query parser.
2025-05-11 03:40:58 [INFO] Initialized new Parquet checkpoint 'arxiv_eval_results.parquet'.
2025-05-11 03:40:58 [INFO] BM25 parser and Parquet checkpoint ready. 0 chunks already processed.


In [6]:
import lmstudio as lms
from lmstudio import Chat

# Load your model once at startup
MODEL_NAME     = "phi-4"  # or whichever you’ve installed
NUM_QUERIES    = 3
# Address of the LM Studio server. Set the LMSTUDIO_HOST environment variable to
# point at a different machine without editing the notebook.
LMSTUDIO_HOST  = os.environ.get("LMSTUDIO_HOST", "192.168.1.150:1234")
model = lms.get_default_client(LMSTUDIO_HOST).llm.model(MODEL_NAME)

def generate_queries_with_lmstudio(prompt: str) -> list[str]:
    """
    Uses LM Studio's Python SDK to generate NUM_QUERIES semantic-search prompts.

    Each completion is requested in its own fresh chat session so the replies
    are independent of one another.

    Args:
        prompt: The user message sent to the model for every completion.

    Returns:
        A list of NUM_QUERIES strings, one stripped reply per completion.
    """
    queries: list[str] = []

    for _ in range(NUM_QUERIES):
        session = Chat()                    # new chat session
        session.add_user_message(prompt)    # user prompt
        response = model.respond(session)   # get assistant reply

        # respond() hands back a result object rather than a plain string;
        # its text lives in `.content`. Fall back to str() for anything else
        # so the function always returns strings, as its signature promises.
        if isinstance(response, str):
            reply = response
        else:
            reply = getattr(response, "content", None)
            if not isinstance(reply, str):
                reply = str(response)
        queries.append(reply.strip())

    return queries


2025-05-11 03:40:58 [INFO] {"client": "<lmstudio.sync_api.Client object at 0x7f6186379b50>", "event": "Websocket handling thread started", "thread_id": "Thread-4"}
2025-05-11 03:40:58 [INFO] {"event": "Websocket handling task started", "ws_url": "ws://192.168.1.150:1234/llm"}
2025-05-11 03:40:58 [INFO] HTTP Request: GET ws://192.168.1.150:1234/llm "HTTP/1.1 101 Switching Protocols"
2025-05-11 03:40:58 [INFO] {"event": "Websocket session established (ws://192.168.1.150:1234/llm)", "ws_url": "ws://192.168.1.150:1234/llm"}


In [None]:
# ── Cell 2: Semantic-Search Query Generation with ETA ──────────────

import time
import json
import logging
import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from sqlalchemy import text
from pydantic import BaseModel
from typing import List

# Module-level logger used by the query-generation code in this cell.
logger = logging.getLogger(__name__)

# Structured-output schema passed to model.respond(..., response_format=QueryList):
# a single "queries" field holding a list of strings.
# (Documented with comments rather than a docstring on purpose.)
class QueryList(BaseModel):
    queries: List[str]

MAX_WORKERS = 2    # size of the thread pool that runs generate_for_chunk
MAX_RETRIES = 3    # attempts per chunk before falling back to an empty query list
LOG_INTERVAL = 50  # emit a progress/ETA line every this many completed chunks

def generate_for_chunk(task):
    """
    Generate semantic-search queries for one chunk and store them in Postgres.

    The model is asked for a JSON object matching QueryList; the call is
    retried up to MAX_RETRIES times. If every attempt fails an empty list is
    stored. The result is written as a JSON array into the ``query`` column of
    ``arxiv_chunks_eval_4`` for the given ``chunk_id``.

    Args:
        task: ``(chunk_id, chunk_data)`` tuple; a falsy ``chunk_data`` is
            treated as an empty snippet.

    Returns:
        Seconds spent on generation plus the database update.
    """
    chunk_id, chunk_data = task
    snippet = chunk_data or ""

    # Derive the example object from NUM_QUERIES so the format shown to the
    # model always agrees with the number of queries requested.
    example = json.dumps(
        {"queries": [f"query{i}" for i in range(1, NUM_QUERIES + 1)]},
        separators=(",", ":"),
    )

    prompt = (
        f"Below is a snippet from a scientific paper. Generate exactly {NUM_QUERIES} "
        "(1-6 words each) semantic-search queries a researcher might use to find it.\n\n"
        "IMPORTANT: return ONLY a JSON object with a single key \"queries\" whose value "
        "is an array of strings—no markdown, no backticks, no explanations. Exactly:\n"
        f"{example}\n\n"
        "Snippet:\n\"\"\"\n"
        f"{snippet}\n"
        "\"\"\""
    )

    t0 = time.monotonic()
    queries: List[str] = []
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            result = model.respond(prompt, response_format=QueryList)
            parsed = result.parsed
            # Accept either a plain dict or a QueryList-like object.
            raw_queries = parsed["queries"] if isinstance(parsed, dict) else parsed.queries
            queries = list(raw_queries)[:NUM_QUERIES]
        except Exception as e:
            logger.warning("Chunk %s: attempt %d/%d failed (%s)", 
                           chunk_id, attempt, MAX_RETRIES, e)
            continue
        if len(queries) < NUM_QUERIES:
            logger.warning("Chunk %s: got %d queries, expected %d", 
                           chunk_id, len(queries), NUM_QUERIES)
        break
    else:
        # Loop ended without a successful attempt.
        logger.error("Chunk %s: all retries failed, defaulting to empty list", chunk_id)
        queries = []
    gen_time = time.monotonic() - t0
    logger.debug("Chunk %s: generation took %.2f s", chunk_id, gen_time)

    t1 = time.monotonic()
    with engine.begin() as conn:
        update_result = conn.execute(text("""
            UPDATE arxiv_chunks_eval_4
               SET query = :q
             WHERE chunk_id = :cid
        """), {"q": json.dumps(queries), "cid": chunk_id})
        updated_rows = update_result.rowcount
    db_time = time.monotonic() - t1
    logger.debug("Chunk %s: DB update took %.2f s", chunk_id, db_time)
    if updated_rows == 0:
        # Nothing matched the WHERE clause, so nothing was actually stored.
        logger.warning("Chunk %s: UPDATE matched no rows", chunk_id)

    total_time = gen_time + db_time
    logger.info("Chunk %s → stored %d queries (%.2f s)", 
                chunk_id, len(queries), total_time)
    return total_time

# Select the chunks that still need queries: paper_cited is non-empty and,
# once its surrounding braces are trimmed, splits into at least 5
# comma-separated entries; and query is NULL, '' or '[]'.
pending_chunks_sql = text("""
        SELECT chunk_id,
               chunk_data
          FROM public.arxiv_chunks_eval_4
         WHERE paper_cited IS NOT NULL
           AND paper_cited <> ''
           AND (query IS NULL OR query = '' OR query = '[]')
           AND array_length(
                 string_to_array(
                   trim(both '{}' FROM paper_cited),
                   ','
                 ),
                 1
               ) >= 5
    """)

# Run the query and materialise every matching row up front.
with engine.connect() as conn:
    pending_result = conn.execute(pending_chunks_sql)
    rows = pending_result.fetchall()


# One (chunk_id, chunk_data) task per pending row.
tasks = [(r.chunk_id, r.chunk_data) for r in rows]
total = len(tasks)
logger.info("Generating queries for %d chunks with %d threads", total, MAX_WORKERS)

# NOTE(review): `processed` is rebound to an integer counter here, replacing the
# set of chunk ids returned earlier by initialize_bm25_and_checkpoint.
processed = 0      # futures finished, successfully or not
failed = 0         # futures that raised
time_accum = 0.0   # summed per-task durations (kept for inspection; not used for the ETA)
start_all = time.monotonic()

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    future_to_cid = {executor.submit(generate_for_chunk, t): t[0] for t in tasks}
    for future in as_completed(future_to_cid):
        cid = future_to_cid[future]
        try:
            time_accum += future.result()
        except Exception as e:
            failed += 1
            logger.error("Chunk %s: unexpected error %s", cid, e)
        processed += 1

        if processed % LOG_INTERVAL == 0 or processed == total:
            elapsed = time.monotonic() - start_all
            # Estimate from wall-clock throughput. Tasks run MAX_WORKERS at a
            # time, so multiplying the average per-task latency by the number
            # of remaining tasks would overstate the time left.
            remaining = (elapsed / processed) * (total - processed)
            eta = datetime.datetime.now() + datetime.timedelta(seconds=remaining)
            logger.info(
                "Progress %d/%d — Elapsed: %.1f m, ETA in %.1f m (at %s)",
                processed, total,
                elapsed/60, remaining/60,
                eta.strftime("%Y-%m-%d %H:%M:%S")
            )

logger.info("Query generation complete: %d/%d chunks done", processed, total)
if failed:
    logger.warning("%d of %d chunks raised an unexpected error", failed, total)


2025-05-11 05:01:29 [INFO] Generating queries for 2746 chunks with 4 threads
2025-05-11 05:01:31 [INFO] Chunk 0712.3869v2_13 → stored 3 queries (1.56 s)


{
  "queries": [
    "genus calculation via branching data",
    "monodromy group reducible case",
    "Klein functions maximal decompositions"
  ]
}
['genus calculation via branching data', 'monodromy group reducible case', 'Klein functions maximal decompositions']


2025-05-11 05:01:32 [INFO] Chunk 0712.3869v2_1 → stored 3 queries (2.84 s)


{
  "queries": [
    "imprimitivity systems permutation groups",
    "generalization Ritt theorem rational functions",
    "Jordan H older theorem maximal decompositions"
  ]
}
['imprimitivity systems permutation groups', 'generalization Ritt theorem rational functions', 'Jordan H older theorem maximal decompositions']


2025-05-11 05:01:33 [INFO] Chunk 0712.3869v2_3 → stored 3 queries (4.03 s)


{
  "queries": [
    "lower semi modular lattices",
    "Ritt theorem rational functions",
    "Jordan H older imprimitivity systems"
  ]
}
['lower semi modular lattices', 'Ritt theorem rational functions', 'Jordan H older imprimitivity systems']


2025-05-11 05:01:34 [INFO] Chunk 0712.3869v2_7 → stored 3 queries (5.14 s)


{
  "queries": [
    "Jordan-H\"older theorem imprimitivity",
    "permutable subgroups lattice",
    "core complementary subgroups"
  ]
}
['Jordan-H"older theorem imprimitivity', 'permutable subgroups lattice', 'core complementary subgroups']


2025-05-11 05:01:36 [INFO] Chunk 0712.3869v2_8 → stored 3 queries (4.93 s)


{
  "queries": [
    "modular lattice isomorphic sublattice",
    "Jordan-H\"older theorem imprimitivity systems",
    "Hamiltonian group normal subgroups"
  ]
}
['modular lattice isomorphic sublattice', 'Jordan-H"older theorem imprimitivity systems', 'Hamiltonian group normal subgroups']


2025-05-11 05:01:37 [INFO] Chunk 1001.2978v1_2 → stored 3 queries (4.41 s)


{
  "queries": [
    "size multiplication preferential relations",
    "semantical interpolation monotonic non-monotonic",
    "revision distance relations"
  ]
}
['size multiplication preferential relations', 'semantical interpolation monotonic non-monotonic', 'revision distance relations']


2025-05-11 05:01:38 [INFO] Chunk 0712.3869v2_16 → stored 3 queries (4.48 s)


{
  "queries": [
    "Belyi functions tetrahedron",
    "maximal decompositions S4",
    "Ritt theorem counterexamples"
  ]
}
['Belyi functions tetrahedron', 'maximal decompositions S4', 'Ritt theorem counterexamples']


2025-05-11 05:01:39 [INFO] Chunk 0712.2596v2_17 → stored 3 queries (4.53 s)


{
  "queries": [
    "Cooper problem on a ring",
    "Little Parks period oscillations",
    "instanton approach flux oscillation"
  ]
}
['Cooper problem on a ring', 'Little Parks period oscillations', 'instanton approach flux oscillation']


2025-05-11 05:01:40 [INFO] Chunk 0712.2596v2_18 → stored 3 queries (4.51 s)


{
  "queries": [
    "instanton amplitude single electron",
    "ground state energy two electrons",
    "cooper pair separation scales"
  ]
}
['instanton amplitude single electron', 'ground state energy two electrons', 'cooper pair separation scales']


2025-05-11 05:01:42 [INFO] Chunk 0712.2596v2_4 → stored 3 queries (4.99 s)


{
  "queries": [
    "self consistency equation superconductors",
    "critical temperature Debye frequency",
    "Cooper pair size critical temperature"
  ]
}
['self consistency equation superconductors', 'critical temperature Debye frequency', 'Cooper pair size critical temperature']


2025-05-11 05:01:43 [INFO] Chunk 0712.2596v2_5 → stored 3 queries (4.98 s)


{
  "queries": [
    "exchange field superconductivity transitions",
    "flux effect on superconducting states",
    "phase diagram critical temperature flux"
  ]
}
['exchange field superconductivity transitions', 'flux effect on superconducting states', 'phase diagram critical temperature flux']


2025-05-11 05:01:44 [INFO] Chunk 0712.2596v2_9 → stored 3 queries (4.96 s)


{
  "queries": [
    "superconducting phase transition",
    "finite size effects superconductivity",
    "critical temperature oscillations"
  ]
}
['superconducting phase transition', 'finite size effects superconductivity', 'critical temperature oscillations']


2025-05-11 05:01:45 [INFO] Chunk 0712.2596v2_10 → stored 3 queries (4.87 s)


{
  "queries": [
    "finite radius superconducting rings",
    "oscillations in Tc with magnetic flux",
    "Little Parks effect modifications"
  ]
}
['finite radius superconducting rings', 'oscillations in Tc with magnetic flux', 'Little Parks effect modifications']


2025-05-11 05:01:46 [INFO] Chunk 0712.2596v2_11 → stored 3 queries (4.79 s)


{
  "queries": [
    "double solutions critical temperature",
    "orbital pair breaking effect",
    "Cooper pairs even odd effect"
  ]
}
['double solutions critical temperature', 'orbital pair breaking effect', 'Cooper pairs even odd effect']


2025-05-11 05:01:48 [INFO] Chunk 0712.2596v2_13 → stored 3 queries (4.74 s)


{
  "queries": [
    "Matsubara sum upper cutoff",
    "critical temperature Tc equation",
    "finiteness correction radius"
  ]
}
['Matsubara sum upper cutoff', 'critical temperature Tc equation', 'finiteness correction radius']


In [8]:
# # --- BM25 Filtering Test for One Sample Row (fixed) ---

# import json
# from sqlalchemy import text

# # Fetch one sample row that still needs a query
# with engine.connect() as conn:
#     sample_row = conn.execute(text("""
#         SELECT paper_id, chunk_id, chunk_data, paper_cited
#           FROM arxiv_chunks_eval_4
#          WHERE query = ''
#          LIMIT 1
#     """)).mappings().first()

# print("=== Sample Row ===")
# for k, v in sample_row.items():
#     print(f"{k}: {v!r}")

# # Preview the snippet
# snippet = sample_row["chunk_data"] or ""
# print("\n=== Snippet Preview ===")
# print(snippet[:200] + ("..." if len(snippet) > 200 else ""))

# # Parse and run BM25 search, extracting fields inside the searcher context
# q = parser.parse(snippet[:200])   # use first 200 characters as the query text
# with ix.searcher() as searcher:
#     hits = searcher.search(q, limit=TOP_N)
#     retrieved = set()
#     for h in hits:
#         if MIN_SCORE is None or h.score >= MIN_SCORE:
#             # Access stored field while searcher is open
#             retrieved.add(h["paper_id"])

# # Original ground truth list
# original_gt = (
#     sample_row["paper_cited"].strip("{}").split(",")
#     if sample_row["paper_cited"] else []
# )
# original_gt = [pid for pid in original_gt if pid]

# # Filter ground truths by BM25 hits
# filtered_gt = [pid for pid in original_gt if pid in retrieved]

# print("\nOriginal Ground Truth IDs:", original_gt)
# print("Retrieved by BM25:", retrieved)
# print("Filtered Ground Truth IDs:", filtered_gt)



In [9]:
# # Test snippet to verify BM25 filtering and LM Studio structured query generation for one row

# import json
# from sqlalchemy import text
# from pydantic import BaseModel
# import lmstudio as lms

# # --- Structured response schema for LM Studio ---
# class QuerySchema(BaseModel):
#     queries: list[str]

# # --- Initialize LM Studio model ---
# MODEL_NAME  = "phi-4"        # your installed LM Studio model
# NUM_QUERIES = 3
# model       = lms.llm(MODEL_NAME)

# def generate_queries_with_lmstudio(prompt: str) -> list[str]:
#     """
#     Uses LM Studio Python SDK to return exactly {"queries": [...]}
#     via the QuerySchema, and returns that list.
#     """
#     resp = model.respond(
#         prompt,
#         response_format=QuerySchema
#     )
#     parsed = resp.parsed  # this will be a dict conforming to QuerySchema
#     if isinstance(parsed, dict):
#         return parsed["queries"]
#     # fallback if parsed is a pydantic instance
#     return parsed.queries

# # --- Fetch one sample row that still needs a query ---
# with engine.connect() as conn:
#     sample_row = conn.execute(text("""
#         SELECT paper_id, chunk_id, chunk_data, paper_cited
#           FROM arxiv_chunks_eval_4
#          WHERE query = ''
#          LIMIT 1
#     """)).mappings().first()

# print("=== Sample Row ===")
# for k, v in sample_row.items():
#     print(f"{k}: {v!r}")

# # --- BM25 filtering ---
# snippet = sample_row["chunk_data"] or ""
# print("\n=== Snippet Preview ===")
# print(snippet[:200] + ("..." if len(snippet) > 200 else ""))

# q = parser.parse(snippet[:200])
# with ix.searcher() as searcher:
#     hits = searcher.search(q, limit=TOP_N)
#     retrieved = {
#         h["paper_id"] for h in hits
#         if MIN_SCORE is None or h.score >= MIN_SCORE
#     }

# original_gt = (
#     sample_row["paper_cited"].strip("{}").split(",")
#     if sample_row["paper_cited"] else []
# )
# original_gt = [pid for pid in original_gt if pid]
# filtered_gt = [pid for pid in original_gt if pid in retrieved]

# print("\nOriginal Ground Truth IDs:", original_gt)
# print("Retrieved by BM25:", retrieved)
# print("Filtered Ground Truth IDs:", filtered_gt)

# # --- LM Studio structured query generation ---
# prompt = (
#     f"Given this scientific snippet, write {NUM_QUERIES} concise "
#     f"semantic-search queries a researcher might use to find it:\n\n\"{snippet}\""
# )

# print("\n=== Prompt ===")
# print(prompt)

# generated_queries = generate_queries_with_lmstudio(prompt)

# print("\nGenerated Queries (structured list):")
# print(generated_queries)


In [10]:
# # ── Sample Test: BM25 Top-N Unique Paper Filtering ──────────────────────────

# import json
# from sqlalchemy import create_engine, text
# from whoosh.index import open_dir
# from whoosh.qparser import QueryParser

# TOP_N      = 5
# MIN_SCORE  = None        # or e.g. 1.0 to threshold scores

# with engine.connect() as conn:
#     sample = conn.execute(text("""
#         SELECT chunk_id, paper_cited, chunk_data
#           FROM arxiv_chunks_eval_4
#          WHERE paper_cited IS NOT NULL
#            AND paper_cited <> ''
#          LIMIT 1
#     """)).mappings().first()

# print("=== Sample Row ===")
# print(f"chunk_id   : {sample['chunk_id']}")
# print(f"paper_cited: {sample['paper_cited']!r}")
# print(f"snippet    : {sample['chunk_data'][:200]!r}…")

# # ── Parse original citations ────────────────────────────────────────────────
# original = sample["paper_cited"].strip("{}").split(",")
# original = [pid for pid in original if pid]
# print("\nOriginal cited papers:", original)

# # ── BM25 search & collect top-N unique paper_ids ────────────────────────────
# snippet = (sample["chunk_data"] or "")[:200]
# q       = parser.parse(snippet)

# unique_pids = []
# with ix.searcher() as searcher:
#     hits = searcher.search(q, limit=TOP_N * 10)
#     for hit in hits:
#         if MIN_SCORE is not None and hit.score < MIN_SCORE:
#             continue
#         pid = hit["paper_id"]
#         if pid not in unique_pids:
#             unique_pids.append(pid)
#             if len(unique_pids) == TOP_N:
#                 break

# print("BM25 top-N unique paper_ids:", unique_pids)

# # ── Filter original citations against those unique hits ────────────────────
# filtered = [pid for pid in original if pid in unique_pids]
# print("Filtered cited papers      :", filtered)


In [11]:
# # ── Cell 1: BM25‐Based Ground‐Truth Filtering (Thread‐Isolated, Verbose Logging) ──

# import json
# import logging
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from sqlalchemy import text
# from whoosh.index import open_dir
# from whoosh.qparser import QueryParser

# # Assumes:
# # engine      : SQLAlchemy engine
# # INDEX_DIR   : path to your Whoosh index dir
# # TOP_N       : int, e.g. 5
# # MIN_SCORE   : float or None
# # MAX_WORKERS : int, e.g. 4
# # logger      : configured logger
# MAX_WORKERS = 4

# def filter_one(task):
#     chunk_id, cited_str, chunk_data = task
#     logger.info("Thread %s: start filtering", chunk_id)
#     # Re-open index and parser inside each thread for isolation
#     ix_local    = open_dir(INDEX_DIR)
#     parser_local = QueryParser("chunk_data", schema=ix_local.schema)

#     # 1) Parse original citations
#     if isinstance(cited_str, str) and cited_str.startswith("{") and cited_str.endswith("}"):
#         original = [pid.strip() for pid in cited_str[1:-1].split(",") if pid.strip()]
#     else:
#         original = []
#     logger.debug("Chunk %s: original citations %s", chunk_id, original)

#     # 2) BM25 search
#     snippet     = (chunk_data or "")[:200]
#     q_local     = parser_local.parse(snippet)
#     window_size = TOP_N * 10

#     unique_pids = []
#     with ix_local.searcher() as searcher:
#         hits = searcher.search(q_local, limit=window_size)
#         logger.debug("Chunk %s: retrieved %d hits", chunk_id, len(hits))
#         for h in hits:
#             score = h.score
#             pid   = h["paper_id"]
#             if MIN_SCORE is not None and score < MIN_SCORE:
#                 logger.debug("Chunk %s: skip pid %s (score %.2f < %.2f)", chunk_id, pid, score, MIN_SCORE)
#                 continue
#             if pid not in unique_pids:
#                 unique_pids.append(pid)
#                 logger.debug("Chunk %s: keep unique pid %s", chunk_id, pid)
#                 if len(unique_pids) >= TOP_N:
#                     break

#     # 3) Filter citations
#     filtered = [pid for pid in original if pid in unique_pids]
#     logger.info("Chunk %s: filtered down to %d/%d citations", chunk_id, len(filtered), len(original))

#     return chunk_id, filtered

# # 1) Load tasks (one-time)
# with engine.connect() as conn:
#     rows = conn.execute(text("""
#         SELECT chunk_id, paper_cited, chunk_data
#           FROM arxiv_chunks_eval_4
#          WHERE paper_cited IS NOT NULL AND paper_cited <> ''
#     """)).fetchall()

# tasks = [(r.chunk_id, r.paper_cited, r.chunk_data) for r in rows]
# total = len(tasks)
# logger.info("Loaded %d tasks; dispatching to %d threads", total, MAX_WORKERS)

# results = []
# processed = 0
# log_interval = 100

# # 2) Execute in parallel, but commit only on main thread
# with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
#     future_to_task = {executor.submit(filter_one, t): t for t in tasks}
#     for future in as_completed(future_to_task):
#         chunk_id = future_to_task[future][0]
#         try:
#             cid, filtered = future.result()
#             results.append((cid, filtered))
#         except Exception as e:
#             logger.error("Chunk %s failed: %s", chunk_id, e)

#         processed += 1
#         if processed % log_interval == 0 or processed == total:
#             logger.info("Overall progress: %d/%d", processed, total)

# # 3) Persist all updates in a single transaction batch
# logger.info("Persisting %d filtered results back to database", len(results))
# with engine.begin() as conn:
#     for chunk_id, filtered in results:
#         array_str = "{" + ",".join(filtered) + "}"
#         conn.execute(text("""
#             UPDATE arxiv_chunks_eval_4
#                SET paper_cited = :pc
#              WHERE chunk_id = :cid
#         """), {"pc": array_str, "cid": chunk_id})

# logger.info("All updates committed; filtering complete.")


In [12]:
# import logging
# import json
# import pandas as pd
# import pyarrow as pa
# import concurrent.futures
# from sqlalchemy import text

# # ── Logging Setup ─────────────────────────────────────────────
# logging.basicConfig(
#     level=logging.INFO,
#     format="%(asctime)s [%(levelname)s] %(message)s",
#     datefmt="%Y-%m-%d %H:%M:%S"
# )
# logger = logging.getLogger(__name__)

# # ── Assumes these are already defined in your notebook: ──────────
# # engine           : SQLAlchemy engine
# # ix               : Whoosh index
# # parser           : QueryParser("chunk_data", schema=ix.schema)
# # parquet_writer   : pyarrow.parquet.ParquetWriter or None
# # processed        : set of chunk_id strings already done
# # PARQUET_FILE     : path to your checkpoint file
# # generate_queries_with_ollama(prompt) : function returning list of strings

# TOP_N       = 3
# MIN_SCORE   = None
# NUM_QUERIES = 3
# MAX_WORKERS = 4

# def process_and_save(row):
#     chunk_id   = str(row["chunk_id"])
#     snippet    = row["chunk_data"] or ""
#     cited_str  = row.get("paper_cited", "")
#     original_gt = cited_str.strip("{}").split(",") if cited_str else []
#     original_gt = [pid for pid in original_gt if pid]

#     logger.info("Processing chunk %s (original_gt=%s)", chunk_id, original_gt)

#     # BM25 filtering
#     q = parser.parse(snippet[:200])
#     with ix.searcher() as searcher:
#         hits = searcher.search(q, limit=TOP_N)
#         retrieved = {
#             h["paper_id"] for h in hits
#             if MIN_SCORE is None or h.score >= MIN_SCORE
#         }
#     filtered_gt = [pid for pid in original_gt if pid in retrieved]
#     no_gt_flag  = not bool(filtered_gt)
#     logger.info("Chunk %s: filtered_gt=%s no_related_gt_flag=%s",
#                 chunk_id, filtered_gt, no_gt_flag)

#     # Ollama query generation
#     prompt = (
#         f"Given this scientific snippet, write {NUM_QUERIES} concise "
#         f"semantic-search queries a researcher might use to find it:\n\n"
#         f"\"{snippet}\""
#     )
#     query_list = generate_queries_with_ollama(prompt)
#     logger.info("Chunk %s: generated %d queries", chunk_id, len(query_list))

#     # Append to Parquet checkpoint
#     df_row = pd.DataFrame([{
#         "chunk_id":           chunk_id,
#         "final_ground_truth": filtered_gt,
#         "query_list":         query_list,
#         "no_related_gt_flag": no_gt_flag
#     }])
#     if parquet_writer:
#         table = pa.Table.from_pandas(df_row, preserve_index=False)
#         parquet_writer.write_table(table)
#         logger.info("Chunk %s: written to Parquet via writer", chunk_id)
#     else:
#         df_row.to_parquet(
#             PARQUET_FILE,
#             engine="fastparquet",
#             append=True,
#             index=False
#         )
#         logger.info("Chunk %s: appended to Parquet file", chunk_id)

#     processed.add(chunk_id)

#     # Update eval table
#     with engine.begin() as conn2:
#         conn2.execute(text("""
#             UPDATE arxiv_chunks_eval_4
#                SET query               = :q,
#                    no_related_gt_flag  = :flag
#              WHERE chunk_id = :cid
#         """), {
#             "q":    json.dumps(query_list),
#             "flag": no_gt_flag,
#             "cid":  chunk_id
#         })
#     logger.info("Chunk %s: database record updated", chunk_id)

#     return chunk_id

# # ── Fetch and filter tasks ───────────────────────────────────────
# logger.info("Fetching chunks needing query generation…")
# with engine.connect() as conn:
#     rows = conn.execute(text("""
#         SELECT paper_id, chunk_id, chunk_data, paper_cited
#           FROM arxiv_chunks_eval_4
#          WHERE query = ''
#     """)).mappings().all()

# tasks = [r for r in rows if str(r["chunk_id"]) not in processed]
# logger.info("Total chunks to process: %d", len(tasks))

# # ── Process in parallel ─────────────────────────────────────────
# with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
#     for cid in executor.map(process_and_save, tasks):
#         logger.info("Finished chunk %s", cid)

# logger.info("Done — total processed: %d", len(processed))
