In [None]:
from src.utils import load_docs_from_jsonl, uniquify, gen_report

import os
from tqdm.auto import tqdm

from src.hf_models import Engine
from src.embedders import Embedders
from src.faiss_store import VectorDB
from src.retrievers import ComposedRetriever
from src.prompts import RAGPromptTemplates
from src.chains import WrapperChains

from torch import bfloat16
from transformers import BitsAndBytesConfig

# Parameters

In [None]:
# Working directory; the relative output folders below (vector_db/, reports/) hang off it.
cwd = os.getcwd()
# Embedding cache lives next to the notebook; model weights go to a shared absolute path.
embed_cache = "./cache/"
model_cache = '/models/model_cache'

# ---- Embedding model -------------------------------------------------------
embed_model_id = 'BAAI/bge-large-en-v1.5'
embed_model_kwargs = {'device': 'cuda'}
embed_show_prog = True

# ---- LLM -------------------------------------------------------------------
# Other checkpoints that were tried (author's notes):
#   "cognitivecomputations/WestLake-7B-v2-laser"  - not bad
#   "meta-llama/Llama-2-13b-chat-hf"              - a bit better than dolphin-2.6-mistral-7b-dpo-laser
#   "google/gemma-7b"
#   "Deci/DeciLM-7B-instruct"                     - bad, repetitions
#   "Deci/DeciLM-7B"                              - bad
model_id = "cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser"

# 4-bit bitsandbytes quantisation settings (fed to BitsAndBytesConfig in the LLM cell).
load_4bit = True
quant_type = "nf4"
double_quant = True
bnb_compute_type = bfloat16
max_new_tokens = 2048

# Contrastive-search decoding parameters, see
# https://huggingface.co/blog/introducing-csearch
penalty_alpha = 0.25
top_k = 4

# ---- Vector store location -------------------------------------------------
vect_db_name = 'break_reg_qa'
vector_store_location = os.path.join(cwd, "vector_db", vect_db_name)
os.makedirs(vector_store_location, exist_ok=True)

# ---- Retriever -------------------------------------------------------------
# Maximal Marginal Relevance (MMR): fetch `fetch_k` candidates, keep `k` of them,
# trading relevance against diversity via `lambda_mult`. Adjust search_kwargs to taste.
# NOTE(review): with k == fetch_k every fetched document is returned, so MMR can only
# reorder them and `lambda_mult` does not affect which documents are selected; raise
# fetch_k above k if diversity-based filtering is intended.
search_type = 'mmr'
search_kwargs = {
    "k": 20,              # documents to return (default: 4)
    "fetch_k": 20,        # documents passed to the MMR algorithm (default: 20)
    "lambda_mult": 0.85,  # 1 = minimum diversity, 0 = maximum (default: 0.5)
}

# Load data

In [None]:
# Pre-parsed inputs: the regulation corpus (JSONL) and the requirements file
# (one requirement/question per line).
# NOTE(review): absolute, machine-specific directory; consider making it configurable.
data_dir = "/home/sf/data/py_proj/2024/RAG-qa-your-docs/preparsed_datadata"
preparsed_loc = os.path.join(data_dir, "regulations.jsonl")
req_loc = os.path.join(data_dir, "requirements.txt")
docs = load_docs_from_jsonl(preparsed_loc)

# Strip the trailing newline from each requirement and skip blank lines, so empty
# strings are never sent through the retriever and the LLM. Read explicitly as UTF-8
# instead of relying on the platform's default locale encoding.
with open(req_loc, 'r', encoding='utf-8') as f:
    requirements = [line.strip() for line in f if line.strip()]

# Set up LLM and the embedder

## Embedder

In [None]:
# Load the sentence-embedding model configured above and expose it as `embedder`
# for the vector store.
embedder_settings = dict(
    model_id=embed_model_id,
    model_cache_dir=model_cache,
    model_kwargs=embed_model_kwargs,
    show_progress=embed_show_prog,
    embed_cache=embed_cache,
)
EmbedCLS = Embedders(**embedder_settings)
EmbedCLS.load()
embedder = EmbedCLS.get_embedder()

## LLM

In [None]:
# 4-bit bitsandbytes quantisation, driven entirely by the config-cell settings.
bnb_config = BitsAndBytesConfig(**{
    "load_in_4bit": load_4bit,
    "bnb_4bit_quant_type": quant_type,
    "bnb_4bit_use_double_quant": double_quant,
    "bnb_4bit_compute_dtype": bnb_compute_type,
})

# Build the LLM engine with the contrastive-search parameters (top_k / penalty_alpha)
# and the max_new_tokens cap from the config cell.
engine_settings = dict(
    model_id=model_id,
    cache_fld=model_cache,
    quant_config=bnb_config,
    device_map='auto',
    max_new_tokens=max_new_tokens,
    top_k=top_k,
    penalty_alpha=penalty_alpha,
)
LMEngine = Engine(**engine_settings)
LMEngine.load()
LMEngine.set_pipeline(batch_size=4)

# Handles used by the chain cells below (tokenizer kept for ad-hoc inspection).
llm, tokenizer = LMEngine.get_llm(), LMEngine.get_tokenizer()

In [None]:
# Re-apply the decoding settings from the config cell and rebuild the pipeline without
# calling LMEngine.load() again (model weights are not reloaded) — useful after tweaking
# penalty_alpha / top_k interactively.
# NOTE(review): on a fresh Run All this repeats what the LLM cell already did; presumably
# set_pipeline() reads these attributes when building the pipeline — confirm in src.hf_models.Engine.
LMEngine.penalty_alpha=penalty_alpha
LMEngine.top_k=top_k
LMEngine.set_pipeline(batch_size=4)
# Refresh `llm` so downstream cells use the rebuilt pipeline.
llm = LMEngine.get_llm()

# Vector Store

In [None]:
# Index the loaded regulation documents with the embedder into the vector store at
# `vector_store_location` (VectorDB comes from src.faiss_store, so presumably FAISS-backed).
# NOTE(review): create() is called unconditionally, so every run presumably re-embeds all of
# `docs`; if VectorDB can load an existing index from disk, prefer that — confirm its API.
VectorStore = VectorDB(embedder=embedder, vector_store_location=vector_store_location)
VectorStore.create(docs)

# Underlying store object handed to the retriever.
db = VectorStore.db

# Set retriever

In [None]:
# Retriever over `db` using the search type (MMR) and search_kwargs from the config cell.
ComplexRetriever = ComposedRetriever(db, search_type, **search_kwargs)

# Chain & prompt template

In [None]:
# Pieces of the long-context prompt: a lead-in (presumably placed before the retrieved
# documents), the answer-format instructions, and the instruction prefix for the query.
instruct1 = "Given these documents:"
instruct2 = (
    "List ID and a brief explanation. "
    "Each explanation shall be no longer than three sentences. "
    "Keep the answer concise."
)
# Optional extra instruction, currently disabled:
#   "If you don't know the answer, just say that you don't know."
query = "Which of these regulations are relevant for the following query:"

prompt_template = RAGPromptTemplates.long_context(instruct1, instruct2, query)

# Wrap the LLM and build the chain that answers over the retrieved documents.
ChainsWrappers = WrapperChains(llm)
chain = ChainsWrappers.make_long_context_chain(prompt_template)

# Single query

In [None]:
# Output folder for ad-hoc single-question reports.
report_loc = os.path.join(cwd, "reports", "run1_02-23-2024")
os.makedirs(report_loc, exist_ok=True)
# Number of the next report file; normally start from 0 (10 resumes an earlier session).
idx = 10

In [None]:
%%time
#idx -= 1 # in case of multiple questions on the same topic
question = "Manual shut off valve for the fuel piping shall be installed." #+ \
#" We are looking for regulations that require a manual ."

fid = f"{idx:04d}.txt"

extracted_docs = ComplexRetriever.get_docs(question)
ans = chain.run(input_documents=extracted_docs, query=query)
t = gen_report(question, ans, extracted_docs)
print(t)

fname = uniquify(os.path.join(report_loc, fid))

with open(fname, 'w') as f:
    f.writelines(t)

idx += 1

# Batched processing of the requirements

In [None]:
# Output folder for the batch run over all requirements (one report file per requirement).
report_loc = os.path.join(cwd, "reports", "run1_02-22-2024")
os.makedirs(report_loc, exist_ok=True)

In [None]:
# Run every requirement through retrieval + generation and write one report per requirement
# to <report_loc>/<idx>.txt. Failures are logged and skipped (deliberately best-effort) so one
# bad requirement does not abort the whole batch.
# Wrapping the list (not the enumerate object) lets tqdm know the total for a real progress bar.
for idx, question in enumerate(tqdm(requirements, desc="requirements")):
    try:
        extracted_docs = ComplexRetriever.get_docs(question)
        # Feed the requirement itself into the template's `query` slot; the module-level
        # `query` string is only the instruction prefix, and passing it here would hide
        # the question from the LLM.
        ans = chain.run(input_documents=extracted_docs, query=question)
        # Same report layout helper as the single-query cell, instead of a duplicated inline template.
        report = gen_report(question, ans, extracted_docs)
        fname = os.path.join(report_loc, f"{idx:04d}.txt")
        with open(fname, 'w', encoding='utf-8') as f:
            f.writelines(report)
    except Exception as e:
        # tqdm.write keeps the message from garbling the progress bar.
        tqdm.write(f"{idx}: {e}")