In [15]:
import logging
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import SystemMessage
from langchain_text_splitters import CharacterTextSplitter

def load_pdf_and_split(path):
    pdf_loader = PyPDFLoader(path)
    pdf_pages = pdf_loader.load_and_split()
    text_chunks = pdf_pages[0].page_content
    return text_chunks

def split_text_into_chunks(text_chunks, chunk_size):
    text_splitter = CharacterTextSplitter(chunk_size=chunk_size)
    chunks = text_splitter.create_documents([text_chunks])
    return chunks

def embed_documents(text_chunks, api_key, model):
    embedding_model = GoogleGenerativeAIEmbeddings(google_api_key=api_key, model=model)
    vectors = embedding_model.embed_documents(text_chunks)
    return vectors

def create_vector_store(chunks, embedding_model):
    db = Chroma.from_documents(chunks, embedding_model)
    db.persist()
    return db

def create_retriever(db_connection):
    retriever = db_connection.as_retriever(search_kwargs={"k": 5})
    return retriever

def build_rag_chain(retriever, chat_template, api_key, model):
    output_parser = StrOutputParser()
    rag_chain = (
        {"context": retriever | (lambda docs: "\n\n".join(doc.page_content for doc in docs)),
         "question": RunnablePassthrough()}
        | chat_template
        | (ChatGoogleGenerativeAI(google_api_key=api_key, model=model)
          | output_parser)
    )
    return rag_chain

def main():
    path = '/content/2404.07143.pdf'
    text_chunks = load_pdf_and_split(path)
    chunks = split_text_into_chunks(text_chunks, chunk_size=300)
    with open("/content/api.txt") as f:
        api_key = f.read().strip()
    vectors = embed_documents(text_chunks, api_key, model="models/embedding-001")
    db = create_vector_store(chunks, GoogleGenerativeAIEmbeddings(google_api_key=api_key, model="models/embedding-001"))
    db_connection = Chroma(embedding_function=GoogleGenerativeAIEmbeddings(google_api_key=api_key, model="models/embedding-001"))
    retriever = create_retriever(db_connection)
    chat_template = ChatPromptTemplate.from_messages([
        SystemMessage(content="I'm a helpful AI assistant. I'll use the provided document to answer your questions."),
        HumanMessagePromptTemplate.from_template("""Answer the following question based on the provided context:

        Context:
        {context}

        Question:
        {question}

        Answer:""")
    ])
    model = "gemini-1.5-pro-latest"
    rag_chain = build_rag_chain(retriever, chat_template, api_key, model)
    user_question = input("Enter your question: ")
    logging.info(f"User question: {user_question}")
    response = rag_chain.invoke(user_question)
    print(response)

if __name__ == "__main__":
    main()


Enter your question: Give the most unique points




## Unique Points of Infini-attention:

Based on the provided context, here are the most unique points of Infini-attention:

* **Combines Compressive Memory with Attention:** Infini-attention integrates a compressive memory system into the traditional attention mechanism. This allows it to handle infinitely long input sequences while maintaining bounded memory and computational costs. This is a significant departure from standard Transformers, which struggle with memory limitations for long sequences.
* **Hybrid Attention Mechanism:**  It utilizes both masked local attention and long-term linear attention within a single Transformer block. This enables the model to capture both local context and long-range dependencies effectively.
* **Efficiency and Scalability:** The use of compressive memory allows for efficient processing of long sequences, making it more scalable than traditional attention mechanisms. This efficiency also translates to faster streaming inference for LLMs.
* **Minim

In [28]:
import logging
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser

class Document:
    def __init__(self, page_content, metadata=None):
        self.page_content = page_content
        self.metadata = metadata


class RAGChain:
    def __init__(self, retriever, llm, output_parser, format_docs):
        self.retriever = retriever
        self.llm = llm
        self.output_parser = output_parser
        self.format_docs = format_docs

    def invoke(self, input_string, retrieved_docs=None):
        if retrieved_docs is not None:
            docs = [Document(page_content=doc.page_content) for doc in retrieved_docs]
        else:
            docs = self.retriever.retrieve(input_string)

        docs_formatted = self.format_docs(docs)
        output = self.llm(docs_formatted)
        parsed_output = self.output_parser.parse(output)
        return parsed_output

pdf_loader = PyPDFLoader('/content/2404.07143.pdf')
pdf_pages = pdf_loader.load_and_split()

text_chunks = [Document(page_content=page.page_content) for page in pdf_pages]

embedding_model = GoogleGenerativeAIEmbeddings(google_api_key="____",
                                               model="models/embedding-001")

db = Chroma.from_documents(text_chunks, embedding_model, persist_directory="./chroma_db_")
db.persist()


db_connection = Chroma(persist_directory="./chroma_db_", embedding_function=embedding_model)
retriever = db_connection.as_retriever(search_kwargs={"k": 5})


llm = ChatGoogleGenerativeAI(google_api_key="____",
                             model="models/gemini-1.5-pro-latest")

output_parser = StrOutputParser()

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

chat_template = ChatPromptTemplate.from_messages([
    SystemMessage(content="""You are a Helpful AI Bot.
    You take the context and question from user. Your answer should be based on the specific context."""),
    HumanMessagePromptTemplate.from_template("""Answer the question based on the given context.
    Context:
    {context}

    Question:
    {question}

    Answer: """)
])

rag_chain = RAGChain(retriever, llm, output_parser, format_docs)

user_input = "Can you explain the main contribution of the Leave No Context Behind paper?"


retrieved_docs = retriever.invoke(user_input)
filtered_docs = [doc for doc in retrieved_docs if isinstance(doc, Document)]

if filtered_docs:
    response = rag_chain.invoke(user_input, retrieved_docs=filtered_docs)
    print(response)
else:
    print("No valid Document objects found in retrieved_docs.")

print(response)


No valid Document objects found in retrieved_docs.
## Summary of "Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention"

This research introduces **Infini-attention**, a novel method for scaling Transformer-based Large Language Models (LLMs) to handle infinitely long input sequences while maintaining manageable memory and computational requirements. 

**The Problem:** Traditional Transformer models struggle with long sequences due to the quadratic complexity of the attention mechanism, leading to massive memory consumption and computational costs. This limits their ability to process and understand extensive contexts effectively.

**The Solution:** Infini-attention tackles this challenge by incorporating a **compressive memory** into the attention mechanism. This memory efficiently stores and retrieves information from long sequences using a fixed number of parameters, avoiding the memory explosion of standard attention. The method combines both **mask