In [34]:
# Forcer le couple compatible avec 0.8.5
!pip install -Uq "google-generativeai==0.8.5" "google-ai-generativelanguage==0.6.15" "langchain-google-genai==2.0.5"

# (Optionnel) Si un autre paquet re-tire une mauvaise version :
!pip install -Uq --no-deps "google-generativeai==0.8.5" "google-ai-generativelanguage==0.6.15"


In [35]:
# --- Imports & types ---
import os
from typing import Annotated, Sequence, TypedDict, List, Literal
from dotenv import load_dotenv

# LangChain core
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document

# Vector store / loaders
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.tools.retriever import create_retriever_tool

# LLM: Google Gemini seulement
from langchain_google_genai import ChatGoogleGenerativeAI

# LangGraph
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages

# Pydantic
from pydantic import BaseModel, Field



In [36]:
# --- AgentState & grading model ---
class AgentState(TypedDict):
    """Shared state passed between the LangGraph nodes of the RAG agent."""

    # Conversation history. add_messages is the LangGraph reducer attached to this key;
    # nodes in this notebook return only their new message(s) (e.g. {"messages": [resp]})
    # and rely on the reducer to merge them into the existing history.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Documents from the most recent retrieval: written by Vector_Retriever and reset to []
    # by Query_Rewriter before the next search.
    documents: List[Document]
    # Number of query rewrites performed so far; compared against MAX_LOOPS in Query_Rewriter.
    loops: int

class grade(BaseModel):
    # Structured-output schema for the relevance judge (used as llm.with_structured_output(grade)).
    # Keeps the name and signature from your previous version.
    # NOTE(review): documented with comments rather than a class docstring on purpose; a docstring
    # would probably be forwarded to the model as the schema/tool description. Confirm before adding one.
    # Typed as plain str, so the model's value is not restricted to 'yes'/'no'; the caller normalizes it.
    binary_score: str = Field(description="Relevance score 'yes' or 'no'")

MAX_LOOPS: int = 2  # query rewrites allowed before Query_Rewriter stops retrying and asks the user to clarify


In [37]:
# --- PATCH 1: LLM: flash ---
from langchain_google_genai import ChatGoogleGenerativeAI

def make_llm(model: str = "gemini-1.5-flash", temperature: float = 0):
    """Build the Gemini chat model shared by every node of the graph.

    Args:
        model: Gemini model name. The default "gemini-1.5-flash" was chosen in this
            notebook as a model more tolerant of quota limits.
        temperature: sampling temperature (default 0, as before).

    Returns:
        A ChatGoogleGenerativeAI instance authenticated with GOOGLE_API_KEY.

    Raises:
        RuntimeError: if GOOGLE_API_KEY is unset, empty or whitespace-only.
    """
    api_key = (os.getenv("GOOGLE_API_KEY") or "").strip()
    if not api_key:
        raise RuntimeError("GOOGLE_API_KEY est absent.")
    # Pass the key we just validated instead of letting the client re-read the environment.
    return ChatGoogleGenerativeAI(model=model, temperature=temperature, google_api_key=api_key)


In [38]:
# --- PATCH 2: Tokens ---
def build_vectorstore(urls: list, persist_dir: str = ".chroma",
                      chunk_size: int = 800, chunk_overlap: int = 120):
    """Open the persisted Chroma index, or build it from ``urls`` when it is missing or empty.

    Args:
        urls: web pages to load and index when no usable index exists yet.
        persist_dir: directory holding the Chroma index (created if needed).
        chunk_size: splitter chunk size (short chunks keep the LLM context small).
        chunk_overlap: overlap between consecutive chunks.

    Returns:
        A Chroma vector store using all-MiniLM-L6-v2 embeddings.

    Raises:
        ValueError: if an index has to be built but ``urls`` is empty or yields no text.

    Note: a non-empty persisted index is reused as-is, so changing ``urls`` has no
    effect until ``persist_dir`` is deleted.
    """
    os.makedirs(persist_dir, exist_ok=True)
    emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    try:
        vs = Chroma(persist_directory=persist_dir, embedding_function=emb)
        # Opening Chroma on a fresh/empty directory succeeds and reports 0 items. Only reuse the
        # store if it actually holds vectors; otherwise the retriever would search an empty index.
        if vs._collection.count() > 0:
            return vs
    except Exception:
        # Unreadable store: best effort, fall through and rebuild from the URLs.
        pass

    if not urls:
        raise ValueError("Aucune URL fournie et aucun index Chroma réutilisable.")
    docs = WebBaseLoader(urls).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = splitter.split_documents(docs)
    if not chunks:
        raise ValueError("Aucun contenu exploitable n'a été extrait des URLs fournies.")
    vs = Chroma.from_documents(chunks, embedding=emb, persist_directory=persist_dir)
    vs.persist()
    return vs

def make_retriever_tool(vs, k: int = 3, description: str = "Recherche dans la base Chroma."):
    """Wrap a vector store as the ``kb_search`` tool plus its underlying retriever.

    Args:
        vs: vector store exposing ``as_retriever`` (the Chroma store from build_vectorstore).
        k: number of chunks returned per search (default 3, kept small to limit context).
        description: tool description shown to the router model.

    Returns:
        ``(tool, retriever)``. The tool name stays "kb_search" because the retrieval node
        matches tool calls by that exact name.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    retriever = vs.as_retriever(search_kwargs={"k": k})
    tool = create_retriever_tool(
        retriever,
        name="kb_search",
        description=description,
    )
    return tool, retriever


In [39]:
# --- Nodes: My_AI_Assistant, Vector_Retriever, Query_Rewriter, Output_Generator ---
def make_nodes(llm, retriever_tool, retriever):
    """Create the node callables of the agentic-RAG graph.

    Args:
        llm: chat model used by every node; must support ``bind_tools`` and
            ``with_structured_output`` (the notebook passes the Gemini model from make_llm).
        retriever_tool: the ``kb_search`` tool bound to the router model.
        retriever: retriever queried directly by the Vector_Retriever node.

    Returns:
        ``(ai_assistant, retrieve, grade_documents, rewrite, generate)``. ``grade_documents``
        is not a node: it is a conditional-edge function returning the next node's name.
    """

    def _last_human_text(state: AgentState) -> str:
        # Content of the most recent HumanMessage (user question or its latest rewrite), "" if none.
        return next((m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), "")

    def _source_of(doc: Document) -> str:
        # Citation label shared by the retrieval preview and the generation context.
        return doc.metadata.get("source") or doc.metadata.get("url") or "unknown"

    # --- Router: answer directly or call kb_search ---
    router_prompt = ChatPromptTemplate.from_messages([
        ("system", "Décide si la question a besoin de la base (appel kb_search) ou réponds directement."),
        ("placeholder", "{messages}"),
    ])
    llm_router = llm.bind_tools([retriever_tool])

    def ai_assistant(state: AgentState):
        """My_AI_Assistant node: append the router model's reply (which may carry tool calls)."""
        resp = (router_prompt | llm_router).invoke({"messages": state["messages"]})
        return {"messages": [resp]}

    # --- Retrieve ---
    def retrieve(state: AgentState):
        """Vector_Retriever node: execute the kb_search calls of the last AIMessage.

        Returns one ToolMessage per call (numbered preview of the hits) and all retrieved
        documents under "documents"; returns {} when the last AIMessage has no tool calls.
        """
        last_ai = next((m for m in reversed(state["messages"]) if isinstance(m, AIMessage)), None)
        if not last_ai or not getattr(last_ai, "tool_calls", None):
            return {}
        out_msgs: List[BaseMessage] = []
        collected: List[Document] = []
        for call in last_ai.tool_calls:
            if call.get("name") != "kb_search":
                continue
            # Fall back to the latest (possibly rewritten) question when the model omitted the query.
            query = (call.get("args") or {}).get("query") or _last_human_text(state)
            # Runnable .invoke replaces the deprecated get_relevant_documents().
            docs = retriever.invoke(query)
            collected.extend(docs)
            preview = []
            for i, d in enumerate(docs, 1):
                snippet = d.page_content[:350].replace("\n", " ")
                preview.append(f"[{i}] {_source_of(d)}: {snippet}…")
            out_msgs.append(
                ToolMessage(
                    content="\n".join(preview) if preview else "(Aucun document)",
                    name="kb_search",
                    tool_call_id=call.get("id", "0"),
                )
            )
        return {"messages": out_msgs, "documents": collected}

    # --- Grader (structured output) ---
    grader = llm.with_structured_output(grade)
    judge_prompt = ChatPromptTemplate.from_messages([
        # Literal braces must be doubled: ChatPromptTemplate uses f-string templating, and a single
        # "{binary_score:...}" is parsed as a required input variable, so invoke() with only
        # q/ctx failed with a missing-variable error. The rendered text is unchanged.
        ("system", "Tu es un juge. Réponds STRICTEMENT {{binary_score:'yes'|'no'}} selon la pertinence globale."),
        ("human", "Question:\n{q}\n\nDocs (extraits):\n{ctx}")
    ])

    def grade_documents(state: AgentState) -> Literal["Output_Generator","Query_Rewriter"]:
        """Conditional edge after Vector_Retriever: generate if the docs are judged relevant, else rewrite."""
        docs = state.get("documents", [])
        if not docs:
            return "Query_Rewriter"
        q = _last_human_text(state)
        ctx = "\n\n".join(d.page_content[:500] for d in docs)
        res = (judge_prompt | grader).invoke({"q": q, "ctx": ctx})
        # The structured result may be None if the model does not follow the schema: treat as "no".
        score = (getattr(res, "binary_score", None) or "").strip().lower().rstrip(".!")
        # The prompts are in French, so also accept a French affirmative.
        return "Output_Generator" if score in {"yes", "oui"} else "Query_Rewriter"

    # --- Rewrite ---
    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", "Réécris la question pour améliorer la recherche (<= 25 mots). Réponse = requête seule."),
        ("human", "Question:\n{q}")
    ])

    def rewrite(state: AgentState):
        """Query_Rewriter node: append a reformulated query, or an AIMessage giving up after MAX_LOOPS."""
        loops = state.get("loops", 0)
        if loops >= MAX_LOOPS:
            return {
                "messages": [AIMessage(content="Après plusieurs tentatives, je manque de sources. Peux-tu préciser la période, les entités ou les mots-clés ?")],
                "documents": [],
                "loops": loops
            }
        q = _last_human_text(state)
        new_q = (rewrite_prompt | llm).invoke({"q": q}).content.strip()
        # An empty rewrite would append an empty HumanMessage; keep the previous query instead.
        return {"messages": [HumanMessage(content=new_q or q)], "loops": loops + 1, "documents": []}

    # --- Generate ---
    gen_prompt = ChatPromptTemplate.from_messages([
        ("system", "Réponds en t'appuyant UNIQUEMENT sur les documents fournis. Cite brièvement les sources [n]."),
        ("human", "Question:\n{q}\n\nDocuments:\n{ctx}")
    ])

    def generate(state: AgentState):
        """Output_Generator node: answer from the retrieved documents, citing them as [n]."""
        docs = state.get("documents", [])
        if not docs:
            return {"messages": [AIMessage(content="Je n'ai pas de documents pertinents à citer. Peux-tu préciser ?")]}
        q = _last_human_text(state)
        # Each document is truncated to 1200 characters to bound the prompt size.
        ctx = "\n\n".join(
            f"[{i}] {_source_of(d)}\n{d.page_content[:1200]}" for i, d in enumerate(docs, 1)
        )
        resp = (gen_prompt | llm).invoke({"q": q, "ctx": ctx})
        return {"messages": [AIMessage(content=resp.content)]}

    return ai_assistant, retrieve, grade_documents, rewrite, generate


In [40]:
# --- Build workflow ---
def should_call_tools(state: AgentState):
    """Conditional edge leaving My_AI_Assistant.

    Scans the history backwards for the newest AIMessage. Returns "Vector_Retriever"
    when it carries tool calls; otherwise (no tool calls, or no AIMessage at all)
    returns "END".
    """
    for message in reversed(state["messages"]):
        if isinstance(message, AIMessage):
            return "Vector_Retriever" if getattr(message, "tool_calls", None) else "END"
    return "END"

def build_app(urls: List[str]):
    """Assemble and compile the agentic-RAG graph over the pages at ``urls``.

    Flow: My_AI_Assistant -> (tool call?) Vector_Retriever -> (docs relevant?) Output_Generator -> END.
    Irrelevant docs go to Query_Rewriter, which loops back to My_AI_Assistant with a new query,
    or ends the run once it has given up after MAX_LOOPS rewrites.

    Returns:
        The compiled LangGraph application.
    """
    llm = make_llm()
    vs = build_vectorstore(urls)
    retriever_tool, retriever = make_retriever_tool(vs)
    ai_assistant, retrieve, grade_documents, rewrite, generate = make_nodes(llm, retriever_tool, retriever)

    def after_rewrite(state: AgentState):
        # Query_Rewriter appends a HumanMessage (new query) to retry, or an AIMessage once it has
        # given up. Only a retry should loop back; looping on the give-up message could make the
        # router search again and cycle until LangGraph's recursion limit.
        messages = state["messages"]
        last = messages[-1] if messages else None
        return "END" if isinstance(last, AIMessage) else "My_AI_Assistant"

    workflow = StateGraph(AgentState)
    workflow.add_node("My_AI_Assistant", ai_assistant)
    workflow.add_node("Vector_Retriever", retrieve)
    workflow.add_node("Query_Rewriter", rewrite)
    workflow.add_node("Output_Generator", generate)

    workflow.set_entry_point("My_AI_Assistant")

    workflow.add_conditional_edges(
        "My_AI_Assistant",
        should_call_tools,
        {"Vector_Retriever": "Vector_Retriever", "END": END},
    )

    workflow.add_conditional_edges(
        "Vector_Retriever",
        grade_documents,
        {"Output_Generator": "Output_Generator", "Query_Rewriter": "Query_Rewriter"},
    )

    workflow.add_conditional_edges(
        "Query_Rewriter",
        after_rewrite,
        {"My_AI_Assistant": "My_AI_Assistant", "END": END},
    )

    # Make the terminal node explicit instead of relying on an implicit dead end.
    workflow.add_edge("Output_Generator", END)

    return workflow.compile()


In [None]:
# --- Test run (Gemini only) ---
# End-to-end smoke test: build the graph over the pages below and ask one question.
# Needs GOOGLE_API_KEY and network access (page loading, Gemini calls and, presumably,
# the embedding-model download on first use).
import os

assert os.getenv("GOOGLE_API_KEY"), "GOOGLE_API_KEY manquant. Fournis ta clé avant de continuer."

# Pages to index. NOTE(review): build_vectorstore may reuse a previously persisted index,
# in which case these URLs are not re-fetched; also verify the pages still resolve.
urls = [
    "https://python.langchain.com/docs/langgraph",
    "https://python.langchain.com/docs/langgraph/concepts",
    "https://python.langchain.com/docs/langgraph/how-tos",
    "https://python.langchain.com/docs/langgraph/examples",
]

app = build_app(urls)

# AgentState is a TypedDict, so this just builds a plain dict with the three state keys.
init_state = AgentState(messages=[HumanMessage(content="À quoi sert StateGraph dans LangGraph ?")],
                        documents=[], loops=0)
result = app.invoke(init_state)

# The answer is the newest AIMessage in the final history: the generator's answer, the router's
# direct reply, or the rewriter's give-up message.
last = next((m for m in reversed(result["messages"]) if isinstance(m, AIMessage)), None)
print("=== Réponse ===")
print(last.content if last else "(pas de réponse)")
