# Tutorial: Using the OpenAI Batch API (for rewrites)

Helpful links:
- Pricing of available models: https://openai.com/api/pricing/
    - This page seems to be the only place that lists the latest, cheaper GPT-4o snapshot (`gpt-4o-2024-08-06`), which is half the price of the default GPT-4o (as of 8/30/24).

- Official docs for the Batch API: https://platform.openai.com/docs/guides/batch/getting-started

In [1]:
# Load imdb for examples
from datasets import load_dataset
import pandas as pd

# Batch API related imports
from openai import OpenAI # OpenAI API
import json # for creating batch files

# And for loading the API key from the .env file
from dotenv import load_dotenv
import os

### Prep the data you want to send to GPT-4o

We'll consider the task of rewriting IMDB reviews to change their length. We load the reviews as usual and add a feature marking whether each review is long (more than 1,200 characters).

In [2]:
# load imdb dataset from huggingface
dataset = load_dataset('imdb')

# create pandas dataframe with both label and text, including both train and test data
imdb = pd.DataFrame({'text': dataset['train']['text'] + dataset['test']['text'], 'label': dataset['train']['label'] + dataset['test']['label']})

# Keep only reviews shorter than 4000 characters
# (.copy() avoids pandas' SettingWithCopyWarning when adding columns below)
imdb = imdb[imdb['text'].str.len() < 4000].copy()

# Flag long reviews (more than length_threshold characters)
imdb['text_length'] = imdb['text'].str.len()
length_threshold = 1200
imdb['is_long'] = imdb['text_length'] > length_threshold

The Batch API works by sending all of your requests at once in a single JSONL file: one JSON object per line, each describing one request.

So, we'll write a few functions that build the rewriting prompt for each review, then save the requests to a JSONL file.

Example input (a single line of the file):

{"custom_id": "9546", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o", "messages": [{"role": "user", "content": "This is a movie review..."}], "temperature": 0.7}}
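To make the format concrete, here is a minimal standard-library sketch that serializes two requests into JSONL and reads them back. The IDs and review text are made up for illustration. The property that matters is that `json.dumps` escapes newlines inside strings, so each request stays on exactly one line even though the prompt contains line breaks.

```python
import json

# Two illustrative requests (made-up IDs and text) in the Batch API line format
requests = [
    {
        "custom_id": str(i),
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": f"Review {i}\nwith a line break"}],
            "temperature": 0.7,
        },
    }
    for i in (9546, 9547)
]

# json.dumps escapes the embedded "\n", so each request occupies exactly one line
jsonl_text = "".join(json.dumps(req) + "\n" for req in requests)

# Reading it back: one json.loads call per line
parsed = [json.loads(line) for line in jsonl_text.splitlines()]
print(len(jsonl_text.splitlines()), parsed[0]["custom_id"])
```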

In [3]:
# create the prompt that we will use to generate rewrites
def review_messages(review, is_long):
    """Build the chat `messages` list asking the model to make `review` 50% longer or shorter.

    Short reviews are asked to get longer; long reviews are asked to get shorter.
    """
    flipped_length = "shorter" if is_long else "longer"
    counterfactual_prompt = f"""
    Original Completion:
    {review}
    [Adjust the original completion so it's 50% {flipped_length}, but change *nothing* else.]
    Adjusted Completion:
    """

    messages = [
        {
            "role": "user",
            "content": counterfactual_prompt,
        },
    ]
    return messages

# create each line of the batch input file
def write_input_jsonl(df, batch_input_filename):
    batch_input = []
    for idx, row in df.iterrows():
        # Use the length flag (not the sentiment label) to pick the rewrite direction
        messages = review_messages(row['text'], row['is_long'])
        batch_input.append({
            "custom_id": str(idx),  # Each request needs a unique ID! The batch fails otherwise
            "method": "POST",  # only option for now
            "url": "/v1/chat/completions",  # endpoint for chat completions
            "body": {
                # Cheaper GPT-4o snapshot (half the price of the default gpt-4o as of 8/30/24).
                # It may not be the highest quality.
                "model": "gpt-4o-2024-08-06",
                "messages": messages,
                "temperature": 0.7
            }
        })
    # save the batch input file as jsonl, creating its folder if needed
    os.makedirs(os.path.dirname(batch_input_filename) or ".", exist_ok=True)
    with open(batch_input_filename, 'w') as f:
        for item in batch_input:
            json.dump(item, f)  # escapes quotes, newlines, and other special characters
            f.write("\n")  # one JSON object per line


In [4]:
# create the batch input file
batch_input_filename = "batch_api_data/sentiment_batch_input.jsonl"

# only sample 10 examples for now, for illustration purposes
write_input_jsonl(imdb.sample(10), batch_input_filename)
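Before uploading, it can be worth checking the file locally, since a duplicate `custom_id` or a malformed line can make the batch fail validation. `check_batch_input` below is a hypothetical helper (not part of the OpenAI SDK), sketched under the assumption that every request in one batch should target the same endpoint. It accepts any iterable of lines, so an open file works directly.

```python
import json
from collections import Counter

def check_batch_input(lines):
    """Hypothetical pre-upload check for a Batch API input file.

    `lines` is any iterable of JSONL lines, e.g. an open file.
    Returns the number of requests; raises ValueError on the first problem found.
    """
    ids, urls = [], set()
    for lineno, line in enumerate(lines, start=1):
        try:
            request = json.loads(line)
        except json.JSONDecodeError as err:
            raise ValueError(f"line {lineno}: invalid JSON ({err})") from err
        if not isinstance(request, dict) or "custom_id" not in request:
            raise ValueError(f"line {lineno}: missing custom_id")
        ids.append(request["custom_id"])
        urls.add(request.get("url"))
    # Every custom_id must be unique across the whole file
    duplicates = sorted(cid for cid, n in Counter(ids).items() if n > 1)
    if duplicates:
        raise ValueError(f"duplicate custom_id values: {duplicates}")
    # Assumption: all requests in one batch share a single endpoint
    if len(urls) > 1:
        raise ValueError(f"requests target more than one endpoint: {urls}")
    return len(ids)
```

Called on the file written above (inside `with open(batch_input_filename) as f:`), it should return 10 for the 10-review sample.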

### Get rewrites from OpenAI Batch API

Now that we have our batch file created, we're ready to submit our batch!

First, we set up the OpenAI client as usual. Then, we upload our batch file and start the batch!

In [5]:
# Setup the OpenAI client
load_dotenv()
client = OpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

In [6]:
# Upload the batch file (the with-block closes the file handle afterwards)
with open(batch_input_filename, "rb") as f:
    batch_input_file = client.files.create(
        file=f,
        purpose="batch"
    )

# Create the batch job
batch = client.batches.create(
    input_file_id=batch_input_file.id,  # the file we just uploaded
    endpoint="/v1/chat/completions",  # must match the "url" of each request; /v1/embeddings and /v1/completions are also supported
    completion_window="24h",  # currently the only option
    metadata={
        "description": "length rewrite"  # Pick a descriptive name
    }
)
print(batch)

Batch(id='batch_SNvLgDYvpYOuRkOphpRygNjV', completion_window='24h', created_at=1725053789, endpoint='/v1/chat/completions', input_file_id='file-4HYg6vMt0iN9Zo9jCsg5BbBQ', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1725140189, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'length rewrite'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0))


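You can check on the status of the batch with the following command. Once `status='completed'`, the results are ready to download. Small batches like this one often finish within minutes (this one took about 8 seconds, per `created_at` and `completed_at`), but a batch can take up to the full 24-hour completion window.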

In [10]:
# Retrieve and print the details of the batch
print(client.batches.retrieve(batch.id))

Batch(id='batch_SNvLgDYvpYOuRkOphpRygNjV', completion_window='24h', created_at=1725053789, endpoint='/v1/chat/completions', input_file_id='file-4HYg6vMt0iN9Zo9jCsg5BbBQ', object='batch', status='completed', cancelled_at=None, cancelling_at=None, completed_at=1725053797, error_file_id=None, errors=None, expired_at=None, expires_at=1725140189, failed_at=None, finalizing_at=1725053797, in_progress_at=1725053789, metadata={'description': 'length rewrite'}, output_file_id='file-PHCHnuVftEDXrqK0BbRkTQ31', request_counts=BatchRequestCounts(completed=10, failed=0, total=10))
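Rather than re-running the cell by hand, you can poll until the batch reaches a terminal state. This is a minimal sketch under a few assumptions: `wait_for_batch` is a hypothetical helper, `retrieve` is any callable that behaves like `client.batches.retrieve` (returns an object with a `.status` attribute), and the terminal statuses (`completed`, `failed`, `expired`, `cancelled`) follow the status list in the Batch API docs.

```python
import time

# Statuses after which a batch will not change any more (per the Batch API docs)
TERMINAL_STATUSES = {"completed", "failed", "expired", "cancelled"}

def wait_for_batch(retrieve, batch_id, poll_seconds=30.0, timeout_seconds=24 * 3600):
    """Poll retrieve(batch_id) until the batch is terminal or the timeout passes.

    `retrieve` is assumed to behave like client.batches.retrieve.
    Returns the last batch object seen; raises TimeoutError otherwise.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        batch = retrieve(batch_id)
        if batch.status in TERMINAL_STATUSES:
            return batch
        if time.monotonic() >= deadline:
            raise TimeoutError(f"batch {batch_id} still '{batch.status}' after {timeout_seconds}s")
        time.sleep(poll_seconds)
```

Usage would look like `batch = wait_for_batch(client.batches.retrieve, batch.id)`, followed by a check that `batch.status == "completed"` before downloading, since a terminal batch may also have failed or expired.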


In [11]:
# path for the batch output file
batch_output_filename = "batch_api_data/sentiment_batch_output.jsonl"

# Retrieve the output file and write it to a local file
# (output_file_id is None until the batch has finished)
content = client.files.content(client.batches.retrieve(batch.id).output_file_id)
content.write_to_file(batch_output_filename)
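The cell above assumes the batch completed and produced an output file. A slightly more defensive sketch also saves the error file, which the Batch API references through `error_file_id` when some requests fail. `download_batch_results` and the error-file path you pass to it are hypothetical; the client calls are the same `client.batches.retrieve` and `client.files.content(...).write_to_file(...)` used above.

```python
def download_batch_results(client, batch_id, output_path, error_path):
    """Hypothetical helper: save a finished batch's output and error files.

    Returns the list of local paths that were written.
    """
    batch = client.batches.retrieve(batch_id)
    if batch.status != "completed":
        raise RuntimeError(f"batch {batch_id} is '{batch.status}', not 'completed'")
    written = []
    # Either file ID may be None (e.g. no failed requests means no error file)
    for file_id, path in ((batch.output_file_id, output_path), (batch.error_file_id, error_path)):
        if file_id is not None:
            client.files.content(file_id).write_to_file(path)
            written.append(path)
    return written
```

For this tutorial, a call might be `download_batch_results(client, batch.id, batch_output_filename, "batch_api_data/sentiment_batch_errors.jsonl")`.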

### Process the outputs

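The main thing to note is that the output lines may be out of order, which is why we assigned each request a unique `custom_id`: use it to match each response to its input.

Each output line also has an `error` field, and requests that failed are written to a separate error file (referenced by the batch's `error_file_id`).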
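A minimal sketch for parsing the output file, assuming each line has the shape described in the Batch API docs: a `custom_id`, a `response` holding `status_code` and `body` (a regular chat completion object), and an `error` field. `parse_batch_output` is a hypothetical helper; it keys results by `custom_id`, so line order doesn't matter, and records `None` for any request that did not come back with status 200.

```python
import json

def parse_batch_output(lines):
    """Map custom_id -> rewritten text (or None if that request failed).

    `lines` is any iterable of JSONL lines, e.g. an open output file.
    """
    results = {}
    for line in lines:
        if not line.strip():
            continue  # skip blank lines defensively
        record = json.loads(line)
        response = record.get("response") or {}
        if record.get("error") is None and response.get("status_code") == 200:
            # The body is a normal chat completion; take the first choice's text
            results[record["custom_id"]] = response["body"]["choices"][0]["message"]["content"]
        else:
            results[record["custom_id"]] = None
    return results
```

Because `custom_id` was set to `str(idx)` when writing the input file, `int(custom_id)` recovers the row's index in the `imdb` DataFrame, which lets you line each rewrite up with its original review.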