In [8]:
from __future__ import annotations

import csv
import gzip
import json
import re
import tarfile
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import arxiv
import pandas as pd
import requests
from tqdm import tqdm


# Base URL of the OpenAlex REST API; fetch_openalex_arxiv_works_cursor queries f"{OPENALEX}/works".
OPENALEX = "https://api.openalex.org"
# OpenAlex source id used in the `locations.source.id:` filter to keep only works hosted on arXiv.
# NOTE(review): assumed to be OpenAlex's arXiv repository source record — confirm via the /sources endpoint.
arxiv_source_id = "https://openalex.org/S4306400194"


In [9]:
# arXiv category prefixes treated as physics. str.startswith accepts a tuple and is True
# when any prefix matches (e.g. "hep-th", "cond-mat.str-el", "physics.optics").
PHYSICS_PREFIXES = (
    "physics.", "astro-ph", "cond-mat", "hep-", "nucl-", "gr-qc", "quant-ph", "math-ph", "nlin"
)
# Prefix shared by the quantitative-biology categories (e.g. "q-bio.BM").
BIOLOGY_PREFIX = "q-bio"

def is_physics(cat : str) -> bool:
    """Return True if the arXiv category ``cat`` starts with one of PHYSICS_PREFIXES.

    Non-string or empty values (None, or NaN read back from a DataFrame) return False
    instead of raising.
    """
    return isinstance(cat, str) and cat.startswith(PHYSICS_PREFIXES)

def is_biology(cat : str) -> bool:
    """Return True if the arXiv category ``cat`` is a quantitative-biology category (``q-bio*``).

    Non-string or empty values return False.
    """
    # Was `cat.startwith(...)` (typo), which raised AttributeError for every non-empty input.
    return isinstance(cat, str) and cat.startswith(BIOLOGY_PREFIX)

In [10]:
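# Precompile the patterns with re.compile() so they are built once and reused by name (the re module also caches recently used patterns internally, so this mainly adds clarity and skips the cache lookup)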

# New-style arXiv ids: YYMM.NNNN or YYMM.NNNNN with an optional version, e.g. 2105.12345 or 2105.12345v2
NEWSTYLE = re.compile(r"^\d{4}\.\d{4,5}(v\d+)?$")
# Old-style ids: archive(.SUBJECT-CLASS)/7 digits with an optional version, e.g. cs.AI/0102030, hep-th/9711200
OLDSTYLE = re.compile(r"^[a-z\-]+(\.[A-Z]{2})?\/\d{7}(v\d+)?$", re.IGNORECASE)
# Optional "arXiv:" prefix as ids are often written in citations, e.g. "arXiv:2105.12345"
_ARXIV_PREFIX_RE = re.compile(r"^arxiv:\s*", re.IGNORECASE)


def normalize_arxiv_id(aid : str) -> str:
    """
    Return ``aid`` in canonical unversioned form.

    Strips surrounding whitespace, a leading "arXiv:" prefix (any case) and a trailing
    version suffix ("v2"). Non-string input (None, NaN, ...) yields "".
    The result is not validated; use is_valid_arxiv_id for that.
    """
    if not isinstance(aid, str):
        return ""
    aid = _ARXIV_PREFIX_RE.sub("", aid.strip())
    # Drop the optional version ending
    return re.sub(r"v\d+$", "", aid)


def is_valid_arxiv_id(aid : str) -> bool:
    """
    Return True if ``aid`` is a complete new-style or old-style arXiv id (version optional).

    fullmatch is used because `$` alone also matches before a trailing newline, which
    would accept "2105.12345\\n". Non-string input returns False instead of raising.
    """
    if not isinstance(aid, str):
        return False
    return bool(NEWSTYLE.fullmatch(aid) or OLDSTYLE.fullmatch(aid))

In [11]:
#Function requests a url (with optional query params) and returns the response body parsed as json
def get_json(url, params = None, retries = 6, backoff = 1.6):
    """
    GET ``url`` and return the decoded JSON body, retrying transient failures.

    HTTP 429/500/502/503/504, connection errors and timeouts are retried, for at most
    ``retries`` attempts in total. The function sleeps ``backoff ** attempt`` seconds
    between attempts (no sleep after the last one).

    Args:
        url: endpoint to request.
        params: optional query-string parameters passed to requests.get.
        retries: total number of attempts (must be >= 1).
        backoff: base of the exponential backoff, in seconds.

    Returns:
        The parsed JSON body of the first HTTP 200 response.

    Raises:
        ValueError: if ``retries`` < 1.
        requests.HTTPError: on a non-retryable status (e.g. 400/403/404), or when every
            attempt returned a retryable status.
        requests.ConnectionError / requests.Timeout: when every attempt failed at the
            network level.
    """
    if retries < 1:
        raise ValueError("retries must be >= 1")
    retryable = (429, 500, 502, 503, 504)
    last_error = None
    for attempt in range(retries):
        try:
            r = requests.get(url, params = params, timeout = 45)
        except (requests.ConnectionError, requests.Timeout) as exc:
            last_error = exc
        else:
            #status code 200 on successful return
            if r.status_code == 200:
                return r.json()
            error = requests.HTTPError(f"HTTP {r.status_code} for {r.url}", response = r)
            if r.status_code not in retryable:
                # A client error (bad filter, 404, ...) will not succeed on retry. Fail fast
                # instead of returning None, which callers only discover as an AttributeError.
                raise error
            last_error = error
        if attempt < retries - 1:
            time.sleep(backoff ** attempt)
    raise last_error
    

In [12]:
#arXiv abstract urls look like arxiv.org/abs/<id> and pdf pages like arxiv.org/pdf/<id>[.pdf].
#Old-style ids contain a slash (hep-th/9711200, math.GT/0309136). An optional "archive[.XX]/" segment is therefore
#captured in front of the last path component; otherwise old-style papers would be cut down to "hep-th" and dropped.
_ARXIV_ABS_RE = re.compile(r"arxiv\.org/abs/((?:[a-z\-]+(?:\.[a-z]{2})?/)?[^?#/]+)", re.IGNORECASE)
_ARXIV_PDF_RE = re.compile(r"arxiv\.org/pdf/((?:[a-z\-]+(?:\.[a-z]{2})?/)?[^?#/]+)", re.IGNORECASE)
_PDF_SUFFIX_RE = re.compile(r"\.pdf$", re.IGNORECASE)

#Taking a work record returned by the OpenAlex API and extracting its arXiv id
def extract_arxiv_id_from_work(work):
    """
    Return the unversioned arXiv id of an OpenAlex work, or None.

    The function scans ``work["locations"]`` in order. For each location it checks
    landing_page_url, then pdf_url. It returns the first arxiv.org/abs/... or
    arxiv.org/pdf/... link whose id is valid after normalization (the version suffix
    and any ".pdf" extension are removed).

    Some works have no arXiv location, e.g. papers indexed by OpenAlex that predate
    arXiv and were never uploaded. These return None and are skipped by the caller,
    since their sources cannot be fetched from arXiv.
    """
    for loc in (work.get("locations") or []):
        if not isinstance(loc, dict):
            continue
        for key in ("landing_page_url", "pdf_url"):
            u = loc.get(key)
            if not u:
                continue
            # search() finds the pattern anywhere in the url (scheme/host prefixes allowed)
            m = _ARXIV_ABS_RE.search(u) or _ARXIV_PDF_RE.search(u)
            if not m:
                continue
            aid = normalize_arxiv_id(_PDF_SUFFIX_RE.sub("", m.group(1)))
            if is_valid_arxiv_id(aid):
                return aid
    return None

In [13]:
def fetch_openalex_arxiv_works_cursor(max_works, mailto=None):
    """
    Fetch up to ``max_works`` arXiv-hosted works from OpenAlex, most-cited first.

    The function pages through ``/works`` with cursor pagination (200 per page). It
    filters on ``locations.source.id:<arxiv_source_id>`` and sorts by
    ``cited_by_count`` descending. Works without an extractable arXiv id are skipped
    and do not count towards ``max_works``.

    Args:
        max_works: maximum number of rows to return.
        mailto: optional contact email, sent as OpenAlex's ``mailto`` query parameter.

    Returns:
        A list of dicts with keys openalex_id, doi, title, publication_year,
        cited_by_count, arxiv_id and type.
    """
    #filter for selecting only works on OpenAlex that have an arXiv location
    the_filter = f"locations.source.id:{arxiv_source_id}"
    #selecting only relevant metadata
    select = ",".join([
        "id", "doi", "title", "publication_year", "cited_by_count",
        "ids", "locations", "type"
    ])

    per_page = 200
    cursor = "*"
    rows = []

    # The bar counts rows kept (the loop's stopping criterion), not works fetched, so it
    # reaches 100% exactly when max_works rows are collected. The context manager closes
    # it even if a request raises.
    with tqdm(total=max_works, desc="OpenAlex fetch") as pbar:
        while len(rows) < max_works:
            params = {
                "filter": the_filter,
                "sort": "cited_by_count:desc",
                "per-page": per_page,
                "cursor": cursor,
                "select": select
            }
            if mailto:
                params["mailto"] = mailto
            # `or {}` guards against a get_json that returns None after exhausting retries
            data = get_json(f"{OPENALEX}/works", params=params) or {}
            results = data.get("results") or []
            #end early when a page comes back empty
            if not results:
                break
            for w in results:
                aid = extract_arxiv_id_from_work(w)
                if not aid:
                    continue
                rows.append({
                    "openalex_id": w.get("id"),
                    "doi": w.get("doi") or "",
                    "title": w.get("title") or "",
                    "publication_year": w.get("publication_year"),
                    "cited_by_count": int(w.get("cited_by_count") or 0),
                    "arxiv_id": aid,
                    "type": w.get("type") or ""
                })
                pbar.update(1)
                if len(rows) >= max_works:
                    break

            cursor = (data.get("meta") or {}).get("next_cursor")
            if not cursor:
                break
            #short pause between pages to stay well within the API rate limit
            time.sleep(0.1)

    return rows

In [14]:
#given a list of arxiv ids, find the primary category of each id
def arxiv_query_categories(arxiv_ids, batch_size = 50):
    """
    Map arXiv ids to their primary category using the arXiv API, ``batch_size`` ids per request.

    Keys are normalized with normalize_arxiv_id (version suffix removed). They therefore
    match the unversioned ids produced by extract_arxiv_id_from_work. The previous code
    keyed by ``get_short_id()``, which includes "v1", "v2", ..., so lookups by input id
    always missed.

    Args:
        arxiv_ids: any iterable of arXiv ids.
        batch_size: number of ids per API query (must be >= 1).

    Returns:
        dict {unversioned arxiv id: primary category}. Ids the API does not return are absent.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    ids = list(arxiv_ids)
    out = {}
    # One client for every batch. arxiv.Client spaces consecutive requests
    # (delay_seconds, 3 s by default) and retries failed pages. It also replaces the
    # deprecated Search.results().
    client = arxiv.Client()
    with tqdm(total=len(ids), desc="arXiv category resolution") as pbar:
        for start in range(0, len(ids), batch_size):
            chunk = ids[start:start + batch_size]
            search = arxiv.Search(id_list=chunk, max_results=len(chunk))
            for result in client.results(search):
                out[normalize_arxiv_id(result.get_short_id())] = result.primary_category
            pbar.update(len(chunk))
    return out

In [15]:
def sort_key(w):
    """Sort key: most citations first, then earliest year (missing/0 year sorts last), then arXiv id."""
    published = w["publication_year"] or 9999
    citations = w["cited_by_count"]
    return (-citations, published, w["arxiv_id"])

def save(name, rows, out_prefix):
    """Write ``rows`` (a list of dicts) to ``{out_prefix}_{name}.csv`` and to ``.jsonl`` (one JSON object per line)."""
    stem = f"{out_prefix}_{name}"
    pd.DataFrame(rows).to_csv(f"{stem}.csv", index=False)
    with open(f"{stem}.jsonl", "w") as handle:
        handle.writelines(json.dumps(record) + "\n" for record in rows)

def build_top10k_OpenAlex(oversample_factor = 10, out_prefix = "Test", target = 1000, mailto = None):
    """
    Fetch the most-cited arXiv works from OpenAlex and save them to ``{out_prefix}_openalex_raw.csv``.

    ``target * oversample_factor`` works are requested (10,000 with the defaults).
    Presumably the oversampling leaves headroom for later filtering, e.g. by category.

    The CSV always has the header
    openalex_id,doi,title,publication_year,cited_by_count,arxiv_id,type in that
    order, even when nothing was fetched, because read_openalex_csv_file reads columns
    by position. publication_year is written as an integer or left blank, never as "2019.0".

    Args:
        oversample_factor: multiplier applied to ``target`` to get the fetch size.
        out_prefix: prefix of the output CSV path.
        target: base number of papers wanted (previously hard-coded to 1000).
        mailto: optional contact email forwarded to the OpenAlex API.

    Returns:
        The fetched rows (list of dicts), as written to the CSV.
    """
    columns = ["openalex_id", "doi", "title", "publication_year", "cited_by_count", "arxiv_id", "type"]
    works = fetch_openalex_arxiv_works_cursor(target * oversample_factor, mailto=mailto)
    raw = pd.DataFrame(works, columns=columns)
    # Nullable integer dtype: an int column holding a missing year would otherwise be
    # upcast to float and written as "2019.0".
    raw["publication_year"] = raw["publication_year"].astype("Int64")
    raw.to_csv(f"{out_prefix}_openalex_raw.csv", index = False)
    return works


In [16]:
# One row of the OpenAlex CSV in structured form (built by read_openalex_csv_file).
# process_work_row uses arxiv_id for the arXiv category lookup and the source download.
# Field order defines the positional __init__ signature generated by @dataclass.
@dataclass
class WorkRow:
    """A single OpenAlex work as read back from the raw CSV."""
    openalex_id: str  # OpenAlex work id/url as written in the CSV
    doi: str  # "" when the work has no DOI
    title: str
    publication_year: Optional[int]  # NOTE(review): may be missing in the CSV — callers should tolerate None
    cited_by_count: int  # OpenAlex citation count; copied into every output record
    arxiv_id: Optional[str]  # None when the CSV cell is empty
    work_type: str  # value of the CSV "type" column


In [17]:
def read_openalex_csv_file(csv_path: Path) -> List[WorkRow]:
    """
    Read an OpenAlex CSV file in the format written by build_top10k_OpenAlex():
      openalex_id, doi, title, publication_year, cited_by_count, arxiv_id, type

    Columns are read by position. The first non-empty row is skipped only if it is that
    header (first cell "openalex_id", UTF-8 BOM tolerated), so a headerless file keeps
    its first record. Blank lines are skipped. Short rows are padded with "". Extra
    columns are ignored.

    Numeric cells are parsed leniently because pandas writes an int column containing
    missing values as floats ("2019.0") and missing values as "". A blank or unparseable
    publication_year becomes None; a blank or unparseable cited_by_count becomes 0.

    Args:
        csv_path: path to the CSV (str or Path).

    Returns:
        A list of WorkRow. arxiv_id is None when the cell is empty.
    """
    rows: List[WorkRow] = []
    with Path(csv_path).open("r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header_checked = False
        for parts in reader:
            if not parts:
                continue
            if not header_checked:
                header_checked = True
                if parts[0].lstrip("\ufeff").strip().lower() == "openalex_id":
                    continue

            # Pad short rows so the 7-way unpack below always succeeds
            padded = (parts + [""] * 7)[:7]
            openalex_id, doi, title, year, cited_by, arxiv_id, work_type = padded

            rows.append(
                WorkRow(
                    openalex_id=openalex_id.strip(),
                    doi=doi.strip(),
                    title=title.strip(),
                    publication_year=_parse_int_or_none(year),
                    cited_by_count=_parse_int_or_none(cited_by) or 0,
                    arxiv_id=arxiv_id.strip() or None,
                    work_type=work_type.strip(),
                )
            )

    return rows


def _parse_int_or_none(text: str) -> Optional[int]:
    """Parse ``text`` as an int, accepting whole-number float spellings such as "2019.0"; return None if blank or invalid."""
    text = (text or "").strip()
    if not text:
        return None
    try:
        return int(text)
    except ValueError:
        pass
    try:
        value = float(text)
    except ValueError:
        return None
    # NaN/inf and non-whole floats are not valid integers
    return int(value) if value.is_integer() else None

In [18]:
def fetch_arxiv_metadata(arxiv_id: str, client=None) -> Tuple[str, Optional[str]]:
    """
    Look up ``arxiv_id`` on the arXiv API and return its primary category.

    Uses ``arxiv.Client.results``; ``Search.results()`` is deprecated in arxiv>=2 and
    emits a DeprecationWarning on every call. Pass a shared ``arxiv.Client`` as
    ``client`` to reuse its request spacing across calls. By default a fresh client is
    created per call.

    Returns:
      (status, primary_category)
      status:
        - "ok" if found
        - "not_found" if no result returned
        - "metadata_error" on exceptions (network, HTTP or parse errors)
    """
    try:
        search = arxiv.Search(id_list=[arxiv_id], max_results=1)
        api = client if client is not None else arxiv.Client()
        result = next(api.results(search), None)
        if result is None:
            return ("not_found", None)
        return ("ok", getattr(result, "primary_category", None))
    except Exception:
        # Best-effort by design: the caller still attempts the source download and records
        # this status, so one failed lookup must not abort a batch run.
        return ("metadata_error", None)

In [19]:



def download_and_extract_arxiv_source(arxiv_id: str, out_dir: Path, client=None) -> str:
    """
    Download the arXiv source package for ``arxiv_id`` into out_dir and unpack it there.

    arXiv serves sources either as a (compressed) tar archive or, for single-file
    submissions, as a gzip-compressed .tex file. The single-file case is written to
    out_dir/"main.tex"; previously it always failed with "extract_error".
    Archives are untrusted uploads, so only regular files and directories that resolve
    inside out_dir are extracted. Links, devices, absolute paths and ".." traversal are
    skipped.

    Args:
        arxiv_id: arXiv id to fetch.
        out_dir: existing directory that receives both the download and the extracted files.
        client: optional arxiv.Client to reuse; a new default client is created when None.

    Returns status:
      - "ok" if extracted
      - "download_error" if the lookup or download failed
      - "extract_error" if the package is neither a readable tar nor gzip data
    """
    try:
        search = arxiv.Search(id_list=[arxiv_id], max_results=1)
        api = client if client is not None else arxiv.Client()
        result = next(api.results(search), None)
        if result is None:
            return "download_error"
        archive_path = Path(result.download_source(dirpath=str(out_dir)))
    except Exception:
        return "download_error"

    try:
        if tarfile.is_tarfile(archive_path):
            with tarfile.open(archive_path, "r:*") as tf:
                _extract_tar_safely(tf, out_dir)
        else:
            _extract_single_gzip(archive_path, out_dir / "main.tex")
        return "ok"
    except Exception:
        return "extract_error"


def _extract_tar_safely(tf: tarfile.TarFile, dest: Path) -> None:
    """Extract only regular files and directories whose resolved path stays inside dest."""
    root = dest.resolve()
    safe_members = []
    for member in tf.getmembers():
        if not (member.isfile() or member.isdir()):
            continue  # symlinks, hard links, devices, FIFOs
        target = (root / member.name).resolve()
        if target != root and root not in target.parents:
            continue  # absolute path or ".." traversal
        safe_members.append(member)
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ and security backports: also sanitizes permissions/metadata and
        # avoids the DeprecationWarning for unfiltered extraction.
        tf.extractall(path=root, members=safe_members, filter="data")
    else:
        tf.extractall(path=root, members=safe_members)


def _extract_single_gzip(archive_path: Path, dest_file: Path) -> None:
    """Decompress a gzip-compressed single-file source into dest_file (raises OSError/EOFError if not gzip)."""
    with gzip.open(archive_path, "rb") as src:
        dest_file.write_bytes(src.read())


# =========================
# 3) File classification + TeX collection
# =========================

def is_probably_binary_file(path: Path, sample_bytes: int = 4096) -> bool:
    """
    Heuristically decide whether ``path`` is a binary (non-text) file.

    Only the first ``sample_bytes`` bytes are read; the previous version loaded the
    whole file just to slice it. The file counts as binary if:
      - it cannot be read, or
      - the sample contains a NUL byte, or
      - more than 30% of the sample are control bytes other than 9-13
        (tab, LF, VT, FF, CR).
    An empty file counts as text.
    """
    try:
        with open(path, "rb") as fh:
            data = fh.read(sample_bytes)
    except Exception:
        return True
    if b"\x00" in data:
        return True
    nontext = sum(1 for byte in data if byte < 9 or 13 < byte < 32)
    return nontext > 0.30 * max(1, len(data))


def collect_tex_files(root_dir: Path) -> Tuple[List[Path], int, int]:
    """
    Recursively scan root_dir and classify every regular file.

    Returns:
      - tex_files: files with a .tex suffix (case-insensitive), sorted by path. rglob
        order depends on the filesystem, and the equation indices downstream follow
        this order, so sorting keeps runs reproducible.
      - skipped_binary: files with a known binary extension, or flagged by
        is_probably_binary_file()
      - skipped_notlatex: text files whose suffix is not .tex
    """
    tex_files: List[Path] = []
    skipped_binary = 0
    skipped_notlatex = 0

    # common binary extensions (skipped without reading the file)
    binary_exts = {".pdf", ".png", ".jpg", ".jpeg", ".gif", ".eps", ".zip", ".gz", ".tar"}

    for p in sorted(Path(root_dir).rglob("*")):
        if not p.is_file():
            continue

        suffix = p.suffix.lower()
        if suffix in binary_exts or is_probably_binary_file(p):
            skipped_binary += 1
        elif suffix == ".tex":
            tex_files.append(p)
        else:
            skipped_notlatex += 1

    return tex_files, skipped_binary, skipped_notlatex


# =========================
# 4) Equation extraction
# =========================

# Display-math environments whose full \begin{...}...\end{...} blocks are extracted
EQUATION_ENVS = [
    "equation", "equation*", "align", "align*", "gather", "gather*",
    "multline", "multline*", "eqnarray", "eqnarray*", "flalign", "flalign*",
    "alignat", "alignat*"
]

# \begin{env} ... \end{env} for any env above; the \1 backreference forces the matching \end
ENV_PATTERN = re.compile(
    r"\\begin\{(" + "|".join(re.escape(e) for e in EQUATION_ENVS) + r")\}(.*?)\\end\{\1\}",
    re.DOTALL
)
# \[ ... \] display math. The lookbehind rejects the "\[" inside a "\\[2pt]" line break
# (common in align/eqnarray), which would otherwise open a bogus block running to the next \].
BRACKET_DISPLAY_PATTERN = re.compile(r"(?<!\\)\\\[(.*?)\\\]", re.DOTALL)
DOLLAR_DISPLAY_PATTERN = re.compile(r"\$\$(.*?)\$\$", re.DOTALL)
# A "%" that starts a comment is preceded by an even number (possibly zero) of backslashes;
# group 1 keeps those backslashes (e.g. the "\\" line break in "\\% note").
_COMMENT_PATTERN = re.compile(r"(?<!\\)((?:\\\\)*)%.*")


def strip_latex_comments(tex: str) -> str:
    """
    Remove LaTeX comments: everything from an unescaped % to the end of its line.

    An escaped percent sign \\% is kept. In \\\\% (a line break followed by %), the
    line break is kept and the comment is removed. Line breaks themselves are preserved.
    """
    return _COMMENT_PATTERN.sub(r"\1", tex)


def extract_equations_from_tex(tex: str) -> List[Tuple[str, str]]:
    """
    Extract display-equation blocks from a TeX document, after stripping comments.

    Returns a list of (kind, latex_block) in document order (by start offset), where kind is one of:
      - environment:<name>
      - display:\\[...\\]
      - display:$$...$$

    Inline $...$ is NOT extracted by default (too noisy).
    """
    tex = strip_latex_comments(tex)
    hits = []  # (start offset, kind, block)

    for m in ENV_PATTERN.finditer(tex):
        hits.append((m.start(), f"environment:{m.group(1)}", m.group(0).strip()))

    for m in BRACKET_DISPLAY_PATTERN.finditer(tex):
        hits.append((m.start(), "display:\\[...\\]", m.group(0).strip()))

    for m in DOLLAR_DISPLAY_PATTERN.finditer(tex):
        hits.append((m.start(), "display:$$...$$", m.group(0).strip()))

    hits.sort(key=lambda hit: hit[0])
    return [(kind, block) for _, kind, block in hits]


# =========================
# 5) Process a single row -> summary + equations
# =========================

def process_work_row(
    row: WorkRow,
    keep_equations: bool,
) -> Tuple[dict, List[dict]]:
    """
    Fetch, unpack and scan the arXiv source of one WorkRow.

    Steps:
      - normalize row.arxiv_id; rows without a valid id get status "not_arxiv"
      - look up the primary category with fetch_arxiv_metadata (a failure here is not fatal)
      - download and extract the source into a temporary directory, removed on return
      - collect the .tex files and extract display equations from each

    Args:
      row: the work to process.
      keep_equations: when True, also return one record per extracted equation;
        otherwise only the count is computed.

    Returns:
      (summary_record, equation_records).
      The summary always has these keys: arxiv_id, cited_by_count,
      arxiv_primary_category (may be None), status, tex_files, skipped_binary,
      skipped_notlatex and equations.
      status is one of:
        - "ok" or "not_arxiv"
        - the download status ("download_error" / "extract_error")
        - "<metadata status>+<download status>" when the metadata lookup failed
          (e.g. "not_found+ok").
      Every equation record also carries cited_by_count and arxiv_primary_category.
    """
    arxiv_id = normalize_arxiv_id(row.arxiv_id or "")

    # Ensure required fields always present in output
    base_summary = {
        "arxiv_id": arxiv_id if arxiv_id else None,
        "cited_by_count": row.cited_by_count,
        "arxiv_primary_category": None,  # filled in after the metadata lookup
        "status": "init",
        "tex_files": 0,
        "skipped_binary": 0,
        "skipped_notlatex": 0,
        "equations": 0,
    }

    equation_records: List[dict] = []

    if not arxiv_id or not is_valid_arxiv_id(arxiv_id):
        base_summary["status"] = "not_arxiv"
        return base_summary, equation_records

    # A metadata failure does not stop processing: the source may still download, and
    # both outcomes are folded into the status string.
    meta_status, primary_cat = fetch_arxiv_metadata(arxiv_id)
    base_summary["arxiv_primary_category"] = primary_cat

    def with_meta(status: str) -> str:
        return status if meta_status == "ok" else f"{meta_status}+{status}"

    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)

        dl_status = download_and_extract_arxiv_source(arxiv_id, tmp_dir)
        if dl_status != "ok":
            base_summary["status"] = with_meta(dl_status)
            return base_summary, equation_records

        tex_files, skipped_binary, skipped_notlatex = collect_tex_files(tmp_dir)

        eq_count = 0
        for tex_path in tex_files:
            try:
                # Explicit UTF-8 so results don't depend on the platform's locale encoding;
                # undecodable bytes (e.g. Latin-1 sources) are dropped.
                content = tex_path.read_text(encoding="utf-8", errors="ignore")
            except OSError:
                continue

            extracted = extract_equations_from_tex(content)
            for idx, (kind, latex_block) in enumerate(extracted):
                eq_count += 1
                if keep_equations:
                    equation_records.append({
                        "arxiv_id": arxiv_id,
                        "cited_by_count": row.cited_by_count,
                        "arxiv_primary_category": primary_cat,
                        # POSIX separators so the dataset is identical across platforms
                        "file": tex_path.relative_to(tmp_dir).as_posix(),
                        "kind": kind,
                        "index": idx,
                        "latex": latex_block,
                    })

        base_summary.update({
            "status": with_meta("ok"),
            "tex_files": len(tex_files),
            "skipped_binary": skipped_binary,
            "skipped_notlatex": skipped_notlatex,
            "equations": eq_count,
        })

        return base_summary, equation_records


# =========================
# 6) Batch driver: CSV file -> JSONL outputs
# =========================

def build_equation_dataset_from_openalex_csv_file(
    csv_path: Path,
    out_summary_jsonl: Path,
    out_equations_jsonl: Optional[Path] = None,
    max_papers: Optional[int] = None,
) -> None:
    """
    Run the whole equation pipeline for an OpenAlex CSV export.

    Rows without an arxiv_id are dropped, and at most ``max_papers`` of the remainder
    are processed, in file order. Each row goes through process_work_row. The function
    writes one JSON line per paper to ``out_summary_jsonl`` and, when
    ``out_equations_jsonl`` is given, one JSON line per equation. Parent directories of
    both outputs are created as needed.
    """
    papers = [work for work in read_openalex_csv_file(csv_path) if work.arxiv_id]
    if max_papers is not None:
        papers = papers[:max_papers]

    out_summary_jsonl.parent.mkdir(parents=True, exist_ok=True)
    if out_equations_jsonl is not None:
        out_equations_jsonl.parent.mkdir(parents=True, exist_ok=True)

    required = ("cited_by_count", "arxiv_primary_category")
    with out_summary_jsonl.open("w", encoding="utf-8") as summary_out:
        equations_out = None
        if out_equations_jsonl:
            equations_out = out_equations_jsonl.open("w", encoding="utf-8")
        try:
            collect = equations_out is not None
            for paper in papers:
                summary, equations = process_work_row(paper, keep_equations=collect)
                # every summary must carry the required fields
                for field in required:
                    assert field in summary
                summary_out.write(json.dumps(summary, ensure_ascii=False) + "\n")

                if not collect:
                    continue
                for record in equations:
                    # ... and so must every equation record
                    for field in required:
                        assert field in record
                    equations_out.write(json.dumps(record, ensure_ascii=False) + "\n")
        finally:
            if equations_out is not None:
                equations_out.close()


# =========================
# 7) Example usage
# =========================

if __name__ == "__main__":
    # Run the equation pipeline on the first 50 arXiv-backed rows of the raw OpenAlex export.
    csv_path = Path("Test_openalex_raw.csv")
    out_dir = Path("out")

    build_equation_dataset_from_openalex_csv_file(
        csv_path=csv_path,
        out_summary_jsonl=out_dir / "summary.jsonl",
        out_equations_jsonl=out_dir / "equations.jsonl",
        max_papers=50,  # set to None for a full run
    )

    print("Done: out/summary.jsonl and out/equations.jsonl")


  result = next(search.results(), None)
  result = next(search.results(), None)


Done: out/summary.jsonl and out/equations.jsonl
