# Simple RAG Demo with OLMo and ArXiv `astro-ph` Dataset

The RAG-based approach below demonstrates how to use AI2's [OLMo-1B](https://huggingface.co/allenai/OLMo-1B) language model to complete an abstract from its opening sentences. The input text is the beginning of an abstract from the `astro-ph` category of the [ArXiv Dataset](https://www.kaggle.com/datasets/Cornell-University/arxiv). In the retrieval-augmented generation (RAG) approach, relevant abstracts are first retrieved from a [Qdrant Vector Database](https://qdrant.tech/) (an in-memory FAISS index is also built below as an alternative) and passed to the model as context for generating the completion.

The input text is taken from the [AstroLLaMa paper](https://arxiv.org/abs/2309.06126). Rather than fine-tuning a model, as AstroLLaMa does, we wanted to see whether a RAG approach can also work.

We will use the following statement as user input:

In [1]:
statement = """The Magellanic Stream (MS) - an enormous ribbon of gas spanning 140∘ of the southern
sky trailing the Magellanic Clouds - has been exquisitely mapped in the five decades since
its discovery. However, despite concerted efforts, no stellar counterpart to the MS has been
conclusively identified. This stellar stream would reveal the distance and 6D kinematics of
the MS, constraining its formation and the past orbital history of the Clouds. We"""

### Utility Functions

In [2]:
import zipfile
import json
import pandas as pd
import io
import fsspec

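def fetch_arxiv_dataset(zip_url: str) -> pd.DataFrame:
    """Read the arXiv metadata snapshot (JSON Lines inside a zip) into a DataFrame."""
    cols = ['id', 'title', 'abstract', 'categories']
    data = []

    # fsspec accepts both local paths and remote URLs
    with fsspec.open(zip_url) as zip_file:
        with zipfile.ZipFile(zip_file) as archive:
            # The archive is expected to hold a single JSON Lines file
            json_file = archive.filelist[0]
            with archive.open(json_file) as raw:
                # Stream line by line; each line is one JSON record (JSON text is UTF-8)
                for line in io.TextIOWrapper(raw, encoding="utf-8"):
                    doc = json.loads(line)
                    data.append([doc['id'], doc['title'], doc['abstract'], doc['categories']])

    return pd.DataFrame(data=data, columns=cols)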

# https://github.com/allenai/open-instruct/blob/main/eval/templates.py
def create_prompt_with_olmo_chat_format(messages, bos="|||IP_ADDRESS|||", eos="|||IP_ADDRESS|||", add_bos=True):
    """Render chat messages into OLMo's <|system|>/<|user|>/<|assistant|> text format."""
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"] + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"] + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
        else:
            raise ValueError(
                "Olmo chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(message["role"])
            )
    # End with an open assistant turn so the model writes the reply
    formatted_text += "<|assistant|>\n"
    if add_bos:
        formatted_text = bos + formatted_text
    return formatted_text

# Post-processing: join the retrieved documents' text into a single context string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
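To see what the chat helper above produces, the sketch below restates it in condensed form as a hypothetical `render_olmo_chat`, using placeholder `<bos>`/`<eos>` strings rather than the values used in this notebook. It then fills the `{question}` slot with `str.format`, which is roughly what LangChain's default f-string `PromptTemplate` does with the returned string. The literal `{question}` and `{context}` placeholders survive the helper because it only concatenates strings.

```python
def render_olmo_chat(messages, bos="<bos>", eos="<eos>", add_bos=True):
    """Condensed, hypothetical restatement of create_prompt_with_olmo_chat_format."""
    tags = {"system": "<|system|>\n", "user": "<|user|>\n", "assistant": "<|assistant|>\n"}
    text = ""
    for message in messages:
        role = message["role"]
        if role not in tags:
            raise ValueError(f"Unsupported role: {role}")
        if role == "assistant":
            # Completed assistant turns are stripped and terminated with eos
            text += tags[role] + message["content"].strip() + eos + "\n"
        else:
            text += tags[role] + message["content"] + "\n"
    # Open an assistant turn so the model continues from here
    text += "<|assistant|>\n"
    return bos + text if add_bos else text


template = render_olmo_chat([
    {"role": "system", "content": "You are an astrophysics expert. Finish the given statement."},
    {"role": "user", "content": "{question}"},
])
# The placeholder is filled later, as a PromptTemplate would do
prompt = template.format(question="No stellar counterpart to the MS has been identified. We")
print(prompt)
```

Values substituted by `str.format` are not re-parsed, so braces inside retrieved abstracts do not break the template.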



### Retrieve documents (arXiv `astro-ph` abstracts)

This section retrieves the arXiv abstracts and creates documents
for loading into a vector database. You can skip running the following sections
if you have a local copy of the Qdrant Vector Database data ready to go.
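Loading the data amounts to streaming a JSON Lines file out of a zip archive and keeping records whose space-separated `categories` field mentions `astro-ph`. The self-contained sketch below reproduces that with the standard library on a tiny in-memory archive; the two records are made-up stand-ins, not real arXiv entries. Because the filter is a substring match, cross-listed papers such as `gr-qc astro-ph.CO` are kept as well, in line with the `str.contains('astro-ph')` filter used in this notebook.

```python
import io
import json
import zipfile

# Two made-up records in the snapshot's JSON Lines layout (illustrative only)
records = [
    {"id": "0000.00001", "title": "A", "abstract": "Made-up abstract about stars.", "categories": "gr-qc astro-ph.CO"},
    {"id": "0000.00002", "title": "B", "abstract": "Made-up abstract about graphs.", "categories": "math.CO"},
]

# Build an in-memory zip holding one file with one JSON object per line
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("snapshot.json", "\n".join(json.dumps(r) for r in records))

buffer.seek(0)
kept = []
with zipfile.ZipFile(buffer) as archive:
    member = archive.filelist[0]  # assume the archive holds a single file
    with archive.open(member) as raw:
        # Decode lazily, line by line, instead of reading the whole file into memory
        for line in io.TextIOWrapper(raw, encoding="utf-8"):
            doc = json.loads(line)
            if "astro-ph" in doc["categories"]:
                kept.append((doc["id"], doc["title"], doc["abstract"], doc["categories"]))

print(kept)
```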

In [3]:
from langchain_community.document_loaders import DataFrameLoader

In [4]:
# Download archive.zip from https://www.kaggle.com/datasets/Cornell-University/arxiv first.
# Kaggle's signed download links are time-limited, so a local copy is used here;
# any path or URL that fsspec can open would also work.
zip_url = "./archive.zip"

In [5]:
# Fetch the dataset containing all arXiv abstracts
df_data = fetch_arxiv_dataset(zip_url)
# Filter the dataset to only include astro-ph category
astro_df = df_data[df_data.categories.str.contains('astro-ph')].reset_index(drop=True)
print("Number of astro-ph papers: ", len(astro_df))

Number of astro-ph papers:  338991


In [6]:
# Eagerly load the DataFrame of abstracts into memory
# as LangChain Document objects (one per row)
loader = DataFrameLoader(astro_df, page_content_column="abstract")
all_docs = loader.load()
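Conceptually, `DataFrameLoader` turns each row into one `Document`: the `abstract` column becomes `page_content` and the remaining columns (`id`, `title`, `categories`) go into `metadata`. The stdlib sketch below mimics that mapping with a hypothetical `Doc` dataclass standing in for LangChain's class, so it runs without LangChain or pandas installed.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Hypothetical stand-in for LangChain's Document."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def rows_to_docs(rows, page_content_column="abstract"):
    docs = []
    for row in rows:
        metadata = dict(row)  # copy so the input row is not mutated
        text = metadata.pop(page_content_column)  # the chosen column becomes the text
        docs.append(Doc(page_content=text, metadata=metadata))
    return docs


rows = [{"id": "0000.00001", "title": "A", "abstract": "Made-up abstract.", "categories": "astro-ph.GA"}]
docs = rows_to_docs(rows)
print(docs[0])
```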

In [7]:
print(f"{len(all_docs)=}")
for doc in all_docs[:5]:
    print(doc)

len(all_docs)=338991
page_content="  We discuss the results from the combined IRAC and MIPS c2d Spitzer Legacy\nobservations of the Serpens star-forming region. In particular we present a set\nof criteria for isolating bona fide young stellar objects, YSO's, from the\nextensive background contamination by extra-galactic objects. We then discuss\nthe properties of the resulting high confidence set of YSO's. We find 235 such\nobjects in the 0.85 deg^2 field that was covered with both IRAC and MIPS. An\nadditional set of 51 lower confidence YSO's outside this area is identified\nfrom the MIPS data combined with 2MASS photometry. We describe two sets of\nresults, color-color diagrams to compare our observed source properties with\nthose of theoretical models for star/disk/envelope systems and our own modeling\nof the subset of our objects that appear to be star+disks. These objects\nexhibit a very wide range of disk properties, from many that can be fit with\nactively accreting disks to so

### Document Embeddings to Qdrant Vector Database

In [8]:
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import HuggingFaceEmbeddings

#### Set Up Vector DB

In [9]:
import os

In [10]:
# Set up the embedding model (sentence-transformers all-MiniLM-L6-v2)
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

In [11]:
from langchain_community.vectorstores import FAISS
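`FAISS.from_documents` embeds every document and stores the vectors in an index. To our understanding, LangChain's default is a flat (exact) L2 index, so each query is compared against every stored vector. The sketch below performs the same brute-force search in plain Python on made-up 3-dimensional vectors (all-MiniLM-L6-v2 embeddings have 384 dimensions). For unit-length vectors, ranking by smallest L2 distance gives the same order as ranking by largest cosine similarity, because ||a - b||^2 = 2 - 2 cos(a, b).

```python
import math


def l2(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def flat_search(query, vectors, k):
    """Exact k-nearest-neighbour search: score every stored vector, keep the k closest."""
    scored = sorted((l2(query, v), i) for i, v in enumerate(vectors))
    return [i for _, i in scored[:k]]


# Made-up unit-length embeddings standing in for document vectors
vectors = [normalize(v) for v in ([1, 0, 0], [0.9, 0.1, 0], [0, 1, 0], [0, 0, 1])]
query = normalize([1, 0.05, 0])
print(flat_search(query, vectors, k=2))
```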


In [13]:
%%time
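# Build an in-memory FAISS index over the first 10,000 abstracts only
# (embedding all ~339k abstracts would take far longer)
documents = all_docs[:10000]
faiss = FAISS.from_documents(documents, embedding)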


CPU times: user 32min 32s, sys: 7min 25s, total: 39min 58s
Wall time: 4min 36s
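The measured wall time gives a rough sense of what embedding the full `astro-ph` set would cost. The extrapolation below is linear, which is an assumption; real throughput depends on hardware, batch size, and abstract length.

```python
# Measured above: 10,000 abstracts embedded and indexed in 4 min 36 s of wall time
docs_timed = 10_000
wall_seconds = 4 * 60 + 36  # 276 s
total_docs = 338_991  # number of astro-ph abstracts printed earlier

throughput = docs_timed / wall_seconds  # documents per second
estimated_seconds = total_docs / throughput  # assumes constant throughput
print(f"{throughput:.1f} docs/s -> about {estimated_seconds / 3600:.1f} h for all abstracts")
```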


In [None]:
%%time
from qdrant_client import QdrantClient

qdrant_path = "./qdrant_data"
qdrant_collection = "arxiv_astro-ph_abstracts"

# If the Qdrant collection already exists on disk, load it (local mode, no server needed)
print(f"Loading existing Qdrant collection '{qdrant_collection}'")
client = QdrantClient(path=qdrant_path)
qdrant = Qdrant(
    client=client,
    collection_name=qdrant_collection,
    embeddings=embedding
)


In [None]:
# %%time
# 
# documents = all_docs[:10000]
# 
# # docker run -p 6333:6333 -v $(pwd)/path/to/data:/qdrant/storage qdrant/qdrant
# qdrant_path="./qdrant_data_v2"
# qdrant_collection="arxiv_astro-ph_abstracts"
# 
# print(f"Creating new Qdrant collection '{qdrant_collection}' from {len(documents)} documents")
# 
# # Load the documents into a Qdrant Vector Database Collection
# # this will save locally in the current directory as sqlite
# qdrant = Qdrant.from_documents(
#     documents,
#     embedding,
#     path=qdrant_path,
#     collection_name=qdrant_collection,
# )


#### Test out the Qdrant collection

In [None]:
# Set up an MMR retriever over the Qdrant collection that returns 5 documents
retriever = qdrant.as_retriever(search_type="mmr", search_kwargs={"k": 5})

In [14]:
# Use the in-memory FAISS index instead (this replaces the Qdrant retriever above)
retriever = faiss.as_retriever(search_type="mmr", search_kwargs={"k": 5})
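Both retrievers use `search_type="mmr"` (maximal marginal relevance), which trades relevance against redundancy so the five returned abstracts are not near-duplicates of one another. The sketch below shows greedy MMR selection on made-up vectors with the score `lambda_mult * sim(query, d) - (1 - lambda_mult) * max sim(d, selected)`. The `lambda_mult=0.5` default matches LangChain's as far as we know, but the function is an illustrative reimplementation, not LangChain's code.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mmr(query, candidates, k, lambda_mult=0.5):
    """Greedy maximal marginal relevance; returns indices of the k selected candidates."""
    relevance = [cosine(query, c) for c in candidates]
    selected, remaining = [], list(range(len(candidates)))
    while remaining and len(selected) < k:
        def score(i):
            # Penalize similarity to the closest already-selected candidate
            redundancy = max((cosine(candidates[i], candidates[j]) for j in selected), default=0.0)
            return lambda_mult * relevance[i] - (1 - lambda_mult) * redundancy
        best = max(remaining, key=score)
        selected.append(best)
        remaining.remove(best)
    return selected


query = [1.0, 0.0]
candidates = [[1.0, 0.1], [1.0, 0.12], [1.0, -0.6]]  # the first two are near-duplicates
print(mmr(query, candidates, k=2))
```

With `lambda_mult=1.0` the redundancy term vanishes and the result is plain top-k similarity, which here picks both near-duplicates.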


In [18]:
# Test out the statement retrieval
found_docs = retriever.get_relevant_documents(statement)

In [19]:
print(format_docs(found_docs))

  We explore the Magellanic Stream (MS) using a Gaussian decomposition of the
HI velocity profiles in the Leiden-Argentine-Bonn (LAB) all-sky HI survey. This
decomposition exposes the MS to be composed of two filaments distinct both
spatially (as first pointed out by Putman et al.) and in velocity. Using the
velocity coherence of the filaments, one can be traced back to its origin in
what we identify as the SouthEast HI Overdensity (SEHO) of the Large Magellanic
Cloud (LMC), which includes 30 Doradus. Parts of the Leading Arm (LA) can also
be traced back to the SEHO in velocity and position. Therefore, at least
one-half of the trailing Stream and most of the LA originates in the LMC,
contrary to previous assertions that both the MS and the LA originate in the
Small Magellanic Cloud (SMC) and/or in the Magellanic Bridge. The two MS
filaments show strong periodic, undulating spatial and velocity patterns that
we speculate are an imprint of the LMC rotation curve. If true, then the drift


### Set Up OLMo Model

In [20]:
from pathlib import Path
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [21]:
model_name = "allenai/OLMo-1B"

In [None]:
# # Download the model and its configuration files locally from the Hugging Face Hub,
# # skipping *.bin files so that the weights are fetched only as safetensors
# local_dir = Path("./OLMo-1B")
# model_path = snapshot_download(
#     repo_id=model_name,
#     ignore_patterns=["*.bin"],
#     local_dir=local_dir,
#     local_dir_use_symlinks=True)

In [22]:
model_path = Path("./OLMo-1B")
olmo = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    local_files_only=True
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [23]:
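# Set up the text-generation pipeline with the OLMo model: low-temperature sampling,
# a mild repetition penalty, and up to 400 new tokens. return_full_text=True keeps the
# prompt in the returned text, which is why the printed answers below repeat it.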
olmo_pipe = pipeline(
    task="text-generation",
    model=olmo,
    tokenizer=tokenizer,
    temperature=0.2,
    do_sample=True,
    repetition_penalty=1.1,
    return_full_text=True,
    max_new_tokens=400,
)
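`temperature=0.2` and `repetition_penalty=1.1` both reshape the next-token logits before sampling. The sketch below applies them to made-up logits in plain Python. Dividing by a temperature below 1 sharpens the distribution toward the top token. The repetition penalty, as we understand Hugging Face's `RepetitionPenaltyLogitsProcessor`, divides positive logits of already-generated tokens by the penalty and multiplies negative ones by it. The token ids and logit values are illustrative assumptions.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def apply_repetition_penalty(logits, generated_ids, penalty):
    out = list(logits)
    for tok in set(generated_ids):
        # Shrink positive logits and push negative logits further down
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def apply_temperature(logits, temperature):
    return [x / temperature for x in logits]


logits = [2.0, 1.5, -0.5]  # made-up scores for a 3-token vocabulary
generated = [0]            # token 0 was already generated

penalized = apply_repetition_penalty(logits, generated, penalty=1.1)
probs_t1 = softmax(penalized)
probs_t02 = softmax(apply_temperature(penalized, 0.2))
print([round(p, 3) for p in probs_t1])
print([round(p, 3) for p in probs_t02])
```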

### Set up the LangChain pipeline for the OLMo model

In [24]:
from langchain_community.llms import HuggingFacePipeline

In [25]:
llm = HuggingFacePipeline(pipeline=olmo_pipe)

#### Define the system prompts

In [26]:
from langchain.prompts import PromptTemplate

In [27]:
no_context_prompt = PromptTemplate(
    input_variables=["question"],
    template=create_prompt_with_olmo_chat_format(messages=[
        {"role": "system", "content": "You are an astrophysics expert. Finish the given statement."}, 
        {"role": "user", "content": "{question}"}
    ]),
)

with_context_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=create_prompt_with_olmo_chat_format(messages=[
        {"role": "system", "content": "You are an astrophysics expert. Use the following pieces of retrieved context to finish the given statement:\n{context}"}, 
        {"role": "user", "content": "{question}"}
    ]),
)

#### Define the LLM chains

In [28]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

In [29]:
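# Parse the LLM output into a plain string
llm_chain = llm | StrOutputParser()
# Baseline: the statement alone fills the prompt
no_context_chain = {"question": RunnablePassthrough()} | no_context_prompt | llm_chain
# RAG: retrieved abstracts fill {context}, the statement fills {question}
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | with_context_prompt
    | llm_chain
)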
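The RAG chain is LangChain Expression Language (LCEL) composition: the leading dict runs each branch on the same input (the retriever plus `format_docs` produce `context`, while `RunnablePassthrough` forwards the statement as `question`), the resulting dict fills the prompt, and the rendered prompt goes to the LLM and then to the string parser. The plain-Python sketch below traces that data flow with hypothetical stand-ins (`fake_retriever`, `fake_llm`, a shortened template) so it runs without LangChain or a model. It is not how LangChain executes runnables internally.

```python
def fake_retriever(query):
    # Hypothetical stand-in: canned abstracts instead of a vector-store search
    return ["Abstract about the Magellanic Stream.", "Abstract about the LMC."]


def format_docs(docs):
    # Same joining rule as the notebook's helper, applied to plain strings
    return "\n\n".join(docs)


def fake_llm(prompt):
    # Hypothetical stand-in for the OLMo pipeline
    return f"<completion for a {len(prompt)}-character prompt>"


TEMPLATE = "<|system|>\nUse this context:\n{context}\n<|user|>\n{question}\n<|assistant|>\n"


def rag_chain(statement):
    # 1. Dict step: every branch receives the same input
    inputs = {
        "context": format_docs(fake_retriever(statement)),
        "question": statement,  # RunnablePassthrough forwards the input unchanged
    }
    # 2. Prompt step: fill the template from that dict
    prompt = TEMPLATE.format(**inputs)
    # 3. LLM step; StrOutputParser would simply return the text
    return prompt, fake_llm(prompt)


prompt, answer = rag_chain("No stellar counterpart to the MS has been identified. We")
print(prompt)
print(answer)
```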

### Invoke the no-context pipeline

In [30]:
no_context_answer = no_context_chain.invoke(statement)

In [31]:
print(no_context_answer)

|||IP_ADDRESS|||<|system|>
You are an astrophysics expert. Finish the given statement.
<|user|>
The Magellanic Stream (MS) - an enormous ribbon of gas spanning 140∘ of the southern
sky trailing the Magellanic Clouds - has been exquisitely mapped in the five decades since
its discovery. However, despite concerted efforts, no stellar counterpart to the MS has been
conclusively identified. This stellar stream would reveal the distance and 6D kinematics of
the MS, constraining its formation and the past orbital history of the Clouds. We
<|assistant|>
have developed a novel technique for measuring the velocity dispersion of stars in the
MS. The method is based on the measurement of the radial velocity of a star as it passes
through the central region of the MS. It uses the fact that the MS is a highly extended
structure, which means that the radial velocity of any star will be affected by the
velocity of surrounding stars. In order to measure the radial velocity of a star, we
measure the Do

### Invoke the RAG chain

In [32]:
rag_answer = rag_chain.invoke(statement)

In [33]:
print(rag_answer)

|||IP_ADDRESS|||<|system|>
You are an astrophysics expert. Use the following pieces of retrieved context to finish the given statement:
  We explore the Magellanic Stream (MS) using a Gaussian decomposition of the
HI velocity profiles in the Leiden-Argentine-Bonn (LAB) all-sky HI survey. This
decomposition exposes the MS to be composed of two filaments distinct both
spatially (as first pointed out by Putman et al.) and in velocity. Using the
velocity coherence of the filaments, one can be traced back to its origin in
what we identify as the SouthEast HI Overdensity (SEHO) of the Large Magellanic
Cloud (LMC), which includes 30 Doradus. Parts of the Leading Arm (LA) can also
be traced back to the SEHO in velocity and position. Therefore, at least
one-half of the trailing Stream and most of the LA originates in the LMC,
contrary to previous assertions that both the MS and the LA originate in the
Small Magellanic Cloud (SMC) and/or in the Magellanic Bridge. The two MS
filaments show strong