In [1]:
import re
import ast
from typing import Dict, List, Optional
from dataclasses import dataclass
from pathlib import Path
import os
from gitingest import ingest
from langchain_qdrant import QdrantVectorStore
from langchain_ollama import OllamaEmbeddings
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    MarkdownHeaderTextSplitter,
    Language,
    PythonCodeTextSplitter
)
 
import nest_asyncio
import asyncio
from qdrant_client import QdrantClient
import yaml
import json
from qdrant_client.http import models as rest
# Enable nested event loops
nest_asyncio.apply()

# Get connection string from environment variable or use default
qdrant_conn = os.getenv('ConnectionStrings__qdrant_http')
print(f'Using Qdrant connection string: {qdrant_conn}')

# Parse connection string
endpoint = qdrant_conn.split(';')[0].split('=')[1]
api_key = qdrant_conn.split(';')[1].split('=')[1]

# Initialize Qdrant client
qdrant = QdrantClient(url=endpoint, api_key=api_key)

# Setup Ollama Embeddings
embeddings = OllamaEmbeddings(
    model="mxbai-embed-large", 
    base_url="http://ollama:11434"
)

# Test connection by listing collections
collections = qdrant.get_collections()
print("Successfully connected to Qdrant!")
print(f"Available collections: {collections}")


@dataclass
class CodeEntity:
    """Represents a code entity (function, class, etc.) with its metadata."""
    name: str
    type: str  # 'function', 'class', 'method'
    docstring: Optional[str]
    start_line: int
    end_line: int
    decorators: List[str]
    parent: Optional[str]
    dependencies: List[str]

class MetadataExtractor:
    """Extracts rich metadata from different file types."""
    
    @staticmethod
    def extract_python_entities(content: str) -> List[CodeEntity]:
        """Extract functions, classes, and methods from Python code."""
        try:
            tree = ast.parse(content)
            entities = []
            
            for node in ast.walk(tree):
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                    # Extract docstring
                    docstring = ast.get_docstring(node)
                    
                    # Get decorators
                    decorators = [
                        ast.unparse(decorator).strip()
                        for decorator in node.decorator_list
                    ]
                    
                    # Find imports and dependencies
                    dependencies = []
                    for sub_node in ast.walk(node):
                        if isinstance(sub_node, ast.Import):
                            dependencies.extend(n.name for n in sub_node.names)
                        elif isinstance(sub_node, ast.ImportFrom):
                            dependencies.append(sub_node.module)
                    
                    entity = CodeEntity(
                        name=node.name,
                        type='class' if isinstance(node, ast.ClassDef) else 'function',
                        docstring=docstring,
                        start_line=node.lineno,
                        end_line=node.end_lineno,
                        decorators=decorators,
                        parent=None,  # Will be filled later for methods
                        dependencies=list(set(dependencies))
                    )
                    entities.append(entity)
            
            return entities
        except SyntaxError:
            return []

    @staticmethod
    def extract_markdown_metadata(content: str) -> Dict:
        """Extract metadata from markdown files."""
        metadata = {
            'headers': [],
            'links': [],
            'code_blocks': [],
            'frontmatter': None
        }
        
        # Extract headers
        headers = re.findall(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE)
        metadata['headers'] = [(len(h[0]), h[1]) for h in headers]
        
        # Extract links
        links = re.findall(r'\[([^\]]+)\]\(([^\)]+)\)', content)
        metadata['links'] = links
        
        # Extract code blocks
        code_blocks = re.findall(r'```(\w+)?\n(.*?)```', content, re.DOTALL)
        metadata['code_blocks'] = [(lang or 'text', code) for lang, code in code_blocks]
        
        # Extract frontmatter
        if content.startswith('---'):
            try:
                fm_match = re.match(r'---\n(.*?)\n---', content, re.DOTALL)
                if fm_match:
                    metadata['frontmatter'] = yaml.safe_load(fm_match.group(1))
            except yaml.YAMLError:
                pass
        
        return metadata

    @staticmethod
    def extract_html_template_metadata(content: str, file_type: str) -> Dict:
        """Extract metadata from HTML/Jinja templates."""
        metadata = {
            'blocks': [],
            'extends': None,
            'includes': [],
            'macros': [],
            'variables': []
        }
        
        if file_type == '.jinja':
            # Extract template inheritance
            extends_match = re.search(r'{%\s*extends\s+[\'"](.+?)[\'"]', content)
            if extends_match:
                metadata['extends'] = extends_match.group(1)
            
            # Extract blocks
            blocks = re.findall(r'{%\s*block\s+(\w+)\s*%}', content)
            metadata['blocks'] = blocks
            
            # Extract includes
            includes = re.findall(r'{%\s*include\s+[\'"](.+?)[\'"]', content)
            metadata['includes'] = includes
            
            # Extract macros
            macros = re.findall(r'{%\s*macro\s+(\w+)\s*\(', content)
            metadata['macros'] = macros
            
            # Extract variables
            variables = re.findall(r'{{(.+?)}}', content)
            metadata['variables'] = [v.strip() for v in variables]
        
        return metadata

def create_chunking_strategies():
    """Create specialized chunking strategies for different file types."""
    return {
        # Python files
        '.py': RecursiveCharacterTextSplitter.from_language(
            language=Language.PYTHON,
            chunk_size=500,
            chunk_overlap=50,
            #separators=["\nclass ", "\ndef ", "\n\n", "\n"]
        ),
        
        # Markdown files
        '.md': MarkdownHeaderTextSplitter(
            headers_to_split_on=[
                ("#", "Header 1"),
                ("##", "Header 2"),
                ("###", "Header 3"),
                ("####", "Header 4")
            ]
        ),
        
        # HTML/Jinja templates
        '.html': RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["</div>", "</template>", "</section>", "\n\n", "\n"]
        ),
        '.jinja': RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["{% block ", "{% extends ", "{% include ", "\n\n", "\n"]
        ),
        
        # Config files
        '.yml': RecursiveCharacterTextSplitter(
            chunk_size=300,
            chunk_overlap=50,
            separators=["---", "\n\n", "\n"]
        ),
        '.yaml': RecursiveCharacterTextSplitter(
            chunk_size=300,
            chunk_overlap=50,
            separators=["---", "\n\n", "\n"]
        ),
        '.toml': RecursiveCharacterTextSplitter(
            chunk_size=300,
            chunk_overlap=50,
            separators=["\n\n", "\n"]
        ),
        
        # Other common file types
        '.json': RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["},", "}\n", "\n"]
        ),
        '.rst': RecursiveCharacterTextSplitter(
            chunk_size=600,
            chunk_overlap=100,
            separators=["\n=+\n", "\n-+\n", "\n\n", "\n"]
        ),
        '.txt': RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", ". "]
        ),
        
        # Default
        'default': RecursiveCharacterTextSplitter(
            chunk_size=400,
            chunk_overlap=50,
            separators=["\n\n", "\n", ". ", " "]
        )
    }

def extract_file_metadata(file_path: str, content: str) -> Dict:
    """Extract comprehensive metadata for a file."""
    file_ext = os.path.splitext(file_path)[1].lower()
    metadata = {
        "file_path": file_path,
        "file_type": file_ext,
        "file_name": os.path.basename(file_path),
        "directory": os.path.dirname(file_path),
        "size_bytes": len(content.encode('utf-8')),
        "num_lines": len(content.splitlines()),
        "is_empty": len(content.strip()) == 0,
        "has_shebang": content.startswith('#!') if content else False,
        "file_level_metadata": {}
    }
    
    # Extract file type specific metadata
    if file_ext == '.py':
        python_entities = MetadataExtractor.extract_python_entities(content)
        metadata['file_level_metadata'].update({
            'classes': [e for e in python_entities if e.type == 'class'],
            'functions': [e for e in python_entities if e.type == 'function'],
            'has_main': any(e.name == '__main__' for e in python_entities),
            'imports': re.findall(r'^(?:from|import)\s+(\S+)', content, re.MULTILINE),
            'doc_coverage': sum(1 for e in python_entities if e.docstring) / len(python_entities) if python_entities else 0
        })
    
    elif file_ext == '.md':
        metadata['file_level_metadata'].update(
            MetadataExtractor.extract_markdown_metadata(content)
        )
    
    elif file_ext in ['.html', '.jinja']:
        metadata['file_level_metadata'].update(
            MetadataExtractor.extract_html_template_metadata(content, file_ext)
        )
    
    elif file_ext in ['.yml', '.yaml']:
        try:
            yaml_content = yaml.safe_load(content)
            metadata['file_level_metadata']['yaml_structure'] = {
                'top_level_keys': list(yaml_content.keys()) if isinstance(yaml_content, dict) else [],
                'is_list': isinstance(yaml_content, list)
            }
        except yaml.YAMLError:
            pass
    
    elif file_ext == '.json':
        try:
            json_content = json.loads(content)
            metadata['file_level_metadata']['json_structure'] = {
                'top_level_keys': list(json_content.keys()) if isinstance(json_content, dict) else [],
                'is_array': isinstance(json_content, list)
            }
        except json.JSONDecodeError:
            pass
    
    return metadata
    

# 2. Let's modify split_by_files to be more verbose and handle errors better
def split_by_files(content: str) -> List[Dict]:
    """Split the concatenated content into individual files with their paths."""
    if not content:
        print("Warning: Content is empty")
        return []
    
    # Split by double newline to handle the actual format
    file_parts = content.split("\n\n")
    files = []
    
    current_file = None
    current_content = []
    
    for part in file_parts:
        if part.startswith("================================================\nFile: "):
            # If we have a previous file, save it
            if current_file:
                files.append({
                    "path": current_file,
                    "content": "\n".join(current_content),
                    "metadata": extract_file_metadata(current_file, "\n".join(current_content))
                })
                current_content = []
            
            # Extract new file path
            file_line = part.split('\n')[1]  # Get the "File: /path" line
            current_file = file_line.replace("File: ", "").strip()
        else:
            if current_file and part.strip():  # Only add non-empty parts
                current_content.append(part)
    
    # Don't forget to add the last file
    if current_file:
        files.append({
            "path": current_file,
            "content": "\n".join(current_content),
            "metadata": extract_file_metadata(current_file, "\n".join(current_content))
        })
    
    print(f"Total files processed: {len(files)}")
    if files:
        print(f"Sample file paths:")
        for i, file in enumerate(files[:3]):  # Show first 3 files
            print(f"  {i+1}. {file['path']}")
    
    return files
 

# 3. Let's modify advanced_repository_chunking to be more verbose
def advanced_repository_chunking(content: str, repo_url: str):
    """Process repository content using file-type specific chunking with rich metadata."""
    print("Starting repository chunking...")
    
    # Get chunking strategies
    chunking_strategies = create_chunking_strategies()
    print("Created chunking strategies")
    
    # Split content into files
    files = split_by_files(content)
    print(f"Split content into {len(files)} files")
    
    if not files:
        print("Warning: No files to process")
        return []
    
    # Extract repository-level metadata
    repo_metadata = {
        "repository_url": repo_url,
        "repository_name": repo_url.split('/')[-1],
        "organization": repo_url.split('/')[-2],
        "total_files": len(files),
        "file_types": {},
        "directory_structure": {}
    }
    
    # Process each file with appropriate chunking strategy
    processed_chunks = []
    for file in files:
        file_ext = file["metadata"]["file_type"]
        print(f"Processing file with extension: {file_ext}")
        
        splitter = chunking_strategies.get(file_ext, chunking_strategies['default'])
        combined_metadata = {
            **file["metadata"],
            "repository": repo_metadata
        }
        
        try:
            if file_ext == '.md':
                header_splits = splitter.split_text(file["content"])
                if any(len(split.page_content) > 600 for split in header_splits):
                    size_splitter = RecursiveCharacterTextSplitter(
                        chunk_size=600,
                        chunk_overlap=50,
                        separators=["\n\n", "\n", ". "]
                    )
                    for split in header_splits:
                        smaller_splits = size_splitter.create_documents(
                            texts=[split.page_content],
                            metadatas=[{
                                **split.metadata,
                                **combined_metadata
                            }]
                        )
                        processed_chunks.extend(smaller_splits)
                else:
                    for split in header_splits:
                        split.metadata.update(combined_metadata)
                    processed_chunks.extend(header_splits)
            else:
                chunks = splitter.create_documents(
                    texts=[file["content"]],
                    metadatas=[combined_metadata]
                )
                processed_chunks.extend(chunks)
                print(f"Added {len(chunks)} chunks for file {file['path']}")
        except Exception as e:
            print(f"Error processing file {file['path']}: {str(e)}")
            continue
    
    print(f"Total chunks created: {len(processed_chunks)}")
    if processed_chunks:
        print("Sample chunk metadata:", processed_chunks[0].metadata)
    return processed_chunks


def setup_qdrant(collection_name: str = "code_repositories"):
    """Setup Qdrant connection and initialize embeddings."""
    # Get connection string from environment variable or use default
    qdrant_conn = os.getenv('ConnectionStrings__qdrant_http')
    if not qdrant_conn:
        raise ValueError("Qdrant connection string not found in environment variables")
    
    # Parse connection string
    endpoint = qdrant_conn.split(';')[0].split('=')[1]
    api_key = qdrant_conn.split(';')[1].split('=')[1]
    
    # Initialize Qdrant client
    qdrant = QdrantClient(url=endpoint, api_key=api_key)
    
    # Initialize embeddings
    embeddings
# 5. Create vector store
vector_store = QdrantVectorStore(
    client=qdrant,
    collection_name="my_repo_collection",
    embedding=embeddings,
)


Using Qdrant connection string: Endpoint=http://qdrant:6333;Key=aMjJKx0t1a6E9hysaCacWz
Successfully connected to Qdrant!
Available collections: collections=[CollectionDescription(name='embedding-demo'), CollectionDescription(name='embedding-demo1'), CollectionDescription(name='my_repo_collection')]


  qdrant = QdrantClient(url=endpoint, api_key=api_key)


In [2]:
repo_url = "https://github.com/cyclotruc/gitingest"
summary, tree, content = ingest(repo_url)
print(summary)
# 4. Process and index the repository
chunks = advanced_repository_chunking(content, repo_url)
vector_size = len(embeddings.embed_query("test"))
# Create the collection first
collection_name = "my_repo_collection"


# Check if collection exists and create if it doesn't
if not qdrant.collection_exists(collection_name):
    qdrant.create_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(
            size=vector_size,
            distance=rest.Distance.COSINE
        )
    )
    print(f"Created new collection: {collection_name}")
else:
    print(f"Collection {collection_name} already exists")



# 6. Add documents to vector store
vector_store.add_documents(chunks)
print("done")

Repository: cyclotruc/gitingest
Files analyzed: 51

Estimated tokens: 51.2k
Starting repository chunking...
Created chunking strategies
Total files processed: 1
Sample file paths:
  1. /README.md
Split content into 1 files
Processing file with extension: .md
Total chunks created: 420
Collection my_repo_collection already exists


KeyboardInterrupt: 

In [None]:
# Test search
results = vector_store.similarity_search(
    "What is the purpose of this repository?",
    k=2
)
print(results)

In [None]:
from typing import List, Dict
from langchain_ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema import StrOutputParser
from langchain.prompts import ChatPromptTemplate
from typing import List, Dict
from langchain_ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema import StrOutputParser
from langchain.prompts import ChatPromptTemplate

class LocalRAG:
    """A class to handle local RAG operations using Ollama for both embeddings and LLM."""
    
    def __init__(self, vector_store, model_name="phi3.5"):
        """
        Initialize the LocalRAG with vector store and model configurations.
        
        Args:
            vector_store: Initialized QdrantVectorStore
            model_name: Name of the Ollama model to use (default: "phi3.5")
        """
        self.vector_store = vector_store
        self.llm = ChatOllama(model=model_name, base_url="http://ollama:11434")
        
        # Define a better prompt template for RAG
        self.template = """You are a helpful AI assistant. Use the following context to answer the question. 
        If you cannot find the answer in the context, say "I cannot find the answer in the provided context."
        
        Context:
        {context}
        
        Question:
        {question}
        
        Answer:"""
        
        self.prompt = PromptTemplate(
            template=self.template,
            input_variables=["context", "question"]
        )
        
    def format_docs(self, docs: List[Dict]) -> str:
        """Format the retrieved documents into a string."""
        return "\n\n".join(doc.page_content for doc in docs)
    
    def retrieve_and_answer(self, question: str, k: int = 3) -> str:
        """
        Retrieve relevant documents and generate an answer.
        
        Args:
            question: User's question
            k: Number of documents to retrieve (default: 3)
            
        Returns:
            str: Generated answer
        """
        # First retrieve the documents
        retrieved_docs = self.vector_store.similarity_search(question, k=k)
        formatted_context = self.format_docs(retrieved_docs)
        
        # Create and execute the RAG chain
        chain = (
            self.prompt | 
            self.llm | 
            StrOutputParser()
        )
        
        # Execute the chain with the prepared context and question
        response = chain.invoke({
            "context": formatted_context,
            "question": question
        })
        
        return response
    
    def get_relevant_chunks(self, question: str, k: int = 3) -> List[Dict]:
        """
        Get the relevant chunks for a question without generating an answer.
        Useful for debugging and understanding what context is being used.
        
        Args:
            question: User's question
            k: Number of documents to retrieve (default: 3)
            
        Returns:
            List[Dict]: List of relevant documents
        """
        return self.vector_store.similarity_search(question, k=k)

def demonstrate_local_rag(vector_store):
    """Demonstrate how to use the LocalRAG class."""
    # Initialize the RAG system
    rag = LocalRAG(vector_store)
    
    # Example questions to test
    questions = [
        # "What is the purpose of this repository?",
        # "How does the code handle different file types?",
        # "What metadata is extracted from Python files?"
        "How can I use this library?"
    ]
    
    print("RAG Demo:\n")
    for question in questions:
        print(f"Question: {question}")
        print("\nRelevant chunks:")
        chunks = rag.get_relevant_chunks(question, k=2)
        for i, chunk in enumerate(chunks, 1):
            print(f"\nChunk {i}:")
            print(f"Source: {chunk.metadata.get('file_path', 'Unknown')}")
            print(f"Content: {chunk.page_content[:200]}...")
        
        print("\nGenerated Answer:")
        answer = rag.retrieve_and_answer(question)
        print(answer)
        print("\n" + "="*80 + "\n")


# To use this code with your existing setup:
"""
# First make sure you have the vector store set up as in your original code
vector_store = QdrantVectorStore(
    client=qdrant,
    collection_name="my_repo_collection",
    embedding=embeddings,
)
vector_store
# Then you can run the demonstration
demonstrate_local_rag(vector_store)
"""
demonstrate_local_rag(vector_store)