# Faiss load/HuggingFace 파이프라인 생성

In [None]:
import langchain
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from dpr_embedding import CustomEmbeddings

# Hugging Face Hub id of the causal LM loaded below to answer the questions.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# Device string passed as `device_map` to the embedding model, the LLM and
# the text-generation pipeline.
DEVICE = "cuda:0"

In [None]:
# Project-local embedding wrapper around the DPR question encoder; it is
# handed to the FAISS store as its embedding function.
embedding = CustomEmbeddings("facebook/dpr-question_encoder-single-nq-base", model_kwargs={"device_map": DEVICE})
# Load the persisted index from ./db/faiss.
# NOTE(security): `allow_dangerous_deserialization=True` opts in to
# pickle-based loading of the stored index data. Only load index files that
# were produced locally or come from a trusted source.
faiss = FAISS.load_local("./db/faiss", embeddings=embedding, allow_dangerous_deserialization=True)

In [None]:
# Load the causal LM in bfloat16 directly onto the target device.
llama = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=DEVICE,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Reuse the EOS token as the padding token, both for the tokenizer and for
# the model's generation defaults (presumably because no dedicated pad token
# is defined for this checkpoint -- confirm against the tokenizer config).
tokenizer.pad_token_id = tokenizer.eos_token_id
llama.generation_config.pad_token_id = tokenizer.eos_token_id

In [6]:
# Text-generation pipeline around the already-loaded model and tokenizer.
# Generation keyword arguments given here become the pipeline's defaults for
# every call:
#   * max_new_tokens caps the length of the generated continuation
#     (alternative: llama.generation_config.max_length = 96).
#   * eos_token_id stops generation at either the tokenizer's EOS token or
#     the "<|eot_id|>" end-of-turn token. Passing it here makes the stop
#     condition effective regardless of how the pipeline is wrapped later.
pipeline = transformers.pipeline(
    task="text-generation",
    model=llama,
    tokenizer=tokenizer,
    device_map=DEVICE,
    max_new_tokens=96,
    eos_token_id=[
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ],
)


# LangChain pipeline

In [15]:
from langchain import hub
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Few-shot question-answering prompt, sent as a single human message.
# `{question}` and `{context}` are filled in by the chain at run time.
# The first demonstration's answer matches the date stated in its own
# context, so that both examples show answers grounded in the retrieved text.
prompt = ChatPromptTemplate.from_messages(
    [
        ("human",
"""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. For example:\n
###
Question: when is the last episode of season 8 of the walking dead
Context: The eighth season of The Walking Dead, an American post-apocalyptic horror television series on AMC, premiered on October 22, 2017, and concluded on April 15, 2018, consisting of 16 episodes.
Answer: April 15, 2018
Question: what is the name of the most important jewish text
Context: Codes of Jewish law are written that are based on the responsa; the most important code, the Shulchan Aruch, largely determines Orthodox religious practice today. Jewish philosophy refers to the conjunction between serious study of philosophy and Jewish theology.
Answer: the Shulchan Aruch
###
Keep the answer very concise as one phrase, rather than sentences or clauses.\n
Question: {question} \nContext: {context} \nAnswer:""")
    ]
)

# Token ids that should end generation: the tokenizer's EOS token and the
# "<|eot_id|>" end-of-turn token.
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Wrap the transformers pipeline as a LangChain LLM.
# NOTE(review): for an already-built pipeline, `model_kwargs` may only be
# recorded on the wrapper rather than forwarded to generation -- confirm that
# these terminators actually take effect.
llm = HuggingFacePipeline(
    pipeline=pipeline,
    model_kwargs={"eos_token_id": terminators},
)

# Retrieve the 4 nearest documents from the FAISS index for each query.
retriever = faiss.as_retriever(search_kwargs={"k": 4})

def format_docs(docs):
    """Concatenate the `page_content` of each document, blank-line separated.

    Args:
        docs: Iterable of objects exposing a `page_content` string attribute.

    Returns:
        A single string with the contents joined by two newlines; an empty
        string when `docs` is empty.
    """
    separator = "\n\n"
    contents = (document.page_content for document in docs)
    return separator.join(contents)

def llama_preprocessing(chatvalue: ChatPromptValue):
    """Render a LangChain chat prompt value as one chat-templated prompt string.

    Each message is converted to a ``{"role", "content"}`` dict and passed to
    the module-level ``tokenizer``'s chat template.

    Args:
        chatvalue: Prompt value produced by the chat prompt template.

    Returns:
        The templated prompt as text (``tokenize=False``). No generation
        prompt is requested, so the template output ends with the last
        message.
    """
    # LangChain labels messages "human"/"ai", whereas chat templates use the
    # role names "user"/"assistant". Other types (e.g. "system") already match
    # and are passed through unchanged.
    role_by_type = {"human": "user", "ai": "assistant"}
    chat = [
        {"role": role_by_type.get(msg.type, msg.type), "content": msg.content}
        for msg in chatvalue.to_messages()
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False)

def output_parser(ai_message) -> str:
    """Drop everything up to and including the first "<|eot_id|>" marker.

    Args:
        ai_message: Raw text returned by the LLM step.

    Returns:
        The text following the first "<|eot_id|>", or the input unchanged
        when the marker does not occur.
    """
    _, marker, remainder = ai_message.partition("<|eot_id|>")
    # An empty `marker` means the separator was not found at all.
    if not marker:
        return ai_message
    return remainder

# End-to-end RAG chain: the input question is sent to the retriever (whose
# documents are merged by `format_docs`) and also passed through unchanged;
# both fill the prompt, which is chat-templated, run through the LLM, and
# finally trimmed by `output_parser`.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llama_preprocessing
    | llm
    | output_parser # comment this step out to see the output including the prompt
    | StrOutputParser()
)


In [None]:
# debug run
async for chunk in rag_chain.astream_log(
    "Who is the Moses's brother?", include_names="Docs"
):
    print("-"*40)
    print(chunk)

## Load and preprocess the NQ dataset

In [None]:
# `datasets` is not imported in any earlier cell, so import it here.
import datasets

nq_raw_datasets = datasets.load_dataset("google-research-datasets/natural_questions")

# Keep only examples that have at least one non-empty short-answer text.
# With batched=True the predicate receives a batch and must return one
# boolean per example.
new_datasets = nq_raw_datasets.filter(
    lambda batch: [
        any(text for sa in d["short_answers"] for text in sa["text"])
        for d in batch["annotations"]
    ],
    batched=True,
)

def preprocessing(data):
    """Attach a flat `answer` string to a dataset example.

    The answer is built from the first short-answer entry whose `text` list is
    non-empty, with its items joined by ", " and surrounding whitespace
    removed. If no entry has text, the answer is an empty string.

    Args:
        data: Example dict containing `annotations.short_answers`, a list of
            dicts that each hold a `text` list.

    Returns:
        The same dict, with the `answer` key set.
    """
    candidates = (
        entry["text"]
        for entry in data["annotations"]["short_answers"]
        if len(entry["text"]) != 0
    )
    first_texts = next(candidates, None)
    if first_texts is None:
        data["answer"] = ""
    else:
        data["answer"] = ", ".join(first_texts).strip()
    return data

# Add the flattened `answer` field to every remaining example.
new_datasets = new_datasets.map(preprocessing)

## Test RAG + general-purpose LLM

In [None]:
# Spot-check the chain on validation examples 20..29, printing the question,
# the prediction, and the ground-truth answer for each.
for i in range(20, 30):
    example = new_datasets["validation"][i]
    question = example["question"]["text"]
    answer = example["answer"]
    print("=" * 5)
    print("Q:", question)
    pred = rag_chain.invoke(question).strip()
    # Remove a leading role label if the model emitted one. This must be a
    # prefix removal: str.strip("assistant") would strip any of those
    # *characters* from both ends and corrupt answers such as "Aaron".
    for role_label in ("assistant", "Assistant"):
        if pred.startswith(role_label):
            pred = pred[len(role_label):]
            break
    print("pred:", pred.strip())
    print("GT:", answer)