<a href="https://colab.research.google.com/github/venkataravuri/ai-ml-models/blob/main/rag_add_citations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Source

https://python.langchain.com/docs/how_to/qa_citations/

In [None]:
# Install the LangChain packages and the `wikipedia` client into the kernel's
# environment (-q: quiet output, -U: upgrade to the latest available versions).
%pip install -qU langchain langchain-openai langchain-anthropic langchain-community wikipedia

In [None]:
from google.colab import userdata
import os

# Copy each provider API key from the Colab secrets store into the process
# environment under the same name.
for secret_name in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY"):
    os.environ[secret_name] = userdata.get(secret_name)

In [None]:
from langchain_openai import ChatOpenAI
# Chat model that serves as the generation step of the RAG chain built below.
llm = ChatOpenAI(model="gpt-4o-mini")

In [None]:
from langchain_community.retrievers import WikipediaRetriever
from langchain_core.prompts import ChatPromptTemplate

# System message for the QA prompt. `{context}` is left as a template
# placeholder to be filled with the retrieved Wikipedia text at run time.
system_prompt = (
    "You're a helpful AI assistant. "
    "Given a user question and some Wikipedia article snippet, "
    "answer the user question. "
    "If none of the articles answer the question, just say you don't know."
    "\n\n"
    "Here are the Wikipedia articles: {context}"
)

# Wikipedia-backed retriever. Going by the parameter names, it returns up to
# six results and caps each document's content at 2000 characters.
retriever = WikipediaRetriever(
    doc_content_chars_max=2000,
    top_k_results=6,
)

# Two-message chat prompt: the system instructions (with the `{context}` slot)
# followed by the user's question (the `{input}` slot).
prompt = ChatPromptTemplate.from_messages(
    [("system", system_prompt), ("human", "{input}")]
)

# Print the rendered template so the reader can inspect it.
prompt.pretty_print()

In [None]:
from typing import List

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

def format_docs(docs: List[Document]) -> str:
    """Concatenate the `page_content` of every document into one string.

    Args:
        docs: Documents whose text should be combined.

    Returns:
        The documents' contents in their original order, separated by a
        blank line; an empty string when `docs` is empty.
    """
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)

# Answer-generation sub-chain. It expects a dict that already holds the
# retrieved documents under "context" (and the question under "input"),
# replaces "context" with the documents' concatenated text so it can be
# substituted into the prompt, then runs prompt -> LLM -> plain string.
rag_chain_from_docs = (
    RunnablePassthrough.assign(context=lambda x: format_docs(x["context"]))
    | prompt
    | llm
    | StrOutputParser()
)

# Retrieval step: pull the question out of the input dict and hand it to the
# Wikipedia retriever, yielding a list of Document objects.
retrieve_docs = (lambda x: x["input"]) | retriever

# Full pipeline. The first assign() adds the raw retrieved documents under
# "context" without touching "input"; the second assign() adds the generated
# answer under "answer". The output is therefore a dict with the keys
# "input", "context" and "answer", so callers can inspect the sources
# alongside the answer.
chain = RunnablePassthrough.assign(context=retrieve_docs).assign(
    answer=rag_chain_from_docs
)

In [None]:
# The chain reads the question from the "input" key, so it must be invoked
# with a dict (a bare `"input": ...` pair is not valid Python).
result = chain.invoke({"input": "What is the capital of France?"})

# Show the available keys, the first entry under "context", and the answer.
print(result.keys())
print(result["context"][0])
print(result["answer"])

In [None]:
from pydantic import BaseModel, Field

# Structured-output schema: an answer plus the IDs of the sources backing it.
# The class docstring and the Field descriptions become part of the schema
# pydantic generates, so their wording is kept verbatim.
class CitedAnswer(BaseModel):
    """Answer the user question based only on the given sources, and cite the sources used."""

    # Free-text answer; required (`...` marks the field as having no default).
    answer: str = Field(
        ...,
        description="The answer to the user question, which is based only on the given sources.",
    )
    # Source IDs are integers, matching the field description below.
    citations: List[int] = Field(
        ...,
        description="The integer IDs of the SPECIFIC sources which justify the answer.",
    )