# Simple RAG Demo with OLMo and ArXiv `astro-ph` Dataset

The retrieval-augmented generation (RAG) approach below demonstrates how to use AI2's [OLMo-1B](https://huggingface.co/allenai/OLMo-1B) LLM to generate an abstract completion for a given input text. The input text is the opening of an abstract from the `astro-ph` category of the [ArXiv Dataset](https://www.kaggle.com/datasets/Cornell-University/arxiv). To generate the completion, the RAG approach retrieves relevant abstracts from a vector database ([Qdrant](https://qdrant.tech/), or an in-memory FAISS index in the later cells), and these documents give the model contextual information for the completion.

The input text was taken from the [AstroLLaMa Paper](https://arxiv.org/abs/2309.06126). Rather than fine-tuning a model, we wanted to see whether a RAG approach could also work.

We will use the following statement as user input:

In [1]:
statement = """The Magellanic Stream (MS) - an enormous ribbon of gas spanning 140° of the southern
sky trailing the Magellanic Clouds - has been exquisitely mapped in the five decades since
its discovery. However, despite concerted efforts, no stellar counterpart to the MS has been
conclusively identified. This stellar stream would reveal the distance and 6D kinematics of
the MS, constraining its formation and the past orbital history of the Clouds. We"""

In [3]:
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain_community.embeddings import HuggingFaceEmbeddings


In [5]:
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

qdrant_path="./qdrant_data"
qdrant_collection="arxiv_astro-ph_abstracts"
 
print(f"Loading existing Qdrant collection '{qdrant_collection}'")
# If the Qdrant Vector Database Collection already exists, load it
client = QdrantClient(path=qdrant_path)
qdrant = Qdrant(
    client=client,
    collection_name=qdrant_collection,
    embeddings=embedding
)



Loading existing Qdrant collection 'arxiv_astro-ph_abstracts'


In [2]:
%load_ext autoreload
%autoreload 2

import zipfile
import json
import pandas as pd
import io
import fsspec

def fetch_arxiv_dataset(zip_url: str) -> pd.DataFrame:
    cols = ['id', 'title', 'abstract', 'categories']

    # fsspec opens local paths and remote URLs alike
    with fsspec.open(zip_url) as zip_file:
        with zipfile.ZipFile(zip_file) as archive:
            data = []
            # The first archive member is the JSON-lines metadata file (one paper per line)
            json_file = archive.filelist[0]
            with archive.open(json_file) as json_lines:
                for line in io.TextIOWrapper(json_lines, encoding="latin-1"):
                    doc = json.loads(line)
                    data.append([doc[col] for col in cols])

            df_data = pd.DataFrame(data=data, columns=cols)
    return df_data

# https://github.com/allenai/open-instruct/blob/main/eval/templates.py
def create_prompt_with_olmo_chat_format(messages, bos="|||IP_ADDRESS|||", eos="|||IP_ADDRESS|||", add_bos=True):
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"] + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"] + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
        else:
            raise ValueError(
                "Olmo chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(message["role"])
                )
    formatted_text += "<|assistant|>\n"
    formatted_text = bos + formatted_text  # forcibly add bos
    return formatted_text

# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
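`fetch_arxiv_dataset` expects a zip archive whose first member is a JSON-lines file: one JSON object per line, each with at least `id`, `title`, `abstract` and `categories` keys. The sketch below exercises the same parsing steps on a tiny in-memory archive. The two records are made up for illustration, `parse_jsonl_zip` is a hypothetical stdlib-only stand-in that returns plain lists instead of a DataFrame, and it decodes as UTF-8.

```python
import io
import json
import zipfile

COLS = ["id", "title", "abstract", "categories"]


def parse_jsonl_zip(fileobj) -> list[list[str]]:
    """Read the first zip member as JSON lines, keeping only the COLS fields."""
    rows = []
    with zipfile.ZipFile(fileobj) as archive:
        first_member = archive.filelist[0]
        with archive.open(first_member) as member:
            for line in io.TextIOWrapper(member, encoding="utf-8"):
                doc = json.loads(line)
                rows.append([doc[col] for col in COLS])
    return rows


# Build a toy archive in memory (made-up records, not real arXiv data).
records = [
    {"id": "0000.0001", "title": "Toy A", "abstract": "Gas.", "categories": "astro-ph.GA", "extra": 1},
    {"id": "0000.0002", "title": "Toy B", "abstract": "Quarks.", "categories": "hep-ph", "extra": 2},
]
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("toy.json", "\n".join(json.dumps(r) for r in records) + "\n")
buffer.seek(0)

rows = parse_jsonl_zip(buffer)
print(rows[0])  # ['0000.0001', 'Toy A', 'Gas.', 'astro-ph.GA']
```

Extra keys (here `extra`) are simply ignored, which is why the real loader can pick four columns out of the much wider arXiv metadata records.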



### Retrieve documents (arXiv `astro-ph` abstracts)

This section retrieves the arXiv abstracts and creates documents
for loading into a vector database. You can skip running the following sections
if you have a local copy of the Qdrant Vector Database data ready to go.

In [3]:
from langchain_community.document_loaders import DataFrameLoader

In [4]:
# zip_url may also be a remote URL (e.g. a Kaggle download link), since fsspec opens
# local paths and URLs alike; the signed Kaggle link used originally has since expired.
zip_url = "./archive.zip"

In [5]:
# Fetch the dataset containing all arXiv abstracts
df_data = fetch_arxiv_dataset(zip_url)
# Filter the dataset to only include astro-ph category
astro_df = df_data[df_data.categories.str.contains('astro-ph')].reset_index(drop=True)
print("Number of astro-ph papers: ", len(astro_df))

Number of astro-ph papers:  338991
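`str.contains('astro-ph')` is a substring match (a regular expression by default, but this pattern has no special characters) on the space-separated `categories` string. It therefore keeps cross-listed papers, like the `gr-qc astro-ph` and `hep-ph astro-ph` rows in the preview further down, and also papers filed under subcategories such as `astro-ph.GA`. A stdlib sketch contrasting that substring test with an exact token match, on made-up category strings:

```python
# Made-up category strings in the arXiv metadata format (space-separated).
categories = [
    "astro-ph",
    "gr-qc astro-ph",
    "astro-ph.GA astro-ph.SR",
    "hep-ph",
    "physics.space-ph",
]

# Substring test, equivalent to pandas' str.contains('astro-ph') for this pattern.
substring_hits = [c for c in categories if "astro-ph" in c]

# Exact token test: only the bare 'astro-ph' category counts.
token_hits = [c for c in categories if "astro-ph" in c.split()]

print(len(substring_hits), len(token_hits))  # 3 2
```

For this filter the substring behavior is the desired one: a token match would silently drop papers listed only under `astro-ph.*` subcategories.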


In [10]:
astro_df = astro_df[:100]
print("Limited set of astro-ph papers: ", len(astro_df))
astro_df

Limited set of astro-ph papers:  100


|     | id | title | abstract | categories |
|-----|----|-------|----------|------------|
| 0 | 0704.0009 | The Spitzer c2d Survey of Large, Nearby, Inste... | We discuss the results from the combined IRA... | astro-ph |
| 1 | 0704.0017 | Spectroscopic Observations of the Intermediate... | Results from spectroscopic observations of t... | astro-ph |
| 2 | 0704.0023 | ALMA as the ideal probe of the solar chromosphere | The very nature of the solar chromosphere, i... | astro-ph |
| 3 | 0704.0044 | Astrophysical gyrokinetics: kinetic and fluid ... | We present a theoretical framework for plasm... | astro-ph nlin.CD physics.plasm-ph physics.spac... |
| 4 | 0704.0048 | Inference on white dwarf binary systems using ... | We report on the analysis of selected single... | gr-qc astro-ph |
| ... | ... | ... | ... | ... |
| 95 | 0704.0486 | Kinematic Decoupling of Globular Clusters with... | About 25% of the Milky Way globular clusters... | astro-ph |
| 96 | 0704.0490 | Long Distance Signaling Using Axion-like Parti... | The possible existence of axion-like particl... | hep-ph astro-ph hep-th |
| 97 | 0704.0510 | Axino warm dark matter and $\Omega_b - \Omega_... | We show that axinos, which are dominantly ge... | hep-ph astro-ph |
| 98 | 0704.0513 | SDSS J233325.92+152222.1 and the evolution of ... | Intermediate polars (IPs) are cataclysmic va... | astro-ph |


In [11]:
# Eagerly load the DataFrame of abstracts into memory
# as LangChain Document objects
loader = DataFrameLoader(astro_df, page_content_column="abstract")
all_docs = loader.load()
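`DataFrameLoader` turns each row into a LangChain `Document` whose `page_content` is the chosen column (`abstract` here) and whose `metadata` holds the remaining columns (`id`, `title`, `categories`). The stdlib sketch below mimics that mapping with a hypothetical `ToyDocument` dataclass, so the shape is visible without LangChain or pandas installed; it illustrates the mapping and is not LangChain's implementation.

```python
from dataclasses import dataclass, field


@dataclass
class ToyDocument:
    """Hypothetical stand-in for LangChain's Document class."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def rows_to_documents(rows: list[dict], page_content_column: str) -> list[ToyDocument]:
    docs = []
    for row in rows:
        metadata = dict(row)                         # copy every column...
        content = metadata.pop(page_content_column)  # ...then move one into page_content
        docs.append(ToyDocument(page_content=content, metadata=metadata))
    return docs


rows = [{"id": "0704.0009", "title": "Toy title", "abstract": "Toy abstract.", "categories": "astro-ph"}]
docs = rows_to_documents(rows, page_content_column="abstract")
print(docs[0].metadata)  # {'id': '0704.0009', 'title': 'Toy title', 'categories': 'astro-ph'}
```

Only `page_content` is embedded later; the metadata travels with each retrieved document, which is how the `id` and `title` of a matching abstract stay recoverable.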

In [12]:
print(f"{len(all_docs)=}")
for doc in all_docs[:5]:
    print(doc)

len(all_docs)=100
page_content="  We discuss the results from the combined IRAC and MIPS c2d Spitzer Legacy\nobservations of the Serpens star-forming region. In particular we present a set\nof criteria for isolating bona fide young stellar objects, YSO's, from the\nextensive background contamination by extra-galactic objects. We then discuss\nthe properties of the resulting high confidence set of YSO's. We find 235 such\nobjects in the 0.85 deg^2 field that was covered with both IRAC and MIPS. An\nadditional set of 51 lower confidence YSO's outside this area is identified\nfrom the MIPS data combined with 2MASS photometry. We describe two sets of\nresults, color-color diagrams to compare our observed source properties with\nthose of theoretical models for star/disk/envelope systems and our own modeling\nof the subset of our objects that appear to be star+disks. These objects\nexhibit a very wide range of disk properties, from many that can be fit with\nactively accreting disks to some 

### Document Embeddings to Qdrant/FAISS Vector Database

In [13]:
from langchain_community.vectorstores import FAISS, Qdrant
from langchain_community.embeddings import HuggingFaceEmbeddings

#### Setup Vector DB

In [14]:
import os

In [15]:
# Set up the embedding model (sentence-transformers all-MiniLM-L6-v2)
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

In [16]:
%%time
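from langchain_community.vectorstores import FAISS

documents = all_docs
# The recorded run below used `MyFAISS`, a debugging subclass of FAISS that is not
# defined in this notebook (it apparently printed the "Inside FAISS __from" lines).
# The stock class builds the same exact IndexFlatL2 index with Euclidean distance.
faiss = FAISS.from_documents(documents, embedding)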


Inside FAISS __from
args=(HuggingFaceEmbeddings(client=SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
), model_name='sentence-transformers/all-MiniLM-L6-v2', cache_folder=None, model_kwargs={}, encode_kwargs={}, multi_process=False, show_progress=False), <faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x7aea4f980300> >, <langchain_community.docstore.in_memory.InMemoryDocstore object at 0x7aea69def950>, {})
kwargs={'normalize_L2': False, 'distance_strategy': <DistanceStrategy.EUCLIDEAN_DISTANCE: 'EUCLIDEAN_DISTANCE'>}
CPU times: user 16 s, sys: 4.63 s, total: 2
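The debug output shows what `from_documents` built: an exact (brute-force) `IndexFlatL2` with the Euclidean distance strategy. Because the MiniLM pipeline ends in a `Normalize()` layer, document and query vectors have unit length, and for unit vectors `|a - b|^2 = 2 - 2 * dot(a, b)`, so ranking by L2 distance gives the same order as ranking by cosine similarity. A pure-Python sketch of that exact search on made-up 3-dimensional vectors (real MiniLM vectors have 384 dimensions):

```python
import math


def normalize(v: list[float]) -> list[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def squared_l2(a: list[float], b: list[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def flat_l2_search(query: list[float], vectors: list[list[float]], k: int) -> list[int]:
    """Exact search: score every stored vector and return the k nearest positions."""
    distances = [squared_l2(query, v) for v in vectors]
    return sorted(range(len(vectors)), key=lambda i: distances[i])[:k]


# Made-up vectors, normalized like the output of the Normalize() layer.
stored = [normalize(v) for v in ([1.0, 0.2, 0.0], [0.0, 1.0, 0.1], [0.7, 0.7, 0.0])]
query = normalize([0.9, 0.4, 0.0])

by_l2 = flat_l2_search(query, stored, k=3)
by_cosine = sorted(range(len(stored)), key=lambda i: -dot(query, stored[i]))
print(by_l2, by_cosine)  # [0, 2, 1] [0, 2, 1]
```

A flat index compares the query with every stored vector, which is fine for the 100 abstracts used here; approximate FAISS indexes trade some accuracy for speed on much larger collections.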

In [20]:
faiss.index_to_docstore_id

{0: '0d6c2c20-4000-4036-9f42-3bde4a8683ed',
 1: '67673884-642a-4d31-8480-57138309e31a',
 2: 'e6347c5e-69f7-4a6f-8e11-53a9524e6651',
 3: '21833176-58d5-4be5-89f5-7b9d5a1634b6',
 4: 'f23afe1f-60a2-41d7-9e41-c6af1f46e1c7',
 5: 'a51cebf2-c4bf-45ab-8157-3c89cc72083a',
 6: '8d28865f-826c-486d-8b8d-3f4e362bbd3c',
 7: '348deed3-5d87-489f-9ec7-95d6f43b2893',
 8: '7912d55d-e9ac-409e-88c8-09b5dcbcf305',
 9: '0a6873ca-3c9b-46f8-bbe1-225309170e2a',
 10: 'cd677485-8e6f-4857-b095-d438efa32549',
 11: '13fae921-943e-4d7d-8e4a-5a2cb3c06b61',
 12: 'eebedc7f-28e8-4cc3-a2f1-14371a3caec4',
 13: 'd304767a-6a48-4cb3-84c7-a6c9ba88d189',
 14: '0655dfe3-dc1c-460e-a4c0-7a875ca24c88',
 15: 'a4bcfde3-96f3-47d7-b9bd-aa24470bebd6',
 16: 'f47c564f-4ff6-41bc-8ae7-361299a06289',
 17: '3abf4ab8-bcfa-4cee-818c-c8280df768d2',
 18: '825bd57a-bde1-4357-ab82-4acd50e75369',
 19: 'dad01875-baa4-4b4f-8be4-64769afee775',
 20: 'ce70b179-af71-42b8-9c9b-85809af4dadd',
 21: '1b111f03-d79a-43b1-a744-adc564019b9e',
 22: 'b465ec2a-b5de-

In [21]:
faiss.docstore._dict

{'0d6c2c20-4000-4036-9f42-3bde4a8683ed': Document(page_content="  We discuss the results from the combined IRAC and MIPS c2d Spitzer Legacy\nobservations of the Serpens star-forming region. In particular we present a set\nof criteria for isolating bona fide young stellar objects, YSO's, from the\nextensive background contamination by extra-galactic objects. We then discuss\nthe properties of the resulting high confidence set of YSO's. We find 235 such\nobjects in the 0.85 deg^2 field that was covered with both IRAC and MIPS. An\nadditional set of 51 lower confidence YSO's outside this area is identified\nfrom the MIPS data combined with 2MASS photometry. We describe two sets of\nresults, color-color diagrams to compare our observed source properties with\nthose of theoretical models for star/disk/envelope systems and our own modeling\nof the subset of our objects that appear to be star+disks. These objects\nexhibit a very wide range of disk properties, from many that can be fit with\na

In [23]:
embedding_vector = faiss.index.reconstruct_n(0, 1)[0]
embedding_vector

array([-5.85617349e-02, -2.59089544e-02,  8.65069479e-02,  3.44750658e-02,
        3.57468501e-02, -1.18778750e-01, -1.53680576e-03,  5.02209626e-02,
        6.04934655e-02,  2.14006640e-02, -5.24109006e-02, -1.81740280e-02,
        4.38794456e-02, -7.64516219e-02,  3.28899845e-02, -5.40960059e-02,
        4.95361425e-02, -7.64981583e-02,  8.54167063e-03,  7.36551434e-02,
        2.51255110e-02, -1.46400603e-02, -4.73659784e-02, -1.86123140e-02,
        3.95741537e-02, -1.37401009e-02, -2.88463999e-02,  2.49426663e-02,
       -5.70209650e-03, -3.78701203e-02,  5.92349172e-02,  1.18892016e-02,
       -2.64140088e-02, -9.54738352e-03,  3.95799391e-02, -7.68278241e-02,
        6.44649118e-02, -3.74760255e-02, -2.53463686e-02, -2.41455790e-02,
       -1.01645570e-02, -5.81586082e-03,  1.89951602e-02,  1.51501298e-02,
       -2.10147332e-02, -4.46411446e-02, -6.13166206e-02,  4.47287364e-03,
       -8.85273740e-02, -6.69401661e-02, -1.98561400e-02, -7.56961405e-02,
       -4.49669659e-02,  

In [None]:
# %%time
# qdrant_path="./qdrant_data"
# qdrant_collection="arxiv_astro-ph_abstracts"
# 
# print(f"Loading existing Qdrant collection '{qdrant_collection}'")
# from qdrant_client import QdrantClient
# # If the Qdrant Vector Database Collection already exists, load it
# client = QdrantClient(path=qdrant_path)
# qdrant = Qdrant(
#     client=client,
#     collection_name=qdrant_collection,
#     embeddings=embedding
# )


In [None]:
# %%time
# 
# documents = all_docs[:10000]
# 
# # docker run -p 6333:6333 -v $(pwd)/path/to/data:/qdrant/storage qdrant/qdrant
# qdrant_path="./qdrant_data_v2"
# qdrant_collection="arxiv_astro-ph_abstracts"
# 
# print(f"Creating new Qdrant collection '{qdrant_collection}' from {len(documents)} documents")
# 
# # Load the documents into a Qdrant Vector Database Collection
# # this will save locally in the current directory as sqlite
# qdrant = Qdrant.from_documents(
#     documents,
#     embedding,
#     path=qdrant_path,
#     collection_name=qdrant_collection,
# )


#### Test out the retriever

In [None]:
# # Setup the retriever for later step
# retriever = qdrant.as_retriever(search_type="mmr", search_kwargs={"k": 5})

In [None]:
retriever = faiss.as_retriever(search_type="mmr", search_kwargs={"k": 5})
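`search_type="mmr"` selects documents by maximal marginal relevance: candidates are first fetched by similarity, then picked one at a time so that each pick balances relevance to the query against redundancy with the documents already picked. As far as I know, LangChain's vector stores fetch `fetch_k=20` candidates and use `lambda_mult=0.5` by default, and both can be passed in `search_kwargs`. A pure-Python sketch of the greedy selection on made-up unit vectors, where cosine similarity equals the dot product:

```python
import math


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine(a, b):
    # Inputs are unit vectors, so the dot product is the cosine similarity.
    return sum(x * y for x, y in zip(a, b))


def mmr(query, candidates, k, lambda_mult=0.5):
    """Greedy maximal marginal relevance over candidate vectors; returns indices."""
    selected = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:
        def mmr_score(i):
            relevance = cosine(query, candidates[i])
            redundancy = max((cosine(candidates[i], candidates[j]) for j in selected), default=0.0)
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy

        best = max(remaining, key=mmr_score)
        selected.append(best)
        remaining.remove(best)
    return selected


# Two near-duplicate documents (0, 1) and one different but still relevant one (2).
docs = [normalize(v) for v in ([1.0, 0.1, 0.0], [1.0, 0.12, 0.0], [0.6, 0.0, 0.8])]
query = normalize([1.0, 0.05, 0.3])

top2_by_similarity = sorted(range(len(docs)), key=lambda i: -cosine(query, docs[i]))[:2]
print(top2_by_similarity, mmr(query, docs, k=2))  # [0, 1] [0, 2]
```

Plain similarity search returns both near-duplicates, while MMR swaps the second one for a more diverse abstract. With `lambda_mult=1.0` the redundancy term vanishes and MMR reduces to plain similarity ranking.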


In [None]:
# Test out the statement retrieval
found_docs = retriever.invoke(statement)

In [None]:
print(format_docs(found_docs))

### Setup OLMo Model

In [None]:
from pathlib import Path
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [None]:
model_name = "allenai/OLMo-1B"

In [None]:
# # Download the model repository from the Hugging Face Hub into a local directory,
# # skipping the PyTorch *.bin weight files so that only the safetensors weights
# # (plus config and tokenizer files) are fetched
# local_dir = Path("./OLMo-1B")
# model_path = snapshot_download(
#     repo_id=model_name,
#     ignore_patterns=["*.bin"],
#     local_dir=local_dir,
#     local_dir_use_symlinks=True)

In [None]:
model_path = Path("./OLMo-1B")
olmo = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    local_files_only=True
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [None]:
# Setup the text generation pipeline with the OLMo model
olmo_pipe = pipeline(
    task="text-generation",
    model=olmo,
    tokenizer=tokenizer,
    temperature=0.2,
    do_sample=True,
    repetition_penalty=1.1,
    return_full_text=True,
    max_new_tokens=400,
)
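With `do_sample=True`, the next token is sampled from a softmax over the logits divided by `temperature`, so `temperature=0.2` concentrates probability on the top tokens (close to greedy decoding) while still allowing some variation. `repetition_penalty=1.1` follows the CTRL-style rule of the Transformers `RepetitionPenaltyLogitsProcessor`: for tokens already in the sequence (prompt included), positive logits are divided by the penalty and negative logits are multiplied by it. A stdlib sketch on made-up logits for a 3-token vocabulary; here the penalty is applied before temperature scaling, and the helper names are hypothetical:

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def apply_repetition_penalty(logits, seen_token_ids, penalty):
    """CTRL-style penalty: make tokens already in the sequence less likely."""
    out = list(logits)
    for t in set(seen_token_ids):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out


def next_token_probs(logits, temperature, seen_token_ids=(), penalty=1.0):
    penalized = apply_repetition_penalty(logits, seen_token_ids, penalty)
    return softmax([x / temperature for x in penalized])


logits = [2.0, 1.5, -1.0]  # made-up scores for tokens 0, 1, 2

p_t1 = next_token_probs(logits, temperature=1.0)
p_t02 = next_token_probs(logits, temperature=0.2)
p_pen = next_token_probs(logits, temperature=0.2, seen_token_ids=[0], penalty=1.1)

print([round(p, 3) for p in p_t1])
print([round(p, 3) for p in p_t02])
print([round(p, 3) for p in p_pen])
```

Lowering the temperature raises the top token's probability, and penalizing an already-seen token shifts probability toward the alternatives, which discourages the looping repetition small models are prone to.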

### Setup the langchain pipeline for the OLMo model

In [None]:
from langchain_community.llms import HuggingFacePipeline

In [None]:
llm = HuggingFacePipeline(pipeline=olmo_pipe)

#### Define the system prompts

In [None]:
from langchain.prompts import PromptTemplate

In [None]:
no_context_prompt = PromptTemplate(
    input_variables=["question"],
    template=create_prompt_with_olmo_chat_format(messages=[
        {"role": "system", "content": "You are an astrophysics expert. Finish the given statement."}, 
        {"role": "user", "content": "{question}"}
    ]),
)

with_context_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=create_prompt_with_olmo_chat_format(messages=[
        {"role": "system", "content": "You are an astrophysics expert. Use the following pieces of retrieved context to finish the given statement:\n{context}"}, 
        {"role": "user", "content": "{question}"}
    ]),
)
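`create_prompt_with_olmo_chat_format` only concatenates strings, so the `{context}` and `{question}` placeholders pass through unchanged and are filled later by `PromptTemplate`, whose default f-string format substitutes them the way Python's `str.format` does. One consequence: literal curly braces typed into a system message (LaTeX such as `\Omega_{b}`, say) would need doubling, while braces inside the substituted context or statement are left alone. The sketch below reproduces the layout for a system plus user message; `BOS` is a hypothetical placeholder, not the real special token.

```python
BOS = "<bos>"  # hypothetical placeholder for the real BOS token string


def olmo_chat_layout(system: str, user: str, bos: str = BOS) -> str:
    """Same layout create_prompt_with_olmo_chat_format builds for system + user turns."""
    return bos + "<|system|>\n" + system + "\n" + "<|user|>\n" + user + "\n" + "<|assistant|>\n"


template = olmo_chat_layout(
    system="You are an astrophysics expert. Use the following pieces of retrieved context "
           "to finish the given statement:\n{context}",
    user="{question}",
)

# PromptTemplate's default f-string format fills the placeholders like str.format.
prompt = template.format(
    context="Abstract one.\n\nAbstract two.",
    question="The Magellanic Stream ... We",
)
print(prompt)
```

The rendered prompt ends with an open `<|assistant|>` turn, so the model's continuation is read as the assistant's reply.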

#### Define the chain of processes for the LLM

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

In [None]:
llm_chain = llm | StrOutputParser()
no_context_chain = {"question": RunnablePassthrough()} | no_context_prompt | llm_chain
rag_chain = {"context": retriever | format_docs, "question": RunnablePassthrough()} | with_context_prompt | llm_chain
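In LangChain Expression Language, `|` chains runnables left to right, and a plain dict at the head of a chain is coerced into a parallel step (`RunnableParallel`) that feeds the same input, the statement, to every value. The retriever branch turns it into formatted context while `RunnablePassthrough` forwards it unchanged as the question. The toy sketch below re-creates that data flow with ordinary functions; `pipe`, `parallel`, and the fake retriever and LLM are hypothetical stand-ins, not LangChain APIs.

```python
from typing import Any, Callable

Step = Callable[[Any], Any]


def pipe(*steps: Step) -> Step:
    """Run steps left to right, like chaining runnables with `|`."""
    def run(value: Any) -> Any:
        for step in steps:
            value = step(value)
        return value
    return run


def parallel(**branches: Step) -> Step:
    """Give the same input to every branch and collect the results in a dict."""
    return lambda value: {name: branch(value) for name, branch in branches.items()}


def fake_retriever(query: str) -> list[str]:
    return ["Abstract about the Magellanic Clouds.", "Abstract about stellar streams."]


def join_docs(docs: list[str]) -> str:
    return "\n\n".join(docs)


def fill_prompt(inputs: dict) -> str:
    return f"Context:\n{inputs['context']}\nStatement: {inputs['question']}"


def fake_llm(prompt: str) -> str:
    return prompt.upper()  # stand-in for the OLMo pipeline


rag_chain = pipe(
    parallel(context=pipe(fake_retriever, join_docs), question=lambda x: x),
    fill_prompt,
    fake_llm,
)
print(rag_chain("We"))
```

The no-context chain is the same shape with only the `question` branch, which is why both chains can be invoked with the bare statement string.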

### Invoke the no-context pipeline

In [None]:
no_context_answer = no_context_chain.invoke(statement)

In [None]:
print(no_context_answer)

### Invoke the RAG chain

In [None]:
rag_answer = rag_chain.invoke(statement)

In [None]:
print(rag_answer)

In [None]:
faiss.save_local("faiss_index")

In [None]:
faiss_local = FAISS.load_local("faiss_index", embedding, allow_dangerous_deserialization=True)

In [None]:
# Sanity check: the reloaded index should hold one vector per document
print(faiss_local.index.ntotal)

In [None]:
retriever = faiss_local.as_retriever(search_type="mmr", search_kwargs={"k": 5})

# Test out the statement retrieval
found_docs = retriever.invoke(statement)
print(format_docs(found_docs))