# Indexing
I need to index everything for sleap and sleap-io at first. 

In [6]:
import os
import dotenv
dotenv.load_dotenv(".local.env")
os.environ["LANGSMITH_TRACING_V2"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_API_KEY"] = dotenv.get_key(".local.env", "LANGSMITH_API_KEY")
os.environ["LANGSMITH_PROJECT"] = "rag-sleap-docs"

In [7]:
import langchain
import langchain_community
import langchain_google_vertexai
import chromadb
import os
import vertexai
vertexai.init(project=dotenv.get_key(".local.env", "GOOGLE_CLOUD_PROJECT_ID"))

In [None]:
# Sanity check: at this point I should be able to see the LangSmith configuration
from langchain_google_vertexai import ChatVertexAI

# gemini = ChatVertexAI(
#     model_name="gemini-2.0-flash-lite",
#     temperature=0.2)
# gemini.invoke("Repeat after me: 'Hello, world!'")

AIMessage(content='Hello, world!\n', additional_kwargs={}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 9, 'candidates_token_count': 5, 'total_token_count': 14, 'prompt_tokens_details': [{'modality': 1, 'token_count': 9}], 'candidates_tokens_details': [{'modality': 1, 'token_count': 5}], 'thoughts_token_count': 0, 'cached_content_token_count': 0, 'cache_tokens_details': []}, 'finish_reason': 'STOP', 'avg_logprobs': -0.010825920104980468, 'model_name': 'gemini-2.0-flash-lite'}, id='run--4d62e52c-c416-4718-9bbd-0f418449d22e-0', usage_metadata={'input_tokens': 9, 'output_tokens': 5, 'total_tokens': 14, 'input_token_details': {'cache_read': 0}})

In [9]:
# Make or bind a ChromaDB client

client = chromadb.PersistentClient(path="./chroma_db")
try:
    sleap_collection = client.get_or_create_collection(name="sleap")
    sleap_io_collection = client.get_or_create_collection(name="sleap_io")
    print("Collections created or accessed successfully.")
except Exception as e:
    print(f"Error creating or accessing collection: {e}")

Collections created or accessed successfully.


## Smart Embedding

Instead of this html based method, I will separately parse the codes (with comments) with ast and generate embeddings for each function and class. This will allow me to search for specific functions or classes more effectively. For guides and examples, I will use the existing markdown files.

In [10]:
import ast
from pathlib import Path
from langchain_core.documents import Document

# --- 1. CONFIGURE YOUR PATHS ---
# Point this to the root of the cloned sleap repository
REPO_PATH = Path("/Users/chan/PersonalProjects/rag-sleap-docs/sleap") 
DOCS_PATH = REPO_PATH / "docs"
SRC_PATH = REPO_PATH / "sleap"

# This class uses an Abstract Syntax Tree to safely parse Python code
class CodeParser(ast.NodeVisitor):
    """An AST visitor to extract functions, classes, and their docstrings."""
    def __init__(self, file_path: str, source: str, repo_path: Path = REPO_PATH):
        self.file_path = file_path
        self.source = source
        self.documents = []
        self.repo_path = repo_path

    def visit_FunctionDef(self, node: ast.FunctionDef):
        docstring = ast.get_docstring(node)
        if docstring:
            # Reconstruct a simple signature
            signature = f"def {node.name}({ast.unparse(node.args)}):"
            content = f"{signature}\n\n{docstring}"
            self.documents.append(Document(
                page_content=content,
                metadata={"source": f"{self.source}-api", "file": self.file_path, "object": node.name}
            ))
        self.generic_visit(node) # Continue visiting children

    def visit_ClassDef(self, node: ast.ClassDef):
        docstring = ast.get_docstring(node)
        if docstring:
            signature = f"class {node.name}:"
            content = f"{signature}\n\n{docstring}"
            self.documents.append(Document(
                page_content=content,
                metadata={"source": f"{self.source}-api", "file": self.file_path, "object": node.name}
            ))
        self.generic_visit(node) # Continue visiting children

def parse_source_code(src_path: Path, source: str, repo_path: Path = REPO_PATH) -> list[Document]:
    """Parses Python source files to extract API documentation."""
    print(f"Parsing source code in: {src_path}...")
    documents = []
    for py_file in src_path.rglob("*.py"):
        try:
            file_content = py_file.read_text(encoding="utf-8")
            tree = ast.parse(file_content)
            relative_path = str(py_file.relative_to(repo_path))
            parser = CodeParser(relative_path, source, repo_path)
            parser.visit(tree)
            documents.extend(parser.documents)
        except Exception as e:
            print(f"--> Could not parse {py_file}: {e}")
            
    print(f"-> Found {len(documents)} docstrings in source code.")
    return documents

def parse_guides(docs_path: Path, source: str, repo_path: Path = REPO_PATH) -> list[Document]:
    """Parses Markdown guides, loading each file as a single document."""
    print(f"Parsing guides in: {docs_path}...")
    documents = []
    for md_file in docs_path.rglob("*.md"):
        file_content = md_file.read_text(encoding="utf-8")
        relative_path = str(md_file.relative_to(repo_path))
        doc = Document(
            page_content=file_content,
            metadata={"source": f"{source}-guide", "file": relative_path}
        )
        documents.append(doc)
    
    for rst_file in docs_path.rglob("*.rst"):
        file_content = rst_file.read_text(encoding="utf-8")
        relative_path = str(rst_file.relative_to(REPO_PATH))
        doc = Document(
            page_content=file_content,
            metadata={"source": f"{source}-guide", "file": relative_path}
        )
        documents.append(doc)
            
    print(f"-> Found {len(documents)} guide files.")
    return documents



# Execute the parsing functions
guide_docs = parse_guides(DOCS_PATH, "sleap", repo_path=REPO_PATH)
api_docs = parse_source_code(SRC_PATH, "sleap", repo_path=REPO_PATH)

# Combine the results into a single list
all_raw_docs = guide_docs + api_docs

print(f"\n✅ Total documents collected: {len(all_raw_docs)}")

# print 2 examples each
print("\nExample documents:")
for doc in guide_docs[:2]:
    print(f"Source: {doc.metadata['source']}, File: {doc.metadata.get('file', 'N/A')}")
    print(doc.page_content[:100] + "...")  # Print first 100 characters
    print("-" * 40)
    
for doc in api_docs[:2]:
    print(f"Source: {doc.metadata['source']}, File: {doc.metadata.get('file', 'N/A')}, Object: {doc.metadata.get('object', 'N/A')}")
    print(doc.page_content[:100] + "...\n")  # Print first 100 characters
    print("-" * 40)

Parsing guides in: /Users/chan/PersonalProjects/rag-sleap-docs/sleap/docs...
-> Found 31 guide files.
Parsing source code in: /Users/chan/PersonalProjects/rag-sleap-docs/sleap/sleap...




-> Found 2072 docstrings in source code.

✅ Total documents collected: 2103

Example documents:
Source: sleap-guide, File: docs/CODE_OF_CONDUCT.md
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welc...
----------------------------------------
Source: sleap-guide, File: docs/help.md
# Help

Stuck? Can't get SLEAP to run? Crashing? Try the recommended tips below.

## Installation

#...
----------------------------------------
Source: sleap-api, File: sleap/rangelist.py, Object: RangeList
class RangeList:

Class for manipulating a list of range intervals.
Each range interval in the list ...

----------------------------------------
Source: sleap-api, File: sleap/rangelist.py, Object: list
def list(self):

Returns the list of ranges....

----------------------------------------


In [11]:
# Now we do sleap-io docs

SLEAPIO_REPO_PATH = Path("/Users/chan/PersonalProjects/rag-sleap-docs/sleap-io") 
SLEAPIO_DOCS_PATH = SLEAPIO_REPO_PATH / "docs"
SLEAPIO_SRC_PATH = SLEAPIO_REPO_PATH / "sleap_io"


# Execute the parsing functions
sio_guide_docs = parse_guides(SLEAPIO_DOCS_PATH, "sleap-io", repo_path=SLEAPIO_REPO_PATH)
sio_api_docs = parse_source_code(SLEAPIO_SRC_PATH, "sleap-io", repo_path=SLEAPIO_REPO_PATH)

# Combine the results into a single list
sio_all_raw_docs = sio_guide_docs + sio_api_docs

DREEM_REPO_PATH = Path("/Users/chan/PersonalProjects/rag-sleap-docs/dreem")
DREEM_DOCS_PATH = DREEM_REPO_PATH / "docs"
DREEM_SRC_PATH = DREEM_REPO_PATH / "dreem"

dreem_guide_docs = parse_guides(DREEM_DOCS_PATH, "dreem", repo_path=DREEM_REPO_PATH)
dreem_api_docs = parse_source_code(DREEM_SRC_PATH, "dreem", repo_path=DREEM_REPO_PATH)

# Combine the results into a single list
dreem_all_raw_docs = dreem_guide_docs + dreem_api_docs


Parsing guides in: /Users/chan/PersonalProjects/rag-sleap-docs/sleap-io/docs...
-> Found 4 guide files.
Parsing source code in: /Users/chan/PersonalProjects/rag-sleap-docs/sleap-io/sleap_io...
-> Found 398 docstrings in source code.
Parsing guides in: /Users/chan/PersonalProjects/rag-sleap-docs/dreem/docs...
-> Found 24 guide files.
Parsing source code in: /Users/chan/PersonalProjects/rag-sleap-docs/dreem/dreem...
-> Found 345 docstrings in source code.


## Now, to split everything: 

In [12]:
from langchain_text_splitters import RecursiveCharacterTextSplitter, Language, MarkdownHeaderTextSplitter

# Combine your raw document lists
all_docs_combined = all_raw_docs + sio_all_raw_docs + dreem_all_raw_docs

# --- 1. Define your splitters ---

# For splitting guides by their headers
headers_to_split_on = [("#", "H1"), ("##", "H2"), ("###", "H3")]
md_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on, strip_headers=False)

# For splitting very long docstrings OR language-specific files
chunk_size = 1000
chunk_overlap = 200
recursive_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap
)
# This is the correct way to handle RST files
rst_splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.RST, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)


# --- 2. Process documents based on type ---

final_chunks = []

for doc in all_docs_combined:
    source_type = doc.metadata.get("source", "")
    file_path = doc.metadata.get("file", "")

    # A. Process structured guides
    if "guide" in source_type:
        if file_path.endswith(".md"):
            chunks = md_splitter.split_text(doc.page_content)
            # Add original metadata back to the new chunks
            for chunk in chunks:
                chunk.metadata.update(doc.metadata)
            final_chunks.extend(chunks)

        elif file_path.endswith(".rst"):
            # Use the language-specific splitter for RST
            chunks = rst_splitter.split_documents([doc])
            final_chunks.extend(chunks)

    # B. Process code docstrings
    elif "api" in source_type:
        # If the docstring is larger than our chunk size, split it
        if len(doc.page_content) > chunk_size:
            chunks = recursive_splitter.split_documents([doc])
            final_chunks.extend(chunks)
        else:
            # Otherwise, keep it as a single, logical chunk
            final_chunks.append(doc)

# --- 3. Verification ---
print(f"Total documents processed: {len(all_docs_combined)}")
print(f"Total chunks created: {len(final_chunks)}")
print("\nSample chunk metadata from a guide:")
# Find and print a sample guide chunk if it exists
for chunk in final_chunks:
    if "guide" in chunk.metadata.get("source", ""):
        print(chunk.metadata)
        break

Total documents processed: 2874
Total chunks created: 3562

Sample chunk metadata from a guide:
{'H1': 'Contributor Covenant Code of Conduct', 'H2': 'Our Pledge', 'source': 'sleap-guide', 'file': 'docs/CODE_OF_CONDUCT.md'}


In [13]:
# Now I embed the documents and add them to the ChromaDB collections
from langchain_community.vectorstores import Chroma
from langchain_google_vertexai import VertexAIEmbeddings
embeddings = VertexAIEmbeddings(
    model_name="text-embedding-004"
)

# Check if collections already have data
sleap_collection = client.get_collection("sleap")

if sleap_collection.count() > 0:
    print(f"SLEAP collection already has {sleap_collection.count()} documents. Skipping embedding.")
    sleap_vectorstore = Chroma(
        client=client,
        collection_name="sleap",
        embedding_function=embeddings
    )
else:
    print("Embedding SLEAP documents...")
    sleap_vectorstore = Chroma.from_documents(
        final_chunks,
        embeddings,
        collection_name="sleap",
        client=client,
    )




Embedding SLEAP documents...


In [14]:
sleap_retriever = sleap_vectorstore.as_retriever(search_kwargs={"k": 6})

In [None]:
# Test this out
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_vertexai import ChatVertexAI
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_template("""
You are a helpful AI assistant specialized in SLEAP (Social LEAP Estimates Animal Poses), SLEAP-IO, and DREEM documentation.

Use the following context from the documentation to answer the user's question. If the answer cannot be found in the context, say "I don't have enough information in the provided documentation to answer that question."

Context:
{context}

Question: {question}

Instructions:
- Provide accurate, detailed answers based on the documentation
- Include code examples when relevant
- Mention specific function names, classes, or modules when applicable
- If discussing installation or setup, be specific about requirements
- For troubleshooting questions, provide step-by-step solutions
- Always cite which part of the documentation (SLEAP, SLEAP-IO, or DREEM) your answer comes from

Answer:
""")

llm = ChatVertexAI(
    model_name="gemini-2.0-flash-lite",
    temperature=0.2)


# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Chain
rag_chain = (
    {"context": sleap_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

# Question
rag_chain.invoke("How do I use DREEM on a existing SLEAP prediction?")

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_vertexai import ChatVertexAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_core.runnables import RunnableLambda

# Create memory to store conversation history
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="answer"
)

# Updated prompt template that includes chat history
prompt = ChatPromptTemplate.from_template("""
You are a helpful AI assistant specialized in SLEAP (Social LEAP Estimates Animal Poses), SLEAP-IO, and DREEM documentation.

Chat History:
{chat_history}

Use the following context from the documentation to answer the user's question. If the answer cannot be found in the context, say "I don't have enough information in the provided documentation to answer that question."

Context:
{context}

Current Question: {question}

Instructions:
- Provide accurate, detailed answers based on the documentation
- Include code examples when relevant
- Mention specific function names, classes, or modules when applicable
- If discussing installation or setup, be specific about requirements
- For troubleshooting questions, provide step-by-step solutions
- Always cite which part of the documentation (SLEAP, SLEAP-IO, or DREEM) your answer comes from
- Reference previous conversation when relevant

Answer:
""")

llm = ChatVertexAI(
    model_name="gemini-2.0-flash-lite",
    temperature=0.2
)

# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Function to get chat history
def get_chat_history(_):
    return memory.chat_memory.messages

# Conversational RAG Chain
conversational_rag_chain = (
    {
        "context": sleap_retriever | format_docs,
        "question": RunnablePassthrough(),
        "chat_history": RunnableLambda(get_chat_history),
    }
    | prompt
    | llm
    | StrOutputParser()
)

# Function to chat with memory
def chat_with_memory(question: str) -> str:
    # Get the response
    response = conversational_rag_chain.invoke(question)
    
    # Save to memory
    memory.save_context({"question": question}, {"answer": response})
    
    return response

# Test the conversational chain
print("=== Conversation 1 ===")
response1 = chat_with_memory("How do I install SLEAP?")
print(response1)

print("\n=== Conversation 2 ===")
response2 = chat_with_memory("What about the GPU requirements?")
print(response2)

print("\n=== Conversation 3 ===")
response3 = chat_with_memory("Can you show me a code example for the installation process we discussed?")
print(response3)