Merged
34 changes: 20 additions & 14 deletions docs/docs/cheatsheet.md
Contributor Author

I ran the linter, but it seems to have run against the whole file. It did find lots of fixes. I can revert if need be, but I think this should be done regardless.

Collaborator

No worries, I will push a commit to fix it.

@@ -2,7 +2,6 @@
sidebar_position: 999
---


# DSPy Cheatsheet

This page will contain snippets for frequent usage patterns.
@@ -216,7 +215,7 @@ def parse_integer_answer(answer, only_first_line=True):
    except (ValueError, IndexError):
        # print(answer)
        answer = 0

    return answer

# Metric Function
@@ -254,15 +253,17 @@ evaluate_program(your_dspy_program)

## DSPy Optimizers

### LabeledFewShot
### LabeledFewShot

```python
from dspy.teleprompt import LabeledFewShot

labeled_fewshot_optimizer = LabeledFewShot(k=8)
your_dspy_program_compiled = labeled_fewshot_optimizer.compile(student = your_dspy_program, trainset=trainset)
```

### BootstrapFewShot
### BootstrapFewShot

```python
from dspy.teleprompt import BootstrapFewShot

@@ -272,6 +273,7 @@ your_dspy_program_compiled = fewshot_optimizer.compile(student = your_dspy_progr
```

#### Using another LM for compilation, specifying in teacher_settings

```python
from dspy.teleprompt import BootstrapFewShot

@@ -306,7 +308,6 @@ loaded_program.load(path=save_path)

Detailed documentation on BootstrapFewShotWithRandomSearch can be found [here](deep-dive/optimizers/bootstrap-fewshot.md).


```python
from dspy.teleprompt import BootstrapFewShotWithRandomSearch

@@ -315,8 +316,8 @@ fewshot_optimizer = BootstrapFewShotWithRandomSearch(metric=your_defined_metric,
your_dspy_program_compiled = fewshot_optimizer.compile(student = your_dspy_program, trainset=trainset, valset=devset)

```
Other custom configurations are similar to customizing the `BootstrapFewShot` optimizer.

Other custom configurations are similar to customizing the `BootstrapFewShot` optimizer.

### Ensemble

@@ -363,7 +364,6 @@ for p in finetune_program.predictors():

Detailed documentation on COPRO can be found [here](deep-dive/optimizers/copro.md).


```python
from dspy.teleprompt import COPRO

@@ -376,7 +376,6 @@ compiled_program_optimized_signature = copro_teleprompter.compile(your_dspy_prog

### MIPRO


```python
from dspy.teleprompt import MIPRO

@@ -392,7 +391,9 @@ compiled_program_optimized_bayesian_signature = teleprompter.compile(your_dspy_p
Note: detailed documentation can be found [here](deep-dive/optimizers/miprov2.md). `MIPROv2` is the latest extension of `MIPRO` which includes updates such as (1) improvements to instruction proposal and (2) more efficient search with minibatching.

#### Optimizing with MIPROv2

This shows how to perform an easy out-of-the-box run with `auto=light`, which configures many hyperparameters for you and performs a light optimization run. You can alternatively set `auto=medium` or `auto=heavy` to perform longer optimization runs. The more detailed `MIPROv2` documentation [here](deep-dive/optimizers/miprov2.md) also provides more information about how to set hyperparameters by hand.

```python
# Import the optimizer
from dspy.teleprompt import MIPROv2
@@ -422,6 +423,7 @@ evaluate(optimized_program, devset=devset[:])
```

#### Optimizing instructions only with MIPROv2 (0-Shot)

```python
# Import the optimizer
from dspy.teleprompt import MIPROv2
@@ -449,6 +451,7 @@ optimized_program.save(f"mipro_optimized")
print(f"Evaluate optimized program...")
evaluate(optimized_program, devset=devset[:])
```

### Signature Optimizer with Types

```python
@@ -466,12 +469,14 @@ compiled_program = optimize_signature(
### KNNFewShot

```python
from dspy.predict import KNN
from sentence_transformers import SentenceTransformer
from dspy import Embedder
from dspy.teleprompt import KNNFewShot
from dspy import ChainOfThought

knn_optimizer = KNNFewShot(KNN, k=3, trainset=trainset)
knn_optimizer = KNNFewShot(k=3, trainset=trainset, vectorizer=Embedder(SentenceTransformer("all-MiniLM-L6-v2").encode))

your_dspy_program_compiled = knn_optimizer.compile(student=your_dspy_program, trainset=trainset, valset=devset)
qa_compiled = knn_optimizer.compile(student=ChainOfThought("question -> answer"))
```

### BootstrapFewShotWithOptuna
@@ -483,21 +488,22 @@ fewshot_optuna_optimizer = BootstrapFewShotWithOptuna(metric=your_defined_metric

your_dspy_program_compiled = fewshot_optuna_optimizer.compile(student=your_dspy_program, trainset=trainset, valset=devset)
```
Other custom configurations are similar to customizing the `dspy.BootstrapFewShot` optimizer.

Other custom configurations are similar to customizing the `dspy.BootstrapFewShot` optimizer.

## DSPy Assertions

### Including `dspy.Assert` and `dspy.Suggest` statements

```python
dspy.Assert(your_validation_fn(model_outputs), "your feedback message", target_module="YourDSPyModule")

dspy.Suggest(your_validation_fn(model_outputs), "your feedback message", target_module="YourDSPyModule")
```

### Activating DSPy Program with Assertions
### Activating DSPy Program with Assertions

**Note**: To use Assertions properly, you must **activate** a DSPy program that includes `dspy.Assert` or `dspy.Suggest` statements from either of the methods above.
**Note**: To use Assertions properly, you must **activate** a DSPy program that includes `dspy.Assert` or `dspy.Suggest` statements from either of the methods above.

```python
# 1. Using `assert_transform_module`:
23 changes: 12 additions & 11 deletions dspy/predict/knn.py
@@ -1,28 +1,30 @@
import numpy as np

from dspy.clients import Embedder
from dspy.primitives import Example


class KNN:
    def __init__(self, k: int, trainset: list, vectorizer=None):
    def __init__(self, k: int, trainset: list[Example], vectorizer: Embedder):
        """
        A k-nearest neighbors retriever that finds similar examples from a training set.

        Args:
            k: Number of nearest neighbors to retrieve
            trainset: List of training examples to search through
            vectorizer: Optional dspy.Embedder for computing embeddings. If None, uses sentence-transformers.
            vectorizer: The `Embedder` to use for vectorization

        Example:
            >>> trainset = [dsp.Example(input="hello", output="world"), ...]
            >>> knn = KNN(k=3, trainset=trainset)
            >>> import dspy
            >>> from sentence_transformers import SentenceTransformer
            >>>
            >>> trainset = [dspy.Example(input="hello", output="world"), ...]
            >>> knn = KNN(k=3, trainset=trainset, vectorizer=dspy.Embedder(SentenceTransformer("all-MiniLM-L6-v2").encode))
            >>> similar_examples = knn(input="hello")
        """

        import dspy.dsp as dsp
        import dspy

        self.k = k
        self.trainset = trainset
        self.embedding = vectorizer or dspy.Embedder(dsp.SentenceTransformersVectorizer())
        self.embedding = vectorizer
        trainset_casted_to_vectorize = [
            " | ".join([f"{key}: {value}" for key, value in example.items() if key in example._input_keys])
            for example in self.trainset
@@ -33,5 +35,4 @@ def __call__(self, **kwargs) -> list:
        input_example_vector = self.embedding([" | ".join([f"{key}: {val}" for key, val in kwargs.items()])])
        scores = np.dot(self.trainset_vectors, input_example_vector.T).squeeze()
        nearest_samples_idxs = scores.argsort()[-self.k :][::-1]
        train_sampled = [self.trainset[cur_idx] for cur_idx in nearest_samples_idxs]
        return train_sampled
        return [self.trainset[cur_idx] for cur_idx in nearest_samples_idxs]
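
The ranking step in `KNN.__call__` can be tried in isolation. The following is a minimal NumPy-only sketch, not the library code: `top_k_by_dot_product` and the toy vectors are hypothetical names introduced purely for illustration. It assumes the query is embedded as a single row of shape `(1, d)`, as in the hunk above, and shows why `argsort()[-k:][::-1]` returns the best match first.

```python
import numpy as np


def top_k_by_dot_product(train_vectors: np.ndarray, query_vector: np.ndarray, k: int) -> list[int]:
    """Return the indices of the k training rows with the highest dot product, best first.

    Hypothetical helper for illustration only; it mirrors the scoring lines in the diff.
    """
    # (n, d) @ (d, 1) -> (n, 1); squeeze() flattens this to one score per training row.
    scores = np.dot(train_vectors, query_vector.T).squeeze()
    # argsort() sorts ascending, so the last k entries are the largest scores;
    # reversing them puts the highest-scoring index first.
    return scores.argsort()[-k:][::-1].tolist()


# Toy 2-d "embeddings" (assumed values, chosen so the scores are easy to check by hand).
train_vectors = np.array(
    [
        [1.0, 0.0],  # score 0.0 against the query below
        [0.0, 1.0],  # score 1.0
        [0.6, 0.8],  # score 0.8
    ]
)
query_vector = np.array([[0.0, 1.0]])

nearest = top_k_by_dot_product(train_vectors, query_vector, k=2)
print(nearest)  # [1, 2]
```

A raw dot product ranks by cosine similarity only when the embeddings are unit-normalised; whether a particular `Embedder` returns normalised vectors is an assumption worth checking before relying on the ordering.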
28 changes: 26 additions & 2 deletions dspy/teleprompt/knn_fewshot.py
@@ -1,16 +1,40 @@
import types

from dspy.clients import Embedder
from dspy.predict.knn import KNN
from dspy.primitives import Example
from dspy.teleprompt import BootstrapFewShot

from .teleprompt import Teleprompter


class KNNFewShot(Teleprompter):
    def __init__(self, k: int, trainset: list, vectorizer=None, **few_shot_bootstrap_args):
    def __init__(self, k: int, trainset: list[Example], vectorizer: Embedder, **few_shot_bootstrap_args):
        """
        KNNFewShot is an optimizer that uses an in-memory KNN retriever to find the k nearest neighbors
        in a trainset at test time. For each input example in a forward call, it identifies the k most
        similar examples from the trainset and attaches them as demonstrations to the student module.

        Args:
            k: The number of nearest neighbors to attach to the student model.
            trainset: The training set to use for few-shot prompting.
            vectorizer: The `Embedder` to use for vectorization
            **few_shot_bootstrap_args: Additional arguments for the `BootstrapFewShot` optimizer.

        Example:
            >>> import dspy
            >>> from sentence_transformers import SentenceTransformer
            >>>
            >>> qa = dspy.ChainOfThought("question -> answer")
            >>> trainset = [dspy.Example(question="What is the capital of France?", answer="Paris").with_inputs("question"), ...]
            >>> knn_few_shot = KNNFewShot(k=3, trainset=trainset, vectorizer=dspy.Embedder(SentenceTransformer("all-MiniLM-L6-v2").encode))
            >>> compiled_qa = knn_few_shot.compile(qa)
            >>> compiled_qa("What is the capital of Belgium?")
        """
        self.KNN = KNN(k, trainset, vectorizer=vectorizer)
        self.few_shot_bootstrap_args = few_shot_bootstrap_args

    def compile(self, student, *, teacher=None, trainset=None, valset=None):
    def compile(self, student, *, teacher=None):
        student_copy = student.reset_copy()

        def forward_pass(_, **kwargs):
83 changes: 38 additions & 45 deletions tests/predict/test_knn.py
@@ -1,55 +1,48 @@
import pytest
import numpy as np
import pytest

import dspy
from dspy.utils import DummyVectorizer
from dspy.predict import KNN
from dspy.utils import DummyVectorizer


def mock_example(question: str, answer: str) -> dspy.Example:
    """Creates a mock DSP example with specified question and answer."""
    return dspy.Example(question=question, answer=answer).with_inputs("question")


# @pytest.fixture
# def setup_knn():
# """Sets up a KNN instance with a mocked vectorizer for testing."""
# dsp.SentenceTransformersVectorizer = DummyVectorizer
# trainset = [
# mock_example("What is the capital of France?", "Paris"),
# mock_example("What is the largest ocean?", "Pacific"),
# mock_example("What is 2+2?", "4"),
# ]
# knn = KNN(k=2, trainset=trainset)
# return knn


# def test_knn_initialization(setup_knn):
# """Tests the KNN initialization and checks if the trainset vectors are correctly created."""
# knn = setup_knn
# assert knn.k == 2, "Incorrect k value"
# assert len(knn.trainset_vectors) == 3, "Incorrect size of trainset vectors"
# assert isinstance(
# knn.trainset_vectors, np.ndarray
# ), "Trainset vectors should be a NumPy array"


# def test_knn_query(setup_knn):
# """Tests the KNN query functionality for retrieving the nearest neighbors."""
# knn = setup_knn
# query = {"question": "What is 3+3?"} # A query close to "What is 2+2?"
# nearest_samples = knn(**query)
# assert len(nearest_samples) == 2, "Incorrect number of nearest samples returned"
# assert nearest_samples[0].answer == "4", "Incorrect nearest sample returned"


# def test_knn_query_specificity(setup_knn):
# """Tests the KNN query functionality for specificity of returned examples."""
# knn = setup_knn
# query = {
# "question": "What is the capital of Germany?"
# } # A query close to "What is the capital of France?"
# nearest_samples = knn(**query)
# assert len(nearest_samples) == 2, "Incorrect number of nearest samples returned"
# assert "Paris" in [
# sample.answer for sample in nearest_samples
# ], "Expected Paris to be a nearest sample answer"
@pytest.fixture
def setup_knn() -> KNN:
    """Sets up a KNN instance with a mocked vectorizer for testing."""
    trainset = [
        mock_example("What is the capital of France?", "Paris"),
        mock_example("What is the largest ocean?", "Pacific"),
        mock_example("What is 2+2?", "4"),
    ]
    return KNN(k=2, trainset=trainset, vectorizer=dspy.Embedder(DummyVectorizer()))


def test_knn_initialization(setup_knn):
    """Tests the KNN initialization and checks if the trainset vectors are correctly created."""
    knn = setup_knn
    assert knn.k == 2, "Incorrect k value"
    assert len(knn.trainset_vectors) == 3, "Incorrect size of trainset vectors"
    assert isinstance(knn.trainset_vectors, np.ndarray), "Trainset vectors should be a NumPy array"


def test_knn_query(setup_knn):
    """Tests the KNN query functionality for retrieving the nearest neighbors."""
    knn = setup_knn
    query = {"question": "What is 3+3?"} # A query close to "What is 2+2?"
    nearest_samples = knn(**query)
    assert len(nearest_samples) == 2, "Incorrect number of nearest samples returned"
    assert nearest_samples[0].answer == "4", "Incorrect nearest sample returned"


def test_knn_query_specificity(setup_knn):
    """Tests the KNN query functionality for specificity of returned examples."""
    knn = setup_knn
    query = {"question": "What is the capital of Germany?"} # A query close to "What is the capital of France?"
    nearest_samples = knn(**query)
    assert len(nearest_samples) == 2, "Incorrect number of nearest samples returned"
    assert "Paris" in [sample.answer for sample in nearest_samples], "Expected Paris to be a nearest sample answer"
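
To see why the tests expect `"4"` for a query like "What is 3+3?", it can help to run the same ranking with a deterministic stand-in vectorizer. The sketch below is an assumption-laden illustration, not DSPy code: `bag_of_words_vectorizer` (L2-normalised word counts over the training vocabulary) and `nearest_answers` are hypothetical helpers, and the real `DummyVectorizer` in `dspy.utils` may work quite differently.

```python
import numpy as np

# Same questions and answers as the test fixture above.
QUESTIONS = [
    "What is the capital of France?",
    "What is the largest ocean?",
    "What is 2+2?",
]
ANSWERS = ["Paris", "Pacific", "4"]


def bag_of_words_vectorizer(texts: list[str], vocabulary: list[str]) -> np.ndarray:
    """Hypothetical stand-in embedder: unit-length word-count vectors over a fixed vocabulary."""
    index = {word: position for position, word in enumerate(vocabulary)}
    vectors = np.zeros((len(texts), len(vocabulary)), dtype=np.float32)
    for row, text in enumerate(texts):
        for word in text.lower().split():
            if word in index:  # words outside the training vocabulary are ignored
                vectors[row, index[word]] += 1.0
    # Normalise each row so the dot product behaves like cosine similarity.
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.maximum(norms, 1e-10)


VOCABULARY = sorted({word for question in QUESTIONS for word in question.lower().split()})
TRAIN_VECTORS = bag_of_words_vectorizer(QUESTIONS, VOCABULARY)


def nearest_answers(question: str, k: int = 2) -> list[str]:
    """Rank training questions by dot product with the query and return the top-k answers."""
    query_vector = bag_of_words_vectorizer([question], VOCABULARY)
    scores = np.dot(TRAIN_VECTORS, query_vector.T).squeeze()
    return [ANSWERS[i] for i in scores.argsort()[-k:][::-1]]


arithmetic_neighbours = nearest_answers("What is 3+3?")
capital_neighbours = nearest_answers("What is the capital of Germany?")
print(arithmetic_neighbours[0])  # 4
print(capital_neighbours[0])  # Paris
```

With this stand-in, the short question "What is 2+2?" wins for the arithmetic query because the shared words make up a larger fraction of its normalised vector, which is the kind of behaviour the assertions above rely on.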