In [1]:
#|export
import logging
import random
import subprocess

import dspy
import numpy as np
from dotenv import load_dotenv
from dspy.datasets import DataLoader
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShotWithRandomSearch, LabeledFewShot

## Text to SQL with DSPy

Heavily adapted from this great notebook: [DSPy-Text2SQL.ipynb](https://github.com/jjovalle99/DSPy-Text2SQL/blob/23a0a347db2d7515c5a28c305dacaea00d09dddc/DSPy-Text2SQL.ipynb)

In [2]:
# |export
########## DSPy Config  ##########
# Generator model: installed as the default LM for the notebook below.
lm = dspy.OllamaLocal(
    "open-hermes-2-4_0",
    max_tokens=2000,
    model_type="text",
)
# Second model: used as the judge inside the correctness metric and, later,
# as an alternative generator for comparison.
evaluator_lm = dspy.OllamaLocal(
    "nous-llama-3-8b-instruct",
    max_tokens=4000,
    model_type="text",
)

dspy.settings.configure(lm=lm)

Testing inference of `open-hermes` and the quantized `nous-llama-3-instruct`

In [3]:
# Smoke-test the generator model with a free-form prompt.
lm(prompt="Who is Bodhi in Point Break?")

['Bodhi is a character from the 1991 film "Point Break." He is portrayed by actor Patrick Swayze. In the movie, Bodhi is the leader of a group of surfers who turn out to be bank robbers. He is known for his charismatic and free-spirited nature, as well as his deep connection with surfing.']

In [4]:
# Smoke-test the second (judge) model with a free-form prompt.
evaluator_lm(prompt="Who is Johnny Utah?")

['Johnny Utah is a fictional character played by Keanu Reeves in the 1991 action film "Point Break". He is an FBI agent who goes undercover as a surfer to investigate a group of bank robbers who are also surfers. The movie follows his adventures and battles with the group, led by the charismatic and mysterious Bodhi (played by Patrick Swayze).']

### Add Arize Phoenix Telemetry

In [5]:
########## Arize Phoenix ##########
import phoenix as px
from openinference.instrumentation.dspy import DSPyInstrumentor
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from openinference.semconv.trace import SpanAttributes

In [6]:
# Start the local Phoenix app; the cell output reports it at http://localhost:6006/.
phoenix_session = px.launch_app()
# Trace endpoint on that same local host/port.
endpoint = "http://localhost:6006/v1/traces"
# No custom resource attributes are set.
resource = Resource(attributes={})
tracer_provider = trace_sdk.TracerProvider(resource=resource)
# Send spans produced under this provider to the endpoint above via the OTLP/HTTP exporter.
span_otlp_exporter = OTLPSpanExporter(endpoint=endpoint)
tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_otlp_exporter))

🌍 To view the Phoenix app in your browser, visit http://localhost:6006/
📺 To view the Phoenix app in a notebook, run `px.active_session().view()`
📖 For more information on how to use Phoenix, check out https://docs.arize.com/phoenix


In [7]:
# Register the provider configured in the previous cell as the global tracer provider.
trace_api.set_tracer_provider(tracer_provider=tracer_provider)
# Turn on the OpenInference DSPy instrumentation.
# NOTE(review): presumably this makes DSPy calls emit spans to Phoenix — confirm in the Phoenix UI.
DSPyInstrumentor().instrument()
# Minimal signature used to smoke-test tracing: one free-form task in, one answer out.
# NOTE(review): the docstring and `desc` strings below are presumably sent to the model
# as prompt text (the `inspect_history` outputs later in this notebook show that for
# other signatures), so treat them as runtime strings rather than documentation.
class TestTrace(dspy.Signature):
    """
    Perform the task requested to the best of your ability.
    """

    task = dspy.InputField(desc="Task to be performed.")
    answer = dspy.OutputField(desc="The answer.")

In [8]:
# Obtain a tracer named after this module from the OpenTelemetry API.
tracer = trace_api.get_tracer(__name__)

In [9]:
# Display the tracer object to confirm it was created.
tracer

<opentelemetry.sdk.trace.Tracer at 0x7fdad8125570>

Test traces

In [10]:
# Plain prediction with the smoke-test signature.
test_trace = dspy.Predict(TestTrace)
answer = test_trace(task="Say Hello.")
# Same signature through chain-of-thought prompting.
test_trace_cot = dspy.ChainOfThought(TestTrace)
pred = test_trace_cot(task="What would Bodhi from Point Break say is the key to surfing?")
# Last expression: shows the answer field as the cell output.
pred.answer

'According to Bodhi from Point Break, the key to surfing is not just about technical skill but also about embracing the moment, feeling the rush of adrenaline, and living life to its fullest.'

In [11]:
# Seed every RNG this notebook can draw from so the train/test samples and
# few-shot demos are repeatable across "Restart Kernel -> Run All".
# NOTE(review): `DataLoader.sample` and the few-shot optimizers appear to draw from
# Python's `random` module rather than NumPy, so seeding NumPy alone would not pin
# the split below — confirm against the installed dspy version.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

Using the [Synthetic Text-to-SQL](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql/viewer/default/train?row=1) dataset

In [12]:
dl = DataLoader()
dataset = dl.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql",  # Dataset name on Hugging Face
    fields=("sql_prompt", "sql_context", "sql"),  # Fields needed
    input_keys=("sql_prompt", "sql_context"),  # What our model expects to receive to generate an output
)
# 50 training samples, 25 testing samples
trainset = dl.sample(dataset["train"], n=50)
testset = dl.sample(dataset["test"], n=25)


Verify an example of the dataset

In [13]:
# Draw one training example to inspect and to reuse in the single-example cells below.
sample = dl.sample(dataset=trainset, n=1)[0]

In [14]:
# Print every field of the sampled example under an upper-cased heading.
for field_name, field_value in sample.items():
    print(f"\n{field_name.upper()}:\n", field_value, sep="\n")


SQL_PROMPT:

What is the total number of decentralized applications in the finance category worldwide?

SQL_CONTEXT:

CREATE TABLE dapps (id INT, category TEXT, region TEXT); INSERT INTO dapps (id, category, region) VALUES (1, 'Finance', 'Global');

SQL:

SELECT COUNT(*) FROM dapps WHERE category = 'Finance';


### Create Signature


In [15]:
# Signature for the text-to-SQL task: question + schema/context in, SQL out.
# The docstring is runtime prompt text (it appears verbatim as the instruction in the
# `lm.inspect_history` output further down), so changing it changes model behaviour.
class TextToSql(dspy.Signature):
    """Transform a natural language query into a SQL query."""

    # Inputs: the question and the context it refers to.
    sql_prompt = dspy.InputField(desc="Natural language query")
    sql_context = dspy.InputField(desc="Context for the query")
    # Output: the generated query.
    sql = dspy.OutputField(desc="SQL query")

### Inference

Baseline Inference 

In [16]:
# Baseline predictor: the bare signature with no optimization applied.
generate_sql_query = dspy.Predict(signature=TextToSql)

In [17]:
# Run the baseline predictor on the example inspected above.
result = generate_sql_query(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"],
)

# Print every predicted field under an upper-cased heading.
for field_name, field_value in result.items():
    print(f"\n{field_name.upper()}:\n", field_value, sep="\n")


SQL:

Sql Prompt: What is the total number of decentralized applications in the finance category worldwide?
Sql Context: CREATE TABLE dapps (id INT, category TEXT, region TEXT); INSERT INTO dapps (id, category, region) VALUES (1, 'Finance', 'Global');
Sql: SELECT COUNT(*) FROM dapps WHERE category = 'Finance' AND region = 'Global';


In [18]:
# Show the raw `sql` field of the prediction, unformatted.
result.sql

"Sql Prompt: What is the total number of decentralized applications in the finance category worldwide?\nSql Context: CREATE TABLE dapps (id INT, category TEXT, region TEXT); INSERT INTO dapps (id, category, region) VALUES (1, 'Finance', 'Global');\nSql: SELECT COUNT(*) FROM dapps WHERE category = 'Finance' AND region = 'Global';"

### Metric of Evaluation

In [19]:
# Judge signature: given the question, its context and a candidate SQL query, produce a
# Yes/No verdict. The docstring, `desc` and `prefix` strings are runtime prompt text
# (they appear verbatim in the `evaluator_lm.inspect_history` output below).
class Correctness(dspy.Signature):
    """Assess if the SQL query accurately answers the given natural language query based on the provided context."""

    sql_prompt = dspy.InputField(desc="Natural language query")
    sql_context = dspy.InputField(desc="Context for the query")
    # The candidate query being judged is an input here, not an output.
    sql = dspy.InputField(desc="SQL query")
    # `prefix` is the label shown before the verdict in the prompt ("Yes/No:").
    correct = dspy.OutputField(desc="Indicate whether the SQL query correctly answers the natural language query based on the given context", prefix="Yes/No:")


In [20]:
def _is_affirmative(answer):
    """Return True when the judge's free-text reply opens with a "yes" verdict.

    The reply is compared case-insensitively on its first alphabetic word, so
    "Yes", " yes.", "**Yes**" and "Yes, it does" all count as affirmative. An
    echoed "Yes/No" field label at the start is skipped before the check.
    None and empty replies are treated as not affirmative.
    """
    import re  # local import keeps this cell self-contained

    text = "" if answer is None else str(answer).strip()
    label = "yes/no"
    if text.lower().startswith(label):
        # Drop the echoed field label so its leading "Yes" is not read as the verdict.
        text = text[len(label):]
    first_word = re.search(r"[A-Za-z]+", text)
    return first_word is not None and first_word.group(0).lower() == "yes"


def correctness_metric(example, pred, trace=None):
    """Score a predicted SQL query by asking the evaluator LM for a Yes/No verdict.

    Args:
        example: Object exposing `sql_prompt` and `sql_context` attributes.
        pred: Object exposing the predicted query as `pred.sql`.
        trace: When not None, the result is returned as a bool instead of an int.
            NOTE(review): presumably DSPy optimizers pass a trace while
            bootstrapping — confirm against the dspy metric contract.

    Returns:
        1 or 0 when `trace` is None; True or False otherwise.
    """
    judge = dspy.Predict(Correctness)

    # Always judge with the evaluator model, whichever LM produced `pred`.
    with dspy.context(lm=evaluator_lm):
        verdict = judge(
            sql_prompt=example.sql_prompt,
            sql_context=example.sql_context,
            sql=pred.sql,
        )

    # Tolerate case, whitespace and trailing text in the reply; an exact
    # comparison with "Yes" would score replies such as "yes." as incorrect.
    is_correct = _is_affirmative(verdict.correct)

    if trace is not None:
        return is_correct

    return int(is_correct)

### Evaluate Single Data Point


In [21]:
# Judge the baseline prediction for the sampled example.
_correctness = correctness_metric(example=sample, pred=result)
verdict_label = "Yes" if _correctness else "No"
print(f"Correct SQL query: {verdict_label}")

Correct SQL query: Yes


In [22]:
# Show the most recent prompt/completion exchanged with the judge model.
evaluator_lm.inspect_history(n=1)





Assess if the SQL query accurately answers the given natural language query based on the provided context.

---

Follow the following format.

Sql Prompt: Natural language query

Sql Context: Context for the query

Sql: SQL query

Yes/No: Indicate whether the SQL query correctly answers the natural language query based on the given context

---

Sql Prompt: What is the total number of decentralized applications in the finance category worldwide?

Sql Context: CREATE TABLE dapps (id INT, category TEXT, region TEXT); INSERT INTO dapps (id, category, region) VALUES (1, 'Finance', 'Global');

Sql: Sql Prompt: What is the total number of decentralized applications in the finance category worldwide? Sql Context: CREATE TABLE dapps (id INT, category TEXT, region TEXT); INSERT INTO dapps (id, category, region) VALUES (1, 'Finance', 'Global'); Sql: SELECT COUNT(*) FROM dapps WHERE category = 'Finance' AND region = 'Global';

Yes/No:[32m Yes[0m





### To-do: Set up Metadata Tracer

In [23]:
# with tracer.start_as_current_span("testing metadata") as parent:
#     current_span = trace_api.get_current_span()
#     current_span.set_attribute("model", "llama-3-7b-instruct-5q")
#     evaluator_lm("Hello there")

In [24]:
# class QuestionClassifier(dspy.Module):
#     def __init__(self):
#         super().__init__()
#         # [snip]
#     def forward(self, question: str) -> tuple[str,str]:
#         current_span = trace_api.get_current_span()
#         current_span.set_attribute("metadata", "{ 'foo': 'bar' }")
#         # [your dspy code]

### Evaluate Entire Dataset

#### Open Hermes 2 (quantized)

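Quantized `open-hermes` outperforms `gpt-3.5-turbo-0125` out of the box (OOTB): `gpt-3.5-turbo-0125` scores about 68%, while `open-hermes` has scored about 80%. Scores vary from run to run because the judge is itself an LLM; the run recorded below came in at 72%.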

In [25]:
# Score the baseline predictor on the test set with open-hermes as the generator.
# `evaluate` is kept as a notebook-level name.
with dspy.context(lm=lm):
    evaluate = Evaluate(
        metric=correctness_metric,
        devset=testset,
        display_table=5,
        display_progress=True,
        num_threads=8,
    )
    evaluate(generate_sql_query)

Average Metric: 18 / 25  (72.0): 100%|██████████| 25/25 [02:43<00:00,  6.53s/it]

Average Metric: 18 / 25  (72.0%)



  df.loc[:, metric_name] = df[metric_name].apply(


Unnamed: 0,sql_prompt,sql_context,example_sql,pred_sql,correctness_metric
0,Insert a new carbon offset program for 'Community Tree Planting' in 'Urban Area X',"CREATE TABLE carbon_offset_programs (program_id INT, program_name TEXT, location TEXT);","INSERT INTO carbon_offset_programs (program_id, program_name, location) VALUES (4, 'Community Tree Planting', 'Urban Area X');","Sql Prompt: Insert a new carbon offset program for 'Community Tree Planting' in 'Urban Area X' Sql Context: CREATE TABLE carbon_offset_programs (program_id INT, program_name TEXT,...",1
1,Which organizations had the highest number of successful interventions in 'Africa' in 2018?,"CREATE TABLE organizations (id INT, name VARCHAR(255)); INSERT INTO organizations (id, name) VALUES (1, 'UNHCR'), (2, 'IOM'), (3, 'World Vision'); CREATE TABLE interventions (id INT,...","SELECT organization_id, MAX(success) as highest_successful_interventions FROM interventions WHERE YEAR(intervention_date) = 2018 AND location = 'Africa';","SELECT organizations.name, COUNT(*) as number_of_successful_interventions FROM organizations JOIN interventions ON organizations.id = interventions.organization_id WHERE interventions.location = 'Africa' AND interventions.success > 0 AND YEAR(intervention_date) = 2018...",1
2,What is the total revenue of natural skincare products in Australia in the last year?,"CREATE TABLE skincare_revenue (revenue_id INT, product_id INT, revenue DECIMAL(5,2), is_natural BOOLEAN, revenue_date DATE); INSERT INTO skincare_revenue VALUES (1, 20, 55.99, true, '2021-12-15');",SELECT SUM(revenue) FROM skincare_revenue WHERE is_natural = true AND revenue_date BETWEEN '2020-01-01' AND '2021-12-31' AND country = 'Australia';,"SELECT SUM(revenue) AS total_revenue FROM skincare_revenue WHERE is_natural = true AND revenue_date >= DATEADD(year, -1, GETDATE()) AND revenue_date < GETDATE();",1
3,How many military aircraft maintenance requests were recorded for the Air Force in Q4 2019?,"CREATE TABLE maintenance_requests (request_id INT, service_branch VARCHAR(255), request_date DATE); INSERT INTO maintenance_requests (request_id, service_branch, request_date) VALUES (1, 'Air Force', '2019-10-01'), (2, 'Navy', '2019-12-02'), (3, 'Air...",SELECT COUNT(*) FROM maintenance_requests WHERE service_branch = 'Air Force' AND EXTRACT(QUARTER FROM request_date) = 4 AND EXTRACT(YEAR FROM request_date) = 2019;,"Sql Prompt: How many military aircraft maintenance requests were recorded for the Air Force in Q4 2019? Sql Context: CREATE TABLE maintenance_requests (request_id INT, service_branch...",1
4,"What is the total number of water wells dug in ""Latin America"" since 2018?","CREATE TABLE water_wells (id INT, project_id INT, location VARCHAR(255), construction_date DATE); INSERT INTO water_wells (id, project_id, location, construction_date) VALUES (1, 4001, 'Colombia', '2019-05-01'); INSERT INTO...",SELECT COUNT(*) FROM water_wells WHERE location = 'Latin America' AND YEAR(construction_date) >= 2018;,SELECT COUNT(*) FROM water_wells WHERE location LIKE '%Latin America%' AND construction_date >= '2018-01-01';,0


#### Llama 3 8B (quantized)

In [26]:
# Score the same baseline predictor with `evaluator_lm` as the generator.
# The metric also judges with `evaluator_lm`, so here the model grades its own output.
# `evaluate` is rebound here and reused by the final evaluation cell.
with dspy.context(lm=evaluator_lm):
    evaluate = Evaluate(
        metric=correctness_metric,
        devset=testset,
        display_table=5,
        display_progress=True,
        num_threads=8,
    )
    evaluate(generate_sql_query)

Average Metric: 20 / 25  (80.0): 100%|██████████| 25/25 [01:49<00:00,  4.39s/it]

Average Metric: 20 / 25  (80.0%)



  df.loc[:, metric_name] = df[metric_name].apply(


Unnamed: 0,sql_prompt,sql_context,example_sql,pred_sql,correctness_metric
0,Insert a new carbon offset program for 'Community Tree Planting' in 'Urban Area X',"CREATE TABLE carbon_offset_programs (program_id INT, program_name TEXT, location TEXT);","INSERT INTO carbon_offset_programs (program_id, program_name, location) VALUES (4, 'Community Tree Planting', 'Urban Area X');","--- Sql Prompt: Insert a new carbon offset program for 'Community Tree Planting' in 'Urban Area X' Sql Context: CREATE TABLE carbon_offset_programs (program_id INT, program_name...",1
1,Which organizations had the highest number of successful interventions in 'Africa' in 2018?,"CREATE TABLE organizations (id INT, name VARCHAR(255)); INSERT INTO organizations (id, name) VALUES (1, 'UNHCR'), (2, 'IOM'), (3, 'World Vision'); CREATE TABLE interventions (id INT,...","SELECT organization_id, MAX(success) as highest_successful_interventions FROM interventions WHERE YEAR(intervention_date) = 2018 AND location = 'Africa';","Sql Prompt: Which organizations had the highest number of successful interventions in 'Africa' in 2018? Sql Context: CREATE TABLE organizations (id INT, name VARCHAR(255)); INSERT...",1
2,What is the total revenue of natural skincare products in Australia in the last year?,"CREATE TABLE skincare_revenue (revenue_id INT, product_id INT, revenue DECIMAL(5,2), is_natural BOOLEAN, revenue_date DATE); INSERT INTO skincare_revenue VALUES (1, 20, 55.99, true, '2021-12-15');",SELECT SUM(revenue) FROM skincare_revenue WHERE is_natural = true AND revenue_date BETWEEN '2020-01-01' AND '2021-12-31' AND country = 'Australia';,"--- Sql Prompt: What is the total revenue of natural skincare products in Australia in the last year? Sql Context: CREATE TABLE skincare_revenue (revenue_id INT,...",0
3,How many military aircraft maintenance requests were recorded for the Air Force in Q4 2019?,"CREATE TABLE maintenance_requests (request_id INT, service_branch VARCHAR(255), request_date DATE); INSERT INTO maintenance_requests (request_id, service_branch, request_date) VALUES (1, 'Air Force', '2019-10-01'), (2, 'Navy', '2019-12-02'), (3, 'Air...",SELECT COUNT(*) FROM maintenance_requests WHERE service_branch = 'Air Force' AND EXTRACT(QUARTER FROM request_date) = 4 AND EXTRACT(YEAR FROM request_date) = 2019;,"Sql Prompt: How many military aircraft maintenance requests were recorded for the Air Force in Q4 2019? Sql Context: CREATE TABLE maintenance_requests (request_id INT, service_branch...",1
4,"What is the total number of water wells dug in ""Latin America"" since 2018?","CREATE TABLE water_wells (id INT, project_id INT, location VARCHAR(255), construction_date DATE); INSERT INTO water_wells (id, project_id, location, construction_date) VALUES (1, 4001, 'Colombia', '2019-05-01'); INSERT INTO...",SELECT COUNT(*) FROM water_wells WHERE location = 'Latin America' AND YEAR(construction_date) >= 2018;,"--- Sql Prompt: What is the total number of water wells dug in ""Latin America"" since 2018? Sql Context: CREATE TABLE water_wells (id INT, project_id...",1


### Optimize for Text-to-SQL

#### Create Program

In [29]:
class TextToSqlProgram(dspy.Module):
    """DSPy module that runs the `TextToSql` signature through chain-of-thought prompting."""

    def __init__(self):
        super().__init__()
        # The only sub-module: a chain-of-thought predictor over the text-to-SQL signature.
        self.program = dspy.ChainOfThought(signature=TextToSql)

    def forward(self, sql_prompt, sql_context):
        """Generate SQL for `sql_prompt` given `sql_context`.

        Returns whatever the wrapped predictor returns; the run shown below
        yields a `Prediction` with `rationale` and `sql` fields.
        """
        return self.program(sql_prompt=sql_prompt, sql_context=sql_context)

#### Few Shot

In [30]:
# Execute optimizer -> This only adds a few shots to the prompt
optimizer = LabeledFewShot(k=4)
optimized_program = optimizer.compile(student=TextToSqlProgram(), trainset=trainset)

In [31]:
# Run the few-shot program on the same sampled example used for the baseline.
optimized_program(sql_context=sample["sql_context"], sql_prompt=sample["sql_prompt"])

Prediction(
    rationale='answer this question. Firstly, we need to identify the table that contains the information about decentralized applications and their categories. In this case, it is the "dapps" table. We also know that we are looking for decentralized applications in the finance category worldwide, so we can filter the data based on these criteria.\n\nNow, let\'s write the SQL query to get the total number of decentralized applications in the finance category worldwide:\n\n```sql\nSELECT COUNT(*) AS total_number FROM dapps WHERE category = \'Finance\' AND region = \'Global\';\n```\n\nThis query will count all the rows in the "dapps" table where the category is "Finance" and the region is "Global". The result of this query will be the total number of decentralized applications in the finance category worldwide.',
    sql="SELECT COUNT(*) AS total_number FROM dapps WHERE category = 'Finance' AND region = 'Global';"
)

In [32]:
# Show the most recent prompt sent to the generator model, including the few-shot examples.
lm.inspect_history(n=1)





Transform a natural language query into a SQL query.

---

Sql Prompt: Find the percentage of tourists visiting New Zealand from North America compared to the global total in 2019?
Sql Context: CREATE TABLE visitor_info (id INT, visitor_country VARCHAR(50), destination_country VARCHAR(50), visit_date DATE);
Sql: SELECT 100.0 * SUM(CASE WHEN visitor_country = 'United States' OR visitor_country = 'Canada' THEN 1 ELSE 0 END) / (SELECT COUNT(*) FROM visitor_info WHERE destination_country = 'New Zealand' AND YEAR(visit_date) = 2019) AS pct_north_america_nz_2019 FROM visitor_info WHERE destination_country = 'New Zealand' AND YEAR(visit_date) = 2019;

Sql Prompt: What is the average number of bikes available at each station?
Sql Context: CREATE TABLE stations (id INT, station_name VARCHAR(20), bikes INT, docks INT); INSERT INTO stations (id, station_name, bikes, docks) VALUES (1, 'Station 1', 10, 20), (2, 'Station 2', 15, 20), (3, 'Station 3', 20, 20);
Sql: SELECT AVG(bikes) FROM stations

### Evaluate the Optimized Program

In [33]:
# Reuses the `evaluate` object created in the previous evaluation cell (same test set and metric).
# NOTE(review): this call sits outside any `dspy.context(...)` block, so generation
# presumably uses the globally configured `lm` — confirm before comparing scores.
evaluate(optimized_program, num_threads=8)

Average Metric: 9 / 25  (36.0): 100%|██████████| 25/25 [06:13<00:00, 14.96s/it]

Average Metric: 9 / 25  (36.0%)



  df.loc[:, metric_name] = df[metric_name].apply(


Unnamed: 0,sql_prompt,sql_context,example_sql,rationale,pred_sql,correctness_metric
0,Insert a new carbon offset program for 'Community Tree Planting' in 'Urban Area X',"CREATE TABLE carbon_offset_programs (program_id INT, program_name TEXT, location TEXT);","INSERT INTO carbon_offset_programs (program_id, program_name, location) VALUES (4, 'Community Tree Planting', 'Urban Area X');",insert a new carbon offset program for 'Community Tree Planting' in 'Urban Area X'. Step 1: Check if the program already exists. We can use...,"INSERT INTO carbon_offset_programs (program_id, program_name, location) VALUES (1, 'Community Tree Planting', 'Urban Area X');",1
1,Which organizations had the highest number of successful interventions in 'Africa' in 2018?,"CREATE TABLE organizations (id INT, name VARCHAR(255)); INSERT INTO organizations (id, name) VALUES (1, 'UNHCR'), (2, 'IOM'), (3, 'World Vision'); CREATE TABLE interventions (id INT,...","SELECT organization_id, MAX(success) as highest_successful_interventions FROM interventions WHERE YEAR(intervention_date) = 2018 AND location = 'Africa';","answer this question. First, we need to filter the interventions that took place in Africa in 2018. Then, we can group them by organization and...","SELECT organizations.name AS organization_name, COUNT(*) AS num_interventions FROM organizations JOIN interventions ON organizations.id = interventions.organization_id WHERE interventions.location = 'Africa' AND YEAR(interventions.intervention_date) = 2018 GROUP BY...",1
2,What is the total revenue of natural skincare products in Australia in the last year?,"CREATE TABLE skincare_revenue (revenue_id INT, product_id INT, revenue DECIMAL(5,2), is_natural BOOLEAN, revenue_date DATE); INSERT INTO skincare_revenue VALUES (1, 20, 55.99, true, '2021-12-15');",SELECT SUM(revenue) FROM skincare_revenue WHERE is_natural = true AND revenue_date BETWEEN '2020-01-01' AND '2021-12-31' AND country = 'Australia';,"answer this question correctly. Firstly, we need to filter the revenue data for natural skincare products. Then, we need to group the results by product...",SELECT SUM(revenue) AS total_revenue FROM skincare_revenue WHERE is_natural = true AND revenue_date >= '2021-01-01' AND revenue_date < '2022-01-01';,1
3,How many military aircraft maintenance requests were recorded for the Air Force in Q4 2019?,"CREATE TABLE maintenance_requests (request_id INT, service_branch VARCHAR(255), request_date DATE); INSERT INTO maintenance_requests (request_id, service_branch, request_date) VALUES (1, 'Air Force', '2019-10-01'), (2, 'Navy', '2019-12-02'), (3, 'Air...",SELECT COUNT(*) FROM maintenance_requests WHERE service_branch = 'Air Force' AND EXTRACT(QUARTER FROM request_date) = 4 AND EXTRACT(YEAR FROM request_date) = 2019;,"answer this question. First, we need to filter the maintenance requests table for only those that belong to the Air Force service branch. Then, we...",SELECT COUNT(*) AS num_requests FROM maintenance_requests WHERE service_branch = 'Air Force' AND request_date BETWEEN '2019-10-01' AND '2019-12-31';,1
4,"What is the total number of water wells dug in ""Latin America"" since 2018?","CREATE TABLE water_wells (id INT, project_id INT, location VARCHAR(255), construction_date DATE); INSERT INTO water_wells (id, project_id, location, construction_date) VALUES (1, 4001, 'Colombia', '2019-05-01'); INSERT INTO...",SELECT COUNT(*) FROM water_wells WHERE location = 'Latin America' AND YEAR(construction_date) >= 2018;,"answer this question. First, we need to filter the water wells that were dug in Latin America. Then, we need to count the number of...","SELECT YEAR(construction_date) AS year, COUNT(*) AS total_wells FROM water_wells WHERE location LIKE '%Latin America%' AND construction_date >= '2018-01-01' GROUP BY YEAR(construction_date) WITH ROLLUP; This query...",0


36.0