In [1]:
##
import os 
from langchain_groq import ChatGroq
from dotenv import load_dotenv

load_dotenv()

GROQ_API_KEY = os.getenv("GROQ_API_KEY")



llm = ChatGroq(model = "llama-3.2-1b-preview",
            groq_api_key = GROQ_API_KEY)

In [2]:
import spacy
from langchain.tools import tool

from typing import Dict, List, Optional

# Load the spaCy English model
nlp = spacy.load("en_core_web_sm")

# A mapping from common natural language phrases to OData query components
NL_TO_OLQUERY_MAP: Dict[str, str] = {
    "find": "select",
    "get": "select",
    "filter": "filter",
    "where": "filter",
    "order by": "orderby",
    "greater than": "gt",
    "less than": "lt",
    "equals": "eq",
    "not equal": "ne",
    "equal to": "eq"
}

def parse_natural_language(query: str) -> Dict[str, List[str]]:
    """
    Parse the natural language query to extract key information.
    """
    doc = nlp(query.lower())
    
    # Placeholder for extracted parts of the OData query
    select_part: List[str] = []
    filter_part: List[str] = []
    orderby_part: Optional[str] = None
    
    # Parse each token and map to OData equivalent
    for i, token in enumerate(doc):
        if token.text == "find" or token.text == "get":
            # The next token is likely what we're selecting
            if i + 1 < len(doc):
                select_part.append(doc[i+1].text)
        elif token.text == "where":
            # Look for the condition after "where"
            if i + 3 < len(doc):
                field = doc[i+1].text
                condition = " ".join([t.text for t in doc[i+2:i+4]])
                if condition in NL_TO_OLQUERY_MAP:
                    odata_condition = NL_TO_OLQUERY_MAP[condition]
                    if i + 4 < len(doc):
                        value = doc[i+4].text
                        filter_part.append(f"{field} {odata_condition} {value}")
    
    return {
        "select": select_part,
        "filter": filter_part,
        "orderby": orderby_part
    }

def construct_odata_query(select_part: List[str], filter_part: List[str], orderby_part: Optional[str]) -> str:
    """
    Construct the OData query from parsed components.
    """
    odata_query = "$select=" + ",".join(select_part) if select_part else ""
    
    if filter_part:
        odata_query += "&$filter=" + " and ".join(filter_part)
        
    if orderby_part:
        odata_query += "&$orderby=" + orderby_part
    
    return odata_query

@tool
def nl_to_odata(query: str) -> str:
    """
    Convert a natural language query to an OData query.

    Args:
        query (str): The natural language query to convert.

    Returns:
        str: The corresponding OData query.

    Example:
        nl_to_odata("find products where price is less than 50")
        -> "$select=products&$filter=price lt 50"
    """
    parsed_query = parse_natural_language(query)
    odata_query = construct_odata_query(parsed_query["select"], parsed_query["filter"], parsed_query["orderby"])
    return odata_query

In [3]:
r1 = nl_to_odata( "Find customers where age greater than 30")
r1

  r1 = nl_to_odata( "Find customers where age greater than 30")


'$select=customers&$filter=age gt 30'

In [4]:
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from langgraph.graph import END, StateGraph, START

### utility

In [5]:
def handle_tool_error(state) -> dict:
    """
    Function to handle errors that occur during tool execution.
    
    Args:
        state (dict): The current state of the AI agent, which includes messages and tool call details.
    
    Returns:
        dict: A dictionary containing error messages for each tool that encountered an issue.
    """
    # Retrieve the error from the current state
    error = state.get("error")
    
    # Access the tool calls from the last message in the state's message history
    tool_calls = state["messages"][-1].tool_calls
    
    # Return a list of ToolMessages with error details, linked to each tool call ID
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",  # Format the error message for the user
                tool_call_id=tc["id"],  # Associate the error message with the corresponding tool call ID
            )
            for tc in tool_calls  # Iterate over each tool call to produce individual error messages
        ]
    }

def create_tool_node_with_fallback(tools: list) -> dict:
    """
    Function to create a tool node with fallback error handling.
    
    Args:
        tools (list): A list of tools to be included in the node.
    
    Returns:
        dict: A tool node that uses fallback behavior in case of errors.
    """
    # Create a ToolNode with the provided tools and attach a fallback mechanism
    # If an error occurs, it will invoke the handle_tool_error function to manage the error
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)],  # Use a lambda function to wrap the error handler
        exception_key="error"  # Specify that this fallback is for handling errors
    )

### Define the State and Assistant Class

In [6]:
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]


class Assistant:
    def __init__(self, runnable: Runnable):
        # Initialize with the runnable that defines the process for interacting with the tools
        self.runnable = runnable

    def __call__(self, state: State):
        while True:
            # Invoke the runnable with the current state (messages and context)
            print("Current state before running tool:", state)
            result = self.runnable.invoke(state)
            

            # Check if the tool was invoked and returned valid results
            print("Result after tool execution:", result)  # DEBUG
            
            # If the tool fails to return valid output, re-prompt the user to clarify or retry
            if not result.tool_calls and (
                not result.content
                or isinstance(result.content, list)
                and not result.content[0].get("text")
            ):
                # Add a message to request a valid response
                messages = state["messages"] + [("user", "Respond with a real output.")]
                state = {**state, "messages": messages}
            else:
                # Break the loop when valid output is obtained
                break

        # Return the final state after processing the runnable
        return {"messages": result}

### prompt

In [7]:
text2Odata_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            '''You are an efficient query assistant specialized in converting natural language into OData queries. 
            Your job is to help users transform their natural language requests into proper OData queries for interacting 
            with their API's. You will take a user's natural language query, parse it, and return the correct OData query format.
            
            
            Always use the "nl_to_odata" tool to generate the OData query. Do not attempt to create the query yourself.
            After getting the result from the tool, present it to the user in a clear format
            ''',
        ),
        ("placeholder", "{messages}"),
    ]
)


In [8]:
text2Odata_tool = [nl_to_odata]

In [9]:
# Bind the tools to the assistant's workflow
text2Odata_assistant_runnable = text2Odata_prompt | llm.bind_tools(text2Odata_tool)

### Build the Graph


In [10]:
builder = StateGraph(State)
builder.add_node("Text2Odata", Assistant(text2Odata_assistant_runnable))
builder.add_node("tools", create_tool_node_with_fallback(text2Odata_tool))

builder.add_edge(START, "Text2Odata")  # Start with the assistant
builder.add_conditional_edges("Text2Odata", tools_condition)  # Move to tools after input
builder.add_edge("tools", "Text2Odata")  # Return to assistant after tool execution


graph = builder.compile()


In [11]:
# Let's create an example conversation a user might have with the assistant
tutorial_questions = [
 "Find customers where age greater than 30"
]

In [12]:
for question in tutorial_questions:
    events = graph.stream(
        {"messages": ("user", question)}, stream_mode="values"
    )
    for event in events:
        print(event)
        print("--------------------------------------")
        human_message = event['messages'][0].content
        print("Human_message:", human_message)
        print("-------------------------------------")
        #ai_message_content = event["messages"][1]["content"]
        #print("AI Message:", ai_message_content)
        



{'messages': [HumanMessage(content='Find customers where age greater than 30', additional_kwargs={}, response_metadata={}, id='b9a873f6-b4e7-4bd9-9d93-9c473b48f5a6')]}
--------------------------------------
Human_message: Find customers where age greater than 30
-------------------------------------
Current state before running tool: {'messages': [HumanMessage(content='Find customers where age greater than 30', additional_kwargs={}, response_metadata={}, id='b9a873f6-b4e7-4bd9-9d93-9c473b48f5a6')]}
Result after tool execution: content='<nl_to_odata>"find customers where age is greater than 30"</nl_to_odata>{"name":"nl_to_odata","description":"Convert a natural language query to an OData query.\\n\\nArgs:\\n    query (str): The natural language query to convert.\\n\\nReturns:\\n    str: The corresponding OData query.\\n\\nExample:\\n    nl_to_odata(\\"find customers where age is greater than 30\\")\\n    -\\u003e \\"$select=customers&$filter=age gt 30\\"","parameters":{"properties":{"qu