In [93]:
import os

from datasets import load_dataset
from dotenv import load_dotenv

ds = load_dataset("allenai/WildChat-1M")
env_local_path = "../../sarai-chat/.env.local"
load_dotenv(env_local_path)
api_key = os.getenv("OPENAI_API_KEY")

In [56]:
from asyncio import Semaphore

from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio as tqdm_asyncio

sem = Semaphore(24)  # 24 concurrent max
client = AsyncOpenAI(api_key=api_key)


async def chat_complete(
    user_prompt: str, system_prompt: str = "", model: str = "gpt-4o", temperature: float = 0, sem: Semaphore = None
) -> str:
    sem = sem or Semaphore(1)

    try:
        async with sem:
            sys_prefix = [dict(role="system", content=system_prompt)] if system_prompt else []
            response = await client.chat.completions.create(
                messages=sys_prefix + [dict(role="user", content=user_prompt)], model=model, temperature=temperature
            )

            return response.choices[0].message.content
    except:  # noqa
        return ""


async def batch_complete(prompts: list[str], **kwargs):
    return await tqdm_asyncio.gather(*[chat_complete(prompt, sem=sem, **kwargs) for prompt in prompts])

In [57]:
import json


def get_persona_prompt(user_messages: list[str]):
    message = "\n".join(user_messages)
    prompt = f"""Respond with fewer than 50 words the persona (e.g., scientist, programmer, artist, layman, etc.), interests (e.g., C++, physics, gardening), and writing style (lowercase informal, texting, formal) of this person: {message}
    
    Return it as a JSON object, e.g., {{"persona": ..., "interests": ["..."], "style": ...}}. Do not explain or add markup.
    """  # noqa

    return prompt


def switch_role(chat_line: dict[str, str]):
    chat_line["role"] = "user" if chat_line["role"] == "assistant" else "assistant"
    chat_line = dict(role=chat_line["role"], content=chat_line["content"])
    return chat_line


num_ft_rows = 10000
ft_examples = []
personas = []
row_idxs = []

row_idx = 0

while len(row_idxs) < num_ft_rows:
    row_idx += 1
    row = ds["train"][row_idx]

    if row["country"] != "United States":
        continue

    row_idxs.append(row_idx)
    personas.append(get_persona_prompt([x["content"] for x in row["conversation"] if x["role"] == "user"]))

persona_strings = await batch_complete(personas)
required_keys = {"persona", "interests", "style"}

for row_idx, persona_string in zip(row_idxs, persona_strings, strict=False):
    row = ds["train"][row_idx]

    try:
        persona_data = json.loads(persona_string)

        if required_keys.difference(set(persona_data.keys())):
            # non-empty; doesn't contain enough keys
            continue

        ft_conversation = []
        ft_conversation.append(dict(role="system", content=json.dumps(persona_data)))
        ft_conversation.append(dict(role="user", content="Start chatting..."))
        ft_conversation += [switch_role(x) for x in row["conversation"]]

        ft_examples.append(ft_conversation)
    except:  # noqa
        continue

100%|██████████| 10000/10000 [15:24<00:00, 10.81it/s]


In [84]:
from pathlib import Path

with Path("data/fine-tuning-5k.jsonl").open("w") as f:
    for ft_example in ft_examples:
        ft_example = ft_example[:-1]  # drop last message, which is not an assistant
        print(json.dumps(dict(messages=ft_example)), file=f)

In [117]:
from openai import OpenAI

api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)

In [104]:
ret = client.files.create(file=open("data/fine-tuning-5k.jsonl", "rb"), purpose="fine-tune")

In [105]:
ft_job = client.fine_tuning.jobs.create(training_file=ret.id, model="gpt-4o-mini-2024-07-18")

In [118]:
for job in client.fine_tuning.jobs.list(limit=3):
    print(job)
    print()

FineTuningJob(id='ftjob-H1zgfvNuwjk8MEdvoC7hcoBP', created_at=1725517535, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs=3, batch_size=9, learning_rate_multiplier=2), model='gpt-3.5-turbo-0125', object='fine_tuning.job', organization_id='org-3ASX67KsWNwPQahmhfe6g8KT', result_files=[], seed=1927579229, status='validating_files', trained_tokens=None, training_file='file-0HkpNIAQwZg4nDAP6b9pfTd3', validation_file=None, estimated_finish=None, integrations=[], user_provided_suffix=None)

FineTuningJob(id='ftjob-NeqHHFppV6WNNwi3AOiWE7el', created_at=1725517026, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs=3, batch_size=9, learning_rate_multiplier=1.8), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-3ASX67KsWNwPQahmhfe6g8KT', result_files=[], seed=1756357400, status='validating_files', traine

In [116]:
client.fine_tuning.jobs.retrieve(ft_job.id)

FineTuningJob(id='ftjob-MNX1SHvY56jUNFTK70MPDOSH', created_at=1725517698, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs=3, batch_size=9, learning_rate_multiplier=1.8), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-3ASX67KsWNwPQahmhfe6g8KT', result_files=[], seed=332634286, status='validating_files', trained_tokens=None, training_file='file-c7LMnasYXwp8K82ZWpof8LPs', validation_file=None, estimated_finish=None, integrations=[], user_provided_suffix=None)

In [119]:
client.fine_tuning.jobs.list_events("ftjob-H1zgfvNuwjk8MEdvoC7hcoBP")

SyncCursorPage[FineTuningJobEvent](data=[FineTuningJobEvent(id='ftevent-B3obcYvWRUJmi9VxmCqVd1L7', created_at=1725517535, level='info', message='Validating training file: file-0HkpNIAQwZg4nDAP6b9pfTd3', object='fine_tuning.job.event', data={}, type='message'), FineTuningJobEvent(id='ftevent-n1gMrtooOEQGZSfpCFmv8ZqL', created_at=1725517535, level='info', message='Created fine-tuning job: ftjob-H1zgfvNuwjk8MEdvoC7hcoBP', object='fine_tuning.job.event', data={}, type='message')], object='list', has_more=False)