In [None]:
import os
import statistics
from collections import Counter
from pathlib import Path

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_classic.schema import Document

# Custom loader function to preserve context and metadata
def load_documents_with_context(dataset_path: str):
    """
    Load ``*.txt`` files from the category subdirectories of a dataset root.

    Every immediate subdirectory of ``dataset_path`` is treated as a category,
    and its name is recorded in each document's ``category`` metadata. Only
    ``*.txt`` files directly inside a category directory are loaded. Files in
    the root itself or in deeper subdirectories are ignored.

    Categories and files are visited in sorted order, so the returned list
    (and anything that samples ``docs[0]``) is the same on every run.
    ``Path.iterdir`` and ``Path.glob`` on their own yield entries in arbitrary
    order.

    Args:
        dataset_path: Root directory of the dataset (str or path-like).

    Returns:
        list of ``Document``: one per successfully read file. Each has the
        following metadata:

        - ``source``
        - ``filename``
        - ``category``
        - ``file_size``: length of the decoded text in characters, not bytes
          on disk.
        - ``word_count``: number of whitespace-separated tokens.
        - ``file_path``: path relative to the dataset root.

    Raises:
        OSError: if ``dataset_path`` itself cannot be listed (e.g. it does not
            exist). Individual files that cannot be read or are not valid
            UTF-8 are reported and skipped.
    """
    documents = []
    dataset_dir = Path(dataset_path)

    # Sort so that load order (and downstream chunk order) is reproducible.
    category_dirs = sorted(p for p in dataset_dir.iterdir() if p.is_dir())
    for category_dir in category_dirs:
        category_name = category_dir.name
        print(f"Loading category: {category_name}")

        # glob("*.txt") also matches directories named like "x.txt"; skip them.
        txt_files = sorted(p for p in category_dir.glob("*.txt") if p.is_file())
        print(f"  Found {len(txt_files)} files")

        for txt_file in txt_files:
            try:
                content = txt_file.read_text(encoding='utf-8')
            except (OSError, UnicodeDecodeError) as e:
                # Best-effort: report the bad file and keep loading the rest.
                print(f"  Error loading {txt_file}: {e}")
                continue

            documents.append(
                Document(
                    page_content=content,
                    metadata={
                        'source': str(txt_file),
                        'filename': txt_file.name,
                        'category': category_name,
                        'file_size': len(content),
                        'word_count': len(content.split()),
                        'file_path': str(txt_file.relative_to(dataset_dir)),
                    },
                )
            )

    return documents

# Load documents with preserved context.
# The DATASET_PATH environment variable overrides the default below.
# r"***" is a placeholder (redacted path) and must point at a real dataset directory.
dataset_path = os.environ.get("DATASET_PATH", r"***")
docs = load_documents_with_context(dataset_path)

print(f"\nTotal documents loaded: {len(docs)}")

# Display category breakdown
if docs:
    categories = Counter(doc.metadata['category'] for doc in docs)

    print("\nDocuments by category:")
    for category, count in sorted(categories.items()):
        print(f"  {category}: {count} documents")

    # Show sample document metadata; long string values are cut to 50 chars.
    print("\nSample document metadata:")
    sample_doc = docs[0]
    for key, value in sample_doc.metadata.items():
        if isinstance(value, str) and len(value) > 50:
            print(f"  {key}: {value[:50]}...")
        else:
            print(f"  {key}: {value}")

    print("\nSample content preview:")
    print(f"{sample_doc.page_content[:200]}...")
else:
    print(f"No documents found under {dataset_path!r}; check the dataset path.")

# Enhanced text splitter that preserves metadata
class ContextAwareTextSplitter(RecursiveCharacterTextSplitter):
    """
    Recursive character splitter whose chunks inherit their parent's metadata.

    Each chunk ``Document`` gets a shallow copy of the source document's
    metadata, plus three positional fields:

    * ``chunk_index``: 0-based position of the chunk within its parent.
    * ``total_chunks``: number of chunks the parent was split into.
    * ``chunk_size``: whitespace-separated word count of the chunk. This is
      words, unlike the splitter's own ``chunk_size`` setting, which is
      measured with ``length_function``.
    """

    def split_documents(self, documents):
        """Split each document and attach inherited plus positional metadata.

        Args:
            documents: Iterable of ``Document`` objects to split.

        Returns:
            A flat list of chunk ``Document`` objects, in input order.
        """
        result = []
        for parent in documents:
            pieces = self.split_text(parent.page_content)
            piece_count = len(pieces)
            result.extend(
                Document(
                    page_content=piece,
                    metadata={
                        **parent.metadata,  # inherit every original field
                        'chunk_index': position,
                        'total_chunks': piece_count,
                        'chunk_size': len(piece.split()),
                    },
                )
                for position, piece in enumerate(pieces)
            )
        return result

# Split into chunks while preserving context.
# Sizes are measured with length_function=len, i.e. in characters.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 100

splitter = ContextAwareTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    length_function=len,
    separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
    # Keep each matched separator at the END of the preceding piece.
    # The default (keep_separator=True) attaches it to the start of the next
    # piece, so sentence-split chunks would begin with ". " / "! " / "? ".
    keep_separator="end",
)

chunks = splitter.split_documents(docs)

print(f"\nTotal chunks created: {len(chunks)}")

# Analyze chunk distribution by category
if chunks:
    chunk_categories = Counter(c.metadata['category'] for c in chunks)
    # The 'chunk_size' metadata written by ContextAwareTextSplitter is a word
    # count, not the character budget configured above.
    chunk_sizes = [c.metadata['chunk_size'] for c in chunks]

    print("\nChunks by category:")
    for category, count in sorted(chunk_categories.items()):
        print(f"  {category}: {count} chunks")

    print("\nChunk size statistics:")
    print(f"  Average: {statistics.mean(chunk_sizes):.1f} words")
    print(f"  Median: {statistics.median(chunk_sizes):.1f} words")
    print(f"  Min: {min(chunk_sizes)} words")
    print(f"  Max: {max(chunk_sizes)} words")

    # Show sample chunk with metadata
    print("\nSample chunk metadata:")
    sample_chunk = chunks[0]
    for key, value in sample_chunk.metadata.items():
        print(f"  {key}: {value}")

    print("\nSample chunk content:")
    print(f"{sample_chunk.page_content[:150]}...")

In [None]:
# Embedding model: a sentence-transformers model loaded through the HuggingFace
# integration. `emb` is used to build the Chroma vector store.
from langchain_classic.embeddings import HuggingFaceEmbeddings

emb = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)


In [13]:
from langchain_classic.vectorstores import Chroma

# On-disk location of the collection. With chromadb >= 0.4, a store built
# without a persist_directory lives only in memory. The './chroma' reported by
# db._persist_directory in that case is just chromadb's settings default.
CHROMA_DIR = "./chroma"

# Deterministic IDs (relative file path + chunk position) mean re-running this
# cell rewrites the same entries instead of appending duplicate chunks to the
# persisted collection.
chunk_ids = [
    f"{c.metadata['file_path']}::{c.metadata['chunk_index']}" for c in chunks
]

db = Chroma.from_documents(
    chunks,
    emb,
    ids=chunk_ids,
    collection_name="knowledge_base",
    persist_directory=CHROMA_DIR,
)
# Return the 4 most similar chunks per query.
retriever = db.as_retriever(search_kwargs={"k": 4})


In [None]:
# NOTE(review): with chromadb >= 0.4, collections created with a persist_directory
# are believed to be written to disk automatically, making this manual persist()
# a deprecated no-op. Without a persist_directory the store is in-memory and
# nothing is saved here. Confirm against the installed chromadb/langchain versions.
db.persist()

In [15]:
# Display the directory the vector store reports it persists to (a private attribute).
# NOTE(review): this may show chromadb's default settings value ('./chroma') even
# when the collection is in-memory only. Verify that the directory actually contains data.
db._persist_directory

'./chroma'

In [20]:
from langchain_openai import ChatOpenAI

# Fail fast with an actionable message when the key is missing, rather than
# surfacing a client-construction error from inside the OpenAI SDK.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it in the environment before running this cell."
    )

llm = ChatOpenAI(
    model="gpt-4o-mini",  # OpenAI-hosted chat model
    api_key=openai_api_key,
)

In [None]:
from langchain_community.llms.ollama import Ollama
from langchain_classic.chains.retrieval_qa.base import RetrievalQA

# `llm` is the ChatOpenAI model created earlier. To run against a local model
# instead, replace it before building the chain, e.g.:
#     llm = Ollama(base_url="http://localhost:11435", model="llama3")
rag = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)

# `invoke` is the supported entry point; calling the chain directly is deprecated.
# RetrievalQA reads the question from the "query" key and returns the answer
# under "result".
qa_output = rag.invoke({"query": "Explain HIV vaccine trial"})
print(qa_output["result"])
