In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
# put safety-tooling in the path
import sys
import os
import pathlib

def put_submodule_in_python_path(submodule_name: str):
    """Make a submodule directory importable by adding it to ``sys.path``.

    The submodule is looked up relative to the current working directory,
    which this notebook treats as the repository root.

    Args:
        submodule_name: Directory name of the submodule, relative to the
            current working directory.

    Returns:
        None. If the directory does not exist, ``sys.path`` is left untouched.
    """
    repo_root = pathlib.Path(os.getcwd())
    submodule_path = repo_root / submodule_name
    if not submodule_path.exists():
        return

    path_entry = str(submodule_path)
    # Notebook cells get re-run; without this guard every re-run would append
    # another copy of the same entry to sys.path.
    if path_entry not in sys.path:
        sys.path.append(path_entry)


put_submodule_in_python_path("safety-tooling")


In [15]:
from safetytooling.apis.inference.openai.batch_api import OpenAIModelBatch
from safetytooling.data_models import ChatMessage, LLMResponse, MessageRole, Prompt
from safetytooling.utils import utils

# Environment setup for the notebook. The *_tag arguments presumably name the
# environment variables that hold the API keys — confirm against
# safetytooling.utils.
utils.setup_environment(
    logging_level="warning",
    openai_tag="OPENAI_API_KEY1",
    anthropic_tag="ANTHROPIC_API_KEY_BATCH",
)

In [16]:
# Five single-turn user prompts: "Say 1" through "Say 5".
prompts = [
    Prompt(messages=[ChatMessage(content=f"Say {number}", role=MessageRole.user)])
    for number in range(1, 6)
]

In [17]:
# Batch client shared by every cell below.
batch_model = OpenAIModelBatch()

# Directory passed as log_dir to the batch calls below; created up front.
log_dir = pathlib.Path("logs")
log_dir.mkdir(parents=True, exist_ok=True)

# Model that every request in this notebook targets.
model_id = "gpt-4o-mini"

In [None]:
responses, batch_id = await batch_model(model_id=model_id, prompts=prompts[0:2], log_dir=log_dir, max_tokens=10)


### Testing each function separately

In [18]:
# Step 1: turn all prompts into a batch input file; the returned path is
# opened and uploaded in the next cell.
# NOTE(review): the file is presumably JSONL written under log_dir (the upload
# output reports 'batch_input.jsonl') — confirm in prompts_to_file.
input_file_path = batch_model.prompts_to_file(model_id, prompts, log_dir, max_tokens=100)


In [20]:
# Step 2: upload the input file for batch use.
# The context manager closes the file handle once the upload call returns;
# a bare open(...) passed inline would stay open for the life of the kernel.
with open(input_file_path, "rb") as input_file:
    batch_file_object = batch_model.client.files.create(file=input_file, purpose="batch")
# Bare last expression so the returned FileObject is shown as cell output.
batch_file_object

FileObject(id='file-EkKG5dL64wNrKcdiGcKqpP', bytes=1070, created_at=1736346945, filename='batch_input.jsonl', object='file', purpose='batch', status='processed', status_details=None)

In [21]:

# Step 3: create the batch job from the id of the uploaded input file.
batch_object = batch_model.create_message_batch(input_file_id=batch_file_object.id)
# Bare last expression so the returned Batch object is shown as cell output.
batch_object

Batch(id='batch_677e8d44657c81909366639bbb5ac367', completion_window='24h', created_at=1736346948, endpoint='/v1/chat/completions', input_file_id='file-EkKG5dL64wNrKcdiGcKqpP', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1736433348, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'batch processing'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0))

In [22]:
import json

# Step 4: persist the batch metadata as JSON in the log directory, in a file
# named after the batch id.
log_file = log_dir.joinpath(f"batch_id_{batch_object.id}.json")
log_file.parent.mkdir(parents=True, exist_ok=True)
with log_file.open("w") as log_handle:
    json.dump(batch_object.model_dump(), log_handle)
print(f"Batch {batch_object.id} created with {len(prompts)} requests")


Batch batch_677e8d44657c81909366639bbb5ac367 created with 5 requests


In [23]:
# Step 5: fetch the batch by id; the result is not stored, only displayed as
# the cell output.
batch_model.retrieve_message_batch(batch_object.id)

Batch(id='batch_677e8d44657c81909366639bbb5ac367', completion_window='24h', created_at=1736346948, endpoint='/v1/chat/completions', input_file_id='file-EkKG5dL64wNrKcdiGcKqpP', object='batch', status='in_progress', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1736433348, failed_at=None, finalizing_at=None, in_progress_at=1736346949, metadata={'description': 'batch processing'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=5))

In [24]:
# Step 6: await polling of the batch and rebind batch_object to the returned
# value; the next cell reads its output_file_id.
# NOTE(review): presumably this waits until the batch reaches a terminal
# state — confirm in poll_message_batch.
batch_object = await batch_model.poll_message_batch(batch_object.id)


In [None]:

# Step 7: fetch the results for the batch's output file.
results = batch_model.retrieve_message_batch_results(batch_object.output_file_id)

# Convert each result into an LLMResponse.
# NOTE(review): assumes each result is a dict with a 'response' entry holding
# 'status_code' and 'body' — confirm against retrieve_message_batch_results.
responses = []
for result in results:
    # Only results with status code 200 are converted in this span; any other
    # result adds nothing to `responses` here.
    if result['response']['status_code'] == 200:
        body = result['response']['body']
        # Only the first choice of each completion is used.
        choice = body['choices'][0]
        responses.append(
            LLMResponse(
                model_id=model_id,
                completion=choice['message']['content'],
                stop_reason=choice['finish_reason'],
                duration=None,  # Batch does not track individual durations
                api_duration=None,
                cost=0,  # hard-coded; no cost is computed from the result here
            )
        )