In [None]:
#|export
import logging
import random
import subprocess

import dspy
import numpy as np
from dotenv import load_dotenv
from dspy.datasets import DataLoader
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShotWithRandomSearch, LabeledFewShot

## Text to SQL with DSPy

Heavily adapted from this great notebook: [DSPy-Text2SQL.ipynb](https://github.com/jjovalle99/DSPy-Text2SQL/blob/23a0a347db2d7515c5a28c305dacaea00d09dddc/DSPy-Text2SQL.ipynb)

In [None]:
# |export
########## DSPy Config  ##########
# Generator LM: the locally served Ollama model whose text-to-SQL output is evaluated and optimized.
lm = dspy.OllamaLocal("open-hermes-2-4_0", max_tokens=2000, model_type="text")
# Judge LM: used by `correctness_metric` to grade generated SQL; it gets a larger token budget.
evaluator_lm = dspy.OllamaLocal("nous-llama-3-8b-instruct", max_tokens=4000, model_type="text")

# Make the generator the default LM. Code that needs the judge switches to it
# temporarily with `dspy.context(lm=evaluator_lm)`.
dspy.settings.configure(lm=lm)

Testing inference of `open-hermes` and the quantized `nous-llama-3-instruct`

In [None]:
# Smoke test: send a raw prompt straight to the generator LM to confirm the Ollama model responds.
lm("Who is Bodhi in Point Break?")

In [None]:
# Smoke test: same check for the judge LM.
evaluator_lm(prompt="Who is Johnny Utah?")

### Add Arize Phoenix Telemetry

In [None]:
########## Arize Phoenix ##########
import phoenix as px
from openinference.instrumentation.dspy import DSPyInstrumentor
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from openinference.semconv.trace import SpanAttributes

In [None]:
# Start the local Phoenix app that will receive and display the traces.
phoenix_session = px.launch_app()

# OTLP/HTTP endpoint the spans are exported to.
# NOTE(review): assumes the Phoenix app launched above listens on localhost:6006 — confirm.
endpoint = "http://localhost:6006/v1/traces"

# Tracer provider with an empty resource; each span is handed to the OTLP exporter
# through a SimpleSpanProcessor.
resource = Resource(attributes={})
tracer_provider = trace_sdk.TracerProvider(resource=resource)
span_otlp_exporter = OTLPSpanExporter(endpoint=endpoint)
span_processor = SimpleSpanProcessor(span_exporter=span_otlp_exporter)
tracer_provider.add_span_processor(span_processor)

In [None]:
# Register the provider globally, then instrument DSPy so its calls are traced through it.
trace_api.set_tracer_provider(tracer_provider=tracer_provider)
DSPyInstrumentor().instrument()
# Minimal one-input / one-output signature used only to check that traces reach Phoenix.
# NOTE(review): DSPy reads a Signature's docstring and field `desc` values when building
# the prompt, so treat them as runtime text rather than documentation.
class TestTrace(dspy.Signature):
    """
    Perform the task requested to the best of your ability.
    """

    # Free-form instruction given to the LM.
    task = dspy.InputField(desc="Task to be performed.")
    # The LM's response to `task`.
    answer = dspy.OutputField(desc="The answer.")

In [None]:
# Tracer handle for creating spans manually (only referenced by the commented-out to-do cell below).
tracer = trace_api.get_tracer(__name__)

In [None]:
# Display the tracer object to confirm it was created.
tracer

Test traces

In [None]:
# Plain prediction: one LM call, which should show up as a trace in Phoenix.
test_trace = dspy.Predict(TestTrace)
answer = test_trace(task="Say Hello.")
# Chain-of-thought variant of the same signature, to produce a second trace.
test_trace_cot = dspy.ChainOfThought(TestTrace)
pred = test_trace_cot(task="What would Bodhi from Point Break say is the key to surfing?")
# Last expression of the cell: display the generated answer.
pred.answer

In [None]:
# Seed every RNG the sampling below may draw from so the train/test subsets are reproducible.
# DSPy's `DataLoader.sample` uses Python's built-in `random` module, so seeding NumPy alone
# leaves the sampled examples different on every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

Using the [Synthetic Text to SQL](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql/viewer/default/train?row=1) dataset.

In [None]:
dl = DataLoader()

# Load only the columns we need from the Hugging Face dataset. The natural-language
# prompt and the schema context are what the model expects to receive; `sql` is the label.
dataset = dl.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql",
    fields=("sql_prompt", "sql_context", "sql"),
    input_keys=("sql_prompt", "sql_context"),
)

# 50 training samples, 25 testing samples.
trainset = dl.sample(dataset["train"], n=50)
testset = dl.sample(dataset["test"], n=25)


Verify an example of the dataset

In [None]:
# Draw one training example to inspect and to reuse in the single-example checks below.
sample = dl.sample(dataset=trainset, n=1)[0]

In [None]:
# Print each field of the sampled example under an upper-cased heading.
for field_name, field_value in sample.items():
    heading = f"\n{field_name.upper()}:\n"
    print(heading)
    print(field_value)

### Create Signature


In [None]:
# Signature for the text-to-SQL task: two inputs (question + context), one output (query).
# NOTE(review): DSPy reads the docstring and field `desc` values when building the prompt,
# so treat them as runtime text rather than documentation.
class TextToSql(dspy.Signature):
    """Transform a natural language query into a SQL query."""

    # The user's question in plain language.
    sql_prompt = dspy.InputField(desc="Natural language query")
    # Context the query must be written against (the dataset's `sql_context` column).
    sql_context = dspy.InputField(desc="Context for the query")
    # The generated SQL.
    sql = dspy.OutputField(desc="SQL query")

### Inference

Baseline Inference 

In [None]:
# Baseline predictor: the bare signature with no demonstrations or optimization applied.
generate_sql_query = dspy.Predict(signature=TextToSql)

In [None]:
# Run the baseline predictor on the sampled example.
result = generate_sql_query(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"],
)

# Print every field of the prediction under an upper-cased heading.
for field_name, field_value in result.items():
    print(f"\n{field_name.upper()}:\n")
    print(field_value)

In [None]:
# Display just the generated SQL string.
result.sql

### Metric of Evaluation

In [None]:
# Signature for the LLM judge: given the question, the context and a candidate query,
# produce a verdict under the "Yes/No:" prefix.
# NOTE(review): DSPy reads the docstring, `desc` and `prefix` values when building the prompt,
# so treat them as runtime text rather than documentation.
class Correctness(dspy.Signature):
    """Assess if the SQL query accurately answers the given natural language query based on the provided context."""

    # The user's question in plain language.
    sql_prompt = dspy.InputField(desc="Natural language query")
    # Context the query is judged against.
    sql_context = dspy.InputField(desc="Context for the query")
    # The candidate SQL being graded (an input here, unlike in `TextToSql`).
    sql = dspy.InputField(desc="SQL query")
    # The judge's verdict; `correctness_metric` interprets this text.
    correct = dspy.OutputField(desc="Indicate whether the SQL query correctly answers the natural language query based on the given context", prefix="Yes/No:")


In [None]:
def correctness_metric(example, pred, trace=None):
    """Grade a predicted SQL query with the evaluator LM.

    The judge is asked, via the `Correctness` signature, whether `pred.sql`
    answers `example.sql_prompt` given `example.sql_context`.

    Args:
        example: Object exposing `sql_prompt` and `sql_context` attributes.
        pred: Object exposing the generated query as `sql`.
        trace: When not None, the verdict is returned as a bool instead of an int.
            (Presumably supplied by DSPy optimizers during bootstrapping — confirm.)

    Returns:
        1 or 0 when `trace` is None, otherwise True or False.
    """
    sql_prompt, sql_context, sql = example.sql_prompt, example.sql_context, pred.sql

    # current_span = trace_api.get_current_span()
    # current_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, "llama-3")

    correctness = dspy.Predict(Correctness)

    # Always grade with the evaluator LM, whichever LM produced `pred`.
    with dspy.context(lm=evaluator_lm):
        correct = correctness(
            sql_prompt=sql_prompt,
            sql_context=sql_context,
            sql=sql,
        )

    # The judge rarely answers with a bare "Yes": tolerate case differences, surrounding
    # whitespace, trailing punctuation and a following explanation by checking only the
    # first word. An exact `== "Yes"` comparison would score all of those as incorrect.
    tokens = str(correct.correct or "").strip().lower().split()
    verdict = tokens[0].strip(".,:;!?\"'`*()[]") if tokens else ""
    score = int(verdict == "yes")

    if trace is not None:
        return score == 1

    return score

### Evaluate Single Data Point


In [None]:
# Grade the baseline prediction for the single sampled example (0/1 because no trace is passed).
_correctness = correctness_metric(example=sample, pred=result)
print(f"Correct SQL query: {'Yes' if _correctness else 'No'}")

In [None]:
# Show the most recent prompt/completion exchanged with the judge LM.
evaluator_lm.inspect_history(n=1)

### To-do: Set up Metadata Tracer

In [None]:
# with tracer.start_as_current_span("testing metadata") as parent:
#     current_span = trace_api.get_current_span()
#     current_span.set_attribute("model", "llama-3-7b-instruct-5q")
#     evaluator_lm("Hello there")

In [None]:
# class QuestionClassifier(dspy.Module):
#     def __init__(self):
#         super().__init__()
#         # [snip]
#     def forward(self, question: str) -> tuple[str,str]:
#         current_span = trace_api.get_current_span()
#         current_span.set_attribute("metadata", "{ 'foo': 'bar' }")
#         # [your dspy code]

### Evaluate Entire Dataset

#### Open Hermes 2 (quantized)

Out of the box, quantized `open-hermes` outperforms `gpt-3.5-turbo-0125`: `gpt-3.5-turbo-0125` scores about 68% and `open-hermes` about 80%.

In [None]:
# Baseline score: run the un-optimized `generate_sql_query` predictor over the test set
# with `lm` as the generator. `correctness_metric` grades each prediction with `evaluator_lm`.
with dspy.context(lm=lm):
    evaluate = Evaluate(
        devset=testset,
        metric=correctness_metric,
        num_threads=8,
        display_progress=True,
        display_table=5,
    )
    evaluate(generate_sql_query)

#### Llama 3 8B (quantized)

In [None]:
# Same evaluation, but with `evaluator_lm` as the generator.
# NOTE(review): `correctness_metric` also grades with `evaluator_lm`, so in this cell the
# model is judging its own output — keep that in mind when comparing the two scores.
# `evaluate` is rebound here; a later cell calls this object again.
with dspy.context(lm=evaluator_lm):
    evaluate = Evaluate(
        devset=testset,
        metric=correctness_metric,
        num_threads=8,
        display_progress=True,
        display_table=5,
    )
    evaluate(generate_sql_query)

### Optimize for Text-to-SQL

#### Create Program

In [None]:
class TextToSqlProgram(dspy.Module):
    """DSPy module that answers a text-to-SQL request with a chain-of-thought predictor."""

    def __init__(self):
        """Create the `dspy.ChainOfThought` predictor over the `TextToSql` signature."""
        super().__init__()
        self.program = dspy.ChainOfThought(signature=TextToSql)

    def forward(self, sql_prompt, sql_context):
        """Run the predictor on one request.

        Args:
            sql_prompt: Natural-language question.
            sql_context: Context the query should be written against.

        Returns:
            Whatever `self.program` returns for these inputs.
        """
        prediction = self.program(sql_prompt=sql_prompt, sql_context=sql_context)
        return prediction

#### Few Shot

In [None]:
# Execute optimizer -> This only adds a few shots to the prompt
# `k=4` is the number of labeled training examples requested as demonstrations.
optimizer = LabeledFewShot(k=4)
# Compile a fresh `TextToSqlProgram` against the training set; no metric is involved here.
optimized_program = optimizer.compile(student=TextToSqlProgram(), trainset=trainset)

In [None]:
# Try the few-shot program on the sampled example (note: `sample` was drawn from `trainset`).
optimized_program(sql_context=sample["sql_context"], sql_prompt=sample["sql_prompt"])

In [None]:
# Show the most recent prompt sent to the generator LM, to see the added demonstrations.
lm.inspect_history(n=1)

### Evaluate the Optimized Program

In [None]:
# Reuses the `evaluate` object built in the most recent evaluation cell. This call sits
# outside any `dspy.context(...)` block, so it presumably runs with the globally configured
# `lm` as the generator — confirm that is the intended model.
evaluate(optimized_program, num_threads=8)

### BootstrapFewShotWithRandomSearch

General rules of thumb for choosing a DSPy optimizer (at the moment):
- If you have few examples (~10), use `BootstrapFewShot`.
- If you have more (~50), use `BootstrapFewShotWithRandomSearch`.
- If you have a lot (~300), use `MIPRO`.
- If you have been able to use one of these with a large LM (e.g., 7B
  parameters or above) and need a very efficient program, compile that
  down to a small LM with `BootstrapFinetune`.

In [None]:
# Random-search optimizer scored with the LLM-judge metric: up to 2 bootstrapped
# demonstrations per program, 8 candidate programs, evaluated on 8 threads.
optimizer2 = BootstrapFewShotWithRandomSearch(
    metric=correctness_metric,
    max_bootstrapped_demos=2,
    num_candidate_programs=8,
    num_threads=8,
)

In [None]:
# Compile a fresh program with the random-search optimizer.
# NOTE(review): `valset=testset` lets the optimizer choose between candidate programs using
# the same examples the evaluation cells score on, so any later score on `testset` is likely
# optimistic. Consider holding out part of `trainset` for validation instead.
optimized_program_2 = optimizer2.compile(
    student=TextToSqlProgram(), trainset=trainset, valset=testset
)