# AbstractAna (cleaned)

Refactored to work with the SQLite database produced by **json2db_refactored.ipynb**.

### What this notebook does
1. **Config**: set DB path, vLLM/OpenAI-compatible endpoint, and model name.
2. **DB setup**: connect to SQLite and ensure an `llm_tags` column exists on `papers`.
3. **Batch tagging**: read abstracts, generate 10–20 tags per paper with a local LLM, and write them back to the DB.
4. **Plotting**: quick top-tags frequency plot.

**Notes**:
- The new SQLite schema (from `json2db_refactored.ipynb`) typically has tables: `papers(id, arxiv_id, version, title, summary, published, year, month)`, `authors(paper_id, position, author)`, and an FTS table.
- If you need arXiv categories (e.g., `cs.CV`), the importer must first be extended to store them; for now, this notebook works *without* categories.


In [ ]:
# --- Config ---
import os
from pathlib import Path

# The DB location can be overridden with the ARXIV_DB_PATH environment variable
# instead of editing this cell; the fallback is the original hard-coded path.
DB_PATH = Path(os.environ.get('ARXIV_DB_PATH', '/Users/wenzheng/Desktop/LLM CS quant/ZZW-LLM/RAGAnalyzer/arxiv.db'))  # <-- update if needed
OPENAI_BASE_URL = 'http://localhost:8889/v1'  # vLLM/OpenAI-compatible endpoint
OPENAI_MODEL = '/models/Qwen3-8B'             # your local model name on vLLM
BATCH_SIZE = 8                                 # how many abstracts per request
MAX_TAGS = 20                                  # cap tags per abstract
DRY_RUN = False                                # True → do not write to DB
LIMIT = 200                                    # process at most this many papers (None for all)
print(DB_PATH.resolve())
if not DB_PATH.exists():
    # sqlite3.connect() would silently create an empty database at this path, and the
    # tagging cell would then fail with "no such table: papers".
    print(f"[warn] {DB_PATH} does not exist; set DB_PATH or ARXIV_DB_PATH")


In [ ]:
# --- Imports ---
import sqlite3, re, aiohttp, asyncio, json
from typing import List, Tuple

def remove_think_tag(text: str) -> str:
    """Strip ``<think>...</think>`` reasoning from a model reply.

    Three shapes are handled:

    * complete ``<think>...</think>`` blocks (plus trailing whitespace) are removed;
    * an orphan ``</think>`` with no opener (the opener was emitted by the chat
      template, not the model) drops everything up to and including it;
    * an unclosed ``<think>`` (generation stopped mid-reasoning, e.g. on
      ``max_tokens``) drops everything from it to the end, so partial reasoning
      is never mistaken for answer text.

    Args:
        text: raw completion text.

    Returns:
        The reply with reasoning removed; may be ``""``.
    """
    text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
    if "</think>" in text:
        # Keep only what follows the last closing tag.
        text = text.rsplit("</think>", 1)[1].lstrip()
    if "<think>" in text:
        # Reasoning that never closed: nothing after the opener is usable.
        text = text.split("<think>", 1)[0]
    return text

def ensure_llm_tags_column(con: sqlite3.Connection):
    """Guarantee that table ``papers`` has an ``llm_tags`` TEXT column.

    The column is added (and committed) only when ``PRAGMA table_info`` does not
    list it; otherwise the schema is left untouched. A status line is printed
    in either case.

    Args:
        con: open connection to the papers database.
    """
    column_names = {info[1] for info in con.execute("PRAGMA table_info(papers)").fetchall()}
    if 'llm_tags' in column_names:
        print("[init] llm_tags column already present")
        return
    con.execute("ALTER TABLE papers ADD COLUMN llm_tags TEXT")
    con.commit()
    print("[init] Added llm_tags TEXT to papers")

def load_papers(con: sqlite3.Connection, limit: int | None = None) -> List[Tuple[str, str]]:
    # Returns list of (id, summary) for papers without llm_tags and with non-empty summary
    cur = con.cursor()
    q = "SELECT id, summary FROM papers WHERE summary IS NOT NULL AND trim(summary) <> '' AND (llm_tags IS NULL OR trim(llm_tags) = '') ORDER BY published DESC"
    if limit:
        q += f" LIMIT {int(limit)}"
    cur.execute(q)
    return cur.fetchall()

def parse_tags(text: str, max_tags: int = 20) -> list:
    """Split a model reply into an ordered, de-duplicated list of tags.

    Reasoning is removed first with ``remove_think_tag``. Tags may be separated
    by commas or newlines. At the start of each line, an echoed ``[#n]`` marker,
    ``n.``/``n)`` numbering followed by whitespace, and an optional ``TAGS:``
    label are dropped. The prompt uses these markers, and models often copy them
    despite being told not to. Leading/trailing ``-``, ``•`` and ``*`` bullets
    are stripped from each tag.

    Args:
        text: raw model output.
        max_tags: maximum number of tags to return.

    Returns:
        Tags in first-seen order. Duplicates are detected case-insensitively and
        the first spelling is kept.
    """
    text = remove_think_tag(text)
    # `\s+` after "n." keeps tags such as "2.5D vision" intact.
    line_prefix = re.compile(r"^\s*(?:\[#?\d+\]\s*|\d+[.)]\s+)?(?:tags\s*:\s*)?", re.IGNORECASE)
    seen: set[str] = set()
    uniq: list[str] = []
    for line in text.split("\n"):
        for part in line_prefix.sub("", line, count=1).split(","):
            tag = part.strip().strip('-•*').strip()
            key = tag.lower()
            if tag and key not in seen:
                seen.add(key)
                uniq.append(tag)
    return uniq[:max_tags]


In [ ]:
# --- Async tagging with local vLLM ---
def _align_reply_lines(lines: list[str], n_items: int) -> list[str]:
    """Map non-empty reply lines back to input positions (``""`` where unmatched).

    A line starting with a ``[#n]`` (or ``[n]``) marker is assigned to item ``n``
    (1-based), with the marker and an optional ``TAGS:`` label removed. Unmarked
    lines that follow a marked line are appended to it. Unmarked lines before
    the first marker (e.g. a preface) are ignored.

    If the reply contains no markers at all, lines are used positionally, but
    only when their count equals ``n_items``. Otherwise the mapping is ambiguous,
    and every slot is left empty rather than risk attaching one paper's tags to
    another.
    """
    marker = re.compile(r"^\**\[#?(\d+)\]\**:?\s*(?:tags\s*:\s*)?", re.IGNORECASE)
    label = re.compile(r"^tags\s*:\s*", re.IGNORECASE)
    aligned = [""] * n_items
    saw_marker = False
    current: int | None = None
    for line in lines:
        m = marker.match(line)
        if m:
            saw_marker = True
            pos = int(m.group(1)) - 1
            current = pos if 0 <= pos < n_items else None
            text = line[m.end():]
        elif current is not None:
            text = line  # continuation of the previously marked item
        else:
            continue
        if current is not None and text:
            aligned[current] = f"{aligned[current]}, {text}" if aligned[current] else text
    if not saw_marker and len(lines) == n_items:
        aligned = [label.sub("", line, count=1) for line in lines]
    return aligned


async def tag_batch(session: aiohttp.ClientSession, items: list[tuple[str,str]]) -> list[tuple[str, list[str]]]:
    """Tag a batch of abstracts with a single chat-completion request.

    All abstracts go into one prompt, each introduced by a 1-based ``[#n]``
    marker, to reduce per-request overhead. The model is asked to answer one line
    per abstract, starting with that marker. ``_align_reply_lines`` matches reply
    lines back to papers.

    Args:
        session: open aiohttp session used for the POST.
        items: ``(paper_id, abstract)`` pairs.

    Returns:
        ``(paper_id, tags)`` for every input item, in input order. ``tags`` is
        empty when no reply line could be confidently matched to that paper. An
        empty tag string written back to the DB is selected again by
        ``load_papers`` on a later run.

    Raises:
        aiohttp.ClientResponseError: if the endpoint returns an HTTP error status.
    """
    prompt_parts = [
        "You are an expert curator. For each abstract, produce 10–20 concise tags (comma-separated), no preface, no numbering.\n",
        "Keep tags as simple phrases (2–4 words). Avoid duplicates and generic terms like 'paper', 'study'.\n",
    ]
    for idx, (_pid, abs_text) in enumerate(items, 1):
        prompt_parts.append(f"\n[#{idx}] ABSTRACT:\n{abs_text}\nTAGS:")
    user_prompt = "".join(prompt_parts)

    payload = {
        "model": OPENAI_MODEL,
        "messages": [
            # Asking for the [#n] marker lets replies be matched by index instead of position.
            {"role": "system", "content": "Return only the tags for each abstract in order, one abstract per line, each line starting with that abstract's marker, e.g. '[#1] tag one, tag two'."},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.2,
        # A flat 1024-token budget truncates larger batches (and any <think> reasoning);
        # scale with the number of abstracts, keeping 1024 as the floor.
        "max_tokens": max(1024, 256 * len(items)),
    }
    url = f"{OPENAI_BASE_URL}/chat/completions"
    async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=None)) as r:
        r.raise_for_status()
        data = await r.json()
    # The chat-completions schema allows a null `content`; treat it as "no tags".
    content = remove_think_tag(data["choices"][0]["message"]["content"] or "")
    lines = [ln.strip() for ln in content.splitlines() if ln.strip()]
    aligned = _align_reply_lines(lines, len(items))
    return [(pid, parse_tags(line, MAX_TAGS)) for (pid, _), line in zip(items, aligned)]

async def process_all(papers: list[tuple[str,str]]):
    """Tag ``papers`` in sequential batches of ``BATCH_SIZE`` using ``tag_batch``.

    If a batch's request fails (HTTP/network error, timeout, or a malformed
    reply), the failure is printed and that batch's papers get empty tag lists.
    The whole run is not aborted, so results from successful batches are still
    returned. Empty tags are written as ``""``, which ``load_papers`` selects
    again on the next run.

    Args:
        papers: ``(paper_id, abstract)`` pairs.

    Returns:
        ``(paper_id, tags)`` for every input paper, in input order.
    """
    results: list[tuple[str, list[str]]] = []
    failed_batches = 0
    async with aiohttp.ClientSession() as session:
        for start in range(0, len(papers), BATCH_SIZE):
            batch = papers[start:start + BATCH_SIZE]
            try:
                tagged = await tag_batch(session, batch)
            except (aiohttp.ClientError, asyncio.TimeoutError, KeyError, IndexError, ValueError) as exc:
                failed_batches += 1
                print(f"[warn] batch starting at paper {start} failed: {exc!r}")
                tagged = [(pid, []) for pid, _ in batch]
            results.extend(tagged)
    if failed_batches:
        print(f"[warn] {failed_batches} batch(es) failed; their papers received no tags")
    return results


In [ ]:
# --- Run tagging ---
import concurrent.futures
from contextlib import closing


def _run_coroutine(coro):
    """Run ``coro`` to completion and return its result.

    ``asyncio.run`` raises RuntimeError when an event loop is already running in
    the current thread, which is the case inside a Jupyter/IPython kernel. In
    that case the coroutine runs on a fresh event loop in a worker thread.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # no loop running here (plain Python)
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


# `closing` guarantees the connection is released even if a step below raises.
with closing(sqlite3.connect(DB_PATH)) as con:
    ensure_llm_tags_column(con)
    papers = load_papers(con, limit=LIMIT)
    print(f"Loaded {len(papers)} papers to tag")
    if papers:
        tagged = _run_coroutine(process_all(papers))
        n_with_tags = sum(1 for _, tags in tagged if tags)
        print(f"Tagged {len(tagged)} papers ({n_with_tags} received at least one tag)")
        if not DRY_RUN:
            # `with con` commits on success and rolls back if any UPDATE fails.
            with con:
                con.executemany(
                    "UPDATE papers SET llm_tags = ? WHERE id = ?",
                    [(", ".join(tags), pid) for pid, tags in tagged],
                )
            print("[write] Tags stored in DB")
        else:
            print("[dry-run] Skipped writing tags to DB")
    else:
        print("Nothing to tag.")


## Plot: top tags
Simple exploratory plot of the most frequent tags in the `llm_tags` column.


In [ ]:
# --- Plot top tags ---
import sqlite3, pandas as pd
import matplotlib.pyplot as plt
from contextlib import closing

TOP_N = 30  # number of most frequent tags to show

# `closing` releases the connection even if the query raises.
with closing(sqlite3.connect(DB_PATH)) as con:
    df = pd.read_sql_query("SELECT llm_tags FROM papers WHERE llm_tags IS NOT NULL AND trim(llm_tags) <> ''", con)

# Explode the comma-separated strings into one tag per row; drop blanks left by
# stray or trailing commas so they are not counted as a tag.
tags = (
    df['llm_tags']
    .str.split(',')
    .explode()
    .str.strip()
)
tags = tags[tags.notna() & (tags != '')]

if tags.empty:
    print("No tags yet. Run the tagging cell first.")
else:
    top = tags.value_counts().head(TOP_N)
    # Grow the figure with the number of bars so the tag labels stay readable.
    plt.figure(figsize=(8, max(4, 0.3 * len(top))))
    top.sort_values().plot(kind='barh')
    plt.title('Top tags')
    plt.xlabel('Count')
    plt.tight_layout()
    plt.show()
