<a href="https://colab.research.google.com/github/rsrini7/Colabs/blob/main/langchain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the libraries imported in the next cell: LangChain (core + community), sentence-transformers embeddings, ChromaDB and LiteLLM.
!pip install langchain langchain-community langchain-core sentence-transformers chromadb litellm --quiet

In [None]:
# langchain_rag_openrouter_litellm.py
from google.colab import userdata
import os
import logging
import sys
from typing import Any, List, Mapping, Optional, Dict, Union, cast, AsyncIterator, Iterator

# --- Langchain Imports ---
from langchain_core.language_models.llms import LLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain_core.outputs import GenerationChunk, Generation
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# --- Other Necessary Imports ---
import litellm

# --- Configuration & Constants ---
# Sentence-transformers model used to embed document chunks.
EMBED_MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"

# Sample corpus location; a dedicated directory so other notebooks' data is untouched.
DATA_DIR: str = "./data_langchain"
SAMPLE_FILE_NAME: str = "sample_langchain.txt"

# LiteLLM model string; the "openrouter/" prefix selects LiteLLM's OpenRouter provider.
# Replace with any other OpenRouter model id as needed.
OPENROUTER_LITELLM_MODEL_STRING: str = "openrouter/openai/gpt-3.5-turbo"

# Directory where the Chroma vector store is persisted between runs.
DB_PATH_LANGCHAIN: str = "./db_chroma_langchain_openrouter_litellm"

# --- Optional debug logging (uncomment to enable) ---
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger('litellm').setLevel(logging.INFO)  # surface LiteLLM's own logs

# --- Custom Langchain LLM Class using LiteLLM ---
class LiteLLMWrapperForLangchain(LLM):
    """
    Custom LangChain LLM wrapper that routes completions through LiteLLM.

    Every prompt is sent as a single ``{"role": "user"}`` chat message via
    ``litellm.completion`` (sync) or ``litellm.acompletion`` (async). Both the
    one-shot (``_call`` / ``_acall``) and token-streaming (``_stream`` /
    ``_astream``) paths are implemented. When ``streaming`` is True, the
    one-shot paths consume the streaming API internally, so callback handlers
    still receive per-token events.
    """
    model_name: str = OPENROUTER_LITELLM_MODEL_STRING
    """The model name to pass to litellm.completion."""

    temperature: float = 0.0
    """The temperature to use for the completion."""

    max_tokens: Optional[int] = 512 # Max tokens for the *output*
    """The maximum number of output tokens to generate; None omits the limit."""

    top_p: float = 1.0
    """The top-p value to use for the completion."""

    litellm_kwargs: Optional[Dict[str, Any]] = None
    """Additional keyword arguments to pass to litellm.completion (never mutated; each call copies it)."""

    streaming: bool = False
    """If True, _call/_acall consume the streaming API and return the concatenated tokens."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "litellm_langchain_wrapper"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        ``litellm_kwargs`` is spread first so that the explicit fields win,
        matching the precedence used when the actual request is built.
        """
        return {
            **(self.litellm_kwargs or {}),
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "streaming": self.streaming,
        }

    def _prepare_litellm_call_kwargs(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """Build the base keyword arguments for one LiteLLM call.

        Args:
            stop: Optional stop sequences for this call only.

        Returns:
            A fresh dict on every call. Callers may freely add or override keys
            without affecting ``self.litellm_kwargs`` or later calls.
        """
        # Copy: writing into self.litellm_kwargs directly would leak per-call
        # keys (stop, stream, call-site overrides) into every subsequent call.
        call_kwargs: Dict[str, Any] = dict(self.litellm_kwargs or {})
        call_kwargs["model"] = self.model_name
        call_kwargs["temperature"] = self.temperature
        if self.max_tokens is not None:  # LiteLLM uses max_tokens for output tokens
            call_kwargs["max_tokens"] = self.max_tokens
        call_kwargs["top_p"] = self.top_p
        if stop:
            call_kwargs["stop"] = stop
        # An api_key given explicitly in litellm_kwargs wins; otherwise use the env var.
        call_kwargs.setdefault("api_key", os.getenv("OPENROUTER_API_KEY"))
        return call_kwargs

    def _build_request(
        self,
        prompt: str,
        stop: Optional[List[str]],
        overrides: Dict[str, Any],
        *,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """Assemble the complete keyword set, including ``messages``, for one LiteLLM call.

        Args:
            prompt: Text sent as the single user message.
            stop: Optional stop sequences.
            overrides: Call-site kwargs; they take precedence over instance settings.
            stream: Whether to request a streaming response.
        """
        request = self._prepare_litellm_call_kwargs(stop=stop)
        request.update(overrides)  # Allow overriding with call-specific kwargs
        # Force the transport mode. A stray "stream" entry must not make a
        # one-shot call return an iterator, or a streaming call return a response.
        if stream:
            request["stream"] = True
        else:
            request.pop("stream", None)
        request["messages"] = [{"role": "user", "content": prompt}]
        return request

    @staticmethod
    def _delta_text(chunk: Any) -> str:
        """Return the incremental text carried by a streamed chunk, or '' if it has none."""
        if not chunk.choices:
            return ""
        delta = chunk.choices[0].delta
        if not delta:
            return ""
        return delta.content or ""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to LiteLLM's completion endpoint and return the generated text.

        When ``streaming`` is enabled, the response is consumed via ``_stream``
        so ``run_manager`` receives per-token callbacks. The concatenated text
        is returned.
        """
        if self.streaming:
            return "".join(
                generation.text
                for generation in self._stream(prompt, stop=stop, run_manager=run_manager, **kwargs)
            )
        response = litellm.completion(**self._build_request(prompt, stop, kwargs))
        return response.choices[0].message.content or ""

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async call out to LiteLLM's completion endpoint; honours ``streaming`` like ``_call``."""
        if self.streaming:
            pieces: List[str] = []
            async for generation in self._astream(prompt, stop=stop, run_manager=run_manager, **kwargs):
                pieces.append(generation.text)
            return "".join(pieces)
        response = await litellm.acompletion(**self._build_request(prompt, stop, kwargs))
        return response.choices[0].message.content or ""

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream responses from LiteLLM, yielding one GenerationChunk per non-empty delta."""
        response_stream = litellm.completion(**self._build_request(prompt, stop, kwargs, stream=True))
        for chunk in response_stream:
            text = self._delta_text(chunk)
            if not text:
                continue
            generation = GenerationChunk(text=text)
            # Notify before yielding, so the callback still fires if the
            # consumer stops iterating right after receiving this chunk.
            if run_manager:
                run_manager.on_llm_new_token(text, chunk=generation)
            yield generation

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Async stream responses from LiteLLM, yielding one GenerationChunk per non-empty delta."""
        # acompletion(stream=True) must be awaited to obtain the async iterator.
        response_stream = await litellm.acompletion(**self._build_request(prompt, stop, kwargs, stream=True))
        async for chunk in response_stream:
            text = self._delta_text(chunk)
            if not text:
                continue
            generation = GenerationChunk(text=text)
            if run_manager:
                await run_manager.on_llm_new_token(text, chunk=generation)
            yield generation

# --- Main Script Logic ---
def _ensure_sample_file(path: str) -> bool:
    """Create ``path`` (and its parent directory) containing short demo text, if missing.

    Returns:
        True if the file was created, False if it already existed.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    if os.path.exists(path):
        return False
    # Explicit UTF-8 so the file round-trips with the loader's encoding in main().
    with open(path, "w", encoding="utf-8") as f:
        f.write("""Langchain is a framework for developing applications powered by language models.
It provides modular components for building complex chains and agents.
Key features include document loaders, text splitters, vector stores, and LLM wrappers.
This example uses Langchain with OpenRouter via LiteLLM for RAG.
Retrieval Augmented Generation enhances LLM responses with external data.
""")
    return True


def _stable_chunk_ids(chunks: List[Any]) -> List[str]:
    """Return a deterministic id for each chunk, derived from its source path and text.

    Storing chunks under stable ids means re-running ingestion over unchanged
    files overwrites the existing entries instead of appending duplicates.
    Identical chunks from the same source get an occurrence counter, so ids
    stay unique within one batch.
    """
    import hashlib  # local import keeps this helper self-contained

    seen: Dict[str, int] = {}
    ids: List[str] = []
    for chunk in chunks:
        key = f"{chunk.metadata.get('source', '')}\x00{chunk.page_content}"
        occurrence = seen.get(key, 0)
        seen[key] = occurrence + 1
        ids.append(hashlib.sha256(f"{key}\x00{occurrence}".encode("utf-8")).hexdigest())
    return ids


def main():
    """Run the end-to-end RAG demo.

    The steps are:
    1. Load the OpenRouter key from Colab Secrets.
    2. Ensure a sample document exists.
    3. Embed its chunks into a persistent Chroma store.
    4. Run a plain retrieval.
    5. Answer the same question with a RetrievalQA chain backed by LiteLLMWrapperForLangchain.

    Setup, ingestion, storage and retrieval failures exit with status 1. A
    failure in the final QA step is reported, but does not exit.
    """
    print("--- Starting Langchain RAG with OpenRouter via LiteLLM ---")

    # 0. Setup: API Keys and Sample Data
    # litellm.set_verbose = True # Uncomment for verbose LiteLLM logs

    try:
        openrouter_api_key = userdata.get('OPENROUTER_API_KEY')
    except userdata.SecretNotFoundError:
        print("ERROR: OPENROUTER_API_KEY not found in Colab Secrets. Please add it.")
        sys.exit(1)
    except Exception as e:
        print(f"ERROR: Could not load OpenRouter API Key: {e}")
        sys.exit(1)
    if not openrouter_api_key:
        # Otherwise an empty secret only surfaces later as an opaque authentication error.
        print("ERROR: OPENROUTER_API_KEY in Colab Secrets is empty. Please set a value.")
        sys.exit(1)
    os.environ["OPENROUTER_API_KEY"] = openrouter_api_key  # For LiteLLM
    print("OpenRouter API Key loaded from Colab Secrets.")

    sample_file_path = os.path.join(DATA_DIR, SAMPLE_FILE_NAME)
    if _ensure_sample_file(sample_file_path):
        print(f"Created dummy sample file: '{sample_file_path}'")

    # Configure Langchain Components
    print(f"\nConfiguring LLM: Langchain LiteLLM Wrapper with model '{OPENROUTER_LITELLM_MODEL_STRING}'")
    llm = LiteLLMWrapperForLangchain(
        model_name=OPENROUTER_LITELLM_MODEL_STRING,
        temperature=0.0,
        max_tokens=256,  # Max output tokens for the LLM response
    )
    print("LLM configured.")

    print(f"\nConfiguring Embedding Model: '{EMBED_MODEL_NAME}'")
    try:
        # NOTE(review): this constructor emits a warning at run time, presumably the
        # langchain_community deprecation pointing to the langchain-huggingface
        # package. Switching would require an extra pip install. Confirm before changing.
        embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
        print("Embedding model configured successfully.")
    except Exception as e:
        print(f"ERROR: Could not load HuggingFace embedding model '{EMBED_MODEL_NAME}': {e}")
        print("Ensure 'pip install sentence-transformers' has been run.")
        sys.exit(1)

    # 1. Ingest Text File
    print("\n--- 1. Ingesting Data ---")
    try:
        # Loads only the sample file. Widen the glob (e.g. "**/*.txt") to ingest
        # every text file under DATA_DIR. UTF-8 matches how the sample file is written.
        loader = DirectoryLoader(
            DATA_DIR,
            glob=f"**/{SAMPLE_FILE_NAME}",
            loader_cls=TextLoader,
            loader_kwargs={"encoding": "utf-8"},
            show_progress=True,
        )
        documents = loader.load()
        if not documents:
            print(f"Warning: No documents loaded from '{DATA_DIR}'. Ensure '{SAMPLE_FILE_NAME}' exists.")
            sys.exit(1)
        print(f"Loaded {len(documents)} document(s). Total characters: {sum(len(doc.page_content) for doc in documents)}")
    except Exception as e:
        print(f"ERROR during document loading: {e}")
        sys.exit(1)

    # Split documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    print(f"Split into {len(texts)} chunks.")

    # 2. Store Contents in a Vector Database (ChromaDB)
    print("\n--- 2. Storing in Vector Database (ChromaDB) ---")
    try:
        print(f"Initializing Chroma vector store at '{DB_PATH_LANGCHAIN}'...")
        # from_documents adds to whatever collection is already persisted here.
        # With random ids, every re-run duplicated each chunk, so retrieval
        # returned the same text several times. Deterministic ids make re-runs
        # overwrite the same entries. Entries left by earlier runs (random ids)
        # are not removed; delete DB_PATH_LANGCHAIN once to clear them.
        vectorstore = Chroma.from_documents(
            documents=texts,
            embedding=embeddings,
            ids=_stable_chunk_ids(texts),
            persist_directory=DB_PATH_LANGCHAIN,
        )
        # No explicit persist(): chromadb >= 0.4 writes to persist_directory
        # automatically, and LangChain's Chroma.persist() is deprecated.
        # _collection is a private attribute, used here only for a diagnostic count.
        print(f"Vector store created/loaded. Collection count (approx): {vectorstore._collection.count()}")
    except Exception as e:
        print(f"ERROR during vector store setup: {e}")
        # If error is "Invalid dimension" check embedding model output vs Chroma expectations.
        sys.exit(1)

    # 3. Perform a Search Operation (via Retriever)
    print("\n--- 3. Performing Explicit Search (Retriever) ---")
    query = "What is Langchain?"
    try:
        # To reuse the store in a later run without re-ingesting:
        # vectorstore = Chroma(persist_directory=DB_PATH_LANGCHAIN, embedding_function=embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 2})  # Get top 2 results
        retrieved_docs = retriever.invoke(query)  # Langchain uses 'invoke'

        print(f"Search query: '{query}'")
        print(f"Found {len(retrieved_docs)} relevant document chunk(s):")
        for i, doc in enumerate(retrieved_docs):
            print(f"  Result {i+1} (Metadata: {doc.metadata}): {doc.page_content[:150].strip()}...")
    except Exception as e:
        print(f"ERROR during retrieval: {e}")
        sys.exit(1)

    # 4. Pass Search Results to LLM for Generating Answers (RetrievalQA Chain)
    print("\n--- 4. Generating Answer with LLM using RetrievalQA Chain ---")
    try:
        # Prompt that restricts the answer to the retrieved context.
        prompt_template_str = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Keep the answer concise and based *only* on the provided context.

Context:
{context}

Question: {question}
Helpful Answer:"""
        QA_PROMPT = PromptTemplate(
            template=prompt_template_str, input_variables=["context", "question"]
        )

        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",  # "stuff" puts all context into the prompt
            retriever=retriever,
            return_source_documents=True,  # Optionally return source documents
            chain_type_kwargs={"prompt": QA_PROMPT},
        )

        print(f"Querying LLM with (via chain): '{query}'")
        result = qa_chain.invoke({"query": query})  # Langchain chains use 'invoke'

        print(f"\nLLM Answer for '{query}':")
        print(f"Answer: {result['result']}")

        print("\nSource Documents considered by LLM:")
        for i, doc in enumerate(result["source_documents"]):
            print(f"  Source {i+1} (Metadata: {doc.metadata}): {doc.page_content[:100].strip()}...")

    except Exception as e:
        # Deliberately non-fatal: report the failure and still print the closing banner.
        print(f"ERROR during RetrievalQA chain execution or LLM call: {e}")

    print("\n--- Langchain RAG with OpenRouter via LiteLLM Finished ---")

# Entry point: runs the demo when executed as a script. A notebook cell also
# executes with __name__ == "__main__", so running the cell runs main().
if __name__ == "__main__":
    main()

--- Starting Langchain RAG with OpenRouter via LiteLLM ---
OpenRouter API Key loaded from Colab Secrets.

Configuring LLM: Langchain LiteLLM Wrapper with model 'openrouter/openai/gpt-3.5-turbo'
LLM configured.

Configuring Embedding Model: 'sentence-transformers/all-MiniLM-L6-v2'


  embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)


Embedding model configured successfully.

--- 1. Ingesting Data ---


100%|██████████| 1/1 [00:00<00:00, 422.81it/s]

Loaded 1 document(s). Total characters: 379
Split into 1 chunks.

--- 2. Storing in Vector Database (ChromaDB) ---
Initializing Chroma vector store at './db_chroma_langchain_openrouter_litellm'...



  vectorstore.persist() # Ensure persistence


Vector store created/loaded. Collection count (approx): 2

--- 3. Performing Explicit Search (Retriever) ---
Search query: 'What is Langchain?'
Found 2 relevant document chunk(s):
  Result 1 (Metadata: {'source': 'data_langchain/sample_langchain.txt'}): Langchain is a framework for developing applications powered by language models.
It provides modular components for building complex chains and agents...
  Result 2 (Metadata: {'source': 'data_langchain/sample_langchain.txt'}): Langchain is a framework for developing applications powered by language models.
It provides modular components for building complex chains and agents...

--- 4. Generating Answer with LLM using RetrievalQA Chain ---
Querying LLM with (via chain): 'What is Langchain?'

LLM Answer for 'What is Langchain?':
Answer: Langchain is a framework for developing applications powered by language models.

Source Documents considered by LLM:
  Source 1 (Metadata: {'source': 'data_langchain/sample_langchain.txt'}): Langchain i