# RAG: Create an index

Let's use LangChain, FAISS, and an embedding model to build a vector index. The index will be used by the RAG pipeline in `rag_inference.ipynb`.

Start by
 - reading in all the .mmd files,
 - splitting them into chunks of at most 512 characters, and
 - writing out a flattened list of LangChain documents (text chunks) as a pickle file.

In [None]:
import pickle
import glob
import os
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from multiprocessing import Pool

def process_document(doc_path):
    """Load one .mmd file and split it into text chunks.

    Kept as a top-level function so it can be handed to ``Pool.map``.

    Args:
        doc_path: Path to a single .mmd text file.

    Returns:
        The list of LangChain documents produced by the module-level
        ``splitter`` (which must be defined before this is called).
    """
    # Read as UTF-8 explicitly: with encoding=None, open() falls back to the
    # locale's preferred encoding, which can fail on non-ASCII text/math.
    doc = TextLoader(doc_path, encoding="utf-8").load()
    chunked_doc = splitter.split_documents(doc)
    return chunked_doc

cache_file = "datasets/mmd_cache.pkl"
# Chunks of at most 512 characters, with 30 characters of overlap between neighbours.
splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=30)

if os.path.exists(cache_file):
    # Reuse chunks from a previous run; delete the cache file to re-chunk.
    with open(cache_file, 'rb') as f:
        d = pickle.load(f)
    chunked_docs = d['chunked_docs']
else:
    # Sort so the chunk order (and the shard split made from it later) is
    # reproducible; glob's ordering depends on the filesystem.
    filenames = sorted(glob.glob("datasets/*mmd/*.mmd"))
    if not filenames:
        # Don't cache an empty result: later runs would silently load it.
        raise FileNotFoundError("no .mmd files matched datasets/*mmd/*.mmd")
    # NOTE(review): Pool pickles process_document by reference. This works
    # with the 'fork' start method, but likely not with 'spawn' (the macOS and
    # Windows default) for a function defined in a notebook. Confirm if you
    # run this elsewhere.
    with Pool() as pool:
        chunked_docs_list = pool.map(process_document, filenames)
    chunked_docs = [chunk for sublist in chunked_docs_list for chunk in sublist] # flatten
    d = {'filenames': filenames, 'chunked_docs': chunked_docs}
    with open(cache_file, 'wb') as f:
        pickle.dump(d, f)

Generating the index (embedding every chunk) can be quite slow. If you're OK with doing it locally, uncomment the following cell:

In [None]:
# # alternatively, run it all locally:
# import os
# from langchain_community.vectorstores import FAISS
# from langchain_community.embeddings import HuggingFaceEmbeddings
# db_cache_file = "datasets/faiss_index.bin"
# embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5', multi_process=True)
# if os.path.exists(db_cache_file):
#     # Load the existing FAISS index (load_local also needs the embedding model)
#     db = FAISS.load_local(db_cache_file, embeddings)
# else:
#     db = FAISS.from_documents(chunked_docs, embeddings)
#     db.save_local(db_cache_file)

Otherwise, let's offload this to a GPU cluster.

Let's start by splitting the list of chunks into 8 roughly equal parts. The split is round-robin, so the shard sizes differ by at most one.

In [None]:
# Deal the chunks round-robin into 8 shards: shard k holds chunks k, k+8, k+16, ...
# Each shard is pickled separately so one embedding job can process it.
chunks = []
for shard_id in range(8):
    shard = chunked_docs[shard_id::8]
    chunks.append(shard)
    with open(f"datasets/mmd_cache_chunk{shard_id}.pkl", 'wb') as f:
        pickle.dump(shard, f)

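Calculate the embeddings with up to 8 GPUs in parallel (one job per shard) using a script like the one below. The script reads shards from `input/mmd_cache_chunk/` and writes indices to `output/mmd_cache_chunk/`. The merge step below expects them in `datasets/`, so copy the files there as needed: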

```
import os
import argparse
import pickle
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

# Embed one shard of the chunk cache and save its FAISS index (run one process per shard).
parser = argparse.ArgumentParser(description="Create a FAISS index from a chunk of the MMD cache")
# choices rejects an out-of-range shard number up front instead of failing later on a missing file
parser.add_argument("chunk", type=int, choices=range(8), help="The chunk number (0-7)")
args = parser.parse_args()
chunk = args.chunk

in_file = f"input/mmd_cache_chunk/mmd_cache_chunk{chunk}.pkl"
out_file = f"output/mmd_cache_chunk/faiss_index_chunk{chunk}.bin"

with open(in_file, "rb") as f:
    chunked_docs = pickle.load(f)
# Must be the same model the merge step passes to FAISS.load_local.
index = FAISS.from_documents(chunked_docs, HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5'))
index.save_local(out_file)
```

Once we have these, let's recombine into a single index:

In [None]:
import os
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
# load_local requires the embedding model; it must match the one used to
# build the shards.
embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')

# Seed the combined index with shard 0, then fold in shards 1-7 one at a time.
# Loading lazily means only one extra shard is held in memory at once,
# instead of materializing all eight before merging.
db = FAISS.load_local("datasets/faiss_index_chunk0.bin", embeddings)
for j in range(1, 8):
    db.merge_from(FAISS.load_local(f"datasets/faiss_index_chunk{j}.bin", embeddings))
db.save_local("datasets/faiss_index.bin")