# I. Tools

In [None]:
from transformers import Tool, pipeline


class NLQsShield(Tool):
    """Agent tool that screens natural-language questions in a text-to-SQL task.

    Wraps a Hugging Face ``text-classification`` pipeline and returns the top
    predicted label for the question; the tool description tells the agent to
    expect ``MALICIOUS`` or ``SAFE``. The pipeline is built on first use and
    cached on the instance, so the model is loaded once rather than on every call.
    """

    name = "nlq_shield"
    description = ("This is a tool that detects malicious natural language questions in a text-to-sql task. It takes a natural language question as input, and returns MALICIOUS if the question is malicious and SAFE else.")

    inputs = ["text"]
    outputs = ["text"]

    # Hugging Face checkpoint used for classification.
    classifier_model = "salmane11/SQLPromptShield4"
    # Forwarded to ``pipeline(device=...)``; 0 is the value the original code
    # used. Override on the class or an instance (e.g. -1 for CPU).
    classifier_device = 0

    def __call__(self, question):
        """Classify ``question`` and return the top predicted label.

        Args:
            question: Natural-language question to screen.

        Returns:
            The ``label`` field of the first prediction returned by the pipeline.
        """
        classifier = getattr(self, "_classifier", None)
        if classifier is None:
            # Building a pipeline loads the model weights; do it once per instance.
            classifier = pipeline(
                "text-classification",
                model=self.classifier_model,
                device=self.classifier_device,
            )
            self._classifier = classifier
        return classifier(question)[0]["label"]

In [None]:
class SQLShield(Tool):
    """Agent tool that screens generated SQL queries.

    Wraps a Hugging Face ``text-classification`` pipeline and returns the top
    predicted label for the query; the tool description tells the agent to
    expect ``MALICIOUS`` or ``SAFE``. The pipeline is built on first use and
    cached on the instance, so the model is loaded once rather than on every call.
    """

    name = "sql_shield"
    description = ("This is a tool that detects malicious SQL queries. It takes an SQL query as input, and returns MALICIOUS if the query is malicious and SAFE else.")

    inputs = ["text"]
    outputs = ["text"]

    # Hugging Face checkpoint used for classification.
    classifier_model = "salmane11/SQLQueryShield2"
    # Forwarded to ``pipeline(device=...)``; 0 is the value the original code
    # used. Override on the class or an instance (e.g. -1 for CPU).
    classifier_device = 0

    def __call__(self, query):
        """Classify ``query`` and return the top predicted label.

        Args:
            query: SQL query text to screen.

        Returns:
            The ``label`` field of the first prediction returned by the pipeline.
        """
        classifier = getattr(self, "_classifier", None)
        if classifier is None:
            # Building a pipeline loads the model weights; do it once per instance.
            classifier = pipeline(
                "text-classification",
                model=self.classifier_model,
                device=self.classifier_device,
            )
            self._classifier = classifier
        return classifier(query)[0]["label"]

In [None]:
from openai import OpenAI
class SQLGenerator(Tool):
    """Agent tool that asks an OpenAI chat model to translate a question into SQL.

    The schema/context is placed in the system message and the question is sent
    as the user message; the model is instructed to reply with only the query.
    The OpenAI client is created on first use and reused afterwards. The API key
    is taken from the ``OPENAI_API_KEY`` environment variable (the openai client's
    default) instead of being hard-coded in source.
    """

    name = "sql_generator"
    description = ("This is a tool that generates sql queries from a natural language question. It takes a question, and its context as input, and returns its related sql query.")

    inputs = ["text", "text"]
    outputs = ["text"]

    # Chat model used for generation; override on the class or an instance.
    chat_model = "gpt-3.5-turbo"

    def __call__(self, question, context):
        """Generate an SQL query for ``question`` against ``context``.

        Args:
            question: Natural-language question to answer.
            context: Schema / context text (e.g. ``CREATE TABLE`` statements)
                interpolated into the system prompt.

        Returns:
            The content of the first choice in the chat completion response.
        """
        client = getattr(self, "_client", None)
        if client is None:
            # With no explicit api_key, the client reads OPENAI_API_KEY and
            # raises if it is missing -- better than shipping a placeholder key.
            client = OpenAI()
            self._client = client

        completion = client.chat.completions.create(
            model=self.chat_model,
            messages=[
                {"role": "system", "content": f"You are an SQL assistant. Answer the following question using SQL on this context {context}. Provide only the query."},
                {"role": "user", "content": question},
            ],
        )

        return completion.choices[0].message.content

# II. OpenAI Agent

In [None]:
from transformers import OpenAiAgent
import os

# Few-shot run prompt for the agent. ``<<all_tools>>`` and ``<<prompt>>`` are
# placeholders filled in by the agent at run time. Both examples call
# ``sql_generator(question, context)`` because SQLGenerator requires the schema
# context as its second argument; the agent imitates these examples.
template = """I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.

Tools:
<<all_tools>>


Task: "Check the safety of the question in the variable `question` before generating its related SQL query. Then check the safety of the generated query"

I will use the following tools: `nlq_shield` to check the question safety, then `sql_generator` to generate the SQL query, then `sql_shield` to check the query safety.

Answer:
```py
question_safety = nlq_shield(question)
print(f"The question is {question_safety}.")
if(question_safety=="SAFE"):
  sql_query = sql_generator(question, context)
  print(f"The SQL query is {sql_query}.")
  
  sql_query_safety = sql_shield(sql_query)
  print(f"The SQL query is {sql_query_safety}.")
else:
  print("ALERT")
```


Task: "Before generating the SQL query related to the variable `question`, ensure its safety. Subsequently, verify the safety of the generated query."

I will use the following tools: `nlq_shield` to check the question safety, then `sql_generator` to generate the SQL query, then `sql_shield` to check the query safety.

Answer:
```py
question_safety = nlq_shield(question)
print(f"The question is {question_safety}.")
if(question_safety=="SAFE"):
  sql_query = sql_generator(question, context)
  print(f"The SQL query is {sql_query}.")
  
  sql_query_safety = sql_shield(sql_query)
  print(f"The SQL query is {sql_query_safety}.")
else:
  print("ALERT")
```


Task: "<<prompt>>"

I will use the following"""

# The API key comes from the environment instead of a hard-coded placeholder.
agent = OpenAiAgent(
    model="gpt-3.5-turbo",
    api_key=os.environ.get("OPENAI_API_KEY"),
    additional_tools=[SQLShield(), NLQsShield(), SQLGenerator()],
    run_prompt_template=template,
)

In [None]:
# Run the agent on a sample question. ``question`` and ``context`` are passed
# as keyword arguments so the agent-generated code can refer to them by name.
safety_task = "Before we proceed with transforming the question into an SQL query, could you please check the question safety. Once we have the SQL query generated, we'll check its safety"
sample_question = 'list the students with high scores'
sample_schema = 'create table students(name varchar(10),score number, full_name varchar(10), class varchar(10))'
result = agent.run(
    safety_task,
    question=sample_question,
    context=sample_schema,
)