# Vendor quality labeling run

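This notebook collects a random sample of prompts and responses from backfill chats. For turns that lack a QT response, it generates one and stores it in the database as a quick-response message (note: this writes to the database).

The resulting sample is exported to CSV files to share with a vendor for labeling.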

In [1]:
import sys  # noqa

sys.path.append("..")
from dotenv import load_dotenv      
from openai import OpenAI
from anthropic import Anthropic
from google import generativeai
from tenacity import retry, stop_after_attempt, wait_fixed
import pandas as pd
from tqdm import tqdm
from sqlmodel import Session, select, text, func
from ypl.backend.db import get_engine
from ypl.db.chats import Chat, ChatMessage, Turn, TurnQuality, Category, MessageType
from ypl.db.language_models import LanguageModel
from ypl.db.users import User
import uuid
import random

load_dotenv()

True

In [2]:
# Get a random sample of backfill chats and their messages.

excluded_categories = {'Math', 'Code', 'Comparison'}

with Session(get_engine()) as session:
    # 1) Pick up to 3000 random non-deleted chats created by non-deleted backfill users.
    is_live_backfill_chat = (
        User.backfill_job_id.is_not(None),
        Chat.deleted_at.is_(None),
        User.deleted_at.is_(None),
    )
    random_chats_query = (
        select(Chat.chat_id)
        .join(User, Chat.creator_user_id == User.user_id)
        .where(*is_live_backfill_chat)
        .order_by(func.random())
        .limit(3000)
    )
    random_chats = session.exec(random_chats_query).all()
    random_chat_ids = list(map(str, random_chats))

    # 2) Load the non-deleted messages of those chats, together with the turn's prompt difficulty
    #    and the message's category. Messages in an excluded category are dropped; messages with
    #    no category (outer join -> NULL name) are kept.
    message_columns = (
        Chat.chat_id,
        ChatMessage.content,
        ChatMessage.message_type,
        ChatMessage.message_id,
        Turn.turn_id,
        TurnQuality.prompt_difficulty,
        Category.name.label("category_name"),
    )
    category_is_allowed = Category.name.notin_(excluded_categories) | Category.name.is_(None)
    query = (
        select(*message_columns)
        .join(Turn, ChatMessage.turn_id == Turn.turn_id)
        .join(TurnQuality, Turn.turn_id == TurnQuality.turn_id)
        .join(Chat, Turn.chat_id == Chat.chat_id)
        .join(Category, ChatMessage.category_id == Category.category_id, isouter=True)
        .where(
            ChatMessage.deleted_at.is_(None),
            Chat.chat_id.in_(random_chat_ids),
            category_is_allowed,
        )
        .order_by(Turn.turn_id, ChatMessage.created_at)
    )

    df = pd.read_sql(query, session.connection())

print(df.shape)

(3800, 7)


In [5]:
# Helpers to add QT responses, if needed (most of the backfill chats don't have them).

# System prompt sent to every QT model by `qt()` (as the Anthropic `system` argument, the Gemini
# `system_instruction`, and the OpenAI system message). It asks for a short, single-line answer
# (fewer than 160 characters) and includes a few example question/answer pairs.
# NOTE: the string is sent verbatim (including trailing spaces on some lines); edit with care.
qt_system_prompt = """
You are a model that will give very concise responses.
IMPORTANT: don't add any explanations on the answer;
don't write full sentences, unless the user is very specifically asking you for a long answer;
for answers that are non-factual, make it witty or funny, but still brief;
if the user asks to output markdown or any markup, return the cleaned text only;
do not use newlines;
NEVER prompt for more information, feedback, or responses.
Respond in fewer than 160 characters, in the language of the user's message.

Here are some examples: 

Question: Why is the sky blue? 
Answer: Rayleigh scattering of sunlight by the atmosphere.

Question: what is the meaning of life?
Answer: 42

Question: How many people are there in the US? 
Answer: 333.3 million

Question: Should I buy Elden Ring?
Answer: Only if you enjoy gorgeous landscapes and repeatedly dying in them
"""

# QT models to sample from; each name must exist in the LanguageModel table.
qt_model_names = ['gemini-1.5-flash-8b', 'claude-3-5-sonnet-20240620', 'gpt-4o']
client_openai = OpenAI()
client_anthropic = Anthropic()
# `generativeai.configure()` configures the google-generativeai module globally and returns None,
# so there is no client object to keep; `qt()` creates a `generativeai.GenerativeModel` per call.
generativeai.configure()


# Get the model ids for the QT models: {model name -> language_model_id as str}.
with Session(get_engine()) as session:
    qt_models = session.exec(
        select(LanguageModel).where(LanguageModel.name.in_(qt_model_names))
    ).all()
    qt_models = {model.name: str(model.language_model_id) for model in qt_models}

# `qt()` only samples from models found above. Fail fast rather than silently skewing the model mix
# (or failing later with an opaque retried IndexError when none are found).
missing_qt_models = sorted(set(qt_model_names) - set(qt_models))
if missing_qt_models:
    raise ValueError(f"QT models not found in LanguageModel table: {missing_qt_models}")

# Get a QT response for a prompt.
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
def qt(user_prompt: str) -> tuple[str, str]:
    """Generate a quick (QT) response to `user_prompt` with a randomly chosen QT model.

    The model is drawn uniformly from `qt_models` on every attempt, so a retry may use a
    different model than the attempt that failed.

    Args:
        user_prompt: The user's prompt text.

    Returns:
        A `(response_text, model_name)` tuple.

    Raises:
        tenacity.RetryError: If all 3 attempts fail (API error, unknown model prefix, or an
            empty response).
    """
    qt_model = random.choice(list(qt_models.keys()))
    response = None

    if qt_model.startswith("claude"):
        message = client_anthropic.messages.create(
            model=qt_model,
            max_tokens=1024,
            system=qt_system_prompt,
            messages=[
                {"role": "user", "content": user_prompt}
            ],
        )
        response = message.content[0].text

    elif qt_model.startswith("gemini"):
        google_model = generativeai.GenerativeModel(qt_model, system_instruction=qt_system_prompt)
        content = google_model.generate_content(user_prompt)
        response = content.text

    elif qt_model.startswith('gpt-4o'):
        # NOTE(review): only this branch prefixes the prompt with "User prompt: "; confirm that
        # this asymmetry with the other providers is intended.
        completion = client_openai.chat.completions.create(
            model=qt_model,
            messages=[
                {"role": "system", "content": qt_system_prompt},
                {"role": "user", "content": f"User prompt: {user_prompt}"},
            ],
        )
        response = completion.choices[0].message.content

    else:
        raise ValueError(f"Unknown model: {qt_model}")

    # OpenAI's `message.content` is optional and any provider may return blank text; raise so the
    # attempt is retried instead of persisting an empty QT message.
    if not response or not response.strip():
        raise ValueError(f"Empty QT response from model: {qt_model}")

    return response, qt_model

# Add a QT response to a turn, storing it in the database.
def add_qt(turn_id: str, prompt: str):
    """Generate a QT response for `prompt` and store it as a new message of turn `turn_id`.

    The message is saved with type QUICK_RESPONSE_MESSAGE, the generating model's name and
    language model id, and a fresh random UUID as its id; the session is committed.

    Returns:
        `(message_id, qt_text)` of the stored message.
    """
    qt_text, qt_model_name = qt(prompt)

    with Session(get_engine()) as session:
        qt_message = ChatMessage(
            message_id=str(uuid.uuid4()),
            turn_id=turn_id,
            message_type=MessageType.QUICK_RESPONSE_MESSAGE,
            content=qt_text,
            assistant_model_name=qt_model_name,
            assistant_language_model_id=qt_models[qt_model_name],
        )
        session.add(qt_message)
        session.commit()
        return qt_message.message_id, qt_text


In [6]:
# Process a turn, adding QT responses if needed, and converting it to a row in the output dataframe.
# For now QT generation is blocking as the number of requests is low, but we can make it async if needed.
def process_turn(turn_id, df: pd.DataFrame):
    """Convert one turn's messages into an output row, generating a QT response if it has none.

    Args:
        turn_id: Id of the turn (the `groupby('turn_id')` key).
        df: The rows of the messages query that belong to this turn.

    Returns:
        A dict with `turn_id`, `prompt_message_id`, `prompt_text`, `prompt_difficulty`,
        `prompt_category` and, for i in 0..2, `response_{i}_message_id` / `response_{i}_text`
        (the QT response and the two assistant responses, in random order). Returns `{}` when
        the turn does not have exactly one user prompt or exactly two assistant responses.

    Side effects:
        When the turn has no QT response, calls `add_qt`, which stores a new message in the DB.
    """
    message_types = df['message_type']

    prompt = df[message_types == MessageType.USER_MESSAGE]
    if len(prompt) != 1:
        # Presumably excluded due to the category filter in the query (messages in excluded
        # categories are dropped there).
        return {}
    prompt_text = prompt.content.values[0]
    prompt_message_id = prompt.message_id.values[0]

    assistant_responses = df[message_types == MessageType.ASSISTANT_MESSAGE]
    if len(assistant_responses) != 2:
        print("Wrong number of responses in turn ", turn_id)
        return {}

    qt_response = df[message_types == MessageType.QUICK_RESPONSE_MESSAGE]
    if qt_response.empty:
        qt_message_id, qt_text = add_qt(turn_id, prompt_text)
    else:
        qt_message_id = qt_response.message_id.values[0]
        qt_text = qt_response.content.values[0]

    res = {
        'turn_id': str(turn_id),
        'prompt_message_id': str(prompt_message_id),
        'prompt_text': prompt_text,
        # Read the metadata from the prompt row itself: the turn's first row is not guaranteed to
        # be the prompt, and other rows (e.g. assistant messages) may carry a different or NULL
        # category -- a NULL here would make the later `dropna()` silently discard the turn.
        'prompt_difficulty': prompt.prompt_difficulty.values[0],
        'prompt_category': prompt.category_name.values[0],
    }

    candidates = [
        (qt_message_id, qt_text),
        (str(assistant_responses.message_id.values[0]), assistant_responses.content.values[0]),
        (str(assistant_responses.message_id.values[1]), assistant_responses.content.values[1]),
    ]
    # Presumably shuffled so a response's position doesn't reveal which one is the QT response.
    random.shuffle(candidates)

    for i, (message_id, text) in enumerate(candidates):
        res[f'response_{i}_message_id'] = str(message_id)
        res[f'response_{i}_text'] = text

    return res


# Process all turns into a dataframe.
# Skipped turns come back as `{}` (an all-NaN row); `dropna()` removes those, along with any
# row that has a missing value.
groups = df.groupby('turn_id')
rows = [process_turn(turn_id, turn_messages) for turn_id, turn_messages in tqdm(groups)]

df_results = pd.DataFrame(rows).dropna()

100%|██████████| 1223/1223 [23:05<00:00,  1.13s/it]


In [7]:
min_count_per_category = 20
total_samples = 500

# Ensure a minimal number of rows for each category. Sample *without* replacement so no turn is
# sent to the vendor twice; a category with fewer than `min_count_per_category` rows contributes
# all of its rows.
sampled_df = pd.concat([
    category_df.sample(n=min(min_count_per_category, len(category_df)), replace=False)
    for _, category_df in df_results.groupby('prompt_category')
])

# Add remaining samples to meet `total_samples`, capped by the rows still available (and at 0
# when the per-category minimums already exceed `total_samples`).
remaining_pool = df_results.drop(sampled_df.index)
remaining_count = max(0, min(total_samples - len(sampled_df), len(remaining_pool)))
remaining_sample = remaining_pool.sample(n=remaining_count, replace=False)
sampled_df = pd.concat([sampled_df, remaining_sample])

assert sampled_df.index.is_unique, "a turn was sampled more than once"

sampled_df.prompt_category.value_counts()

prompt_category
Factual             144
Creative Writing     79
Advice               48
Opinion              38
Other                37
Education            36
Analysis             26
Summarization        24
Multilingual         23
Reasoning            23
Entertainment        22
Name: count, dtype: int64

In [8]:
# Spot check: 20 random rows, with the `*_id` columns hidden so the text columns are readable.
id_columns = [column for column in df_results.columns if column.endswith('_id')]
df_results.sample(20).drop(columns=id_columns)

Unnamed: 0,prompt_text,prompt_difficulty,prompt_category,response_0_text,response_1_text,response_2_text
199,What is the best way to clean a small stain fr...,4.0,Advice,Cleaning a small stain from a leather wallet c...,Cleaning a small stain from a leather wallet r...,"Use a damp cloth and mild soap, gently rub, th..."
170,Is there a known method to predict if a person...,7.0,Factual,"Yes, there are methods to predict mortality ri...","Yes, there are methods and models used in the ...","No foolproof method, just risk models with pro..."
216,"give me a quote from any book, movie, or anime...",4.0,Entertainment,"""People will come and go in your life, but onl...","""I'm not crazy. My reality is just different f...",One relatable quote comes from the anime *Full...
1188,Why is there a small white line on the left ha...,3.0,Factual,\n\nPlease provide the example image so I can ...,"However, I don't see an example image or any c...",Possibly a misalignment or formatting issue in...
921,How do we design a video game?,7.0,Other,Designing a video game is a complex process t...,"Concept, gameplay, story, art, code, test, pol...","Designing a video game is a complex, creative,..."
935,You are a Internet troll. You were tasked to c...,3.0,Creative Writing,I'm here to promote positive and respectful in...,Backside Buffoon,I cannot fulfill your request. \n\nMy purpose ...
72,List all the colleges in Texas that have a mus...,5.0,Factual,"Sam Houston State University, Texas Woman's Un...",Here's a list of colleges in Texas that offer ...,As of my last knowledge update in October 2023...
693,How much would it cost to buy 80% of the stock...,6.0,Factual,"Apple Inc.'s market cap is around $3 trillion,...",To calculate the cost of buying 80% of Apple I...,I cannot give you a precise cost for buying 80...
893,I need a technical overview of the concept of ...,7.0,Analysis,The pseudospectral method is a numerical techn...,Numerical technique using spectral basis funct...,Pseudospectral methods are a class of numerica...
504,Where can I find a library in new york,3.0,Factual,The city that never sleeps... and reads! There...,Check the New York Public Library or use Googl...,New York City has many libraries across its fi...


In [9]:
# Shuffle (the sample above is ordered by category), then write two CSVs: the vendor copy without
# the category/difficulty columns, and an internal copy that keeps them.
shuffled_df = sampled_df.sample(frac=1).reset_index(drop=True)

# Derive file names from the actual sample size so they stay correct if `total_samples` changes
# (500 rows -> 'promptstart_500.csv', as before).
output_prefix = f'promptstart_{len(shuffled_df)}'
shuffled_df.drop(columns=['prompt_category', 'prompt_difficulty']).to_csv(f'{output_prefix}.csv', index=False)
shuffled_df.to_csv(f'{output_prefix}_with_categories_and_difficulty.csv', index=False)