In [1]:
from langchain_community.document_loaders import DirectoryLoader

# Knowledge base for retrieval: every .txt file under ./data, searched recursively.
# (DirectoryLoader's default file loader relies on the `unstructured` package — see the install cells.)
loader = DirectoryLoader('data', glob="**/*.txt")

docs = loader.load()
# Fail fast: with an empty corpus the FAISS index build in the next cell would
# otherwise fail with an unrelated-looking error.
if not docs:
    raise FileNotFoundError("No .txt documents found under 'data/' (glob='**/*.txt').")

In [2]:
from langchain_community.embeddings import HuggingFaceEmbeddings

from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Embedding model pinned explicitly (it is the library's current default) so the
# retrieval index cannot silently change if a future release changes that default.
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
# Chunking: maximum chunk length and overlap between consecutive chunks
# (measured in characters with the splitter's default length function).
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200

embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
documents = text_splitter.split_documents(docs)
# FAISS index over the chunk embeddings; the retriever in a later cell is built from `vector`.
vector = FAISS.from_documents(documents, embeddings)

In [3]:
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Generator LLM. To evaluate Mistral instead, swap the two assignments below
# (and use the matching prompt and QA file in the later cells).
#model_path = "mistralai/Mistral-7B-Instruct-v0.1"
model_path = "meta-llama/Llama-2-7b-chat-hf"

# Generation budget per answer.
# NOTE(review): 20 new tokens may truncate verbose chat-style answers — confirm this is intended.
MAX_NEW_TOKENS = 20

# device_map="auto" delegates weight placement across available devices to accelerate.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Tokenizer padding: reuse EOS as the pad token and pad on the left for generation.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

generation_settings = dict(
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=MAX_NEW_TOKENS,
    pad_token_id=tokenizer.eos_token_id,
)
pipe = pipeline("text-generation", **generation_settings)
llm = HuggingFacePipeline(pipeline=pipe)

# Alternative construction straight from the model id (note: it does not apply the
# tokenizer padding settings above):
#llm = HuggingFacePipeline.from_model_id(
#    model_id=model_path,
#    task="text-generation",
#    device_map="auto",
#    pipeline_kwargs={"max_new_tokens": MAX_NEW_TOKENS},
#)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
from langchain_core.prompts import PromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain

# RAG prompt in Llama-2 chat format: the instruction sits in the <<SYS>> block,
# retrieved documents fill {context}, and the user question fills {input}.
LLAMA2_RAG_TEMPLATE = (
    "[INST] <<SYS>>\n"
    "Answer the following question based only on the provided context.\n"
    "<</SYS>>\n"
    "\n"
    "<context>\n"
    "{context}\n"
    "</context>\n"
    "\n"
    "Question: {input} [/INST]"
)

# Mistral-Instruct variant (no system block); pair it with the Mistral model_path:
#MISTRAL_RAG_TEMPLATE = (
#    "[INST] Answer the following question based only on the provided context:\n"
#    "\n"
#    "<context>\n"
#    "{context}\n"
#    "</context>\n"
#    "\n"
#    "Question: {input} [/INST]"
#)

prompt = PromptTemplate.from_template(LLAMA2_RAG_TEMPLATE)

# "Stuff" chain: the retrieved documents are inserted together into {context} for one LLM call.
document_chain = create_stuff_documents_chain(llm, prompt)

In [5]:
from langchain.chains import create_retrieval_chain

# Number of chunks retrieved per question. The retriever option is `search_kwargs`;
# a `search_args` keyword is not a retriever setting, so `k` would never be applied.
TOP_K = 4
retriever = vector.as_retriever(search_kwargs={"k": TOP_K})
# question -> retriever -> stuff-documents chain; the answer is returned under "answer".
retrieval_chain = create_retrieval_chain(retriever, document_chain)

In [6]:
# Optional smoke test of the whole RAG chain on a single question (uncomment to run).
#response = retrieval_chain.invoke({"input": "Cristiano Ronaldo currently plays for"})
#response

In [7]:
import json
import os

from tqdm.auto import tqdm

# Questions to answer with the RAG pipeline. The default path points at the QA file
# produced for the model chosen above (keep the two in sync); set the
# OUTDATED_QA_FILE environment variable to read it from another location.
#outdated_qa_file = "/home/simone/papers/ACL/knowledge-editing/models_generation/results_w_prompt/Mistral-7B-Instruct-v0.1/qa_to_update.json"
outdated_qa_file = os.environ.get(
    "OUTDATED_QA_FILE",
    "/home/simone/papers/ACL/knowledge-editing/models_generation/results_w_prompt/llama2-7b-chat/qa_to_update.json",
)

# Domains whose elements carry an extra attribute level:
#   element -> attribute -> {"questions": {...}}
# every other domain is element -> {"questions": {...}}.
NESTED_DOMAINS = {"countries_byGDP", "organizations"}

# One output directory per generator model, e.g. results/Llama-2-7b-chat-hf.
out_dir = os.path.join("results", model_path.rsplit("/", 1)[-1])
os.makedirs(out_dir, exist_ok=True)


def answer_questions(questions, chain):
    """Answer every question with the retrieval chain.

    Args:
        questions: mapping of question type -> question text.
        chain: object whose ``invoke({"input": q})`` returns a dict with an ``"answer"`` key.

    Returns:
        ``{"questions": {qt: q, ...}, "answers": {qt: answer, ...}}`` keyed by the
        same question types, in the input's order.
    """
    result = {"questions": {}, "answers": {}}
    for qtype, question in questions.items():
        response = chain.invoke({"input": question})
        result["questions"][qtype] = question
        result["answers"][qtype] = response["answer"]
    return result


with open(outdated_qa_file, "r", encoding="utf-8") as f:
    outdated_questions = json.load(f)

answers = {}
for domain, elements in outdated_questions.items():
    domain_answers = {}
    for element in tqdm(elements, desc=domain):
        entry = elements[element]
        if domain in NESTED_DOMAINS:
            domain_answers[element] = {
                attribute: answer_questions(entry[attribute]["questions"], retrieval_chain)
                for attribute in entry
            }
        else:
            domain_answers[element] = answer_questions(entry["questions"], retrieval_chain)
    answers[domain] = domain_answers

    # Persist each domain as soon as it finishes, so a failure in a later domain
    # does not discard the (slow) generations already completed.
    with open(os.path.join(out_dir, f"{domain}_answers.json"), "w", encoding="utf-8") as f:
        json.dump(domain_answers, f, indent=4)

companies_byRevenue: 100%|██████████| 4/4 [00:20<00:00,  5.07s/it]
organizations: 100%|██████████| 1/1 [00:05<00:00,  5.76s/it]
countries_byGDP: 100%|██████████| 31/31 [02:49<00:00,  5.48s/it]
athletes_byPayment: 100%|██████████| 10/10 [00:55<00:00,  5.56s/it]


In [None]:
# Embedding model and FAISS vector-store dependencies (run before the cells above on a fresh env).
# `%pip` installs into the running kernel's environment; `!pip` may target a different interpreter.
%pip install sentence_transformers faiss-cpu

In [None]:
# LangChain plus the community integrations imported above (`langchain_community` is a
# separate distribution that recent `langchain` releases no longer install automatically).
%pip install langchain langchain-community

In [None]:
# Model loading: `accelerate` is required by from_pretrained(..., device_map="auto").
%pip install huggingface_hub transformers accelerate

In [None]:
# File-parsing backend used by DirectoryLoader's default loader.
%pip install unstructured