In [0]:
# ============================================================
# GLOBAL CLEANUP — suppress all logs, warnings, HF noise
# ============================================================

import os
import warnings
import logging

warnings.filterwarnings("ignore")

logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("sentence_transformers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

spark.sparkContext.setLogLevel("ERROR")

In [0]:
# ============================================================
# BRONZE INGESTION — Raw arXiv JSON → Delta
# ============================================================

from pyspark.sql.functions import *
import time

input_path = "/FileStore/arxiv_raw/unzipped/arxiv-metadata-oai-snapshot.json"
bronze_path = "/FileStore/delta/arxiv_bronze_v2"

start = time.time()

df_bronze = (
    spark.read
         .option("multiLine", True)
         .option("mode", "PERMISSIVE")
         .json(input_path)
         .withColumn("ingest_ts", current_timestamp())
)

(df_bronze.write.format("delta").mode("overwrite").save(bronze_path))

spark.sql(f"""
CREATE TABLE IF NOT EXISTS arxiv_bronze_v2
USING DELTA
LOCATION '{bronze_path}'
""")

elapsed = (time.time() - start) / 60
print(f"Bronze done in {elapsed:.2f} min")


Bronze done in 0.05 min


In [0]:
# ============================================================
# SILVER LAYER — Cleaned Metadata
# ============================================================

silver_path = "/FileStore/delta/arxiv_silver_v2"

start = time.time()

df_raw = spark.table("arxiv_bronze_v2")

df_silver = (
    df_raw
        .select(
            "id", "title", "abstract", "authors", "categories",
            "update_date", "versions"
        )
        .withColumn("title_clean", lower(trim(col("title"))))
        .withColumn("abstract_clean", lower(trim(col("abstract"))))
        .withColumn("authors_clean", lower(trim(col("authors"))))
        .withColumn("categories_clean", lower(trim(col("categories"))))
        .filter(col("abstract_clean").isNotNull())
        .withColumn("clean_ts", current_timestamp())
)

(df_silver.write.format("delta").mode("overwrite").save(silver_path))

spark.sql(f"""
CREATE TABLE IF NOT EXISTS arxiv_silver_v2
USING DELTA
LOCATION '{silver_path}'
""")

elapsed = (time.time() - start) / 60
print(f"Silver done in {elapsed:.2f} min")

Silver done in 0.12 min


In [0]:
# ============================================================
# GOLD LAYER — Chunking + Distributed Embeddings (Spark UDF)
# ============================================================

import time
from pyspark.sql.functions import *
from pyspark.sql.types import ArrayType, FloatType
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np

start = time.time()
gold_path = "/FileStore/delta/arxiv_gold_v2"

# ------------------------------------------------------------
# 1. Load Silver table
# ------------------------------------------------------------
df_silver = spark.table("arxiv_silver_v2")

# ------------------------------------------------------------
# 2. Chunk abstracts into groups of sentences
# ------------------------------------------------------------
CHUNK_SIZE = 5

df_sentences = df_silver.select(
    "id",
    col("title_clean").alias("title"),
    col("categories_clean").alias("categories"),
    split(
        regexp_replace(col("abstract_clean"), r"\s+", " "),
        r"(?<=[\.\?\!])\s+"
    ).alias("sentences")
)

df_exploded = (
    df_sentences
        .select(
            "id", "title", "categories",
            posexplode("sentences").alias("sent_idx", "sentence")
        )
        .filter(col("sentence") != "")
)

df_chunks = (
    df_exploded
        .withColumn("chunk_id", floor(col("sent_idx") / CHUNK_SIZE))
        .groupBy("id", "title", "categories", "chunk_id")
        .agg(concat_ws(" ", collect_list("sentence")).alias("chunk_text"))
        .filter(col("chunk_text") != "")
)

# ------------------------------------------------------------
# 3. Distributed embedding using Pandas UDF
# ------------------------------------------------------------

# Broadcast model name
model_name = "all-MiniLM-L6-v2"
bc_model_name = spark.sparkContext.broadcast(model_name)

@pandas_udf(ArrayType(FloatType()))
def embed_chunks(batch: pd.Series) -> pd.Series:
    # Load model once per worker
    model = SentenceTransformer(bc_model_name.value)
    embeddings = model.encode(batch.tolist(), show_progress_bar=False)
    return pd.Series([list(vec.astype("float32")) for vec in embeddings])

# Apply distributed embedding
df_gold = df_chunks.withColumn("embedding", embed_chunks("chunk_text"))

# ------------------------------------------------------------
# 4. Write Gold Delta table
# ------------------------------------------------------------
(
    df_gold.write
        .format("delta")
        .mode("overwrite")
        .save(gold_path)
)

spark.sql(f"""
CREATE TABLE IF NOT EXISTS arxiv_gold_v2
USING DELTA
LOCATION '{gold_path}'
""")

# ------------------------------------------------------------
# 5. Print timing
# ------------------------------------------------------------
elapsed = (time.time() - start) / 60
print(f"Gold done in {elapsed:.2f} min")


Gold done in 0.54 min


In [0]:
# ============================================================
# DISTRIBUTED NEAREST-NEIGHBOR SEARCH (Spark Pandas UDF)
# ============================================================

import numpy as np
from pyspark.sql.functions import pandas_udf
from pyspark.sql import functions as F

start = time.time()

df_gold = (
    spark.table("arxiv_gold_v2")
         .select("chunk_text", "embedding")
         .cache()
)

emb_model = SentenceTransformer("all-MiniLM-L6-v2")

def embed_query(q: str) -> np.ndarray:
    return emb_model.encode([q], show_progress_bar=False).astype("float32")[0]

@pandas_udf("float")
def cosine_sim(col: pd.Series) -> pd.Series:
    mat = np.stack(col.values).astype(np.float32)
    q = cosine_sim.qvec
    sims = np.dot(mat, q) / (np.linalg.norm(mat, axis=1) * np.linalg.norm(q))
    return pd.Series(sims)

def retrieve_spark(query: str, k: int = 5):
    qvec = embed_query(query)
    cosine_sim.qvec = qvec

    df_scored = df_gold.withColumn("score", cosine_sim("embedding"))

    topk = (
        df_scored.orderBy(F.desc("score"))
                 .limit(k)
                 .select("chunk_text")
                 .toPandas()
    )

    return [row.chunk_text for row in topk.itertuples()]

elapsed = (time.time() - start) / 60
print(f"Index setup done in {elapsed:.2f} min")


Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

Index setup done in 0.02 min


In [0]:
# ============================================================
# RAG WITH PHI-2 — CLEAN OUTPUT ONLY
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

LLM_MODEL_NAME = "microsoft/phi-2"  # example; you can swap this

start = time.time()

tok = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)

tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME)

gen = pipeline(
    "text-generation",
    model=model,
    tokenizer=tok,
    device=-1,
    pad_token_id=tok.eos_token_id
)

def generate_answer(question: str, k: int = 5, max_new_tokens: int = 256):
    chunks = retrieve_spark(question, k=k)

    context = "\n\n".join(
        [f"Chunk {i+1}:\n{c}" for i, c in enumerate(chunks)]
    )

    prompt = (
        "You are a helpful assistant that answers questions using ONLY the provided context.\n"
        "If the answer is not in the context, say: 'The context does not contain the answer.'\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        "Answer:"
    )

    out = gen(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tok.eos_token_id
    )[0]["generated_text"]

    return out.replace(prompt, "").strip()

elapsed = (time.time() - start) / 60
print(f"RAG setup done in {elapsed:.2f} min")


Loading weights:   0%|          | 0/453 [00:00<?, ?it/s]

RAG setup done in 0.04 min


In [0]:
# ============================================================
# FINAL CLEAN EXECUTION — ONLY PRINTS QUESTION + ANSWER
# ============================================================

start = time.time()

question = "How are graph neural networks used for molecular property prediction?"
answer = generate_answer(question, k=5)

print("Question:", question)
print("Answer:", answer)

elapsed = (time.time() - start) / 60
print(f"Answered found in {elapsed:.2f} min")

Question: How are graph neural networks used for molecular property prediction?
Answer: Graph neural networks are used for molecular property prediction by taking into account the structure of molecules and their interactions with other molecules. This allows for more accurate predictions of molecular properties, such as solubility and toxicity, which can be useful in drug discovery and other applications.
Answered found in 0.85 min
