In [2]:
import os

from second_brain.config import settings
from second_brain.infrastructure.mongo import MongoDBService

os.environ["OPENAI_API_KEY"] = settings.OPENAI_API_KEY

mongodb_client = MongoDBService(settings.MONGODB_URI)

[32m2025-01-18 10:56:28.616[0m | [1mINFO    [0m | [36msecond_brain.infrastructure.mongo.service[0m:[36m__init__[0m:[36m54[0m - [1mConnected to MongoDB instance:
 URI: mongodb://decodingml:decodingml@localhost:27017/?directConnection=true
 Database: second_brain
 Collection: mongodb://decodingml:decodingml@localhost:27017/?directConnection=true[0m


In [3]:
from langchain_mongodb.retrievers import (
    MongoDBAtlasParentDocumentRetriever,
)
from second_brain.application.rag import get_splitter
from second_brain.application.rag.embeddings import EmbeddingModelBuilder

embedding_model = EmbeddingModelBuilder().get_model()
parent_doc_retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
    connection_string=settings.MONGODB_URI,
    embedding_model=embedding_model,
    child_splitter=get_splitter(200),
    parent_splitter=get_splitter(800),
    database_name=settings.MONGODB_DATABASE_NAME,
    collection_name="rag",
    text_key="page_content",
    search_kwargs={"k": 10},
)

In [4]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Retrieve and parse documents
retrieve = {
    "context": parent_doc_retriever
    | (lambda docs: "\n\n".join([d.page_content for d in docs])),
    "question": RunnablePassthrough(),
}
template = """Answer the question based only on the following context. If no context is provided, respond with I DON'T KNOW: \
{context}

Question: {question}
"""
# Define the chat prompt
prompt = ChatPromptTemplate.from_template(template)
# Define the model to be used for chat completion
llm = ChatOpenAI(temperature=0, model="gpt-4o-2024-11-20")
# Parse output as a string
parse_output = StrOutputParser()
# Naive RAG chain
rag_chain = retrieve | prompt | llm | parse_output

In [5]:
answer = rag_chain.invoke("How can I optimize LLMs for inference?")
print(answer)


To optimize LLMs for inference, you can use the following techniques:

1. **Lower Precision (Quantization)**:
   - Use 8-bit or 4-bit precision to reduce memory usage and computational requirements without significant performance loss.

2. **Flash Attention**:
   - Implement Flash Attention for faster and more memory-efficient inference by utilizing on-chip memory (SRAM) instead of slower GPU VRAM.

3. **Speculative Decoding**:
   - Use a smaller model to generate draft tokens and a larger model to verify them, reducing latency, memory usage, and compute demands.

4. **Caching**:
   - Implement KV-caching or prompt caching to reuse computations and speed up inference.

5. **Compilers**:
   - Use tools like `torch.compile()` or TensorRT to optimize model execution.

6. **Continuous Batching**:
   - Dynamically batch requests to maximize GPU utilization.

7. **Optimized Attention Mechanisms**:
   - Use techniques like PagedAttention or FlashAttention for efficient attention computation.
