In [None]:
import pdfplumber
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from sentence_transformers import SentenceTransformer
import glob
import os
from tqdm import tqdm
from typing import Generator, Tuple, Iterator

import faiss
import numpy as np
import pickle

In [None]:
def pickle_read(filename: str):
    """Load and return the object pickled in ``<filename>.pkl``.

    Counterpart of ``pickle_write``: ``filename`` is given without the
    ``.pkl`` extension, which is appended here.

    Warning: ``pickle.load`` can execute arbitrary code, so only read files
    written by this notebook or another trusted source.

    Raises:
        FileNotFoundError: if ``<filename>.pkl`` does not exist.
    """
    with open(filename + ".pkl", "rb") as f:
        return pickle.load(f)

def pickle_write(data, filename: str):
    """Pickle ``data`` into ``<filename>.pkl``, overwriting any existing file.

    ``filename`` is the path without extension; ``.pkl`` is appended here.
    """
    target_path = filename + ".pkl"
    with open(target_path, "wb") as out_file:
        pickle.dump(data, out_file)

In [None]:
def index_documents(knowledge_base_path: str = "raw-kb") -> Iterator[Tuple[str, str, int]]:
    """Yield the text of every page of every PDF in ``knowledge_base_path``.

    Only files matching ``*.pdf`` directly inside the directory are read (no
    recursion; on case-sensitive filesystems ``*.PDF`` files are skipped).
    Files are visited in sorted path order so the vector ids assigned by the
    indexing loop are reproducible across runs and machines.

    Args:
        knowledge_base_path: Directory containing the source PDFs.

    Yields:
        ``(page_text, filename, page_index)`` tuples. ``filename`` is the PDF's
        base name and ``page_index`` is 0-based. A page with no extractable
        text yields ``""``.
    """
    for filepath in sorted(glob.glob(os.path.join(knowledge_base_path, "*.pdf"))):
        filename = os.path.basename(filepath)
        with pdfplumber.open(filepath) as pdf:
            for page_index, page in enumerate(pdf.pages):
                # extract_text() may return None for text-less (e.g. scanned)
                # pages in some pdfplumber versions; normalise to "" so callers
                # always receive a str, as the annotation promises.
                yield page.extract_text() or "", filename, page_index

In [None]:
# Token-based splitter: chunk size and overlap are counted in tokens of the
# named sentence-transformers model. 384 tokens is presumably chosen to match
# the model's maximum sequence length — confirm if the model is changed.
# Alternative model kept from an earlier experiment:
#   "sentence-transformers/all-MiniLM-L6-v2"
fine_splitter = SentenceTransformersTokenTextSplitter(
    chunk_overlap=50,
    model_name="all-mpnet-base-v2",
    tokens_per_chunk=384,
)

# Reuse the model object the splitter holds for embedding, so chunking and
# embedding use the same model. NOTE(review): ``_model`` is a private langchain
# attribute and may change between library versions.
model = fine_splitter._model

In [None]:
# Load the persisted FAISS index and its chunk metadata if available; otherwise
# start from an empty exact L2 index sized to the embedding model's output.
# To force a rebuild, delete "basic_rag.faiss".
try:
    index = faiss.read_index("basic_rag.faiss")
    # NOTE(review): pickle.load can execute arbitrary code — only load
    # metadata files this notebook wrote itself.
    meta = pickle_read("meta")
    print("Index read")
except (RuntimeError, OSError, EOFError, pickle.UnpicklingError) as exc:
    # faiss reports a missing/corrupt index file as RuntimeError; the pickle
    # helper raises OSError (e.g. FileNotFoundError) or an unpickling error.
    # Any other exception (KeyboardInterrupt, NameError, ...) is a real problem
    # and is deliberately allowed to propagate.
    print(f"Could not load cached index ({type(exc).__name__}: {exc})")
    # Public API for the dimension of the vectors ``model.encode`` returns.
    index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())
    meta = {}
    print("Index created")

In [None]:
%%time
document_index = index.ntotal
meta = {}
for text, filename, page_index in tqdm(index_documents()):
    for chunk in fine_splitter.split_text(text):
        embeddings = model.encode([chunk])
        meta[document_index] = {"filename": filename, "page_index": page_index, "text": chunk}
        index.add(embeddings)
        document_index += 1
faiss.write_index(index, "basic_rag.faiss")
pickle_write(meta, "meta")

In [None]:
# Inspect the last chunk embedded by the indexing cell (its loop variable;
# undefined if that cell produced no chunks).
chunk

In [None]:
# First 10 components of the last stored embedding, as Python floats.
embeddings[0][:10].tolist()

In [None]:
# Re-encode the same chunk; presumably matches the preview above if encoding is
# deterministic — a quick sanity check of what went into the index.
model.encode([chunk])[0][:10].tolist()

In [None]:
# Embed a free-text query and retrieve the ids/distances of its 10 nearest chunks.
scores, indexies = index.search(
    model.encode(["Philipp Schindler"]),
    k=10,
)
print(indexies[0])
# Alternative sanity check: query with the last indexed chunk's own vector
# instead of a text query:
#scores, indexies = index.search(embeddings, k=10)

In [None]:
# Show each hit's distance, vector id and stored metadata. FAISS fills unused
# result slots with id -1 when the index holds fewer than k vectors, so skip
# those instead of failing on ``meta[-1]``.
for score, idx in zip(scores[0], indexies[0]):
    if idx == -1:
        continue
    print(score, idx, meta[idx])