In [None]:
import os
from getpass import getpass
import requests
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from genai.extensions.langchain import LangChainInterface
from langchain.document_loaders import UnstructuredPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.indexes import VectorstoreIndexCreator
from genai.model import Credentials, Model
from genai.schemas import GenerateParams, ModelType
from langchain.vectorstores.faiss import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
import torch
from langchain.document_loaders import TextLoader
from langchain.vectorstores import Chroma
from langchain import PromptTemplate, LLMChain

In [None]:
# Folder holding the source PDFs; its listing is displayed as a quick sanity check.

pdf_folder_path = './pdfs'
os.listdir(path=pdf_folder_path)

In [None]:
# Tunable chunking parameters (measured in characters).
chunk_limit = 500
# The overlap must be strictly smaller than the chunk size. With overlap == size the
# splitter carries over as much of the previous chunk as still fits, so consecutive
# chunks are near-duplicates and top-k retrieval returns the same passage repeatedly.
chunk_overlap = 50
if not 0 <= chunk_overlap < chunk_limit:
    raise ValueError(
        f"chunk_overlap ({chunk_overlap}) must be in [0, chunk_limit={chunk_limit})"
    )

# Location of the persisted vector store. The directory name encodes the chunking
# parameters so stores built with different settings do not collide.
vec = f"article_{chunk_limit}_{chunk_overlap}"
persist_directory = os.path.join("vec", vec)

In [None]:
# LLM setup: IBM BAM credentials.
# Prefer the GENAI_KEY environment variable. Otherwise prompt with getpass, which reads
# the key without echoing it. With input() the key is shown as typed and can end up
# in the saved cell output.
api_key = os.environ.get("GENAI_KEY") or getpass("input your BAM Key")
api_endpoint = "https://bam-api.res.ibm.com/v1/"

creds = Credentials(api_key=api_key, api_endpoint=api_endpoint)

In [None]:
# Load the PDFs.
# Only *.pdf files are loaded; stray entries such as .DS_Store or checkpoint folders
# could otherwise make the loader fail. Files are sorted because os.listdir returns
# entries in arbitrary order. Sorting makes the concatenated text, and therefore the
# chunks and the vector store built from it, reproducible across runs.
pdf_file_names = sorted(
    fn for fn in os.listdir(pdf_folder_path) if fn.lower().endswith(".pdf")
)
loaders = [UnstructuredPDFLoader(os.path.join(pdf_folder_path, fn)) for fn in pdf_file_names]
documents = [loader.load() for loader in loaders]
# Join every Document returned for every file; the original code kept only the first
# one per file. The blank-line separator stops the last words of one PDF from fusing
# with the first words of the next. "\n\n" is also the first separator
# RecursiveCharacterTextSplitter tries.
new_documents = "\n\n".join(
    page.page_content for file_docs in documents for page in file_docs
)

In [None]:
# Embedding model used both to build and to query the vector store.
# NOTE(review): the intfloat/multilingual-e5-* model card recommends "query: " /
# "passage: " input prefixes; none are added here. Consider adding them (this would
# require rebuilding the store).
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")

# Test the store's own directory. Listing "vec/" raised FileNotFoundError on a fresh
# checkout where the "vec" folder does not exist yet.
if os.path.isdir(persist_directory):
    # Reuse the previously persisted store.
    index = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
else:
    # Split the concatenated PDF text into overlapping chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_limit, chunk_overlap=chunk_overlap)
    pdf_text = text_splitter.split_text(new_documents)

    # Embed the chunks into a Chroma store located at persist_directory.
    # NOTE(review): with chromadb < 0.4 an explicit index.persist() may be needed for
    # the data to reach disk. Confirm against the installed version.
    index = Chroma.from_texts(pdf_text, embeddings, persist_directory=persist_directory)

In [None]:
# Generation parameters for the answering model:
# - deterministic (greedy) decoding
# - between 3 and 1500 new tokens
# - a mild repetition penalty
# - no streaming
params_qa = GenerateParams(
    stream=False,
    decoding_method="greedy",
    repetition_penalty=1.1,
    min_new_tokens=3,
    max_new_tokens=1500,
).dict()

# LangChain-compatible wrapper around the Llama-2 70B chat model served by BAM.
llm_qa = LangChainInterface(
    model='meta-llama/llama-2-70b-chat',
    params=params_qa,
    credentials=creds,
)

In [None]:
# Retrieval QA chain. The retriever takes the user's question, searches the vector
# store and returns the most relevant chunks. The "stuff" chain places those chunks
# into a single prompt for llm_qa.
# With the default search_type="similarity" only `k` is meaningful: `fetch_k` belongs
# to search_type="mmr" and `score_threshold` to "similarity_score_threshold".
# Depending on the LangChain/Chroma version, such extra keys are either ignored or
# forwarded to the Chroma query, where they may be rejected. So only `k` is passed.
# NOTE(review): the chat UI in this notebook answers via generate_qa() and does not
# use this chain.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm_qa,
    chain_type="stuff",
    retriever=index.as_retriever(search_kwargs={'k': 4}),
    input_key="question",
)

In [None]:
def generate_qa(pdf_text, question):
    """Answer `question` with llm_qa, grounded in the supplied context text.

    Args:
        pdf_text: Context passages (e.g. retrieved chunks), substituted into the
            prompt's {ground_text} slot.
        question: The user's question, substituted into the {question} slot.

    Returns:
        The completion produced by LLMChain.predict for the rendered prompt.
    """
    # Llama-2 chat format: the system message sits inside <<SYS>> ... <</SYS>>, and the
    # user turn is closed by [/INST]. The Japanese instruction says: answer the final
    # question using only the context below, and always answer in Japanese.
    qa_prompt_text = """<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant in Japanese. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

あなたは以下の文脈の情報のみを使って最後の質問に答えます。ただし、必ず日本語で回答して下さい。

文脈: 
{ground_text}
                
質問: 
{question}[/INST]

回答:"""
    qa_prompt = PromptTemplate(
        input_variables=["ground_text", "question"],
        template=qa_prompt_text,
    )
    # verbose=True makes the chain print the rendered prompt, which is useful for debugging.
    chain = LLMChain(llm=llm_qa, prompt=qa_prompt, verbose=True)
    return chain.predict(question=question, ground_text=pdf_text)

In [None]:
# Launch the chat UI.
def add_text(history, text):
    """Append the user's message as a new, not-yet-answered chat turn.

    Args:
        history: Current chat history (list of (user, bot) pairs). It is not mutated.
        text: The message the user just submitted.

    Returns:
        A tuple (new_history, "") where new_history is a fresh list ending with
        (text, None). The empty string clears the input textbox.
    """
    pending_turn = (text, None)  # the reply slot is filled in later
    return history + [pending_turn], ""

def bot(history):
    """Fill in the assistant reply for the latest chat turn.

    Retrieves the chunks most relevant to the newest user message from `index`,
    asks the LLM via generate_qa(), prints the answer and the source chunks, and
    stores the answer in the last history entry.

    Args:
        history: Chat history as a list of (user_message, bot_message) pairs, lists
            or tuples. The last entry is the turn to answer.

    Returns:
        The same list object, with its last entry replaced by [user_message, answer].
    """
    question = history[-1][0]

    # Retrieve once and reuse the result for both the prompt context and the printed
    # sources. This avoids a second, redundant embedding/search round-trip.
    retriever = index.as_retriever()
    docs = retriever.get_relevant_documents(query=question)
    # Context block: each chunk followed by a newline.
    context = "".join(f"{d.page_content}\n" for d in docs)

    answer = generate_qa(context, question)
    print(f"\n回答: \n {answer}")
    print(f"\n参考箇所: \n {docs}")

    # Replace the entry rather than item-assigning into it. A tuple entry (add_text
    # appends tuples) therefore works too, not only the lists Gradio sends back.
    history[-1] = [question, answer]
    return history

# Chat UI: a chatbot pane plus a textbox. Submitting the textbox first runs add_text
# (its outputs update both the chatbot and the textbox), then bot() fills in the answer.
with gr.Blocks() as demo:
    # HTML id attributes must not contain whitespace, so the elem_id is hyphenated.
    # NOTE(review): .style() is deprecated in later Gradio 3.x releases and removed in
    # 4.x (there: gr.Chatbot(height=...), gr.Textbox(container=...)). It is kept here
    # for the Gradio version this notebook appears to target.
    chatbot = gr.Chatbot([], elem_id="chat-with-pdf").style(height=600)
    with gr.Row():
        with gr.Column(scale=0.6):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Type your question and press enter",
            ).style(container=False)

    txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
        bot, chatbot, chatbot
    )

demo.launch()