# Text Classification with DSPy

This guide will walk you through building a text classification pipeline using DSPy and demonstrate how DSPy’s automatic prompt optimization can enhance text classification performance.

Before we get started, let's install the necessary packages.

In [0]:
!pip install -q dspy datasets
%restart_python

In this guide, we will use OpenAI's gpt-4o-mini model. Let's set up the OpenAI credentials.

In [0]:

import getpass
import os

openai_key = getpass.getpass("Please enter your OpenAI API key: ")
os.environ["OPENAI_API_KEY"] = openai_key

Please enter your OpenAI API key:  [REDACTED]
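If the key is already exported in your environment, you may want to skip the prompt. The sketch below shows one way to do that. `ensure_api_key` is a hypothetical helper written for this illustration; it is not part of DSPy or the OpenAI SDK.

```python
import getpass
import os


def ensure_api_key(var_name="OPENAI_API_KEY", prompt_fn=getpass.getpass, env=None):
    """Return the key stored under var_name, prompting only when it is missing.

    Hypothetical helper for illustration; `env` defaults to os.environ.
    """
    env = os.environ if env is None else env
    key = env.get(var_name)
    if not key:
        # Ask interactively and store the answer so later cells can read it.
        key = prompt_fn(f"Please enter your {var_name}: ")
        env[var_name] = key
    return key
```

Calling `ensure_api_key()` with no arguments reads from and writes to `os.environ`, which is where the cell above stores the key.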

Import dependencies.

In [0]:
import dspy
import random
from datasets import load_dataset
from typing import Literal

## Prepare The Dataset

In this guide, we will use the [Banking77 dataset](https://huggingface.co/datasets/PolyAI/banking77), a 77-class text classification dataset focused on banking topics. Let's load the dataset from the Hugging Face Hub.

In [0]:
trainset_hf = load_dataset("PolyAI/banking77", split="train", trust_remote_code=True)
valset_hf = load_dataset("PolyAI/banking77", split="test", trust_remote_code=True)



Generating train split: 10003 examples

Generating test split: 3080 examples



### Convert Dataset to DSPy Dataset

To use the dataset in DSPy's optimization or evaluation phase, we need to convert it into a DSPy dataset. The easiest way is to convert it into a list of `dspy.Example` objects. For a complete guide on preparing a DSPy dataset, please refer to [this guide](https://dspy.ai/building-blocks/4-data/).

Each training example has three fields: `text`, `hint`, and `label`, with `hint` and `label` both set to the gold label's text representation, e.g., "virtual_card_not_working". Each validation example has only `text` and `label`.

In [0]:
CLASSES = trainset_hf.features["label"].names

trainset = []
valset = []

for text, label in zip(trainset_hf["text"], trainset_hf["label"]):
    trainset.append(dspy.Example(text=text, hint=CLASSES[label], label=CLASSES[label]).with_inputs("text", "hint"))

for text, label in zip(valset_hf["text"], valset_hf["label"]):
    valset.append(dspy.Example(text=text, label=CLASSES[label]).with_inputs("text"))

# Shuffle the dataset.
random.Random(0).shuffle(trainset)
random.Random(0).shuffle(valset)
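The conversion above relies on two small ideas: mapping each integer class id to its name, and shuffling with a fixed seed so the order is reproducible. A minimal sketch with toy stand-in data (the three rows and the shortened class list are assumptions for illustration, not the real dataset):

```python
import random

# Toy stand-ins for the real data (assumed values, for illustration only).
class_names = ["card_arrival", "card_acceptance", "terminate_account"]
rows = [
    ("Where is my card?", 0),
    ("Please delete my account.", 2),
    ("Where is my card accepted?", 1),
]

# Map each integer class id to its text name, as CLASSES[label] does above.
examples = [{"text": text, "label": class_names[idx]} for text, idx in rows]

# Two shuffles seeded identically produce the same order.
first = examples.copy()
second = examples.copy()
random.Random(0).shuffle(first)
random.Random(0).shuffle(second)
print(first == second)  # True
```

Because `random.Random(0)` creates a fresh generator each time, rerunning the cell yields the same `trainset[:200]` and `valset[:100]` slices.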

## Build Your Model

Now let's build our text classifier. In this guide, we will build a simple model with a single `dspy.ChainOfThoughtWithHint` module. We constrain the output label to be one of the predefined `CLASSES` by adding `type_=Literal[tuple(CLASSES)]` to the signature.

In [0]:
signature = dspy.Signature("text -> label").with_updated_fields('label', type_=Literal[tuple(CLASSES)])
classifier = dspy.ChainOfThoughtWithHint(signature)
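To see what `Literal[tuple(CLASSES)]` evaluates to, here is a small standard-library sketch. It uses a toy three-class list (an assumption for brevity) and a hypothetical `is_valid_label` helper; it shows only how the type is built and inspected, not how DSPy enforces it internally.

```python
from typing import Literal, get_args

# Toy subset of the label names (assumed for illustration).
classes = ["card_acceptance", "getting_virtual_card", "terminate_account"]

# Subscripting Literal with a tuple is equivalent to listing each value.
LabelType = Literal[tuple(classes)]


def is_valid_label(value):
    """Hypothetical check: is the value one of the allowed literals?"""
    return value in get_args(LabelType)


print(get_args(LabelType))
print(is_valid_label("terminate_account"), is_valid_label("close_account"))
```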

Don't forget to configure the LM we want to use. In this guide, we will use OpenAI's gpt-4o-mini.

In [0]:
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

Let's try out our classifier on a sample input.

In [0]:
classifier(**valset[0].inputs())

## Optimize Your Model

Now let's start optimizing the model!

### Define Scoring Metric and Run Evaluation

Before optimization, let's run an evaluation to see how our model performs without any sort of optimization. For text classification, the metric is simply accuracy.

In [0]:
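def accuracy(example, pred, trace=None):
    # Score 1 for an exact label match and 0 otherwise; a prediction
    # without a usable `label` field also counts as a miss.
    try:
        return int(example.label == pred.label)
    except Exception:
        return 0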
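The metric can be sanity-checked without calling an LM. The sketch below uses `types.SimpleNamespace` objects as stand-ins for `dspy.Example` and prediction objects (an assumption made so the snippet runs on its own) and averages the per-example scores.

```python
from types import SimpleNamespace


def accuracy(example, pred, trace=None):
    # Same logic as the metric above: 1 for an exact match, else 0.
    try:
        return int(example.label == pred.label)
    except Exception:
        return 0


# Stand-in (gold, prediction) pairs; the last prediction has no label at all.
pairs = [
    (SimpleNamespace(label="terminate_account"), SimpleNamespace(label="terminate_account")),
    (SimpleNamespace(label="order_physical_card"), SimpleNamespace(label="get_physical_card")),
    (SimpleNamespace(label="card_acceptance"), SimpleNamespace()),
]

scores = [accuracy(example, pred) for example, pred in pairs]
print(scores, sum(scores) / len(scores))
```

The third pair exercises the `except` branch: reading a missing `label` attribute raises an error, which the metric converts into a score of 0.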

In [0]:
from dspy.evaluate import Evaluate

# Set up the evaluator, using only the first 100 validation examples.
evaluator = Evaluate(devset=valset[:100], num_threads=5, display_progress=True, display_table=5)

In [0]:
eval_score = evaluator(classifier, metric=accuracy)


| | text | example_label | rationale | pred_label | accuracy |
|---|---|---|---|---|---|
| 0 | Please delete my account. | terminate_account | terminate_account. We are processing a request to delete the user's account, which falls under the category of account termination. | terminate_account | ✔️ [1] |
| 1 | There's a debit on my account that I didn't do. | direct_debit_payment_not_recognised | This situation describes a debit on the account that the user did not authorize, which suggests a potential issue with a transaction that may need... | request_refund | |
| 2 | Where do I order a virtual card from? | getting_virtual_card | The user is inquiring about the process of ordering a virtual card, which falls under the category of obtaining a virtual card. Therefore, the appropriate... | getting_virtual_card | ✔️ [1] |
| 3 | Where is my card accepted? | card_acceptance | card_acceptance. We are addressing a query about where the card can be used, which pertains to its acceptance at various locations or merchants. | card_acceptance | ✔️ [1] |
| 4 | If I want a physical card, do I have to pay anything? | order_physical_card | The question is about whether there is a fee associated with obtaining a physical card, which relates to the process of getting a physical card.... | get_physical_card | |
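With 77 classes, it helps to see which labels get confused with which. A minimal sketch that tallies the mismatches, using only the five displayed rows as input (in practice you would collect the pairs from the full evaluation run):

```python
from collections import Counter

# (gold label, predicted label) pairs copied from the five rows displayed above.
rows = [
    ("terminate_account", "terminate_account"),
    ("direct_debit_payment_not_recognised", "request_refund"),
    ("getting_virtual_card", "getting_virtual_card"),
    ("card_acceptance", "card_acceptance"),
    ("order_physical_card", "get_physical_card"),
]

# Count each distinct (gold, predicted) mismatch.
confusions = Counter((gold, pred) for gold, pred in rows if gold != pred)
for (gold, pred), count in confusions.most_common():
    print(f"{gold} -> {pred}: {count}")
```

Both misses in this sample involve closely related labels, which is the kind of error that better instructions or few-shot demonstrations may reduce.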


### Configure Optimizer and Run Optimization

Let's configure our optimizer and run the optimization.

In [0]:
from dspy.teleprompt import MIPROv2


# Initialize the optimizer.
optimizer = MIPROv2(
    metric=accuracy,
    num_candidates=12,
    init_temperature=0.3,
    verbose=False,
    num_threads=4,
)

# Optimize the program.
print("Optimizing program with MIPRO V2...")
optimized_classifier = optimizer.compile(
    classifier.deepcopy(),
    trainset=trainset[:200],
    valset=valset[:100],
    max_bootstrapped_demos=5,
    max_labeled_demos=5,
    num_trials=25,
    minibatch_size=20,
    minibatch_full_eval_steps=5,
)

Optimizing program with MIPRO V2...

Projected Language Model (LM) Calls

Based on the parameters you have set, the maximum number of LM calls is projected as follows:

- Prompt Generation: 10 data ... y

2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: 
==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==
2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: These will be used as few-shot example candidates for our program and for creating instructions.

2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: Bootstrapping N=12 sets of demonstrations...


Bootstrapping set 1/12
Bootstrapping set 2/12
Bootstrapping set 3/12
Bootstrapped 5 full traces after 5 examples for up to 1 rounds, amounting to 5 attempts.
Bootstrapping set 4/12
Bootstrapped 3 full traces after 3 examples for up to 1 rounds, amounting to 3 attempts.
Bootstrapping set 5/12
Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.
Bootstrapping set 6/12
Bootstrapped 3 full traces after 3 examples for up to 1 rounds, amounting to 3 attempts.
Bootstrapping set 7/12
Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.
Bootstrapping set 8/12
Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.
Bootstrapping set 9/12
Bootstrapped 3 full traces after 3 examples for up to 1 rounds, amounting to 3 attempts.
Bootstrapping set 10/12
Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.
Bootstrapping set 11/12
Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.
Bootstrapping set 12/12
Bootstrapped 4 full traces after 4 examples for up to 1 rounds, amounting to 4 attempts.

2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: 
==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==
2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: We will use the few-shot examples from the previous step, a generated dataset summary, a summary of the program code, and a randomly selected prompting tip to propose instructions.
2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: 
Proposing instructions...

2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: Proposed Instructions for Predictor 0:

2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: 0: Given the fields `text`, produce the fields `label`.

2024/11/12 22:21:37 INFO dspy.teleprompt.mipro_optimizer_v2: 1: Analyze the provided user inquiry in the `text` field and classify it into one of the predefined categories represented by the `label` field. Additionally, generate a detailed rationa
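The `compile` arguments largely determine how much evaluation work the trial loop performs. The sketch below is a rough back-of-the-envelope estimate under an assumed schedule: every trial scores a candidate on `minibatch_size` examples, and every `minibatch_full_eval_steps` trials a full pass over `valset` is run. `estimate_program_calls` is a hypothetical helper, and the optimizer's own "Projected Language Model (LM) Calls" message should be treated as the authoritative figure.

```python
def estimate_program_calls(num_trials, minibatch_size, full_eval_steps, valset_size):
    """Rough estimate of program evaluations in the trial loop (assumed schedule)."""
    # Assumption: one minibatch evaluation per trial.
    minibatch_calls = num_trials * minibatch_size
    # Assumption: one full validation pass every `full_eval_steps` trials.
    full_eval_calls = (num_trials // full_eval_steps) * valset_size
    return minibatch_calls, full_eval_calls


# Values taken from the compile() call in this guide.
minibatch_calls, full_eval_calls = estimate_program_calls(25, 20, 5, 100)
print(minibatch_calls, full_eval_calls, minibatch_calls + full_eval_calls)  # 500 500 1000
```

Under these assumptions, lowering `num_trials` or `minibatch_size` is the most direct way to reduce cost, at the expense of a less thorough search.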

Let's see the performance of the optimized classifier. 

In [0]:
# Launch evaluation.
eval_score = evaluator(optimized_classifier, metric=accuracy)


| | text | example_label | rationale | pred_label | accuracy |
|---|---|---|---|---|---|
| 0 | Please delete my account. | terminate_account | The user is requesting to delete their account, which directly relates to the action of terminating an account. | terminate_account | ✔️ [1] |
| 1 | There's a debit on my account that I didn't do. | direct_debit_payment_not_recognised | The user is reporting an unauthorized debit on their account, which suggests that they do not recognize a transaction that has occurred. This situation is... | card_payment_not_recognised | |
| 2 | Where do I order a virtual card from? | getting_virtual_card | The user is asking about the process to obtain a virtual card, which relates to the action of getting a virtual card. | getting_virtual_card | ✔️ [1] |
| 3 | Where is my card accepted? | card_acceptance | The user is asking about the locations or types of merchants where their card can be used, which pertains to card acceptance. | card_acceptance | ✔️ [1] |
| 4 | If I want a physical card, do I have to pay anything? | order_physical_card | The user is asking about potential fees associated with obtaining a physical card, which relates to the costs involved in card issuance. | get_physical_card | |


In our run, optimization boosted accuracy on the validation set from 66% to 77%!
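To put those two numbers in context, the sketch below computes the absolute and relative gain, plus a standard error for each accuracy on 100 examples. The standard error uses the normal approximation to a binomial proportion, which is an assumption added here and not something the evaluator reports.

```python
import math

n = 100                     # number of validation examples evaluated
before, after = 0.66, 0.77  # accuracies reported for this run

absolute_gain = after - before
relative_gain = absolute_gain / before


def standard_error(p, n):
    """Standard error of a proportion under the normal approximation."""
    return math.sqrt(p * (1 - p) / n)


print(f"absolute gain: {absolute_gain * 100:.0f} points")
print(f"relative gain: {relative_gain:.1%}")
print(f"standard error before: {standard_error(before, n):.3f}")
print(f"standard error after: {standard_error(after, n):.3f}")
```

With only 100 examples, each estimate carries an uncertainty of several points. Note also that `valset[:100]` is used both by the optimizer and by the evaluator in this guide, so scoring on a held-out slice such as `valset[100:200]` would likely give a cleaner estimate of the gain.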

## Conclusion 

Congratulations on finishing the guide! To further improve classification accuracy, you can try a more powerful model such as gpt-4o, have the optimizer try more candidates over more rounds, or apply fine-tuning. Keep exploring!