<a href="https://colab.research.google.com/github/syedmahmoodiagents/guardrails/blob/main/GuardRails.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
!pip install langchain-huggingface --q

In [None]:
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel

In [None]:
from langchain_core.runnables import RunnableLambda

In [14]:
import os, getpass

In [23]:
os.environ['HF_TOKEN'] = getpass.getpass()

··········


In [47]:
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

In [48]:
llm = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="oss-gpt-20b"))

In [49]:
class AnswerSchema(BaseModel):
    answer: str
    citations: list[str]

parser = PydanticOutputParser(pydantic_object=AnswerSchema)

In [50]:
primary_prompt = PromptTemplate(
    template="""
Answer the question using the context below.

Question:
{question}

Context:
{context}

{format_instructions}
""",
    input_variables=["question", "context"],
    partial_variables={
        "format_instructions": parser.get_format_instructions()
    }
)

In [51]:
def detect_missing_citations(response: AnswerSchema):
    if not response.citations:
        raise ValueError("Detective guardrail failed: no citations")
    return response

detect_citations = RunnableLambda(detect_missing_citations)

In [52]:
primary_chain = (primary_prompt | llm | parser | detect_citations)

In [53]:

retry_prompt = PromptTemplate(
    template="""
Answer again with STRICT rules.

Rules:
- Use ONLY the provided context
- Provide at least one citation
- If the answer is not in the context, say "I don’t know"

Question:
{question}

Context:
{context}

{format_instructions}
""",
    input_variables=["question", "context"],
    partial_variables={
        "format_instructions": parser.get_format_instructions()
    }
)

In [54]:
retry_chain = (
    retry_prompt | llm | parser | detect_citations
)

In [55]:
def safe_fallback(_):
    return AnswerSchema(
        answer="I don’t have enough information to answer this safely.",
        citations=[]
    )

fallback_chain = RunnableLambda(safe_fallback)

In [56]:
final_chain = primary_chain.with_fallbacks([
    retry_chain,
    fallback_chain
]).with_config({"max_retries": 5})

In [57]:
context_text = """
Aspirin is commonly used to reduce pain, fever, and inflammation.
(Source: Medical Handbook, Page 12)
"""

result = final_chain.invoke({
    "question": "What is aspirin used for?",
    "context": context_text
})

In [58]:
print(result)

answer='I don’t have enough information to answer this safely.' citations=[]
