In [1]:
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate
from langchain.callbacks import StreamingStdOutCallbackHandler
from langchain.memory import ConversationBufferMemory
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.chains import LLMChain
from langchain.document_loaders import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings, CacheBackedEmbeddings
from langchain.vectorstores import Chroma
from langchain.storage import LocalFileStore

In [2]:
# Chat model that answers the questions; a low temperature makes its answers
# less random from run to run.
llm = ChatOpenAI(temperature=0.1)

# Conversation buffer exposed under the "history" key as message objects
# (return_messages=True), matching the MessagesPlaceholder("history") slot in
# the prompt and the ["history"] lookup the chain performs.
memory = ConversationBufferMemory(
    memory_key="history",
    return_messages=True,
)

In [3]:
# Build the retriever: split the source document into chunks, embed each chunk
# through an on-disk embedding cache, and index the vectors in Chroma.
DOCUMENT_PATH = "./files/document.txt"
EMBEDDING_CACHE_DIR = "./.cache"

cache_dir = LocalFileStore(EMBEDDING_CACHE_DIR)

# NOTE: CharacterTextSplitter only breaks on "\n", so any paragraph longer than
# chunk_size stays a single oversized chunk -- that is what the
# "Created a chunk of size ..." warnings printed by this cell report.
# TODO(review): consider a recursive splitter if chunk size matters for recall.
splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=600,
    chunk_overlap=100,
)

loader = UnstructuredFileLoader(DOCUMENT_PATH)
docs = loader.load_and_split(text_splitter=splitter)

embeddings = OpenAIEmbeddings()
# Namespace the cache by embedding model. Without a namespace, cached entries
# are not tied to the model that produced them, so switching models would
# silently reuse vectors from the previous one. (Entries cached by the
# un-namespaced version will be recomputed once.)
cached_embedding = CacheBackedEmbeddings.from_bytes_store(
    embeddings,
    cache_dir,
    namespace=embeddings.model,
)

vectorstore = Chroma.from_documents(docs, cached_embedding)

retriever = vectorstore.as_retriever()

Created a chunk of size 717, which is longer than the specified 600
Created a chunk of size 608, which is longer than the specified 600
Created a chunk of size 642, which is longer than the specified 600
Created a chunk of size 1444, which is longer than the specified 600
Created a chunk of size 1251, which is longer than the specified 600
Created a chunk of size 1012, which is longer than the specified 600
Created a chunk of size 1493, which is longer than the specified 600
Created a chunk of size 819, which is longer than the specified 600
Created a chunk of size 1458, which is longer than the specified 600
Created a chunk of size 1411, which is longer than the specified 600
Created a chunk of size 742, which is longer than the specified 600
Created a chunk of size 669, which is longer than the specified 600
Created a chunk of size 906, which is longer than the specified 600
Created a chunk of size 703, which is longer than the specified 600
Created a chunk of size 1137, which is lon

In [4]:
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Answer questions using only the following context. If you don't know the answer just say you don't know, don't make it up:\n\n{context}",
        ),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ]
)


def format_docs(documents):
    """Join the retrieved documents' text into one context string.

    The retriever returns a list of Document objects. Passed directly into the
    ``{context}`` string slot, that list would be rendered with ``str()`` --
    Document reprs with metadata and escaped newlines -- instead of plain text.
    """
    return "\n\n".join(document.page_content for document in documents)


def load_history(_):
    """Return the stored chat messages for the prompt (the chain input is ignored)."""
    return memory.load_memory_variables({})["history"]


# The raw question string is fanned out to all three prompt inputs.
# NOTE(review): the retriever searches with the question text alone, so
# follow-ups that rely on earlier turns (e.g. pronouns) are not disambiguated
# for retrieval -- only the LLM sees the history.
chain = (
    {
        "context": retriever | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
        "history": RunnableLambda(load_history),
    }
    | prompt
    | llm
)

In [5]:
def invoke_qa(query: str):
    """Answer ``query`` with the RAG chain and record the exchange in memory.

    Args:
        query: The user's question, passed to the chain as its raw input.

    Returns:
        The chain's output message; its ``.content`` is what is saved to
        conversation memory.
    """
    answer = chain.invoke(query)
    # Persist this turn so the next call's prompt receives it as history.
    memory.save_context({"input": query}, {"output": answer.content})
    return answer

In [6]:
# First question (on a fresh kernel, memory is still empty at this point).
invoke_qa(query="Is Aaronson guilty?")

AIMessage(content='Jones, Aaronson, and Rutherford were guilty of the crimes they were charged with.')

In [7]:
# Follow-up: "he" can only be resolved through the conversation history in the
# prompt; the retriever still searches with this question text alone.
invoke_qa(query="What message did he write in the table?")

AIMessage(content='The message he wrote under the table was: "GOD IS POWER"')

In [8]:
# Standalone question with no reference to earlier turns.
invoke_qa(query="Who is Julia?")

AIMessage(content="Julia is a significant person in the protagonist's life, with whom he shares a complex and emotional relationship.")