# Safeguarding AI Virtual Assistant for Customer Service with NVIDIA NeMo Guardrails

AI agents present a significant opportunity for businesses to scale and elevate customer service and support interactions. By automating routine inquiries and enhancing response times, these agents improve efficiency and customer satisfaction, helping organizations stay competitive. 

However, alongside these benefits, AI agents come with risks. Large language models (LLMs) are vulnerable to generating inappropriate or off-topic content and can be susceptible to jailbreak attacks. To fully realize the potential of generative AI in customer service, it is essential to implement robust AI safety and security measures.

This tutorial equips AI builders with actionable steps to integrate essential safeguards into AI agents for customer service applications. We’ll explore how to integrate AI safeguard NIM microservices using NeMo Guardrails to build guardrail configurations that ensure your AI agent can identify and mitigate unsafe interactions in real time. Then, we’ll take it a step further by connecting these capabilities to the sophisticated agentic workflows outlined in the NVIDIA AI Blueprint for AI virtual assistants. By the end, you’ll have a clear understanding of how to create a scalable and secure AI assistant tailored to your brand’s unique needs. 

Figure 1 details the architecture workflow of integrating **[NeMo Guardrails](https://docs.nvidia.com/nemo/guardrails/index.html)** and safeguarding **[NIM microservices](https://developer.nvidia.com/nim)** in the **[NVIDIA AI Blueprint for virtual assistants](https://build.nvidia.com/nvidia/ai-virtual-assistant-for-customer-service)**.


## Prerequisites

### Docker compose

#### System requirements

Ubuntu 20.04 or 22.04 based machine, with sudo privileges

Install software requirements
- Install Docker Engine and Docker Compose. Refer to the instructions for Ubuntu.
- Ensure the Docker Compose plugin version is 2.29.1 or higher.
- Run docker compose version to confirm.
- Refer to Install the Compose plugin in the Docker documentation for more information.
- To configure Docker for GPU-accelerated containers, install the NVIDIA Container Toolkit.
- Install git

By default the provided configurations use GPU optimized databases such as Milvus.


### Safety NIM Microservices

#### Compute Requirements
If you are going to deploy the **[Safety NIM Miccroservices](https://docs.nvidia.com/_preview?_cms.db.previewId=00000194-6b79-d6bc-a7b5-6bfb99b10000&_fields=true&_mainObjectId=&_date=#nemoguard)** using the downloadable containers from **[NVIDIA NGC](https://registry.ngc.nvidia.com/)**, then there might be a need for higher compute
- minimum of 4xH100, or 4xA100

If the Safety NIM Microservices are used using the build.nvidia.com endpoints, there is no need for additional compute
- 1xH100, 1xA100

## Getting API Keys - Very Important

To run the pipeline you need to obtain an API key from NVIDIA. These will be needed in a later step to Set up the environment file.

- Required API Keys: These APIs are required by the pipeline to execute LLM queries.

- NVIDIA API Catalog
  1. Navigate to **[NVIDIA API Catalog](https://build.nvidia.com/explore/discover)**.
  2. Select any model, such as llama-3.3-70b-instruct.
  3. On the right panel above the sample code snippet, click on "Get API Key". This will prompt you to log in if you have not already.

NOTE: The API key starts with nvapi- and ends with a 32-character string. You can also generate an API key from the user settings page in NGC (https://ngc.nvidia.com/).

Export API Keys

In [None]:
import os

NVIDIA_API_KEY = input("Please enter your NVIDIA API key (nvapi-): ")
NGC_API_KEY=NVIDIA_API_KEY
os.environ["NVIDIA_API_KEY"] = NVIDIA_API_KEY
os.environ["NGC_CLI_API_KEY"] = NGC_API_KEY
os.environ["NGC_API_KEY"] = NGC_API_KEY

# Step 1: Deploying the NIM Blueprint 

Open the jupyter notebook  **[./ai-virtual-assistant/deploy/ai_virtual_assistant_notebook.ipynb](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/blob/main/deploy/ai_virtual_assistant_notebook.ipynb)** and run through the cells (Shift + Enter) to start the **[NIM blueprint](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant)** and all the necessary docker containers. Following the same notebook, run the **[./ai-virtual-assistant/notebooks/ingest_data.ipynb](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/blob/main/notebooks/ingest_data.ipynb)** to ingest the structured and unstructured data types.

# Step 2: Download the NeMo Guardrails Toolkit 

Start by cloning the NeMo Guardrails repository:

In [None]:
!git clone https://github.com/NVIDIA/NeMo-Guardrails.git nemoguardrails

Make sure that the notebook is operating from the `ai-virtual-assistant` directory. If it's not, it changes to that directory.

In [None]:
import os

current_path = os.getcwd()
last_part = os.path.basename(current_path)

if os.path.basename(os.getcwd()) != "ai-virtual-assistant":
    os.chdir("ai-virtual-assistant")

os.getcwd()

We login into the NGC catalogue.

In [None]:
!docker login nvcr.io -u '$oauthtoken' -p $NGC_API_KEY

## Build the NeMo Guardrails with Docker

First setup the `nemoguardrails.yaml` file for NeMo Guardrails and then launch the **[container](https://docs.nvidia.com/nemo/guardrails/user_guides/advanced/using-docker.html)** by using the following command:

In [None]:
%%bash
docker compose -f deploy/compose/nemoguardrails.yaml up -d

Before running the nemoguardrails server, we need to add the guardrails configuration. Let's deploy the safety NIMs and integrate it with NeMo Guardrails

## Step 3:  Deploying the Safety NIMs

This tutorial equips AI builders with actionable steps to integrate essential safeguards into AI agents for customer service applications. It demonstrates how to leverage NVIDIA NeMo Guardrails, a scalable rail orchestration platform, including the following three new AI safeguard models offered as NVIDIA NIM microservices:

**[Llama 3.1 NemoGuard 8B ContentSafety](https://build.nvidia.com/nvidia/llama-3_1-nemoguard-8b-content-safety)** for safeguarding input prompts and output responses in AI interactions, ensuring AI systems align with ethical standards. Llama 3.1 NemoGuard 8B ContentSafety is trained on the Aegis Content Safety Dataset including 35,000 human annotated AI safety data samples. It features explicit response labels curated through an automated process using an ensemble of LLM-as-a-judge across NVIDIA-developed and open community LLMs.

**[Llama 3.1 NemoGuard 8B TopicControl](https://build.nvidia.com/nvidia/llama-3_1-nemoguard-8b-topic-control)** for keeping conversations focused on approved topics, avoiding derailment or inappropriate content. Llama 3.1 NemoGuard 8B TopicControl is fine-tuned on synthetic data to maintain context and enforce boundaries consistently throughout entire AI conversations. 

**[NemoGuard JailbreakDetect](https://build.nvidia.com/nvidia/nemoguard-jailbreak-detect)** for protection against jailbreak attempts, helping to maintain AI integrity in adversarial scenarios. NemoGuard JailbreakDetect is an LLM jailbreak classification model trained on a dataset of 17,000 known challenging and successful jailbreaks, built in part using NVIDIA Garak, an open-source toolkit for LLM and application vulnerability scanning developed by the NVIDIA Research team.

Each of the Safety NIMs can be deployed either as a downloadable container or via the endpoint

Let us see how to deploy NIMs as downloadable containers

### Llama 3.1 NemoGuard 8B ContentSafety

The Llama 3.1 NemoGuard 8B ContentSafety NIM follows a set of 42 Safety hazard categories with data distributions including annotations for jailbreak data, diverse cultural and geographical AI content safety like Hazards and Self Harm. Custom and novel safety risk categories and policy can also be provided in the instruction for the model to categorize using the novel taxonomy and policy. The model detects if the user input and/or the LLM response are safe or unsafe, and if unsafe, gives the violated category in the response. 

In [None]:
!mkdir safety-nims

In [None]:
%%writefile safety-nims/content-safety.sh
export NGC_API_KEY=<your NGC personal key>
export NIM_IMAGE=<Path to latest NIM docker container>
export MODEL_NAME="llama-3.1-nemoguard-8b-content-safety"
docker pull $NIM_IMAGE

docker run -it --name=$MODEL_NAME \
    --gpus="device=0" --runtime=nvidia \
    -e NGC_API_KEY="$NGC_API_KEY" \
    -e NIM_SERVED_MODEL_NAME=$MODEL_NAME \
    -e NIM_CUSTOM_MODEL_NAME=$MODEL_NAME \
    -v $LOCAL_NIM_CACHE:"/opt/nim/.cache/" \
    -u $(id -u) \
    -p 8123:8000 \
    $NIM_IMAGE

On your teminal (irrespective of running locally or on VM) run the `ai-virtual-assistant/safety-nims/content-safety.sh` to deploy the NIM

### Llama 3.1 NemoGuard 8B TopicControl

The Llama 3.1 NemoGuard 8B TopicControl NIM can be used for topical and dialogue moderation of user prompts in human-assistant interactions being designed for task-oriented dialogue agents and custom policy-based moderation. Given a system instruction (also called topical instruction, i.e. specifying which topics are allowed and disallowed) and a conversation history ending with the last user prompt, the model returns a binary response that flags if the user message respects the system instruction, (i.e. message is on-topic or a distractor/off-topic). 

In [None]:
%%writefile safety-nims/topic-control.sh
export NGC_API_KEY=<your NGC personal key>
export NIM_IMAGE=<Path to latest NIM docker container>
export MODEL_NAME="llama-3.1-nemoguard-8b-topic-control"
docker pull $NIM_IMAGE

docker run -it --name=$MODEL_NAME \
    --gpus="device=1" --runtime=nvidia \
    -e NGC_API_KEY="$NGC_API_KEY" \
    -e NIM_SERVED_MODEL_NAME=$MODEL_NAME \
    -e NIM_CUSTOM_MODEL_NAME=$MODEL_NAME \
    -v $LOCAL_NIM_CACHE:"/opt/nim/.cache/" \
    -u $(id -u) \
    -p 8124:8000 \
    $NIM_IMAGE

On your teminal (irrespective of running locally or on VM) run the `ai-virtual-assistant/safety-nims/topic-control.sh` to deploy the NIM

### NemoGuard JailbreakDetect
The NemoGuard JailbreakDetect NIM was developed to detect attempts to jailbreak large language models. The Jailbreak detection model uses Snowflake-arctic-embed-m embeddings. It is trained on the combination of three open datasets, mixed together, de-duplicated, and reviewed for data quality. Jailbreak data was augmented with the use of garak.

In [None]:
%%writefile safety-nims/jailbreak-detect.sh
export NGC_API_KEY=<your NGC personal key>
export NIM_IMAGE=<Path to latest NIM docker container>
export MODEL_NAME='ardennes-jailbreak-arctic'
docker pull $NIM_IMAGE

docker run -it --name=$MODEL_NAME \
    --gpus="device=1" --runtime=nvidia \
    -e NGC_API_KEY="$NGC_API_KEY" \
    -v $LOCAL_NIM_CACHE:"/opt/nim/.cache/" \
    -u $(id -u) \
    -p 8125:8000 \
    $NIM_IMAGE

On your teminal (irrespective of running locally or on VM) run the `ai-virtual-assistant/safety-nims/jailbreak-detect.sh` to deploy the NIM

While building the guardrails configuration, integrate the three safeguard NIM microservices, start with creating the config directory: 


```
├── config
│   ├── config.yml
│   ├── prompts.yml

```

Now, add each configuration option one by one, starting with the models in the config.yml file:

## Step 4: Building the NeMo Guardrails configuration with the Safety NIMs
### 1. Creating the `config.yml` file

In [None]:
!mkdir config

In [None]:
%%writefile config/config.yml
models:
  - type: main
    engine: nim
    model: meta/llama-3.1-70b-instruct

  - type: "llama-3.1-nemoguard-8b-content-safety"
    engine: nim
    parameters:
      base_url: "http://localhost:8123/v1"
      model_name: "llama-3.1-nemoguard-8b-content-safety"
  - type: topic_control
    engine: nim
    parameters:
      base_url: "http://localhost:8124/v1/"
      model_name: "llama-3.1-nemoguard-8b-topic-control"

rails:
  config:
    jailbreak_detection:
      server_endpoint: ""
      embedding: "snowflake/snowflake-arctic-embed-m-long"
  input:
    flows:
      - content safety check input $model=llama-3.1-nemoguard-8b-content-safety
      - topic safety check input $model=topic_control
      - jailbreak detection model
  output:
    flows:
      - content safety check output $model=llama-3.1-nemoguard-8b-content-safety

Both the Content safety NIM and the Topic control NIM come with a set of policy compliances which can be added to the guardrails configuration using the `prompts.yml` file

### 2. Creating the `prompts.yml` file

In [None]:
%%writefile config/prompts.yml
# These are the default prompts released by Meta, except for policy O7, which was added to address direct insults.
prompts:

  - task: content_safety_check_input $model=llama_3.1_aegis_guard_2.0
    content: |
      Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below.

      <BEGIN UNSAFE CONTENT CATEGORIES>
      S1: Violence.
      S2: Sexual.
      S3: Criminal Planning/Confessions.
      S4: Guns and Illegal Weapons.
      S5: Controlled/Regulated Substances.
      S6: Suicide and Self Harm.
      S7: Sexual (minor).
      S8: Hate/Identity Hate.
      S9: PII/Privacy.
      S10: Harassment.
      S11: Threat.
      S12: Profanity.
      S13: Needs Caution.
      S14: Other.
      S15: Manipulation.
      S16: Fraud/Deception.
      S17: Malware.
      S18: High Risk Gov Decision Making.
      S19: Political/Misinformation/Conspiracy.
      S20: Copyright/Trademark/Plagiarism.
      S21: Unauthorized Advice.
      S22: Illegal Activity.
      S23: Immoral/Unethical.
      <END UNSAFE CONTENT CATEGORIES>

      <BEGIN CONVERSATION>

      user: {{ user_input }}

      <END CONVERSATION>

      Provide your safety assessment for the above conversation in the below output JSON format:
      {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}

      Do not include anything other than the output JSON in your response.
      Output JSON:
    output_parser: is_content_safe_aegis_input
    max_tokens: 50

  - task: content_safety_check_output $model=llama_3.1_aegis_guard_2.0
    content: |
      Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below.

      <BEGIN UNSAFE CONTENT CATEGORIES>
      S1: Violence.
      S2: Sexual.
      S3: Criminal Planning/Confessions.
      S4: Guns and Illegal Weapons.
      S5: Controlled/Regulated Substances.
      S6: Suicide and Self Harm.
      S7: Sexual (minor).
      S8: Hate/Identity Hate.
      S9: PII/Privacy.
      S10: Harassment.
      S11: Threat.
      S12: Profanity.
      S13: Needs Caution.
      S14: Other.
      S15: Manipulation.
      S16: Fraud/Deception.
      S17: Malware.
      S18: High Risk Gov Decision Making.
      S19: Political/Misinformation/Conspiracy.
      S20: Copyright/Trademark/Plagiarism.
      S21: Unauthorized Advice.
      S22: Illegal Activity.
      S23: Immoral/Unethical.
      <END UNSAFE CONTENT CATEGORIES>

      <BEGIN CONVERSATION>

      user: {{ user_input }}

      response: agent: {{ bot_response }}

      <END CONVERSATION>

      Provide your safety assessment for the above conversation in the below output JSON format:
      {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}

      Do not include anything other than the output JSON in your response.
      Output JSON:
    output_parser: is_content_safe_aegis_output
    max_tokens: 50

  - task: content_safety_check_input $model=llama_guard
    content: |
      <s>[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
        
  - task: topic_safety_check_input $model=topic_control
    content: |
      You are to act as a customer service agent, providing users with factual information in accordance to the knowledge base. Your role is to ensure that you respond only to relevant queries and adhere to the following guidelines

      Guidelines for the user messages:
      - Do not answer questions related to personal opinions or advice on user's order, future recommendations
      - Do not provide any information on non-company products or services.
      - Do not answer enquiries unrelated to the company policies.
      - Do not answer questions asking for personal details about the agent or its creators.
      - Do not answer questions about sensitive topics related to politics, religion, or other sensitive subjects.
      - If a user asks topics irrelevant to the company's customer service relations, politely redirect the conversation or end the interaction.
      - Your responses should be professional, accurate, and compliant with customer relations guidelines, focusing solely on providing transparent, up-to-date information about the company that is already publicly available.

## Step5: Wrapping the guardrails configuration around the agentic system

With the configuration complete, you could use it as is to apply guardrails to a general-purpose conversational AI by interfacing with the NeMo Guardrails server through its API. The assistant or agent from the NIM Blueprint performs multiple tasks, a few including RAG, checking if the user is compliant with the return policy, and thereby updating the return option, getting the user's purchase history. 

Start with adding chains to the following agent components
- `src/analytics/main.py`
- `src/agent/utils.py`
- `src/agent/main.py`

#### 1. Analytics

In [None]:
%%writefile src/analytics/main.py
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
from datetime import datetime
from enum import Enum
from typing import Annotated, Generator, Literal, Sequence, TypedDict

from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field
from src.analytics.datastore.session_manager import SessionManager
from src.common.utils import get_config, get_llm, get_prompts
from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

logger = logging.getLogger(__name__)
prompts = get_prompts()

# TODO get the default_kwargs from the Agent Server API
default_llm_kwargs = {"temperature": 0, "top_p": 0.7, "max_tokens": 1024}

# Initialize persist_data to determine whether data should be stored in the database.
persist_data = os.environ.get("PERSIST_DATA", "true").lower() == "true"

# Initialize session manager during startup
session_manager = None
try:
    session_manager = SessionManager()
except Exception as e:
    logger.info(f"Failed to connect to DB during init, due to exception {e}")
    
# Initialize  guardrails configuration
rail_config = RailsConfig.from_path("./config")
guardrails = RunnableRails(rail_config, input_key="query", output_key='content')
    
    
def get_database():
    """
    Connect to the database.
    """
    global session_manager
    try:
        if not session_manager:
            session_manager = SessionManager()

        return session_manager
    except Exception as e:
        logger.info(f"Error connecting to database: {e}")
        return None


def generate_summary(conversation_history):
    """
    Generate a summary of the conversation.

    Parameters:
        conversation_history (List): The conversation text.

    Returns:
        str: A summary of the conversation.
    """
    logger.info(f"conversation history: {conversation_history}")
    llm = get_llm(**default_llm_kwargs)
    prompt = prompts.get("summary_prompt", "")
    for turn in conversation_history:
        prompt += f"{turn['role']}: {turn['content']}\n"

    prompt += "\n\nSummary: "
    
    # Apply guardrails to the chain
    chain_with_guardrails = guardrails | llm
    response = chain_with_guardrails.invoke({"query": prompt})

    return response.content


def generate_session_summary(session_id):
    # TODO: Check for corner cases like when session_id does not exist
    session_manager = get_database()

    # Check if summary already exists in database
    session_info = session_manager.get_session_summary_and_sentiment(session_id)
    if session_info and session_info.get("summary", None):
        return session_info

    # Generate summary and session info
    conversation_history = session_manager.get_conversation(session_id)
    summary = generate_summary(conversation_history)
    sentiment = generate_sentiment(conversation_history)

    if persist_data:
        # Save the summary and sentiment in database
        session_manager.save_summary_and_sentiment(
            session_id,
            {
                "summary": summary,
                "sentiment": sentiment,
                "start_time": conversation_history[0].get("timestamp", 0),
                "end_time": conversation_history[-1].get("timestamp", 0),
            }
        )
    return {
        "summary": summary,
        "sentiment": sentiment,
        "start_time": datetime.fromtimestamp(
            float(conversation_history[0].get("timestamp", 0))
        ),
        "end_time": datetime.fromtimestamp(
            float(conversation_history[-1].get("timestamp", 0))
        ),
    }


def fetch_user_conversation(user_id, start_time=None, end_time=None):
    """
    Fetch a user's conversation from the database.
    """
    try:
        # TODO: Use start time and end time to filter the data
        session_manager = get_database()
        conversations = session_manager.list_sessions_for_user(user_id)
        logger.info(f"Conversation: {conversations}")
        return conversations
    except Exception as e:
        logger.error(f"Error fetching conversation: {e}")
        return None


def generate_sentiment(conversation_history):
    # Define an Enum for the sentiment values
    class SentimentEnum(str, Enum):
        POSITIVE = "positive"
        NEUTRAL = "neutral"
        NEGATIVE = "negative"

    # Define the Pydantic model using the Enum
    class Sentiment(BaseModel):
        """Sentiment for conversation."""

        sentiment: SentimentEnum = Field(
            description="Relevant value 'positive', 'neutral' or 'negative'"
        )

    logger.info("Finding sentiment for conversation")
    llm = get_llm(**default_llm_kwargs)
    prompt = prompts.get("sentiment_prompt", "")
    for turn in conversation_history:
        prompt += f"{turn['role']}: {turn['content']}\n"

    llm_with_tool = llm.with_structured_output(Sentiment)

    # Apply guardrails to the chain
    chain_with_guardrails = guardrails | llm_with_tool
    response = chain_with_guardrails.invoke({"query": prompt})
    
    sentiment = response.content.sentiment.value
    logger.info(f"Conversation classified as {sentiment}")
    return sentiment


def generate_sentiment_for_query(session_id):
    """Generate sentiment for user query and assistant response
    """

    logger.info("Fetching sentiment for queries")
    # Check if the sentiment is already identified in database, if yes return that
    session_manager = get_database()

    session_info = session_manager.get_query_sentiment(session_id)

    if session_info and session_info.get("messages", None):
        return {
        "messages": session_info.get("messages"),
            "session_info": {
                "session_id": session_id,
                "start_time": session_info.get("start_time"),
                "end_time": session_info.get("start_time"),
            },
        }

    class SentimentEnum(str, Enum):
        POSITIVE = "positive"
        NEUTRAL = "neutral"
        NEGATIVE = "negative"

    # Define the Pydantic model using the Enum
    class Sentiment(BaseModel):
        """Sentiment for conversation."""

        sentiment: SentimentEnum = Field(
            description="Relevant value 'positive', 'neutral' or 'negative'"
        )


    # Generate summary and session info
    conversation_history = session_manager.get_conversation(session_id)
    logger.info(f"Conversation history: {conversation_history}")

    logger.info("Finding sentiment for conversation")
    llm = get_llm(**default_llm_kwargs)

    llm_with_tool = llm.with_structured_output(Sentiment)
    
    # Apply guardrails to the chain
    chain_with_guardrails = guardrails | llm_with_tool

    messages = []
    # TODO: parallize this operation for faster response
    # Find sentiment for individual query and assistant response
    for turn in conversation_history:
        prompt = prompts.get("query_sentiment_prompt", "")
        prompt += f"{turn['role']}: {turn['content']}\n"

        response = chain_with_guardrails.invoke({"query": prompt})
        sentiment = response.content.sentiment.value
        messages.append({
            "role": turn["role"],
            "content": turn["content"],
            "sentiment": sentiment,
        })

    session_info = {
        "messages": messages,
        "start_time": conversation_history[0].get("timestamp", 0),
        "end_time": conversation_history[-1].get("timestamp", 0),
    }
    if persist_data:
        # Save information before sending it to user
        session_manager.save_query_sentiment(session_id, session_info)
    return {
        "messages": messages,
            "session_info": {
                "session_id": session_id,
                "start_time": datetime.fromtimestamp(
                    float(conversation_history[0].get("timestamp", 0))
                ),
                "end_time": datetime.fromtimestamp(
                    float(conversation_history[-1].get("timestamp", 0))
                ),
            },
    }


#### 2. Agent

In [None]:
%%writefile src/agent/utils.py
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os
import logging
from typing import Dict
from pydantic import BaseModel, Field
from urllib.parse import urlparse

import requests

from psycopg_pool import AsyncConnectionPool
from psycopg.rows import dict_row
import psycopg2

from src.common.utils import get_llm, get_prompts, get_config
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.prebuilt import ToolNode
from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

prompts = get_prompts()
logger = logging.getLogger(__name__)

# TODO get the default_kwargs from the Agent Server API
default_llm_kwargs = {"temperature": 0, "top_p": 0.7, "max_tokens": 1024}

canonical_rag_url = os.getenv('CANONICAL_RAG_URL', 'http://unstructured-retriever:8081')
canonical_rag_search = f"{canonical_rag_url}/search"

# Initialize  guardrails configuration
rail_config = RailsConfig.from_path("./config")
guardrails = RunnableRails(rail_config, input_key="query", output_key='content')

def get_product_name(messages, product_list) -> Dict:
    """Given the user message and list of product find list of items which user might be talking about"""

    # First check product name in query
    # If it's not in query, check in conversation
    # Once the product name is known we will search for product name from database
    # We will return product name from list and actual name detected.

    llm = get_llm(**default_llm_kwargs)

    class Product(BaseModel):
        name: str = Field(..., description="Name of the product talked about.")

    prompt_text = prompts.get("get_product_name")["base_prompt"]
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", prompt_text),
        ]
    )
    llm = llm.with_structured_output(Product)

    chain = prompt | llm
    # Adding guardrails to the chain
    chain_with_guardrails = guardrails | chain
    # query to be used for document retrieval
    # Get the last human message instead of messages[-2]
    last_human_message = next((m.content for m in reversed(messages) if isinstance(m, HumanMessage)), None)
    response = chain_with_guardrails.invoke({"query": last_human_message})

    product_name = response.content.name

    # Check if product name is in query
    if product_name == 'null':

        # Check for produt name in user conversation
        fallback_prompt_text = prompts.get("get_product_name")["fallback_prompt"]
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", fallback_prompt_text),
            ]
        )

        llm = get_llm(**default_llm_kwargs)
        llm = llm.with_structured_output(Product)

        chain = prompt | llm
        # Adding guardrails to the chain
        chain_with_guardrails = guardrails | chain
        # query to be used for document retrieval
        response = chain.invoke({"query": messages})

        product_name = response.content.name
    # Check if it's partial name exists or not
    if product_name == 'null':
        return {}

    def filter_products_by_name(name, products):
        # TODO: Replace this by llm call to check if that can take care of cases like
        # spelling mistakes or words which are seperated
        # TODO: Directly make sql query with wildcard
        name_lower = name.lower()

        # Check for exact match first
        exact_match = [product for product in products if product.lower() == name_lower]
        if exact_match:
            return exact_match

        # If no exact match, fall back to partial matches
        name_parts = [part for part in re.split(r'\s+', name_lower) if part.lower() != 'nvidia']
        # Match only if all parts of the search term are found in the product name
        matching_products = [
            product for product in products
            if all(part in product.lower() for part in name_parts if part)
        ]

        return matching_products

    matching_products = filter_products_by_name(product_name, product_list)

    return {
        "product_in_query": product_name,
        "products_from_purchase": list(set([product for product in matching_products]))
    }


def handle_tool_error(state) -> dict:
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }


def create_tool_node_with_fallback(tools: list) -> dict:
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


async def get_checkpointer() -> tuple:
    settings = get_config()

    if settings.checkpointer.name == "postgres":
        print(f"Using {settings.checkpointer.name} hosted on {settings.checkpointer.url} for checkpointer")
        db_user = os.environ.get("POSTGRES_USER")
        db_password = os.environ.get("POSTGRES_PASSWORD")
        db_name = os.environ.get("POSTGRES_DB")
        db_uri = f"postgresql://{db_user}:{db_password}@{settings.checkpointer.url}/{db_name}?sslmode=disable"
        connection_kwargs = {
            "autocommit": True,
            "prepare_threshold": 0,
            "row_factory": dict_row,
        }

        # Initialize PostgreSQL checkpointer
        pool = AsyncConnectionPool(
            conninfo=db_uri,
            min_size=2,
            kwargs=connection_kwargs,
        )
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
        return checkpointer, pool
    elif settings.checkpointer.name == "inmemory":
        print(f"Using MemorySaver as checkpointer")
        return MemorySaver(), None
    else:
        raise ValueError(f"Only inmemory and postgres is supported chckpointer type")


def remove_state_from_checkpointer(session_id):

    settings = get_config()
    if settings.checkpointer.name == "postgres":
        # Handle cleanup for PostgreSQL checkpointer
        # Currently, there is no langgraph checkpointer API to remove data directly.
        # The following tables are involved in storing checkpoint data:
        # - checkpoint_blobs
        # - checkpoint_writes
        # - checkpoints
        # Note: checkpoint_migrations table can be skipped for deletion.
        try:
            app_database_url = settings.checkpointer.url

            # Parse the URL
            parsed_url = urlparse(f"//{app_database_url}", scheme='postgres')

            # Extract host and port
            host = parsed_url.hostname
            port = parsed_url.port

            # Connect to your PostgreSQL database
            connection = psycopg2.connect(
                dbname=os.getenv('POSTGRES_DB', None),
                user=os.getenv('POSTGRES_USER', None),
                password=os.getenv('POSTGRES_PASSWORD', None),
                host=host,
                port=port
            )
            cursor = connection.cursor()

            # Execute delete commands
            cursor.execute("DELETE FROM checkpoint_blobs WHERE thread_id = %s", (session_id,))
            cursor.execute("DELETE FROM checkpoint_writes WHERE thread_id = %s", (session_id,))
            cursor.execute("DELETE FROM checkpoints WHERE thread_id = %s", (session_id,))

            # Commit the changes
            connection.commit()
            logger.info(f"Deleted rows with thread_id: {session_id}")

        except Exception as e:
            logger.info(f"Error occurred while deleting data from checkpointer: {e}")
            # Optionally rollback if needed
            if connection:
                connection.rollback()
        finally:
            # Close the cursor and connection
            if cursor:
                cursor.close()
            if connection:
                connection.close()
    else:
        # For other supported checkpointer(i.e. inmemory) we don't need cleanup
        pass

def canonical_rag(query: str, conv_history: list)  -> str:
    """Use this for answering generic queries about products, specifications, warranties, usage, and issues."""

    entry_doc_search = {"query": query, "top_k": 4, "conv_history": conv_history}
    response = requests.post(canonical_rag_search, json=entry_doc_search).json()

    # Extract and aggregate the content
    aggregated_content = "\n".join(chunk["content"] for chunk in response.get("chunks", []))

    return aggregated_content

In [None]:
%%writefile src/agent/main.py
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Annotated, TypedDict, Dict
from langgraph.graph.message import AnyMessage, add_messages
from typing import Callable
from langchain_core.messages import ToolMessage, AIMessage, HumanMessage, SystemMessage
from typing import Annotated, Optional, Literal, TypedDict
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import tools_condition
from langchain_core.runnables import RunnableConfig
from src.agent.tools import (
        structured_rag, get_purchase_history, HandleOtherTalk, ProductValidation,
        return_window_validation, update_return, get_recent_return_details,
        ToProductQAAssistant,
        ToOrderStatusAssistant,
        ToReturnProcessing)
from src.agent.utils import get_product_name, create_tool_node_with_fallback, get_checkpointer, canonical_rag
from src.common.utils import get_llm, get_prompts
from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

logger = logging.getLogger(__name__)
prompts = get_prompts()
# TODO get the default_kwargs from the Agent Server API
default_llm_kwargs = {"temperature": 0.2, "top_p": 0.7, "max_tokens": 1024}

# Initialize  guardrails configuration
rail_config = RailsConfig.from_path("./config")
guardrails = RunnableRails(rail_config, input_key="query", output_key='content')

# STATE OF THE AGENT
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    user_id: str
    user_purchase_history: Dict
    current_product: str
    needs_clarification: bool
    clarification_type: str
    reason: str

# NODES FOR THE AGENT
def validate_product_info(state: State, config: RunnableConfig):
    # This node will take user history and find product name based on query
    # If there are multiple name of no name specified in the graph then it will

    # This dict is to populate the user_purchase_history and product details if required
    response_dict = {"needs_clarification": False}
    if state["user_id"]:
        # Update user purchase history based
        response_dict.update({"user_purchase_history": get_purchase_history(state["user_id"])})

        # Extracting product name which user is expecting
        product_list = list(set([resp.get("product_name") for resp in response_dict.get("user_purchase_history", [])]))

        # Extract product name from query and filter from database
        product_info = get_product_name(state["messages"], product_list)

        product_names = product_info.get("products_from_purchase", [])
        product_in_query = product_info.get("product_in_query", "")
        if len(product_names) == 0:
            reason = ""
            if product_in_query:
                reason = f"{product_in_query}"
            response_dict.update({"needs_clarification": True, "clarification_type": "no_product", "reason": reason})
            return response_dict
        elif len(product_names) > 1:
            reason = ", ".join(product_names)
            response_dict.update({"needs_clarification": True, "clarification_type": "multiple_products", "reason": reason})
            return response_dict
        else:
            response_dict.update({"current_product": product_names[0]})

    return response_dict

async def handle_other_talk(state: State, config: RunnableConfig):
    """Handles greetings and queries outside order status, returns, or products, providing polite redirection and explaining chatbot limitations."""

    prompt = prompts.get("other_talk_template", "")

    prompt = ChatPromptTemplate.from_messages(
        [
        ("system", prompt),
        ("placeholder", "{messages}"),
        ]
    )

    # LLM
    llm_settings = config.get('configurable', {}).get("llm_settings", default_llm_kwargs)
    llm = get_llm(**llm_settings)
    llm = llm.with_config(tags=["should_stream"])

    # Chain
    small_talk_chain = prompt | llm
    
    # Adding guardrails
    small_talk_chain_guardrails = guardrails | small_talk_chain
    response = await small_talk_chain_guardrails.ainvoke(state, config)

    return {"messages": [response.content]}


def create_entry_node(assistant_name: str) -> Callable:
    def entry_node(state: State) -> dict:
        tool_call_id = state["messages"][-1].tool_calls[0]["id"]
        return {
            "messages": [
                ToolMessage(
                    content=f"The assistant is now the {assistant_name}. Reflect on the above conversation between the host assistant and the user."
                    f" The user's intent is unsatisfied. Use the provided tools to assist the user. Remember, you are {assistant_name},"
                    " and the booking, update, other other action is not complete until after you have successfully invoked the appropriate tool."
                    " If the user changes their mind or needs help for other tasks, let the primary host assistant take control."
                    " Do not mention who you are - just act as the proxy for the assistant.",
                    tool_call_id=tool_call_id,
                )
            ]
        }

    return entry_node

async def ask_clarification(state: State, config: RunnableConfig):

    # Extract the base prompt
    base_prompt = prompts.get("ask_clarification")["base_prompt"]
    previous_conversation = [m for m in state['messages'] if not isinstance(m, ToolMessage)]
    base_prompt = base_prompt.format(previous_conversation=previous_conversation)

    purchase_history = state.get("user_purchase_history", [])
    if state["clarification_type"] == "no_product" and state['reason'].strip():
        followup_prompt = prompts.get("ask_clarification")["followup"]["no_product"].format(
            reason=state['reason'],
            purchase_history=purchase_history
        )
    elif not state['reason'].strip():
        followup_prompt = prompts.get("ask_clarification")["followup"]["default"].format(reason=purchase_history)
    else:
        followup_prompt = prompts.get("ask_clarification")["followup"]["default"].format(reason=state['reason'])

    # Combine base prompt and followup prompt
    prompt = f"{base_prompt} {followup_prompt}"

    # LLM
    llm_settings = config.get('configurable', {}).get("llm_settings", default_llm_kwargs)
    llm = get_llm(**llm_settings)
    llm = llm.with_config(tags=["should_stream"])
    
    # Adding the guardrails
    chain_with guardrails = guardrails | llm

    response = await chain_with_guardrails.ainvoke(prompt, config)

    return {"messages": [response.content]}

async def handle_product_qa(state: State, config: RunnableConfig):

    # Extract the previous_conversation
    previous_conversation = [m for m in state['messages'] if not isinstance(m, ToolMessage) and m.content]
    message_type_map = {
        HumanMessage: "user",
        AIMessage: "assistant",
        SystemMessage: "system"
    }

    # Serialized conversation
    get_role = lambda x: message_type_map.get(type(x), None)
    previous_conversation_serialized = [{"role": get_role(m), "content": m.content} for m in previous_conversation if m.content]
    last_message = previous_conversation_serialized[-1]['content']

    retireved_content = canonical_rag(query=last_message, conv_history=previous_conversation_serialized)

    # Use the RAG Template to generate the response
    base_rag_prompt = prompts.get("rag_template")
    rag_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", base_rag_prompt),
        MessagesPlaceholder("chat_history") + "\n\nCONTEXT:  {context}"
    ]
    )
    rag_prompt = rag_prompt.format(chat_history=previous_conversation, context=retireved_content)

    # LLM
    llm_settings = config.get('configurable', {}).get("llm_settings", default_llm_kwargs)
    llm = get_llm(**llm_settings)
    llm = llm.with_config(tags=["should_stream"])
    
    # Adding guardrails
    chain_with_guardrails = guardrails | llm

    response = await chain_with_guardrails.ainvoke(rag_prompt, config)

    return {"messages": [response.content]}

class Assistant:
    def __init__(self, prompt: str, tools: list):
        self.prompt = prompt
        self.tools = tools

    async def __call__(self, state: State, config: RunnableConfig):
        while True:

            llm_settings = config.get('configurable', {}).get("llm_settings", default_llm_kwargs)
            llm = get_llm(**llm_settings)
            runnable = self.prompt | llm.bind_tools(self.tools)
            runnable_with_guardrails = guardrails | runnable
            state = await runnable_with_guardrails.invoke(state)
            last_message = state["messages"][-1]
            messages = []
            if isinstance(last_message, ToolMessage) and last_message.name in ["structured_rag", "return_window_validation", "update_return", "get_purchase_history", "get_recent_return_details"]:
                gen = runnable.with_config(
                tags=["should_stream"],
                callbacks=config.get(
                    "callbacks", []
                ),  # <-- Propagate callbacks (Python <= 3.10)
                )
                async for message in gen.astream(state):
                    messages.append(message.content)
                result = AIMessage(content="".join(messages))
            else:
                result = runnable_with_guardrails.invoke(state)

            if not result.tool_calls and (
                not result.content
                or isinstance(result.content, list)
                and not result.content[0].get("text")
            ):
                messages = state["messages"] + [("user", "Respond with a real output.")]
                state = {**state, "messages": messages}
                messages = state["messages"] + [("user", "Respond with a real output.")]
                state = {**state, "messages": messages}
            else:
                break
        return {"messages": result}

# order status Assistant
order_status_prompt_template = prompts.get("order_status_template", "")

order_status_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            order_status_prompt_template
        ),
        ("placeholder", "{messages}"),
    ]
)

order_status_safe_tools = [structured_rag]
order_status_tools = order_status_safe_tools + [ProductValidation]

# Return Processing Assistant
return_processing_prompt_template = prompts.get("return_processing_template", "")

return_processing_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            return_processing_prompt_template
        ),
        ("placeholder", "{messages}"),
    ]
)

return_processing_safe_tools = [get_recent_return_details, return_window_validation]
return_processing_sensitive_tools = [update_return]
return_processing_tools = return_processing_safe_tools + return_processing_sensitive_tools + [ProductValidation]

primary_assistant_prompt_template = prompts.get("primary_assistant_template", "")

primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            primary_assistant_prompt_template
        ),
        ("placeholder", "{messages}"),
    ]
)

primary_assistant_tools = [
        HandleOtherTalk,
        ToProductQAAssistant,
        ToOrderStatusAssistant,
        ToReturnProcessing,
    ]

# BUILD THE GRAPH
builder = StateGraph(State)


# SUB AGENTS
# Create product_qa Assistant
builder.add_node(
    "enter_product_qa",
    handle_product_qa,
)

builder.add_edge("enter_product_qa", END)

builder.add_node("order_validation", validate_product_info)
builder.add_node("ask_clarification", ask_clarification)

# Create order_status Assistant
builder.add_node(
    "enter_order_status", create_entry_node("Order Status Assistant")
)
builder.add_node("order_status", Assistant(order_status_prompt, order_status_tools))
builder.add_edge("enter_order_status", "order_status")
builder.add_node(
    "order_status_safe_tools",
    create_tool_node_with_fallback(order_status_safe_tools),
)


def route_order_status(
    state: State,
) -> Literal[
    "order_status_safe_tools",
    "order_validation",
    "__end__"
]:
    route = tools_condition(state)
    if route == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    tool_names = [t.name for t in order_status_safe_tools]
    do_product_validation = any(tc["name"] == ProductValidation.__name__ for tc in tool_calls)
    if do_product_validation:
        return "order_validation"
    if all(tc["name"] in tool_names for tc in tool_calls):
        return "order_status_safe_tools"
    return "order_status_sensitive_tools"

builder.add_edge("order_status_safe_tools", "order_status")
builder.add_conditional_edges("order_status", route_order_status)

# Create return_processing Assistant
builder.add_node("return_validation", validate_product_info)

builder.add_node(
    "enter_return_processing",
    create_entry_node("Return Processing Assistant"),
)
builder.add_node("return_processing", Assistant(return_processing_prompt, return_processing_tools))
builder.add_edge("enter_return_processing", "return_processing")

builder.add_node(
    "return_processing_safe_tools",
    create_tool_node_with_fallback(return_processing_safe_tools),
)
builder.add_node(
    "return_processing_sensitive_tools",
    create_tool_node_with_fallback(return_processing_sensitive_tools),
)


def route_return_processing(
    state: State,
) -> Literal[
    "return_processing_safe_tools",
    "return_processing_sensitive_tools",
    "return_validation",
    "__end__",
]:
    route = tools_condition(state)
    if route == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    do_product_validation = any(tc["name"] == ProductValidation.__name__ for tc in tool_calls)
    if do_product_validation:
        return "return_validation"
    tool_names = [t.name for t in return_processing_safe_tools]
    if all(tc["name"] in tool_names for tc in tool_calls):
        return "return_processing_safe_tools"
    return "return_processing_sensitive_tools"


builder.add_edge("return_processing_sensitive_tools", "return_processing")
builder.add_edge("return_processing_safe_tools", "return_processing")
builder.add_conditional_edges("return_processing", route_return_processing)


def user_info(state: State):
    return {"user_purchase_history": get_purchase_history(state["user_id"]), "current_product": ""}

builder.add_node("fetch_purchase_history", user_info)
builder.add_edge(START, "fetch_purchase_history")
builder.add_edge("ask_clarification", END)

# Primary assistant
builder.add_node("primary_assistant", Assistant(primary_assistant_prompt, primary_assistant_tools))
builder.add_node(
    "other_talk", handle_other_talk
)

#  Add "primary_assistant_tools", if necessary
def route_primary_assistant(
    state: State,
) -> Literal[
    "enter_product_qa",
    "enter_order_status",
    "enter_return_processing",
    "other_talk",
    "__end__",
]:
    route = tools_condition(state)
    if route == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if tool_calls:
        if tool_calls[0]["name"] == ToProductQAAssistant.__name__:
            return "enter_product_qa"
        elif tool_calls[0]["name"] == ToOrderStatusAssistant.__name__:
            return "enter_order_status"
        elif tool_calls[0]["name"] == ToReturnProcessing.__name__:
            return "enter_return_processing"
        elif tool_calls[0]["name"] == HandleOtherTalk.__name__:
            return "other_talk"
    raise ValueError("Invalid route")

builder.add_edge("other_talk", END)

# The assistant can route to one of the delegated assistants,
# directly use a tool, or directly respond to the user
builder.add_conditional_edges(
    "primary_assistant",
    route_primary_assistant,
    {
        "enter_product_qa": "enter_product_qa",
        "enter_order_status": "enter_order_status",
        "enter_return_processing": "enter_return_processing",
        "other_talk":"other_talk",
        END: END,
    },
)


def is_order_product_valid(state: State)  -> Literal[
    "ask_clarification",
    "order_status"
]:
    """Conditional edge from validation node to decide if we should ask followup questions"""
    if state["needs_clarification"] == True:
        return "ask_clarification"
    return "order_status"

def is_return_product_valid(state: State)  -> Literal[
    "ask_clarification",
    "return_processing"
]:
    """Conditional edge from validation node to decide if we should ask followup questions"""
    if state["needs_clarification"] == True:
        return "ask_clarification"
    return "return_processing"

builder.add_conditional_edges(
    "order_validation",
    is_order_product_valid
)
builder.add_conditional_edges(
    "return_validation",
    is_return_product_valid
)

builder.add_edge("fetch_purchase_history", "primary_assistant")


# Allow multiple async loop togeather
# This is needed to create checkpoint as it needs async event loop
# TODO: Move graph build into a async function and call that to remove nest_asyncio
import nest_asyncio
nest_asyncio.apply()

# To run the async main function
import asyncio

memory = None
pool = None

# TODO: Remove pool as it's not getting used
# WAR: It's added so postgres does not close it's session
async def get_checkpoint():
    global memory, pool
    memory, pool = await get_checkpointer()

asyncio.run(get_checkpoint())

# Compile
graph = builder.compile(checkpointer=memory,
                        interrupt_before=["return_processing_sensitive_tools"],
                        #interrupt_after=["ask_human"]
                        )

try:
    # Generate the PNG image from the graph
    png_image_data = graph.get_graph(xray=True).draw_mermaid_png()
    # Save the image to a file in the current directory
    with open("graph_image.png", "wb") as f:
        f.write(png_image_data)
except Exception as e:
    # This requires some extra dependencies and is optional
    logger.info(f"An error occurred: {e}")

with the guardrails configuration built and wrapped around the agent, we will run the nemoguardrails server. Make sure to add the absolute path of the `config` directory and the `container image` and run the following cell

In [None]:
%%bash
docker run -p 8000:8000 -v </path/to/local/config/>:/config <IMAGE_NAME>

## Exposing the Interface for Testing (optional)

The Blueprint comes equiped with a basic UI for testing the deployment. This interface is served at port 3001. In order to expose the port and try out the interaction, you need to follow the steps below.

First, navigate back to the created Launchable instance page and click on the Access menu.


![Access Menu](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/raw/main/docs/imgs/brev-cli-install.png)


Scroll down until you find "Using Tunnels" section and click on Share a Service button.


![Using Tunnels](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/raw/main/docs/imgs/brev-tunnels.png)


Enter the port 3001, as that is where the UI service endpoint is. Confirm with Done. Then click on Edit Access and make the port public:


![Share Access](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/raw/main/docs/imgs/brev-share-access.png)


Past this point, by clicking on the link, the UI should appear in your browser and you are free to interact with the assistant and to ask him about the data that was ingested.


![AI Virtual Assistant Interface](https://github.com/NVIDIA-AI-Blueprints/ai-virtual-assistant/raw/main/docs/imgs/ai-virtual-assistant-interface.png)