## Using Prem AI and DSPy for Text to SQL Generation

In this cookbook, we explore how to use the PremAI SDK and DSPy to generate SQL from text. In this tutorial, we are going to:

1. Write simple prompts using `dspy.Signature` to give instructions to the LLM.
2. Use `dspy.Module` to write a simple Text2SQL pipeline.
3. Use `dspy.teleprompt` to automatically optimize the prompt for better results.
4. Use `dspy.Evaluate` to evaluate the results.

### Objective

The objective of this tutorial is to introduce you to DSPy and show how to use it with the Prem SDK. In the later part of the tutorial, we also show how you can use the Prem Platform to debug, see what DSPy has been optimizing, and find out how it can be improved further.


**NOTE:**

For those who are not familiar with DSPy: it is an LLM orchestration tool whose API is very similar to [PyTorch](https://pytorch.org/). The main focus of DSPy is to help developers write clean, modular code instead of very large prompts. You can learn more about it in the [documentation](https://dspy-docs.vercel.app/).

Before getting started, we create a new virtual environment and install all the required packages from this [requirements.txt](/text-2-sql/requirements.txt) file. Once that is done, let's import the required packages.

In [2]:
import dspy
from dspy import PremAI
from dspy.evaluate import Evaluate
from dspy.datasets import DataLoader



First, we define some constants and instantiate the DSPy PremAI LM. We will be using CodeLlama as our base LLM to generate SQL queries. Here is what our launchpad looks like:

![image](assets/text2sql_launchpad.png)

If you are not familiar with setting up a project on Prem, we recommend taking a quick look at [this guide](https://docs.premai.io/introduction). It is very intuitive, and you can even get started for free.

Prem AI offers a variety of models (see the [list](https://docs.premai.io/get-started/supported-models) here), so you can experiment with all the models. 

We ran some initial experiments in the [Prem Playground](https://app.premai.io/s/97e27016-3265-42e9-8358-68340a4d3ed7) and saw that CodeLlama performed consistently better than other models such as Claude 3, GPT-4o and Mistral. So for this cookbook we use Code Llama Instruct 70B by Meta AI.

In [3]:
# Replace the API key and project ID with your own.
# The values below will not be valid in your case.

PREMAI_API_KEY = "G91lPIK3XDX7ohwxu6EIlmRDiHQnmO7SMn"
PROJECT_ID = "4071"
generation_kwargs = {
    "temperature": 0.1,
    "max_tokens": 1000
}
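Hardcoding credentials in a notebook makes them easy to leak. As a minimal sketch, assuming you export the key and project ID as environment variables, you could load them as shown below. The variable names `PREMAI_API_KEY` and `PREMAI_PROJECT_ID` and the helper `load_prem_credentials` are hypothetical choices for this example, not names required by the Prem SDK.

```python
import os


def load_prem_credentials(env=os.environ):
    """Return (api_key, project_id) read from environment variables.

    The variable names used here are assumptions made for this sketch.
    """
    api_key = env.get("PREMAI_API_KEY")
    project_id = env.get("PREMAI_PROJECT_ID")
    if not api_key or not project_id:
        raise RuntimeError(
            "Set PREMAI_API_KEY and PREMAI_PROJECT_ID before running the notebook."
        )
    return api_key, project_id
```

The returned pair can then be assigned to the `PREMAI_API_KEY` and `PROJECT_ID` constants used in the rest of the notebook.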

We used the following system prompt.

```markdown
You are an expert in SQL. You can understand and write complex SQL queries. You will be given some plain text as a questions and you are required to generate SQL query from that. Do not generate anything else. 
```

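Additionally, we set the temperature to 0.1 and max_tokens to 500 in the launchpad. Note that the `generation_kwargs` defined above request a `max_tokens` of 1000 instead. You can copy these configurations to your experiment to reproduce our results or to get started with your own implementation.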

Now, we instantiate our `lm` object and test it before moving forward. 

In [4]:
lm = PremAI(
    project_id=PROJECT_ID,
    api_key=PREMAI_API_KEY,
    **generation_kwargs
)
dspy.configure(lm=lm)

lm("hello")

["SELECT 'hello' AS message;"]

### Loading Datasets

We start by loading a dataset. We are going to use the [gretelai/synthetic_text_to_sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql) dataset for our example. We are also going to split the dataset into training, validation, and test splits. The code below shows how we load and split the dataset using DSPy.

In [5]:
data_loader = DataLoader()

# Load the dataset from huggingface

trainset = data_loader.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql",
    fields=("sql_prompt", "sql_context", "sql"),
    input_keys=("sql_prompt", "sql_context"),
    split="train"
)

testset = data_loader.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql", 
    fields=("sql_prompt", "sql_context", "sql"), 
    input_keys=("sql_prompt", "sql_context"), 
    split="test"
)

We sample a very small amount of data because we are not going to do any kind of fine-tuning here. DSPy will optimize our prompt by tweaking it in several ways (for example, by adding optimal few-shot examples). DSPy also provides options for fine-tuning the weights, but that is out of the scope of this tutorial.

In [6]:
trainset = data_loader.sample(dataset=trainset, n=100)
testset = data_loader.sample(dataset=testset, n=75)

_trainval = data_loader.train_test_split(
    dataset=trainset, 
    test_size=0.25, 
    random_state=1399
)

trainset, valset = _trainval["train"], _trainval["test"]
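With 100 sampled examples and `test_size=0.25`, the split leaves 75 examples for training and 25 for validation. The sketch below illustrates the idea of a seeded shuffle-and-split in plain Python. It is not how `DataLoader.train_test_split` is implemented, and `split_examples` is a hypothetical helper written only for this illustration.

```python
import random


def split_examples(examples, test_size=0.25, seed=1399):
    """Shuffle a copy of `examples` with a fixed seed and split it in two."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_test = int(round(len(shuffled) * test_size))
    # The first n_test items form the held-out part, the rest is for training.
    return shuffled[n_test:], shuffled[:n_test]


train, val = split_examples(range(100))
print(len(train), len(val))  # 75 25
```

Fixing the seed means the same examples land in each split on every run, which keeps later comparisons between prompts fair.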

Let's also take one sample, which we will use for initial checks of our implementation.

In [7]:
sample = data_loader.sample(dataset=trainset, n=1)[0]

for k, v in sample.items():
    print(f"\n{k.upper()}:\n")
    print(v)


SQL_PROMPT:

Find the top 3 menu items with the highest profit margin in each restaurant.

SQL_CONTEXT:

CREATE TABLE MenuItems (restaurant_id INT, menu_item_id INT, cost DECIMAL(10,2), price DECIMAL(10,2)); INSERT INTO MenuItems (restaurant_id, menu_item_id, cost, price) VALUES (1, 101, 5.00, 12.00), (1, 102, 6.00, 15.00), (1, 103, 4.50, 11.00), (2, 101, 4.00, 10.00), (2, 102, 5.50, 14.00);

SQL:

SELECT restaurant_id, menu_item_id, cost, price, (price - cost) as profit_margin FROM menuitems WHERE profit_margin IN (SELECT MAX(profit_margin) FROM menuitems GROUP BY restaurant_id, FLOOR((ROW_NUMBER() OVER (PARTITION BY restaurant_id ORDER BY profit_margin DESC)) / 3)) ORDER BY restaurant_id, profit_margin DESC;


### Creating a DSPy Signature

Most LLM orchestration frameworks, like LangChain and LlamaIndex, ask you to prompt the language model explicitly. Eventually, we need to tweak or optimize the prompt to get better results. However, writing and managing big prompts can get very messy.

The whole point of DSPy is to shift the prompting process to a more programmatic paradigm. A Signature in DSPy lets you specify what the language model's input and output behaviour should be.

**NOTE**
> The field names inside a Signature have semantic significance. Each field name that represents a role (for example, `question`, `answer` or `sql_query`) defines a prompt variable and tells the model what that variable is about. This is explained further in the example below.

For this example, we will look at class-based DSPy Signatures. You can learn more about Signatures in the [official documentation](https://dspy-docs.vercel.app/docs/building-blocks/signatures).

#### Class-based DSPy Signatures

In a class-based signature, you define a class where:

- The class docstring expresses the nature of the overall task.

- The input and output variables are declared using `InputField` and `OutputField`, each with a description of what that variable holds.

Here is a simple example of a class-based DSPy Signature.

In [8]:
class Emotion(dspy.Signature):
    # Define your overall task description here
    """Classify emotion among sandness, joy, love, anger, fear, surprise."""

    sentence = dspy.InputField(
        desc="A sentence which needs to be classified"
    )

    sentiment = dspy.OutputField(
        desc="classify in either of one class: sandness / joy / love / anger / fear /surprise. Do not write anything else, just the class. Sentiment:"
    )

classify = dspy.Predict(Emotion)
classify(sentence="The day is super gloomy today ughhh")

Prediction(
    sentiment='sandness'
)

To inspect the actual prompt that was sent, you can use the `lm.inspect_history` method. Here is what it looks like for the `Emotion` class.

In [9]:
lm.inspect_history(n=1)




Classify emotion among sandness, joy, love, anger, fear, surprise.

---

Follow the following format.

Sentence: A sentence which needs to be classified
Sentiment: classify in either of one class: sandness / joy / love / anger / fear /surprise. Do not write anything else, just the class. Sentiment:

---

Sentence: The day is super gloomy today ughhh
Sentiment: sandness





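The prompt printed above follows a fixed layout: the task description, a format section listing each field with its description, and then the filled-in inputs with the output label left open. The sketch below reproduces that layout in plain Python to make the structure explicit. It is an illustration only, not DSPy's actual rendering code, and `render_prompt` is a hypothetical helper.

```python
def render_prompt(instructions, fields, inputs):
    """Lay out a prompt in the style printed above.

    `fields` maps each field label to its description, with the output
    field last; `inputs` maps the input labels to concrete values.
    """
    labels = list(fields)
    format_block = "\n".join(f"{label}: {desc}" for label, desc in fields.items())
    filled = "\n".join(f"{label}: {inputs[label]}" for label in labels[:-1])
    # The output label is left open so the model completes it.
    return (
        f"{instructions}\n\n---\n\n"
        f"Follow the following format.\n\n{format_block}\n\n---\n\n"
        f"{filled}\n{labels[-1]}:"
    )


prompt = render_prompt(
    "Classify the emotion of a sentence.",
    {
        "Sentence": "A sentence which needs to be classified",
        "Sentiment": "Just the class, nothing else",
    },
    {"Sentence": "The day is super gloomy today ughhh"},
)
print(prompt)
```

Seen this way, the docstring and the `desc` arguments of a Signature are the only free text you write. The rest of the prompt is assembled for you.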

Now we make a signature for our SQL generator. In this signature, we will:

1. Define the overall task in the class docstring.
2. Provide input and output prompt variables. 
3. Provide the description of Input and Output variables to tell what is expected here. 

You can think of this as writing a prompt very similar to the one shown below, but in a more programmable way:

```python
prompt = """
Transform a natural language query into an SQL query.
Do not output anything other than SQL Query. You will be given the 
sql_prompt which will tell what you need to do and sql_context
which will give some additional context to generate the right SQL. Here are the inputs:

Natural language query: {sql_prompt}
The context of the query: {sql_context}

Write the SQL here: 
"""
```

In [10]:
class Text2SQLSignature(dspy.Signature):
    """Transform a natural language query into a SQL query.
    You will be given the sql_prompt which will tell what you need to do 
    and a sql_context which will give some additional context to generate the right SQL.
    Only generate the SQL query nothing else. You should give one correct answer.
    starting and ending with ```
    """

    sql_prompt = dspy.InputField(desc="Natural language query")
    sql_context = dspy.InputField(desc="Context for the query")
    sql = dspy.OutputField(desc="SQL Query")

Now, let's generate a single sample with the signature we just wrote. Please note that the signature is just a blueprint of what we want to achieve. You are not required to tweak the prompt by hand anymore. It is DSPy's job to optimize it into the best prompt.

In [11]:
generate_sql_from_query = dspy.Predict(signature=Text2SQLSignature)
result = generate_sql_from_query(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"]
)

for k, v in result.items():
    print(f"\n{k.upper()}:\n")
    print(v)


SQL:

```
WITH profit_margin AS (
  SELECT restaurant_id, menu_item_id, (price - cost) / price AS margin
  FROM MenuItems
)
SELECT restaurant_id, menu_item_id, margin
FROM (
  SELECT restaurant_id, menu_item_id, margin,
         ROW_NUMBER() OVER (PARTITION BY restaurant_id ORDER BY margin DESC) AS row_num
  FROM profit_margin
) AS subquery
WHERE row_num <= 3;
```


### DSPy Module

Now we know how Signatures work in DSPy. A signature, combined with a prompting technique, behaves like a function: it takes inputs and returns outputs. You can compose several of these into a single module or program.

If you come from a deep learning background, you have probably heard of `torch.nn.Module`, which helps compose multiple layers into a single program.

Similarly, in DSPy, multiple modules can be composed into bigger modules (programs). Let's create a simple module for our Text2SQL generator.

In [12]:
class Text2SQLProgram(dspy.Module):
    def __init__(self, signature: dspy.Signature):
        super().__init__()
        self.program = dspy.Predict(signature=signature)

    def forward(self, sql_prompt, sql_context):
        return self.program(
            sql_prompt=sql_prompt,
            sql_context=sql_context
        )

You might wonder why we are doing the same thing again, only wrapped inside another class. The reason is that you might have multiple signatures, each with a different purpose, and a module lets you chain them together to produce a single output. That is out of the scope of the current tutorial, though. Now, let's run this module for a sanity check.

In [13]:
text2sql = Text2SQLProgram(signature=Text2SQLSignature)
result = text2sql(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"]
)

for k, v in result.items():
    print(f"\n{k.upper()}:\n")
    print(v)


SQL:

```
WITH profit_margin AS (
  SELECT restaurant_id, menu_item_id, (price - cost) / price AS margin
  FROM MenuItems
)
SELECT restaurant_id, menu_item_id, margin
FROM (
  SELECT restaurant_id, menu_item_id, margin,
         ROW_NUMBER() OVER (PARTITION BY restaurant_id ORDER BY margin DESC) AS row_num
  FROM profit_margin
) AS subquery
WHERE row_num <= 3;
```


Awesome, now that things are working, let's define a simple metric to quantify how well the SQL queries are being generated.

### Metrics in DSPy

A metric is a function that quantifies how close a prediction is to the ground truth, and therefore how good the predicted output is. A simple example is accuracy. In our case, a very simple metric is a direct string match: is the predicted SQL string identical to the actual SQL string once both have been normalised?

In [14]:
import re 
import sqlparse

def normalise_sql_string(sql_string):
    normalized = re.sub(r'```', '', sql_string).strip()
    return re.sub(r'\s+', ' ', normalized)

def compare_sqls(ground_truth, prediction, trace=None):
    ground_truth = ground_truth.sql
    prediction = prediction.sql
    
    ground_truth = normalise_sql_string(sql_string=ground_truth)
    prediction = normalise_sql_string(sql_string=prediction)

    ground_truth_parsed = sqlparse.format(
        ground_truth,
        reindent=True,
        keyword_case="upper"
    ).strip()

    prediction_parsed = sqlparse.format(
        prediction,
        reindent=True,
        keyword_case="upper"
    ).strip()

    return ground_truth_parsed == prediction_parsed
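Exact string matching is strict: two queries that return the same rows but are written differently count as a mismatch. As an alternative (a swapped-in technique, not what this cookbook measures), one could execute both queries against the schema in `sql_context` and compare the results. The sketch below uses Python's built-in `sqlite3`. It assumes the context and queries are valid in SQLite's dialect, which will not hold for every example in the dataset, and `same_result_set` is a hypothetical helper.

```python
import sqlite3


def same_result_set(sql_context, ground_truth_sql, predicted_sql):
    """Return True if both queries yield the same rows on the given schema."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(sql_context)  # create tables and insert sample rows
        expected = conn.execute(ground_truth_sql).fetchall()
        predicted = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        # Invalid SQL, or a feature SQLite does not support, counts as a mismatch.
        return False
    finally:
        conn.close()
    # Compare as sorted lists so that row order does not matter.
    return sorted(expected, key=repr) == sorted(predicted, key=repr)


context = (
    "CREATE TABLE MenuItems (restaurant_id INT, menu_item_id INT, "
    "cost DECIMAL(10,2), price DECIMAL(10,2)); "
    "INSERT INTO MenuItems VALUES (1, 101, 5.00, 12.00), (1, 102, 6.00, 15.00);"
)
print(same_result_set(
    context,
    "SELECT menu_item_id FROM MenuItems WHERE price > 12",
    "select menu_item_id from menuitems where price >= 15",
))  # True
```

Because rows are sorted before comparison, this sketch ignores `ORDER BY`. A stricter variant would compare the lists as returned.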

In [15]:
from tqdm import tqdm

scores = []

for x in tqdm(valset, total=len(valset)):
    prediction = text2sql(
        sql_prompt=x.sql_prompt, sql_context=x.sql_context
    )
    ground_truth = x
    score = compare_sqls(ground_truth=ground_truth, prediction=prediction)
    scores.append(score)

100%|██████████| 25/25 [01:16<00:00,  3.04s/it]
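The loop above only collects one boolean per validation example. To turn such a list into a single number, a small helper like the following could be used. `accuracy` is a hypothetical name, and the example scores are made up for illustration. They are not the results of the run above.

```python
def accuracy(scores):
    """Return the percentage of truthy entries in `scores` (0.0 if empty)."""
    if not scores:
        return 0.0
    return 100.0 * sum(bool(s) for s in scores) / len(scores)


# Made-up scores for illustration only.
example_scores = [True, False, True, True]
print(f"{accuracy(example_scores):.1f}%")  # 75.0%
```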


### Optimizing a DSPy program

An optimizer in DSPy optimizes the overall prompt workflow by tuning the prompt and/or the LM weights to maximize the target metrics, such as accuracy. 

An optimizer takes three inputs:

1. **The DSPy Program**: This may be a single module (e.g., `dspy.Predict`) or a complex multi-module program. We have already defined this in the above cells.

2. **Metric Function:** This is the function that evaluates the program and assigns a score in the end. We have already defined this in the above cell.

3. **Few training inputs:** This may be very small (i.e., only 5 or 10 examples) and incomplete (only inputs to your program, without any labels).

You can learn more about DSPy Optimizers in their [official documentation](https://dspy-docs.vercel.app/docs/building-blocks/optimizers). For this program, we are **NOT** optimizing any weights; hence we are not fine-tuning here. 

To keep things simple, we use the `LabeledFewShot` optimizer. It simply constructs few-shot examples (called demos) from the provided labelled input and output data points. Let's see it in action.

In [17]:
from dspy.teleprompt import LabeledFewShot

# k = number of few shot examples
optimizer = LabeledFewShot(k=4)

optimized_text2sql = optimizer.compile(
    student=text2sql,
    trainset=trainset
)
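Conceptually, a labelled few-shot optimizer picks `k` labelled examples from the training set and places them in the prompt ahead of the new question. The sketch below shows that idea in plain Python. It is not the `LabeledFewShot` implementation. The helper names, the toy training set, and the field labels in the rendered prompt are all illustrative assumptions.

```python
import random


def pick_demos(trainset, k=4, seed=0):
    """Pick up to k labelled examples to use as demonstrations."""
    rng = random.Random(seed)
    return rng.sample(trainset, min(k, len(trainset)))


def prompt_with_demos(demos, sql_prompt, sql_context):
    """Prepend the demos to the new question, leaving the final SQL open."""
    blocks = [
        f"Sql Prompt: {d['sql_prompt']}\nSql Context: {d['sql_context']}\nSql: {d['sql']}"
        for d in demos
    ]
    blocks.append(f"Sql Prompt: {sql_prompt}\nSql Context: {sql_context}\nSql:")
    return "\n\n---\n\n".join(blocks)


# Toy training set, made up for this sketch.
schema = "CREATE TABLE users (id INT);"
toy_trainset = [
    {"sql_prompt": "Count users", "sql_context": schema, "sql": "SELECT COUNT(*) FROM users;"},
    {"sql_prompt": "List user ids", "sql_context": schema, "sql": "SELECT id FROM users;"},
    {"sql_prompt": "Largest user id", "sql_context": schema, "sql": "SELECT MAX(id) FROM users;"},
]

demos = pick_demos(toy_trainset, k=2)
few_shot_prompt = prompt_with_demos(demos, "Smallest user id", schema)
print(few_shot_prompt)
```

The model then sees worked examples before the open question, which is often enough to steer the output format without changing any weights.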

Finally, we use `dspy.Evaluate` to evaluate our overall system. Here is how we do it.

In [18]:
evaluate = Evaluate(
    devset=valset,
    metric=compare_sqls,
    num_threads=3,
    display_progress=True,
    display_table=0
)

evaluate(optimized_text2sql)

Average Metric: 11 / 25  (44.0): 100%|██████████| 25/25 [00:22<00:00,  1.09it/s]


44.0

There are several more types of optimizers, and all of them follow the same flow, so you can swap them in and out to test which works best. You can learn more about the different DSPy optimizers [here](https://dspy-docs.vercel.app/docs/building-blocks/optimizers).

### Conclusion

In this tutorial, we walked through a straightforward example of generating SQL from text. 44% is a fair accuracy for a starting point. In this example we used CodeLlama, but you can try different models and see which works best for you.

Additionally, when you use Prem, you can see all the traces and runs of the model captured in the [traces section](https://docs.premai.io/get-started/monitoring). Inside Traces, you can monitor each LLM run and see how DSPy optimized your prompt. Here is an example from our case:

![](../assets/text2sql_traces.png)

As the picture above shows, DSPy added some in-context examples to optimize the initial prompt. You can do the same for different settings (for example, a different LLM or a different optimizer), and all of those runs will be captured here.


Additionally, if you are using [Business plans](https://app.premai.io/users/organization/billing/), you can use these traces to further fine-tune your model to work even better. 