<a href="https://colab.research.google.com/github/withpi/cookbook-withpi/blob/main/colabs/Query_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://withpi.ai"><img src="https://play.withpi.ai/logo/logoFullBlack.svg" width="240"></a>

<a href="https://code.withpi.ai"><font size="4">Documentation</font></a>

<a href="https://play.withpi.ai"><font size="4">Technique Catalog</font></a>

# Query Classification

Query classification is a key primitive for AI systems and a building block for powerful RAG pipelines.

A "query" could be:
  * a search term in an e-commerce search (e.g., "levis 501")
  * a message from a customer to a support team
  * a prompt to an AI assistant

The "classification" part could involve:
  * deciding whether retrieval is needed
  * deciding what data sources to use
  * deciding what pre- or post-processing to apply
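The decisions above can be sketched as a routing table keyed by the classifier's label. This is a minimal, hypothetical example: the labels, the keyword rules in `classify_query`, and the data-source names are illustrative assumptions. A real system would replace `classify_query` with an actual query classifier, such as the one used later in this notebook.

```python
# Hypothetical sketch: a query label decides whether to retrieve and from where.
# Labels, keyword rules, and data-source names are illustrative assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class RoutingDecision:
    needs_retrieval: bool
    data_source: Optional[str]


def classify_query(query: str) -> str:
    """Toy stand-in for a real query classifier (keyword rules are assumptions)."""
    text = query.lower()
    if any(word in text for word in ("refund", "order", "shipping")):
        return "Support"
    if any(char.isdigit() for char in text):
        return "Product Search"  # e.g. model numbers such as "levis 501"
    return "Chit-chat"


# Each label maps to a retrieval decision and an (assumed) data source.
ROUTES = {
    "Product Search": RoutingDecision(needs_retrieval=True, data_source="product_catalog"),
    "Support": RoutingDecision(needs_retrieval=True, data_source="help_center_articles"),
    "Chit-chat": RoutingDecision(needs_retrieval=False, data_source=None),
}


def route(query: str) -> RoutingDecision:
    """Classify the query, then look up what the pipeline should do with it."""
    return ROUTES[classify_query(query)]


print(route("levis 501"))
# RoutingDecision(needs_retrieval=True, data_source='product_catalog')
```

Keeping the label-to-action mapping in a plain dictionary means the classifier and the pipeline behavior can change independently.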

This Colab walks you through classifying queries for your use case and improving that classification with minimal work. A small amount of **data labeling** to teach the model makes it better at classifying your queries, and lets you use smaller, faster models.

## Install and initialize SDK

Connect to a regular CPU Python 3 runtime. You won't need GPUs for this notebook.

You'll need a `WITHPI_API_KEY` from https://play.withpi.ai. Add it to your notebook secrets (the key icon in the left sidebar).

Run the cell below to install the packages and load the SDK.

In [None]:
%%capture

import os
from google.colab import files, userdata

# Load the notebook secret into the environment so the Pi Client can access it.
os.environ["WITHPI_API_KEY"] = userdata.get('WITHPI_API_KEY')

%pip install withpi litellm httpx datasets jinja2 tqdm requests transformers

# Import a bunch of useful libraries for later.
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
import json
from pathlib import Path
import re

import datasets
import httpx
import litellm
import jinja2
from tqdm.notebook import tqdm
from withpi import PiClient
from IPython.display import display
import pandas as pd

client = PiClient()


def generate(system: str, user: str, model: str) -> str:
    """generate passes the provided system and user prompts into the given model
    via LiteLLM"""
    messages = [
        {"content": system, "role": "system"},
        {"content": user, "role": "user"},
    ]
    return litellm.completion(model=model, messages=messages).choices[0].message.content


class printer(str):
    """printer makes strings with embedded newlines print more nicely"""

    def __repr__(self):
        return self


def print_response(response: str):
    """print_response pretty-prints an LLM response, respecting newlines"""
    display(printer(response))


def print_scores(pi_scores):
    """print_scores pretty-prints a Pi Score response as a table."""
    for dimension_name, dimension_scores in pi_scores.dimension_scores.items():
        print(f"{dimension_name}: {dimension_scores.total_score}")
        for (
            subdimension_name,
            subdimension_score,
        ) in dimension_scores.subdimension_scores.items():
            print(f"\t{subdimension_name}: {subdimension_score}")
        print("\n")
    print("---------------------")
    print(f"Total score: {pi_scores.total_score}")


def save_file(filename: str, model: str):
    """save_file offers to download the model with the given filename"""
    Path(filename).write_text(model)
    files.download(filename)



def load_and_split_dataset(url: str) -> datasets.DatasetDict:
    """load_and_split_dataset pulls in the Parquet file at url and does a 90/10 split"""
    return datasets.load_dataset(
        "parquet", data_files=url, split="train"
    ).train_test_split(test_size=0.1)


def do_bulk_inference(dataset, system, model):
    """do_bulk_inference performs inference on the 'input' column of dataset, using
    the provided system prompt.  The model identified will be used via LiteLLM"""

    def do_generate(user, pbar):
        result = generate(system, user, model)
        pbar.update(1)
        return result

    futures = []
    pbar = tqdm(total=len(dataset))
    with ThreadPoolExecutor(max_workers=4) as executor:
        for row in dataset:
            futures.append(executor.submit(do_generate, row["input"], pbar))
    return [future.result() for future in futures]


def do_bulk_templated_inference(dataset, optimized, model):
    """do_bulk_templated_inference performs inference on the 'input' column of dataset,
    using the provided optimized prompt.  It should be a Jinja2 template as returned
    by DSPy"""
    prompt_template = jinja2.Template(optimized)
    result_extractor = re.compile(
        r".*\[\[ ## response ## \]\](.*)\[\[ ## completed ## \]\]", re.DOTALL
    )

    def do_generate(prompt: str, pbar) -> str:
        messages = json.loads(prompt_template.render(input=prompt))
        result = (
            litellm.completion(model=model, messages=messages)
            .choices[0]
            .message.content
        )

        pbar.update(1)
        return result_extractor.match(result).group(1)

    futures = []
    pbar = tqdm(total=len(dataset))
    with ThreadPoolExecutor(max_workers=4) as executor:
        for row in dataset:
            futures.append(executor.submit(do_generate, row["input"], pbar))
    return [future.result() for future in futures]


def generate_table(
    job_id: str, training_data: dict, is_done: bool, additional_columns: dict[str, str]
):
    """Generate a training progress table dynamically."""
    data_dict = {}
    for header in ["Step", "Epoch", "Learning Rate", "Training Loss", "Eval Loss"]:
        data_dict[header] = []
    for header in additional_columns.keys():
        data_dict[header] = []

    for step, data in training_data.items():
        data_dict["Step"].append(step)
        for header, key in [
            ("Epoch", "epoch"),
            ("Learning Rate", "learning_rate"),
            ("Training Loss", "loss"),
            ("Eval Loss", "eval_loss"),
        ]:
            data_dict[header].append(data.get(key, "X"))
        for header, key in additional_columns.items():
            data_dict[header].append(data.get(key, "X"))

    if not is_done:
        data_dict["Step"].append("...")
        for header in ["Epoch", "Learning Rate", "Training Loss", "Eval Loss"]:
            data_dict[header].append("")
        for header in additional_columns.keys():
            data_dict[header].append("")

    return pd.DataFrame(data_dict)


def stream_response(job_id: str, method, additional_columns: dict[str, str]):
    """stream_response streams messages from the provided method

    method should be a Pi client object with `retrieve` and `stream_messages`
    endpoints.  This is primarily for convenience."""

    print(f"Training Status for {job_id}")

    training_data = defaultdict(dict)
    is_log_console = False

    stream_output = display(
        generate_table(
            job_id, training_data, is_done=False, additional_columns=additional_columns
        ),
        display_id=True,
    )

    while True:
        response = method.retrieve(job_id=job_id)
        if (response.state != "QUEUED") and (response.state != "RUNNING"):
            if response.state == "DONE" and not is_log_console:
                for line in response.detailed_status:
                    try:
                        data_dict = json.loads(line)
                        training_data[data_dict["step"]].update(data_dict)
                    except Exception:
                        pass
                stream_output.update(
                    generate_table(
                        job_id,
                        training_data,
                        is_done=True,
                        additional_columns=additional_columns,
                    )
                )
            return response

        with method.with_streaming_response.stream_messages(
            job_id=job_id, timeout=None
        ) as response:
            is_done = False
            for line in response.iter_lines():
                if line == "DONE":
                    is_done = True
                try:
                    data_dict = json.loads(line)
                    training_data[data_dict["step"]].update(data_dict)
                except Exception:
                    pass
                stream_output.update(
                    generate_table(
                        job_id,
                        training_data,
                        is_done,
                        additional_columns=additional_columns,
                    )
                )
                is_log_console = True


###  QUERY CLASSIFICATION  ###

import requests
from io import BytesIO
import tarfile

# Caching jobs allows us to avoid repeating job starts.
if 'job_id_by_hash' not in locals():
    job_id_by_hash = {}

def generate_inputs(
    queries: list[str],
    exploration_mode: str = "ADVENTUROUS",
    application_description: str = "",
    num_examples: int = 10,
) -> list[str]:
    job_hash = hash(json.dumps([
        num_examples,
        exploration_mode,
        queries,
        application_description
    ]))

    if job_hash not in job_id_by_hash:
        print(f"Generating {num_examples} new examples...")
        data_generation_status = client.data.generate.start_job(
            application_description=application_description,
            num_inputs_to_generate=num_examples,
            num_shots=10,
            exploration_mode=exploration_mode,
            seeds=queries,
        )
        job_id_by_hash[job_hash] = data_generation_status.job_id

    job_id = job_id_by_hash[job_hash]
    job_status = client.data.generate.retrieve(job_id)
    if job_status.state == "ERROR":
        raise RuntimeError(f"Generation job {job_id} failed")
    if job_status.state == "DONE":
        return job_status.data

    with (
        client.data.with_streaming_response.generate.stream_messages(
            job_id
        ) as response
    ):
        for line in response.iter_lines():
            print(f"Generation progress: {line}")

    job_status = client.data.generate.retrieve(
        job_id
    )
    assert job_status.state == "DONE"
    return job_status.data


def download_classifier(job_id: str, path: str, serving_id: int | None = None) -> Path:
    if serving_id is None:
        # TODO: explicitly use the lowest eval loss here by default.
        #   this currently happens to be the zeroth element, but we can
        #   avoid relying on it!
        #
        # job_status = client.search.query_classifier.distill.retrieve(job_id=job_id)
        serving_id = 0
    url = json.loads(client.search.query_classifier.distill.download(job_id=job_id, serving_id=serving_id))
    return download_and_extract(url, path)


def download_and_extract(url: str, output_dir: str) -> Path:
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    response = requests.get(url)
    response.raise_for_status()
    tar_bytes = BytesIO(response.content)
    with tarfile.open(fileobj=tar_bytes, mode="r:gz") as tar:
        tar.extractall(path=output_dir)
    return output_path


# Detecting spam in social media DMs
For our task, let's imagine that we need to detect spam messages in DMs sent between people on a social media platform.

We'll lay out:
* **input description** - describing what the inputs are to help calibrate our data generator
* **classes** - the classes we want to split it into, with descriptions to help calibrate the query classifier
* **examples** - queries and classes that we know to be correct
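Because the examples must agree with the declared classes, a quick consistency check before calling the classifier can catch typos in labels. This is a sketch under an assumed convention (every example label is a declared class, and every class has at least one example), not a documented Pi SDK requirement; `check_examples` is a hypothetical helper.

```python
# Sanity check before calling the classifier (an assumed convention, not a
# documented Pi SDK requirement): every example label must be a declared class,
# and every class should have at least one example.
from collections import Counter


def check_examples(classes: list[dict], examples: list[dict]) -> Counter:
    """Return per-label example counts; raise ValueError on inconsistencies."""
    known = {c["label"] for c in classes}
    counts = Counter(example["label"] for example in examples)
    unknown = set(counts) - known
    if unknown:
        raise ValueError(f"Examples use undeclared labels: {sorted(unknown)}")
    missing = known - set(counts)
    if missing:
        raise ValueError(f"Classes with no examples: {sorted(missing)}")
    return counts


demo_classes = [
    {"label": "Spam", "description": "Unsolicited promotions or suspicious links"},
    {"label": "Not Spam", "description": "Genuine messages from real users"},
]
demo_examples = [
    {"label": "Spam", "text": "Claim your FREE prize now!"},
    {"label": "Not Spam", "text": "Lunch tomorrow at 2?"},
    {"label": "Spam", "text": "Earn 300% returns monthly, message me!"},
]
print(check_examples(demo_classes, demo_examples))
# Counter({'Spam': 2, 'Not Spam': 1})
```

In the notebook you could call `check_examples(classes, examples)` right after defining them; with the 18 examples in the next cell it should report 9 per class.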

In [None]:
# This is describing the kind of inputs we'll be dealing with.
input_description = "direct messages received by users on a social media network"

# These are the classes that we want to classify our queries into.
# In this example, the queries are direct messages that users receive,
# which we want to classify as spam or legitimate communication.
classes = [
    {"label": "Spam", "description": "Unsolicited messages promoting products, services, or containing suspicious links"},
    {"label": "Not Spam", "description": "Legitimate messages from real users with genuine communication intent"}
]

# Here are some labeled examples that we provide to the zero-shot
# query classification model, and later use to evaluate our distilled model.
examples = [
    {"label": "Spam", "text": "CONGRATULATIONS! You've been selected to receive a free luxury watch! Click here to claim: bit.ly/2x4fakelink"},
    {"label": "Spam", "text": "Hello dear, I noticed your profile and think you have great investment potential. I manage a portfolio worth $25M and can help you earn 300% returns monthly. Message me for details."},
    {"label": "Spam", "text": "Hi there! 👋 I'm selling handmade bracelets at 70% discount today only! Check my store: handmade-jewelry.fakesite.co/discount"},
    {"label": "Not Spam", "text": "Hey, are we still meeting for coffee tomorrow at 2? Let me know if you need to reschedule."},
    {"label": "Not Spam", "text": "Hi Sarah! Just wanted to thank you for your help with the project yesterday. Your input really made a difference. Hope we can collaborate again soon!"},
    {"label": "Not Spam", "text": "Dude, did you see that game last night? Incredible comeback in the final quarter! We should watch the next one together."},
    {"label": "Spam", "text": "URGENT: Your account will be suspended in 24 hours! Verify your identity now: security-verify.suspiciouslink.net"},
    {"label": "Spam", "text": "Hello friend! I am international business person with opportunity for you. I have $5,000,000 USD to transfer and need your assistance. Reply for 30% commission."},
    {"label": "Spam", "text": "Want to lose 20 pounds in just 1 week? Our new miracle supplement guarantees results or money back! Order now with code SLIM50 for 50% off!"},
    {"label": "Not Spam", "text": "The photos from our hiking trip are finally uploaded! Here's the link to the shared album I promised: photos.legitimatesite.com/album/hiking2023"},
    {"label": "Not Spam", "text": "Mom asked me to check if you're still coming for dinner on Sunday. She's planning the menu and wants to know if you're bringing anyone."},
    {"label": "Not Spam", "text": "Hi Professor, I'm emailing about the assignment due next week. Could I get a short extension? I've been dealing with some health issues and could use a couple extra days."},
    {"label": "Spam", "text": "⚠️ Your profile has received 38 secret admirers today! Upgrade your account to see who likes you: premium-vip-access.scamsite.co"},
    {"label": "Spam", "text": "Exclusive opportunity for selected members only! Join our private cryptocurrency group and learn how I made $50,000 in just one week trading! Limited spots available!"},
    {"label": "Spam", "text": "Your package delivery failed! Reschedule delivery by confirming your details: delivery-service.maliciouslink.com"},
    {"label": "Not Spam", "text": "Hey, can you send me that recipe you mentioned yesterday? I want to try making it this weekend."},
    {"label": "Not Spam", "text": "The concert tickets just went on sale! Should I grab two for us for the Friday show, or does Saturday work better for you?"},
    {"label": "Not Spam", "text": "Just landed! The flight was delayed but I'm finally here. Can you still pick me up from the airport or should I grab a taxi?"},
]

# Let's classify and see how our classifications turn out!
response = client.search.query_classifier.classify(
    queries=[
        "Hey! I noticed your photography skills and I'm impressed! I run a modeling agency and think you'd be perfect for our upcoming campaign. DM me for details!",
        "Hi Alex, are we still on for lunch tomorrow at the usual place? I might be running about 10 minutes late.",
        "CONGRATULATIONS! You've won a $1000 gift card! Claim now: www.totallylegitprizes.co/claim-now",
        "Can you send me the notes from yesterday's meeting? I had to leave early and missed the last part.",
        "FREE iPhone 13 Pro! You've been selected in our monthly giveaway! Click here to claim within 24 hours: bit.ly/claim-prize"
    ],
    classes=classes,
    mode="generative",
    examples=examples,
)

for classification in response.results:
    print(f"Query: {repr(classification.query)}")
    print(f"Classification: {repr(classification.prediction)}")
    print()


Query: "Hey! I noticed your photography skills and I'm impressed! I run a modeling agency and think you'd be perfect for our upcoming campaign. DM me for details!"
Classification: 'Spam'

Query: 'Hi Alex, are we still on for lunch tomorrow at the usual place? I might be running about 10 minutes late.'
Classification: 'Not Spam'

Query: "CONGRATULATIONS! You've won a $1000 gift card! Claim now: www.totallylegitprizes.co/claim-now"
Classification: 'Spam'

Query: "Can you send me the notes from yesterday's meeting? I had to leave early and missed the last part."
Classification: 'Not Spam'

Query: "FREE iPhone 13 Pro! You've been selected in our monthly giveaway! Click here to claim within 24 hours: bit.ly/claim-prize"
Classification: 'Spam'



## Distillation

Distillation is the process of training a small model with the help of a big model. It is a powerful way to create fast classifiers for decision-making in latency-sensitive settings such as RAG pipelines and AI agents.

In distillation, we refer to the big model as the **teacher model** and the small model as the **student model**. To transfer what the teacher knows, you also need a **pool of unlabeled data** for the teacher to label.

In this case:

* The teacher model is the query classification model from the Pi SDK.
* The student model is a ModernBERT model, a small and capable model for this task.
* The unlabeled data are queries that we'll generate with the Pi SDK's "generate from seeds" synthetic data generator. Given seed queries (here, the text of our labeled `examples`), it generates more queries like them.
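The teacher/student/pool roles above boil down to a simple data flow: the teacher labels each query in the pool, and those (input, label) pairs become the student's training set. A hypothetical sketch: `toy_teacher` is a stand-in for the Pi query classifier, and the record keys mirror the `llm_input` / `llm_output` fields used in the labeling cell later in this notebook.

```python
# Hypothetical sketch of the distillation data flow: a teacher labels an
# unlabeled pool, producing (input, label) records to train a student on.
# `toy_teacher` is a stand-in for the Pi query classifier (an assumption).
from typing import Callable


def build_training_set(pool: list[str], teacher: Callable[[str], str]) -> list[dict]:
    """Label every unlabeled query with the teacher's prediction."""
    return [{"llm_input": query, "llm_output": teacher(query)} for query in pool]


def toy_teacher(query: str) -> str:
    # Assumed toy rule: links or an all-caps "FREE" look like spam.
    return "Spam" if ("http" in query or "FREE" in query) else "Not Spam"


unlabeled_pool = ["FREE gift card, claim now!", "Are we still on for lunch?"]
for record in build_training_set(unlabeled_pool, toy_teacher):
    print(record)
# {'llm_input': 'FREE gift card, claim now!', 'llm_output': 'Spam'}
# {'llm_input': 'Are we still on for lunch?', 'llm_output': 'Not Spam'}
```

The student only ever sees the teacher's labels, so mistakes the teacher makes on the pool are passed on to the student.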



### Generating queries

First, let's generate some queries. To illustrate the process quickly, we'll generate only 100 queries. In a realistic run, you will have the best results generating at least 500 queries.

**Tip:** If you have a pool of real user queries, use them. A model distilled on real queries will usually outperform one distilled only on generated inputs, because it learns from the kinds of queries it will actually see.
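The generation cell splits the budget with `100 // len(groups)`. That gives exactly 50 per class here, but silently drops the remainder when the total isn't divisible by the number of classes (with 3 classes, `100 // 3` is 33 per class, 99 in total). A sketch of a split that keeps the full budget; `per_class_quotas` is a hypothetical helper name.

```python
def per_class_quotas(total: int, labels: list[str]) -> dict[str, int]:
    """Split `total` as evenly as possible across labels.

    The first `total % len(labels)` labels receive one extra sample, so the
    quotas always sum to `total`.
    """
    if not labels:
        raise ValueError("labels must not be empty")
    base, remainder = divmod(total, len(labels))
    return {label: base + (1 if i < remainder else 0) for i, label in enumerate(labels)}


print(per_class_quotas(100, ["Not Spam", "Spam"]))  # {'Not Spam': 50, 'Spam': 50}
print(per_class_quotas(100, ["A", "B", "C"]))  # {'A': 34, 'B': 33, 'C': 33}
```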

In [None]:
# We'll generate 100 training samples in total, and split them
# equally across the different classes to yield a balanced
# training set.
examples_df = pd.DataFrame(examples)
groups = examples_df.groupby('label')
samples_per_group = 100 // len(groups)
classes_by_label = {c["label"]: c["description"] for c in classes}

generated_queries = []
for label, label_examples in groups:
    # We can use the Pi SDK to generate inputs given some examples of
    # input that we'd like more of. A description of the application
    # helps the input generator create inputs that are relevant to our use case.
    new_inputs = generate_inputs(
        queries=label_examples["text"].tolist(),
        application_description=f"""
            Queries should match the input description and belong to
            the following class:

            <class_name>{label}</class_name>
            <class_description>{classes_by_label[label]}</class_description>
            <input_description>{input_description}</input_description>
            """,
        num_examples=samples_per_group,
        exploration_mode='ADVENTUROUS'
    )

    print(f"Adding inputs for label {label}:")
    for new_input in new_inputs:
        print(f"- {new_input}")

    generated_queries.extend(new_inputs)



Generating 50 new examples...
Generation progress: LAUNCHING
Generation progress: RUNNING
Generation progress: [INFO] Progress=> Good: 0/50 Bad: 0 Similar: 0
Generation progress: [INFO] Generation LLM temperature fixed or updated to 1.43
Generation progress: [INFO] Data Generation Ongoing => Good Inputs: 2/50. Bad Inputs: 1. Similar Inputs: 0
Generation progress: [INFO] Progress=> Good: 2/50 Bad: 1 Similar: 0
Generation progress: [INFO] Generation LLM temperature fixed or updated to 1.23
Generation progress: [INFO] Data Generation Ongoing => Good Inputs: 2/50. Bad Inputs: 6. Similar Inputs: 0
Generation progress: [INFO] Progress=> Good: 2/50 Bad: 6 Similar: 0
Generation progress: [INFO] With similarity score: 0.80, New input: 'Amy, I'm setting up a shared doc for the team project notes. Can you review it and add your comments today?' too similar to: 
Generation progress:  'Hi Sarah! Just wanted to thank you for your help with the project yesterday. Your input really made a difference. 

### Labeling queries
Now we'll use the teacher model to generate good labels for these queries, going from a **pool of inputs** to **labeled examples** that we can use for training.
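The labeling cell sends queries to the classifier five at a time. The chunking it does with `range` and slicing can be factored into a small helper, sketched here; `batched` is a hypothetical local name (Python 3.12 also ships `itertools.batched`, which yields tuples instead of lists).

```python
from typing import Iterator, TypeVar

T = TypeVar("T")


def batched(items: list[T], batch_size: int) -> Iterator[list[T]]:
    """Yield consecutive slices of `items`; the last one may be shorter."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


demo_queries = [f"query {i}" for i in range(12)]
print([len(batch) for batch in batched(demo_queries, 5)])  # [5, 5, 2]
```

In the notebook, `for batch in batched(generated_queries, batch_size):` would replace the manual index loop.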

In [None]:
import pandas as pd

# This table allows us to monitor labeling progress.
table = display(pd.DataFrame(), display_id=True)

# We're going to take 5 queries at a time and use the query
# classifier to label them. The labeled examples that we give to the
# classifier help it make good labeling decisions.
batch_size = 5
training_examples = []
for i in range(0, len(generated_queries), batch_size):
    batch = generated_queries[i:i+batch_size]
    response = client.search.query_classifier.classify(
        queries=batch,
        classes=classes,
        examples=examples,
        mode="generative",
    )

    for classification in response.results:
        training_examples.append({
            "llm_input": classification.query,
            "llm_output": classification.prediction,
        })
        table.update(pd.DataFrame(training_examples))



|     | llm_input | llm_output |
|-----|-----------|------------|
| 0 | Hey Jamie, wanted to share the venue details f... | Not Spam |
| 1 | Fantastic photos from your trip to Paris—inspi... | Not Spam |
| 2 | Hey Chloe, are you still selling any plants? I... | Not Spam |
| 3 | Really enjoyed the article you sent me! Do you... | Not Spam |
| 4 | Where did you say that practice group was meet... | Not Spam |
| ... | ... | ... |
| 95 | Expand horizons travel absolutely free! Detail... | Spam |
| 96 | 🚀 Maximize your followers today! Click the lin... | Spam |
| 97 | 💻 Special Offer Just For You: Get premium soft... | Spam |
| 98 | 📢 Huge discounts at the biggest sale of the ye... | Spam |


### Training a tiny classifier

Lastly, let's train a classifier using the training examples that we've gathered through generated inputs and zero-shot classification.
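The generator was asked for 50 queries per class, but the training labels come from the teacher, so a query generated for "Spam" can end up labeled "Not Spam" and the final balance may drift. A quick look at the label distribution before starting the job can catch this. A sketch: `label_distribution` and `underrepresented` are hypothetical helpers, and the 20% threshold is an arbitrary assumption.

```python
from collections import Counter


def label_distribution(training_examples: list[dict]) -> dict[str, float]:
    """Fraction of training examples per teacher-assigned label (`llm_output`)."""
    counts = Counter(example["llm_output"] for example in training_examples)
    total = sum(counts.values())
    if total == 0:
        return {}
    return {label: count / total for label, count in counts.items()}


def underrepresented(distribution: dict[str, float], min_share: float = 0.2) -> list[str]:
    """Labels whose share falls below `min_share` (an assumed threshold)."""
    return sorted(label for label, share in distribution.items() if share < min_share)


demo = [{"llm_input": f"promo {i}", "llm_output": "Spam"} for i in range(9)]
demo.append({"llm_input": "Lunch tomorrow?", "llm_output": "Not Spam"})
distribution = label_distribution(demo)
print(underrepresented(distribution))  # ['Not Spam']
```

In the notebook, `label_distribution(training_examples)` would show the balance of the set about to be sent to `distill.start_job`.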

In [None]:
# Kick off the training job.
job_status = client.search.query_classifier.distill.start_job(
    examples=training_examples,
    base_model="MODERNBERT_BASE",
    learning_rate=5e-5,
    num_train_epochs=3,
)
job_id = job_status.job_id

In [None]:
# Monitor the training.
response = stream_response(
    job_status.job_id,
    client.search.query_classifier.distill,
    additional_columns={"F1": "eval_f1"},
)

if response.state != "DONE":
    print(f"Job state = {response.state}")
    print("The last status messages:\n{}".format('\n'.join(response.detailed_status[-5:])))
else:
    print("Classifier model = {}".format(response.trained_models[0].model_dump_json(indent=2)))

# NOTE: After training finishes, it may take a minute or two for the model to
#       become available.

Training Status for classification_jobs:c9341d7667f4a5fbc66d90d39bd05d649eb5319d020962f40f5dcf41b41adf8f:e52e8fac-861f-43a0-96e2-cb21aedf387e


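|   | Step | Epoch | Learning Rate | Training Loss | Eval Loss | F1 |
|---|------|-------|---------------|---------------|-----------|----|
| 0 | 2 | 0.333333 | 3.3e-05 | 0.6693 | 0.693945 | 0.494949 |
| 1 | 4 | 0.666667 | 4.7e-05 | 0.655 | 0.526855 | 0.8 |
| 2 | 6 | 1.0 | 4e-05 | 0.4645 | 0.397363 | 0.89899 |
| 3 | 8 | 1.333333 | 3.3e-05 | 0.3735 | 0.306348 | 1.0 |
| 4 | 10 | 1.666667 | 2.7e-05 | 0.2753 | 0.262891 | 1.0 |
| 5 | 12 | 2.0 | 2e-05 | 0.2456 | 0.233252 | 1.0 |
| 6 | 14 | 2.333333 | 1.3e-05 | 0.2181 | 0.219189 | 1.0 |
| 7 | 16 | 2.666667 | 7e-06 | 0.2079 | 0.213232 | 1.0 |
| 8 | 18 | 3.0 | 0.0 | 0.2128 | 0.210596 | 1.0 |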


Classifier model = {
  "contract_score": 0.0,
  "epoch": 3.0,
  "eval_loss": 0.21059569716453552,
  "serving_id": 0,
  "serving_state": "UNLOADED",
  "step": 18
}
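`download_classifier` carries a TODO to pick the checkpoint with the lowest eval loss instead of assuming `serving_id=0`. A minimal sketch of that selection, assuming each entry of `response.trained_models` can be turned into a dict with `serving_id` and `eval_loss` keys (for example via pydantic's `model_dump()`, since the entries support `model_dump_json()` above). `best_serving_id` is a hypothetical helper, and the second checkpoint in the demo is hypothetical too; the cell above prints only `trained_models[0]`.

```python
def best_serving_id(trained_models: list[dict]) -> int:
    """Return the `serving_id` of the checkpoint with the lowest `eval_loss`.

    Entries without an `eval_loss` are ignored (an assumption about how
    incomplete checkpoints should be treated).
    """
    candidates = [m for m in trained_models if m.get("eval_loss") is not None]
    if not candidates:
        raise ValueError("No checkpoint reports an eval_loss")
    return min(candidates, key=lambda m: m["eval_loss"])["serving_id"]


# Demo data: the first entry mirrors the output above; the second is hypothetical.
demo_models = [
    {"serving_id": 0, "eval_loss": 0.21059569716453552, "step": 18},
    {"serving_id": 1, "eval_loss": 0.2132, "step": 16},
]
print(best_serving_id(demo_models))  # 0
```

In the notebook this might be used as `download_classifier(job_id, "./model-checkpoint", serving_id=best_serving_id([m.model_dump() for m in response.trained_models]))`.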


### Evaluating the model

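We already have a handful of **ground-truth examples**: the `examples` we used at the beginning to generate our query set and to guide the labeling.

As a rough approximation of accuracy, let's see how well the distilled model performs on those examples. Keep in mind that they are not a true held-out test set: they seeded the generated queries and guided the teacher's labels, so the score may be optimistic.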
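With only 18 ground-truth examples, each mistake moves accuracy by about 5.6 percentage points (1/18), and the 83.3% reported below corresponds to 15 of 18 correct. A per-class breakdown shows whether the errors concentrate in one class. A stdlib-only sketch; `accuracy_report` is a hypothetical helper, and in the notebook you might call it as `accuracy_report(df["label"].tolist(), df["prediction"].tolist())`.

```python
from collections import Counter


def accuracy_report(labels: list[str], predictions: list[str]) -> dict:
    """Overall accuracy, a (true, predicted) confusion count, and per-class recall."""
    if len(labels) != len(predictions) or not labels:
        raise ValueError("labels and predictions must be non-empty and equal length")
    confusion = Counter(zip(labels, predictions))
    correct = sum(count for (true, pred), count in confusion.items() if true == pred)
    recall = {}
    for label in sorted(set(labels)):
        # Recall = correct predictions for this label / all examples with this label.
        support = sum(count for (true, _), count in confusion.items() if true == label)
        recall[label] = confusion[(label, label)] / support
    return {"accuracy": correct / len(labels), "confusion": dict(confusion), "recall": recall}


report = accuracy_report(
    ["Spam", "Spam", "Not Spam", "Not Spam"],
    ["Spam", "Not Spam", "Not Spam", "Not Spam"],
)
print(report["accuracy"])  # 0.75
print(report["recall"])  # {'Not Spam': 1.0, 'Spam': 0.5}
```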

In [None]:
from transformers import pipeline

def evaluate(dataset, classifier):
    df = pd.DataFrame(dataset)  # columns: ['text', 'label']
    df["prediction"] = [prediction["label"] for prediction in classifier(df["text"].tolist())]
    df["correct"] = df["label"] == df["prediction"]
    return df

path = download_classifier(job_id, "./model-checkpoint")
classifier = pipeline("text-classification", model="./model-checkpoint")
df = evaluate(examples, classifier)
print(f"Accuracy: {100*df['correct'].mean():.1f}%")

df[['text', 'label', 'prediction', 'correct']]

Device set to use cpu


Accuracy: 83.3%


|   | text | label | prediction | correct |
|---|------|-------|------------|---------|
| 0 | CONGRATULATIONS! You've been selected to recei... | Spam | Spam | True |
| 1 | Hello dear, I noticed your profile and think y... | Spam | Not Spam | False |
| 2 | Hi there! 👋 I'm selling handmade bracelets at ... | Spam | Spam | True |
| 3 | Hey, are we still meeting for coffee tomorrow ... | Not Spam | Not Spam | True |
| 4 | Hi Sarah! Just wanted to thank you for your he... | Not Spam | Not Spam | True |
| 5 | Dude, did you see that game last night? Incred... | Not Spam | Not Spam | True |
| 6 | URGENT: Your account will be suspended in 24 h... | Spam | Spam | True |
| 7 | Hello friend! I am international business pers... | Spam | Not Spam | False |
| 8 | Want to lose 20 pounds in just 1 week? Our new... | Spam | Spam | True |
| 9 | The photos from our hiking trip are finally up... | Not Spam | Spam | False |


## Next steps

You now have a tiny model that approximates the performance of the zero-shot model. Here are some steps you can take to improve the performance of this model.

 1. **Providing more inputs for input generation.** If you have 20-100 real inputs, passing more of them as seeds for data generation can help.
 2. **Using real input data for training.** If you have 100+ real inputs, you can give those to the teacher model for labeling instead of generated inputs. Real inputs better reflect the **input distribution**: the data the model will see during inference.
 3. **Using real examples for training.** If you have 300+ high-quality labeled examples (the more the better), you can train a classifier on them directly. This yields the best performance for your use case, but such data can take time to collect. Steps (1) and (2) are therefore a quick way to deploy a first iteration of a system that collects real inputs, which can then be labeled to build a high-quality training set.
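For step 3, once you have real labeled examples, hold some out for evaluation instead of scoring on the same examples that shaped the training data. A stdlib sketch of a stratified split that keeps each label's share in both halves; `stratified_split`, the 20% test fraction, and the fixed seed are illustrative choices. scikit-learn's `train_test_split(..., stratify=labels)` offers equivalent behavior.

```python
import random
from collections import defaultdict


def stratified_split(examples: list[dict], test_fraction: float = 0.2, seed: int = 0):
    """Split labeled examples into (train, test), preserving each label's share."""
    if not 0 < test_fraction < 1:
        raise ValueError("test_fraction must be between 0 and 1")
    by_label = defaultdict(list)
    for example in examples:
        by_label[example["label"]].append(example)
    rng = random.Random(seed)  # fixed seed makes the split reproducible
    train, test = [], []
    for label in sorted(by_label):
        group = by_label[label][:]
        rng.shuffle(group)
        # At least one test example per label, but always keep one for training.
        n_test = min(len(group) - 1, max(1, round(len(group) * test_fraction)))
        test.extend(group[:n_test])
        train.extend(group[n_test:])
    return train, test


demo = [{"label": "Spam", "text": f"s{i}"} for i in range(10)]
demo += [{"label": "Not Spam", "text": f"n{i}"} for i in range(10)]
train_set, test_set = stratified_split(demo)
print(len(train_set), len(test_set))  # 16 4
```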