# 2.7 RAG with sources
### Using structured output to include sources in the answer
* Based on https://python.langchain.com/docs/how_to/qa_sources
* See also: https://python.langchain.com/docs/how_to/qa_citations/

## Setup

### Install dependencies

In [None]:
%pip install python-dotenv~=1.0 docarray~=0.40.0 pypdf~=5.1 --upgrade --quiet
%pip install chromadb~=0.5.18 sentence-transformers~=3.3 --upgrade --quiet 
%pip install langchain~=0.3.7 langchain_openai~=0.2.6 langchain_community~=0.3.5 langchain-chroma~=0.1.4 langchainhub~=0.1.21 --upgrade --quiet

# If running locally, you can do this instead:
#%pip install -r ../requirements.txt

### Load environment variables

In [None]:
import os
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# If running in Google Colab, you can use this code instead:
# from google.colab import userdata
# os.environ["AZURE_OPENAI_API_KEY"] = userdata.get("AZURE_OPENAI_API_KEY")
# os.environ["AZURE_OPENAI_ENDPOINT"] = userdata.get("AZURE_OPENAI_ENDPOINT")

### Setup models

In [None]:
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
api_version = "2024-10-01-preview"
llm = AzureChatOpenAI(deployment_name="gpt-4o", temperature=0.0, api_version=api_version)
embedding_model = AzureOpenAIEmbeddings(model="text-embedding-3-large", api_version=api_version)

### Setup LangSmith tracing for this notebook

In [None]:
import os

# API key etc is in the .env file
# my_name = "Totoro"
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_PROJECT"] = f"tokyo24-test-{my_name}"

## Setting up a basic RAG chain to start us off


In [None]:
import bs4
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# 1. Load, chunk and index the contents of the blog to create a retriever.
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs=dict(
        parse_only=bs4.SoupStrainer(
            class_=("post-content", "post-title", "post-header")
        )
    ),
)
docs = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(docs)
vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
retriever = vectorstore.as_retriever()


# 2. Incorporate the retriever into a question-answering chain.
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know. Use three sentences maximum and keep the "
    "answer concise."
    "\n\n"
    "{context}"
)

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

question_answer_chain = create_stuff_documents_chain(llm, prompt)
rag_chain = create_retrieval_chain(retriever, question_answer_chain)

In [None]:
print(vectorstore._collection.count())

result = rag_chain.invoke({"input": "What is Task Decomposition?"})
print(result["answer"])
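The splitter above is configured with `chunk_size=1000` and `chunk_overlap=200`. As a rough illustration of what those two numbers mean, here is a minimal plain-Python sketch of fixed-size chunking with overlap. The helper `chunk_text` is hypothetical and deliberately simplified: `RecursiveCharacterTextSplitter` additionally tries to break on natural separators such as paragraphs and lines, which this sketch does not attempt.

```python
def chunk_text(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Split text into fixed-size character windows that overlap.

    Hypothetical helper for illustration only; it ignores separators.
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks


sample = "abcdefghijklmnopqrstuvwxyz"
chunks = chunk_text(sample, chunk_size=10, chunk_overlap=4)
print(chunks)
# ['abcdefghij', 'ghijklmnop', 'mnopqrstuv', 'stuvwxyz']
```

Each chunk repeats the last 4 characters of the previous one, so a sentence that straddles a boundary still appears whole in at least one chunk.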

## Customizing the chain to prepare for more advanced output

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# This Runnable takes a dict with keys 'input' and 'context',
# formats them into a prompt, and generates a response.
rag_chain_from_docs = (
        {
            "input": lambda x: x["input"],  # input query
            "context": lambda x: format_docs(x["context"]),  # context
        }
        | prompt  # format query and context into prompt
        | llm  # generate response
        | StrOutputParser()  # coerce to string
)

# Pass input query to retriever
return_input = RunnableLambda(lambda x: x["input"]) # For clarity, you can also just use a lambda directly
retrieve_docs = return_input | retriever

# Below, we chain `.assign` calls. Each call takes the incoming dict and
# adds one key to it: first "context", then "answer". The value for each
# key is computed by a Runnable that receives the whole dict built so far,
# so the "answer" step can read both "input" and "context".
chain = RunnablePassthrough.assign(context=retrieve_docs).assign(
    answer=rag_chain_from_docs
)

chain.invoke({"input": "What is Task Decomposition?"})
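To make the chained `.assign` calls less magical, here is a plain-Python analogy of the data flow. It is not how LangChain implements `RunnablePassthrough.assign`; the `assign` function and the two `fake_*` steps are hypothetical stand-ins for the retriever and the LLM chain.

```python
from typing import Any, Callable

# A step receives the whole state dict and returns the value for one new key.
Step = Callable[[dict[str, Any]], Any]


def assign(state: dict[str, Any], **steps: Step) -> dict[str, Any]:
    """Return a copy of `state` with one extra key per step (hypothetical helper)."""
    new_values = {key: step(state) for key, step in steps.items()}
    return {**state, **new_values}


def fake_retrieve(state: dict[str, Any]) -> list[str]:
    # Stand-in for `retrieve_docs`: only looks at the "input" key.
    return [f"doc about {state['input']}"]


def fake_answer(state: dict[str, Any]) -> str:
    # Stand-in for `rag_chain_from_docs`: can see "input" and "context".
    return f"Answer based on {len(state['context'])} document(s)."


state = {"input": "task decomposition"}
state = assign(state, context=fake_retrieve)  # adds "context"
state = assign(state, answer=fake_answer)     # adds "answer"
print(state)
```

The original keys are never dropped, which is why the final result still contains the input and the retrieved context alongside the answer.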

## Structure sources in model response

#### See also chapters `2.4 (Extraction)` and `2.3 (Tagging)` for a recap of structured output.

Because the above LCEL implementation is composed of Runnable primitives, it is straightforward to extend. Below, we make two simple changes:

1. We use the model's tool-calling features to generate structured output consisting of an answer and a list of sources. The schema for the response is represented by the `AnswerWithSources1` class below. Note that there are **_two ways_** of implementing this: as a `TypedDict` (if you only need JSON) or as a Pydantic `BaseModel` (if you want an object).
2. We remove the `StrOutputParser()`, as we expect dict output in this scenario.
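As a small aside on the `TypedDict` option, the sketch below uses only the standard library to show two properties that matter here: a `TypedDict` value is an ordinary `dict` at runtime, and the extra arguments of `Annotated` (the `...` default marker and the description string) are stored as metadata that a library can read back. The class `ExampleAnswer` is a hypothetical schema for illustration, and `typing` is used instead of `typing_extensions` on the assumption of Python 3.9 or newer.

```python
from typing import Annotated, TypedDict, get_type_hints


class ExampleAnswer(TypedDict):
    """Hypothetical schema: an answer plus the sources it relied on."""

    answer: str
    sources: Annotated[
        list[str],
        ...,  # marks the field as having no default value
        "List of sources used to answer the question",
    ]


# At runtime a TypedDict value is just a dict; nothing is validated.
example: ExampleAnswer = {"answer": "42", "sources": ["Example Author 2023"]}
is_plain_dict = type(example) is dict
print(is_plain_dict)

# The Annotated extras survive as metadata on the type hint.
hints = get_type_hints(ExampleAnswer, include_extras=True)
metadata = hints["sources"].__metadata__
print(metadata)
```

A Pydantic `BaseModel` differs on the first point: the result is an instance with attribute access and validation rather than a bare dict.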

In [None]:
from langchain_core.runnables import RunnablePassthrough
from typing_extensions import Annotated, TypedDict
from pydantic import BaseModel, Field

# Desired schema for response

# OPTION 1 - Simple JSON schema, by using TypedDict  
class AnswerWithSources1(TypedDict):
    """An answer to the question, with sources."""

    answer: str
    sources: Annotated[
        list[str],
        ...,
        "List of sources (author + year) used to answer the question",
    ]

# OPTION 2 - Pydantic object
# TODO: Your task - implement this class after having tried option 1
# Hint: Look at `1.4 (Extraction)` and how a list of responses is used there
class AnswerWithSources2(BaseModel):
    """An answer to the question, with sources."""
    # TODO



# Our rag_chain_from_docs has the following changes:
# - add `.with_structured_output` to the LLM;
# - remove the output parser
rag_chain_from_docs = (
        {
            "input": lambda x: x["input"],
            "context": lambda x: format_docs(x["context"]),
        }
        | prompt
        | llm.with_structured_output(AnswerWithSources1)
)

retrieve_docs = (lambda x: x["input"]) | retriever

chain2 = RunnablePassthrough.assign(context=retrieve_docs).assign(
    answer=rag_chain_from_docs
)

response = chain2.invoke({"input": "What is Chain of Thought?"})
structured_answer = response["answer"]
print(type(structured_answer))

In [None]:
import json

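# With OPTION 1 (TypedDict) the structured answer is already a plain dict.
structured_answer_json = structured_answer
# With OPTION 2 (Pydantic) convert the object to a dict first:
#structured_answer_json = structured_answer.model_dump()

# Pretty-print the structured answer (can also use `print(json.dumps(structured_answer_json, indent=2))`)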
from IPython.display import JSON
JSON(structured_answer_json)
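If you switch between the two schema options, a small helper can save you from commenting lines in and out. The sketch below is one possible approach: `to_plain_dict` is a hypothetical helper, `FakeModel` is a stand-in for a Pydantic object so the example runs without extra dependencies, and the answer text and source are placeholder values.

```python
import json
from typing import Any


def to_plain_dict(answer: Any) -> dict:
    """Return the structured answer as a dict, whichever option produced it."""
    if isinstance(answer, dict):
        return answer  # OPTION 1: TypedDict output is already a dict
    if hasattr(answer, "model_dump"):
        return answer.model_dump()  # OPTION 2: Pydantic-style object
    raise TypeError(f"Unsupported answer type: {type(answer).__name__}")


class FakeModel:
    """Stand-in for a Pydantic model; only mimics `model_dump()`."""

    def __init__(self, answer: str, sources: list[str]) -> None:
        self.answer = answer
        self.sources = sources

    def model_dump(self) -> dict:
        return {"answer": self.answer, "sources": list(self.sources)}


# Placeholder values, not real model output.
from_dict = to_plain_dict({"answer": "Placeholder.", "sources": ["Example Author 2023"]})
from_model = to_plain_dict(FakeModel("Placeholder.", ["Example Author 2023"]))

print(json.dumps(from_model, indent=2))
```

In the notebook you would call `to_plain_dict(structured_answer)` and pass the result to `JSON(...)` or `json.dumps(...)`.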