# Embedding-based Scientific Article Recommendation System

In this section, we implement the *embedding-based* component of our article recommendation pipeline.
The goal is to take a short text fragment as input and return the **N most semantically similar scientific articles** 
from our arXiv-based dataset.

We will:
1. Load and preprocess the text data.
2. Generate dense embeddings for all articles using a pre-trained model.
3. Encode the input query into the same semantic space.
4. Retrieve the top-N most similar articles using **K-Nearest Neighbors (KNN)** or **Approximate Nearest Neighbors (ANN)**.
5. Evaluate the system performance using the **Top-N Accuracy** metric.


In [1]:
# Core Python libraries
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
import re
import os

# NLP and Embeddings
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
import faiss


# Visualization and evaluation
import matplotlib.pyplot as plt
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


Function: load_and_preprocess_data()

Description:
    Load and preprocess arXiv metadata from a large JSONL file.
    Reads the arXiv dataset line by line (to handle very large files efficiently),
    extracts essential fields, cleans the text from HTML tags and symbols, and
    constructs a DataFrame containing one row per article.

    The function also adds a 'categories' column, which may contain multiple
    space-separated category labels such as 'math.PR math.AG'.

Inputs:
    data_path (str): Path to the JSONL file (each line is a separate JSON object).
    max_rows (int, optional): Maximum number of rows to read (useful for testing
                                on a subset). If None, reads the entire file.

Outputs:
    df (pd.DataFrame): DataFrame containing the following columns:
        - 'id'         : arXiv identifier
        - 'title'      : article title
        - 'authors'    : author list (string)
        - 'abstract'   : original abstract text
        - 'clean_text' : preprocessed abstract text
        - 'categories' : string of one or more categories (e.g. 'math.PR math.AG')

Notes:
    - Lines that cannot be decoded as JSON are skipped automatically.
    - If an abstract is missing or empty, the entry is ignored.
    - Special characters and LaTeX commands are stripped from 'clean_text'.

In [2]:
def load_and_preprocess_data(data_path: str, max_rows: int = None) -> pd.DataFrame:
    records = []
    with open(data_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if max_rows and i >= max_rows:
                break
            try:
                obj = json.loads(line)
                if "abstract" not in obj or not obj["abstract"]:
                    continue
                abstract = obj["abstract"]
                clean_text = (
                    re.sub(r"<.*?>", "", abstract)
                    .replace("\n", " ")
                    .replace("\\", "")
                )
                clean_text = re.sub(r"[^a-zA-Z0-9\s]", "", clean_text).lower().strip()
                categories = obj.get("categories", "").strip()
                records.append({
                    "id": obj.get("id", ""),
                    "title": obj.get("title", ""),
                    "authors": obj.get("authors", ""),
                    "abstract": abstract,
                    "clean_text": clean_text,
                    "categories": categories
                })
            except json.JSONDecodeError:
                # Skip malformed lines
                continue

    df = pd.DataFrame(records)
    return df

In [3]:
# Example usage
df = load_and_preprocess_data('./dataset/arxiv-metadata-oai-snapshot.json')

print("Number of articles loaded:", len(df))
print("Columns:", df.columns.tolist())
print(df.head(3)[["id", "title", "categories"]])

Number of articles loaded: 2872766
Columns: ['id', 'title', 'authors', 'abstract', 'clean_text', 'categories']
          id                                              title  \
0  0704.0001  Calculation of prompt diphoton production cros...   
1  0704.0002           Sparsity-certifying Graph Decompositions   
2  0704.0003  The evolution of the Earth-Moon system based o...   

       categories  
0          hep-ph  
1   math.CO cs.CG  
2  physics.gen-ph  


Function: sample_by_top_categories()

Description:
Select a balanced sample of articles by top categories.
    - Computes the frequency of each (primary) category in `df`.
    - Selects the top-k categories by frequency.
    - For each selected category, samples up to `n` articles.
    - If some selected categories have fewer than `n` articles, attempts to
        fill the remaining quota by drawing from the next-most-frequent categories,
        in order, until approximately N = n * k articles have been collected
        or the dataset is exhausted.
    - Preserves uniqueness by 'id' (no duplicate articles in the output).

Inputs:
    df (pd.DataFrame): Input DataFrame containing at least the category column
                        and an 'id' column.
    k (int): Number of top categories to consider.
    n (int): Desired number of articles per category (target).
    category_col (str): Name of the column holding category information
                        (default: "categories"). The cell expects strings
                        like "cs.AI cs.LG" or "hep-ph".
    primary_split (str): Delimiter to split multi-label category strings;
                            default is space (" "). The first token becomes
                            the "primary" category.
    random_state (int): Random seed used for sampling for reproducibility.

Outputs:
    sampled_df (pd.DataFrame): DataFrame containing the sampled articles.
    info (dict): Summary information including:
        - 'selected_categories' : list of the final categories used (in order)
        - 'counts_by_category'  : dict mapping category -> number sampled
        - 'target_total'        : requested total (n * k)
        - 'actual_total'        : actual number of returned articles

Notes:
    - The function will raise a ValueError if the category_col or 'id' column
        are missing from `df`.
    - If df contains categories as lists rather than strings, it will try to
        handle them by taking the first element.

In [4]:
def sample_by_top_categories(df: pd.DataFrame, k: int, n: int,
                             category_col: str = "categories",
                             primary_split: str = " ",
                             random_state: int = 42) -> (pd.DataFrame, dict):
    
    if category_col not in df.columns:
        raise ValueError(f"DataFrame does not contain column '{category_col}'")
    if 'id' not in df.columns:
        raise ValueError("DataFrame must contain an 'id' column for uniqueness checks")

    # Create a primary category column
    def _extract_primary(cat):
        if pd.isna(cat):
            return "unknown"
        if isinstance(cat, list) and len(cat) > 0:
            return str(cat[0]).strip()
        s = str(cat).strip()
        if s == "":
            return "unknown"
        return s.split(primary_split)[0] if primary_split in s else s

    df = df.copy()
    df["primary_category"] = df[category_col].apply(_extract_primary)

    # Compute counts and order categories
    counts = df["primary_category"].value_counts()
    all_categories = counts.index.tolist()
    if len(all_categories) == 0:
        return pd.DataFrame(columns=df.columns), {
            "selected_categories": [],
            "counts_by_category": {},
            "target_total": n * k,
            "actual_total": 0
        }

    # Determine how many top categories we can actually use
    k_adj = min(k, len(all_categories))
    top_k_categories = all_categories[:k_adj]

    target_total = n * k
    sampled_ids = set()
    sampled_rows = []
    counts_by_category = {}

    # First pass: take up to n from each of the top_k categories
    for cat in top_k_categories:
        cat_df = df[df["primary_category"] == cat]
        take = min(n, len(cat_df))
        if take > 0:
            sampled_cat = cat_df.sample(n=take, random_state=random_state)
            # ensure uniqueness by id
            sampled_cat = sampled_cat[~sampled_cat['id'].isin(sampled_ids)]
            # update sets
            sampled_rows.append(sampled_cat)
            sampled_ids.update(sampled_cat['id'].tolist())
            counts_by_category[cat] = len(sampled_cat)
        else:
            counts_by_category[cat] = 0

    # If we still need more to reach target_total, iterate over remaining categories
    if len(sampled_ids) < target_total:
        remaining_needed = target_total - len(sampled_ids)
        # get remaining categories in order of frequency (after the initial top_k)
        remaining_categories = [c for c in all_categories if c not in top_k_categories]
        # iterate and pull up to `n` from each until we fill or run out
        for cat in remaining_categories:
            if remaining_needed <= 0:
                break
            cat_df = df[df["primary_category"] == cat]
            # exclude already sampled ids
            cat_df = cat_df[~cat_df['id'].isin(sampled_ids)]
            if len(cat_df) == 0:
                continue
            take = min(n, len(cat_df), remaining_needed)
            sampled_cat = cat_df.sample(n=take, random_state=random_state)
            sampled_rows.append(sampled_cat)
            sampled_ids.update(sampled_cat['id'].tolist())
            counts_by_category[cat] = counts_by_category.get(cat, 0) + len(sampled_cat)
            remaining_needed = target_total - len(sampled_ids)

    # Concatenate sampled parts and return
    if sampled_rows:
        result_df = pd.concat(sampled_rows).drop_duplicates(subset=['id']).reset_index(drop=True)
    else:
        result_df = pd.DataFrame(columns=df.columns)

    info = {
        "selected_categories": list(counts.index[:k_adj]),
        "counts_by_category": counts_by_category,
        "target_total": target_total,
        "actual_total": len(result_df)
    }
    return result_df, info

In [5]:
sampled, summary = sample_by_top_categories(df, k=5, n=20, category_col="categories", random_state=123)
print("Target total (n*k):", summary["target_total"])
print("Actual retrieved:", summary["actual_total"])
print("Selected categories:", summary["selected_categories"])
print("Counts by category (sampled):", summary["counts_by_category"])
display(sampled.head())

Target total (n*k): 100
Actual retrieved: 100
Selected categories: ['hep-ph', 'cs.CV', 'quant-ph', 'cs.LG', 'hep-th']
Counts by category (sampled): {'hep-ph': 20, 'cs.CV': 20, 'quant-ph': 20, 'cs.LG': 20, 'hep-th': 20}


Unnamed: 0,id,title,authors,abstract,clean_text,categories,primary_category
0,hep-ph/9502275,Nonleptonic Two-Body Decays of D Mesons in Bro...,Ian Hinchliffe and Thomas A. Kaeding,"Decays of the D mesons to two pseudoscalars,...",decays of the d mesons to two pseudoscalars to...,hep-ph,hep-ph
1,1301.1123,Axions : Theory and Cosmological Role,"Masahiro Kawasaki, Kazunori Nakayama",We review recent developments on axion cosmo...,we review recent developments on axion cosmolo...,hep-ph astro-ph.CO,hep-ph
2,hep-ph/9607210,Froissart boundary for deep inelastic structur...,"A.L.Ayala (IF UFRGS), M.B.Gay Ducati (IF UFRGS...",In this letter we derive the Froissart bound...,in this letter we derive the froissart boundar...,hep-ph,hep-ph
3,hep-ph/0604243,Higher $\eta_c(nS)$ and $\eta_b (nS)$ mesons,A.M.Badalian (Institute of Theoretical and Exp...,The hyperfine splittings in heavy quarkonia ...,the hyperfine splittings in heavy quarkonia ar...,hep-ph,hep-ph
4,1702.08417,Strong couplings and form factors of charmed m...,"Alfonso Ballon-Bayona, Gastao Krein, Carlisson...",We extend the two-flavor hard-wall holograph...,we extend the twoflavor hardwall holographic m...,hep-ph hep-lat hep-th,hep-ph


Function: compute_and_save_embeddings()

Description:
    Computes sentence or document embeddings for a given collection of texts using
    a pre-trained SentenceTransformer model, normalizes them for cosine similarity
    computation, and saves both the embeddings and their corresponding document IDs
    to disk. 

    This implementation is optimized for large datasets by using NumPy memory-mapped
    arrays (memmaps), which prevent excessive RAM usage and allow writing embeddings
    incrementally in batches.

Inputs:
    texts (List[str]): List of raw or preprocessed document texts to embed.
    doc_ids (List[str]): List of unique document identifiers corresponding to the texts.
    model_name (str): Name of the SentenceTransformer model to use
                      (default: 'all-mpnet-base-v2').
    batch_size (int): Number of documents processed per batch (default: 64).
    out_emb_path (str): Path to save the resulting NumPy memmap file of embeddings (.npy).
    out_ids_path (str): Path to save the document IDs (.npy).

Outputs:
    None. The function saves two files to disk:
        - out_emb_path : A NumPy array of shape (N, d) containing normalized embeddings.
        - out_ids_path : A NumPy array of shape (N,) containing corresponding document IDs.

Notes:
    - Each embedding vector is L2-normalized to ensure that cosine similarity 
      corresponds to inner product similarity.
    - Using memmap storage allows the function to handle millions of documents efficiently.
    - The SentenceTransformer model must be compatible with batch encoding.

In [6]:
def compute_and_save_embeddings(texts, doc_ids, model_name='all-mpnet-base-v2',
                                batch_size=64, out_emb_path='embeddings.npy',
                                out_ids_path='doc_ids.json', overwrite=True):
    if os.path.exists(out_emb_path) and not overwrite:
        raise FileExistsError(f"{out_emb_path} already exists. Set overwrite=True to replace.")

    model = SentenceTransformer(model_name)
    n = len(texts)
    emb_dim = model.get_sentence_embedding_dimension()

    emb_memmap = np.lib.format.open_memmap(out_emb_path, mode='w+', dtype='float32', shape=(n, emb_dim))

    for i in tqdm(range(0, n, batch_size), desc="Embedding batches"):
        batch_texts = texts[i:i+batch_size]
        batch_emb = model.encode(batch_texts, show_progress_bar=False, convert_to_numpy=True)
        # normalize rows to unit vectors (for cosine via inner product)
        norms = np.linalg.norm(batch_emb, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        batch_emb = batch_emb / norms
        emb_memmap[i:i+len(batch_emb)] = batch_emb.astype('float32')

    # ensure data flushed to disk
    del emb_memmap

    # save doc ids as JSON
    with open(out_ids_path, 'w', encoding='utf-8') as f:
        json.dump(list(doc_ids), f, ensure_ascii=False)

    print(f"Saved embeddings -> {out_emb_path}")
    print(f"Saved doc ids -> {out_ids_path}")

In [7]:
# Example usage
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

texts = sampled["clean_text"].tolist()
doc_ids = sampled["id"].tolist()

compute_and_save_embeddings(
    texts,
    doc_ids,
    model_name='all-MiniLM-L6-v2',
    batch_size=8,
    out_emb_path='embeddings.npy',
    out_ids_path='doc_ids.json',
    overwrite=True
)
# Load and verify
loaded_emb = np.load('embeddings.npy', mmap_mode='r')
with open('doc_ids.json', 'r', encoding='utf-8') as f:
    loaded_ids = json.load(f)

print("Embeddings shape:", loaded_emb.shape)
print("Number of doc ids:", len(loaded_ids))
print("Example embedding (first doc) first 10 dims:", loaded_emb[0][:10])

Embedding batches: 100%|██████████| 13/13 [00:01<00:00, 11.13it/s]

Saved embeddings -> embeddings.npy
Saved doc ids -> doc_ids.json
Embeddings shape: (100, 384)
Number of doc ids: 100
Example embedding (first doc) first 10 dims: [-0.0930824  -0.08645523  0.02369268  0.05961437  0.01182459 -0.06102546
 -0.03752167  0.06795102 -0.05304568 -0.01412609]





Function: build_faiss_index()

Description:
    Builds and saves a FAISS index for fast Approximate Nearest Neighbor (ANN) search
    over precomputed document embeddings. This function loads a NumPy array of embeddings
    and constructs an index structure optimized for efficient similarity queries.

    Two types of FAISS indices are supported:
        1. "flat"  – Exact search using a flat (brute-force) inner product index.
                     Suitable for smaller datasets or evaluation baselines.
        2. "hnsw"  – Hierarchical Navigable Small World (HNSW) graph-based index
                     providing approximate nearest neighbor search with high recall
                     and significantly faster query times for large datasets.

    The index can later be loaded and used to efficiently retrieve the top-N most similar
    items to a given query embedding.

Inputs:
    embeddings_path (str): Path to the NumPy file (.npy) containing the precomputed
                           document embeddings of shape (N, d).
    index_path (str): Path where the built FAISS index will be saved (.index file).
    index_type (str): Type of FAISS index to build. Supported values:
                      - 'flat' : Exact search (IndexFlatIP)
                      - 'hnsw' : Approximate search (IndexHNSWFlat)
    ef_construction (int): Construction parameter controlling recall vs. build time
                           for HNSW. Higher values improve recall but increase
                           build time and memory usage.
    M (int): Number of bi-directional links created for each node in the HNSW graph.
             Larger values increase recall but also memory footprint.

Outputs:
    faiss.Index: The constructed FAISS index object (also saved to disk at index_path).

Notes:
    - The embeddings should be normalized if cosine similarity is intended, since
      FAISS IndexFlatIP and IndexHNSWFlat use inner product as the similarity metric.
    - This function can handle large embedding matrices efficiently using memory mapping.

In [8]:
def build_faiss_index(embeddings_path='embeddings.npy', index_path='faiss.index',
                      index_type='hnsw', ef_construction=200, M=32):
    emb = np.load(embeddings_path, mmap_mode='r')  # shape (N, d)
    d = emb.shape[1]
    if index_type == 'flat':
        index = faiss.IndexFlatIP(d)  # inner product -> cosine if vectors normalized
        index.add(emb)
    elif index_type == 'hnsw':
        index = faiss.IndexHNSWFlat(d, M)  # M controls connectivity
        index.hnsw.efConstruction = ef_construction
        index.add(emb)
    else:
        raise ValueError('index_type not supported')
    faiss.write_index(index, index_path)
    return index

Function: retrieve_similar_articles()

Description:
    Given a query text (e.g., a fragment of scientific writing), this function embeds
    the query in the same semantic space as the precomputed article embeddings and
    retrieves the top-N most similar articles based on cosine similarity.

    Two methods can be used:
        (1) Exact K-Nearest Neighbors (KNN) search using sklearn’s NearestNeighbors
        (2) Approximate Nearest Neighbors (ANN) search using Faiss
    depending on the size of the dataset and computational constraints.

Inputs:
    query (str): The input text for which to find related articles.
    model (SentenceTransformer): The embedding model used for encoding.
    embeddings (numpy.ndarray): Precomputed embeddings of all documents.
    articles (List[str]): Corresponding article IDs or titles.
    top_n (int): Number of most similar articles to return.
    use_ann (bool): If True, use approximate search (e.g., FAISS).

Outputs:
    pandas.DataFrame: A ranked table of the top-N retrieved articles with their
                      similarity scores.

In [14]:
def retrieve_similar_articles(query, model, embeddings, articles, top_n=5, use_ann=False):
    query_emb = model.encode([query], convert_to_numpy=True)
    # normalize
    query_emb = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    
    if use_ann:
        index = build_faiss_index()
        distances, indices = index.search(query_emb.astype('float32'), top_n)
    else:
        # Exact search using sklearn
        nbrs = NearestNeighbors(n_neighbors=top_n, metric="cosine").fit(embeddings)
        distances, indices = nbrs.kneighbors(query_emb)
    
    results = pd.DataFrame({
        "article_id": [articles[i] for i in indices[0]],
        "similarity": [1 - d for d in distances[0]]
    })
    return results

In [None]:
# query taken from hep-ph/9502275
query = "Decays of the D mesons to two pseudoscalars to two vectors and to pseudoscalar plus vector are discussed in the context of broken flavor SU 3 A few assumptions are used to reduce the number of parameters"
top_n = 5

embeddings = np.load('embeddings.npy', mmap_mode='r')
results = retrieve_similar_articles(query, model, embeddings, doc_ids, top_n, use_ann=True)

print(f"Top-{top_n} Recommended Articles:")
display(results)

Top-5 Recommended Articles:


Unnamed: 0,article_id,similarity
0,hep-ph/9502275,0.833745
1,1702.08417,0.304641
2,1904.12566,0.129331
3,1707.03249,-0.035064
4,hep-ph/0604243,-0.035543
