In [4]:
import os

from datasets import load_dataset
from dotenv import load_dotenv

ds = load_dataset("allenai/WildChat-1M")
env_local_path = ".env.local"
load_dotenv(env_local_path)
api_key = os.getenv("OPENAI_API_KEY")
do_persona_ft = False  # if False, we do personaless fine-tuning, which seems to do better
filtered_languages = {"en"}  # if empty, do not do filtering. Otherwise, include only languages in this set
filtered_countries = {"United States", "Canada", "United Kingdom", "New Zealand", "Australia"}

In [5]:
from asyncio import Semaphore

from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio as tqdm_asyncio

sem = Semaphore(24)  # 24 concurrent max
client = AsyncOpenAI(api_key=api_key)


async def chat_complete(
    user_prompt: str, system_prompt: str = "", model: str = "gpt-4o", temperature: float = 0, sem: Semaphore = None
) -> str:
    sem = sem or Semaphore(1)

    try:
        async with sem:
            sys_prefix = [dict(role="system", content=system_prompt)] if system_prompt else []
            response = await client.chat.completions.create(
                messages=sys_prefix + [dict(role="user", content=user_prompt)], model=model, temperature=temperature
            )

            return response.choices[0].message.content
    except:  # noqa
        return ""


async def batch_complete(prompts: list[str], **kwargs):
    return await tqdm_asyncio.gather(*[chat_complete(prompt, sem=sem, **kwargs) for prompt in prompts])

In [None]:
import json


def get_persona_prompt(user_messages: list[str]):
    message = "\n".join(user_messages)
    prompt = f"""Respond with fewer than 50 words the persona (e.g., scientist, programmer, artist, layman, etc.), interests (e.g., C++, physics, gardening), and writing style (lowercase informal, texting, formal) of this person: {message}
    
    Return it as a JSON object, e.g., {{"persona": ..., "interests": ["..."], "style": ...}}. Do not explain or add markup.
    """  # noqa

    return prompt


def switch_role(chat_line: dict[str, str]):
    chat_line["role"] = "user" if chat_line["role"] == "assistant" else "assistant"
    chat_line = dict(role=chat_line["role"], content=chat_line["content"])
    return chat_line


num_ft_rows = 300000
rows = ds["train"].shuffle()[:num_ft_rows]

In [4]:
import pandas as pd

df = pd.DataFrame(rows)

In [5]:
gdf = df.groupby("hashed_ip")
df = gdf.head(20)

In [10]:
from fast_langdetect import detect
from tqdm import tqdm

ft_examples = []
personas = []
filt_rows = []
filt_user_texts = []

for _, row in tqdm(df.iterrows()):
    if filtered_countries and row["country"] not in filtered_countries:
        continue

    user_text = "\n".join(x["content"] for x in row["conversation"] if x["role"] == "user")

    try:
        lang = detect(user_text)
    except:  # noqa
        continue

    lang = lang["lang"]

    if filtered_languages and lang not in filtered_languages:
        continue

    personas.append(get_persona_prompt([x["content"] for x in row["conversation"] if x["role"] == "user"]))
    filt_rows.append(row)
    filt_user_texts.append(row["conversation"][0]["content"])

224492it [00:07, 30666.32it/s]


In [12]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").cuda()
user_vecs = model.encode(filt_user_texts)

In [45]:
import numpy as np

sel_vec_idxs = []
cutoff = np.quantile(np.linalg.norm(user_vecs[:5000] - user_vecs[5000:10000], axis=1), 0.01)
print("Dedup cutoff", cutoff)

for idx, user_vec in tqdm(enumerate(user_vecs)):
    sel = True

    for sel_vec_idx in sel_vec_idxs:
        if np.linalg.norm(user_vecs[sel_vec_idx] - user_vec) < cutoff:
            sel = False
            break

    if sel:
        sel_vec_idxs.append(idx)

Cutoff 1.0082007086277007


15360it [00:50, 302.18it/s]


In [58]:
filt_new_rows = np.array(filt_rows, dtype=object)[sel_vec_idxs].tolist()
new_personas = np.array(personas, dtype=object)[sel_vec_idxs].tolist()

In [60]:
if do_persona_ft:
    persona_strings = await batch_complete(new_personas)
else:
    persona_strings = ["{}"] * len(filt_new_rows)

required_keys = {"persona", "interests", "style"}

for row, persona_string in zip(filt_new_rows, persona_strings, strict=False):
    try:
        persona_data = json.loads(persona_string)

        if do_persona_ft and required_keys.difference(set(persona_data.keys())):
            # non-empty; doesn't contain enough keys
            continue

        ft_conversation = []

        if do_persona_ft:
            ft_conversation.append(dict(role="system", content=json.dumps(persona_data)))

        ft_conversation.append(dict(role="user", content="Start chatting..."))
        ft_conversation += [switch_role(x) for x in row[3]]

        ft_examples.append(ft_conversation)
    except:  # noqa
        continue

In [61]:
len(ft_examples)

6240

In [62]:
from pathlib import Path

output_name = "data/personaless-fine-tuning-clean-6k.jsonl"

with Path(output_name).open("w") as f:
    for ft_example in ft_examples:
        ft_example = ft_example[:-1]  # drop last message, which is not an assistant
        print(json.dumps(dict(messages=ft_example)), file=f)

In [63]:
from openai import OpenAI

output_name = "data/personaless-fine-tuning-clean-clean-6k.jsonl"
client = OpenAI(api_key=api_key)

In [64]:
ret = client.files.create(file=open(output_name, "rb"), purpose="fine-tune")

In [65]:
ft_job = client.fine_tuning.jobs.create(
    training_file=ret.id, model="gpt-4o-mini-2024-07-18", hyperparameters=dict(n_epochs=2)
)

In [67]:
for job in client.fine_tuning.jobs.list(limit=3):
    print(job)
    print()

FineTuningJob(id='ftjob-5BVZ98vdyUMEoLLLBCNy8Zi9', created_at=1726294435, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs=2, batch_size='auto', learning_rate_multiplier='auto'), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-3ASX67KsWNwPQahmhfe6g8KT', result_files=[], seed=1525065784, status='validating_files', trained_tokens=None, training_file='file-24JsIyrbW2HBWU8aHWrN4Boz', validation_file=None, estimated_finish=None, integrations=[], user_provided_suffix=None)

FineTuningJob(id='ftjob-0jGeCaBGHfz5s2S3VD2ZGLdU', created_at=1726276803, error=Error(code=None, message=None, param=None), fine_tuned_model='ft:gpt-4o-mini-2024-07-18:yupp::A7DE2LPf', finished_at=1726282496, hyperparameters=Hyperparameters(n_epochs=2, batch_size=15, learning_rate_multiplier=1.8), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-3ASX67KsWNwPQahmhfe6g8KT', result_f

In [116]:
client.fine_tuning.jobs.retrieve(ft_job.id)

FineTuningJob(id='ftjob-MNX1SHvY56jUNFTK70MPDOSH', created_at=1725517698, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs=3, batch_size=9, learning_rate_multiplier=1.8), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-3ASX67KsWNwPQahmhfe6g8KT', result_files=[], seed=332634286, status='validating_files', trained_tokens=None, training_file='file-c7LMnasYXwp8K82ZWpof8LPs', validation_file=None, estimated_finish=None, integrations=[], user_provided_suffix=None)

In [119]:
client.fine_tuning.jobs.list_events("ftjob-H1zgfvNuwjk8MEdvoC7hcoBP")

SyncCursorPage[FineTuningJobEvent](data=[FineTuningJobEvent(id='ftevent-B3obcYvWRUJmi9VxmCqVd1L7', created_at=1725517535, level='info', message='Validating training file: file-0HkpNIAQwZg4nDAP6b9pfTd3', object='fine_tuning.job.event', data={}, type='message'), FineTuningJobEvent(id='ftevent-n1gMrtooOEQGZSfpCFmv8ZqL', created_at=1725517535, level='info', message='Created fine-tuning job: ftjob-H1zgfvNuwjk8MEdvoC7hcoBP', object='fine_tuning.job.event', data={}, type='message')], object='list', has_more=False)