In [None]:
import os
import re
import json
import glob
import time
import openai
import tiktoken
from tqdm import tqdm
from openai import OpenAI
from collections import defaultdict

# OpenAI API key read from the environment so no secret is stored in the notebook.
# os.environ.get returns None when OPENAI_API_KEY is not set.
API_KEY = os.environ.get("OPENAI_API_KEY")

In [20]:
def extract_list_items(text):
    """Split a numbered-list response into its individual entries.

    A list marker is a run of digits at the start of a line (optionally
    indented) followed by ")", ".", "-" and/or whitespace. Anchoring the marker
    to the line start keeps numbers inside a preamble or inside a sentence
    (e.g. "Here are 10 sentences:") from being treated as the start of an item.

    Args:
        text: Raw model output expected to contain a numbered list.

    Returns:
        A list of item strings, each stripped of surrounding whitespace and
        normalised to end with exactly one trailing period. Empty when no
        numbered items are found.
    """
    items = re.findall(
        # An item runs (possibly across lines) until the next line-leading
        # marker or the end of the text.
        r"^[ \t]*\d+[\).\s-]+\s*(.+?)(?=\n[ \t]*\d+[\).\s-]+|\Z)",
        text.strip(),
        re.DOTALL | re.MULTILINE,
    )
    # Drop whitespace-only captures so they do not turn into a bare ".".
    return [item.strip().rstrip(".") + "." for item in items if item.strip()]

def get_article(word):
    """Choose the indefinite article ("a" or "an") for ``word``.

    The choice is purely spelling-based: "an" when the first non-whitespace
    character is one of a/e/i/o/u (case-insensitive), otherwise "a".

    Args:
        word: The word or phrase the article will precede.

    Returns:
        "an" or "a". Empty or whitespace-only input yields "a" instead of
        raising IndexError.
    """
    stripped = word.strip()
    if not stripped:
        return "a"
    return "an" if stripped[0].lower() in "aeiou" else "a"

def fill_prompt_templates(template_list, category_name):
    """Render every supported template for one category.

    Templates are classified by how many "{}" slots they contain:
    one slot receives the category name, two slots receive the article
    returned by get_article followed by the category name. Templates with any
    other slot count are left out of the result.

    Args:
        template_list: Iterable of str.format-style template strings.
        category_name: Class name to substitute into the templates.

    Returns:
        List of rendered prompts, in the order of the templates that were kept.
    """
    article = get_article(category_name)

    # Positional arguments to pass to str.format for each supported slot count.
    args_by_slot_count = {
        1: (category_name,),
        2: (article, category_name),
    }

    rendered = []
    for template in template_list:
        format_args = args_by_slot_count.get(template.count("{}"))
        if format_args is not None:
            rendered.append(template.format(*format_args))
    return rendered

In [21]:
# Prompt templates keyed by dataset name. Each "{}" is a positional slot filled by
# fill_prompt_templates: templates with two slots receive (article, class name),
# templates with one slot receive the class name only.
# The batching cell iterates over the keys of this dict, so only the datasets that
# are NOT commented out below ("cars" at the moment) get batch files generated.
prompt_templates = { 
    "cars": [
        "How can you identify {} {}?",
        "Description of {} {}, a type of car.",
        "A caption of a photo of {} {}:",
        "What are the primary characteristics of {} {}?",
        "Description of the exterior of {} {}.",
        "What are the identifying characteristics of {} {}, a type of car?",
        "Describe an image from the internet of {} {}.",
        "What does {} {} look like?",
        "Describe what {} “{}”, a type of car, looks like.",
        "List the visual cues that help in identifying {} “{}”."
    ],
    # "pets": [
    #     "Describe what a pet {} looks like.",
    #     "Visually describe {} '{}', a type of pet."
    # ],
    # "dtd": [
    #     "What does a “{}” material look like?",
    #     "What does a “{}” surface look like?",
    #     "What does a “{}” texture look like?",
    #     "What does a “{}” object look like?",
    #     "What does a “{}” thing look like?",
    #     "What does a “{}” pattern look like?"
    # ],
    # "resisc45": [
    #     "Describe a satellite photo of {} {}.",
    #     "Describe {} {} as it would appear in an aerial image.",
    #     "How can you identify {} {} in an aerial photo?",
    #     "Describe the satellite photo of {} {}.",
    #     "Describe an aerial photo of {} {}."
    # ],
    # "flowers": [
    #     "Describe how to identify {} {}, a type of flower.",
    #     "What does {} {} flower look like?"
    # ],
    # "fgvc": [
    #     "What does {} {} look like?",
    #     "Describe the exterior of {} {} aircraft.",
    #     "Visually distinguishing features of {} {} airplane?",
    #     "What are the primary characteristics of {} {}?",
    # ],
    # "ucf101":[
    #     "Describe the action of {} in a photo.",
    #     "What does the act of {} look like in an image?",
    #     "Describe key visual features of “{}”.",
    #     "Describe the action “{}”."
    # ],
    # "eurosat": [
    #     "Describe a satellite photo of {} “{}”.",
    #     "Describe {} “{}” as it would appear in an aerial image.",
    #     "How can you identify {} “{}” in an aerial photo?",
    #     "Describe the satellite photo of {} “{}”.",
    #     "Describe an aerial photo of {} “{}”.",
    #     "List how one can recognize the {} “{}” within an aerial image.",
    #     "List the distinguishing features of the {} “{}” in a satellite photo.",
    #     "List the visual cues that help in identifying the {} “{}” in an aerial image.",
    #     "List the visual characteristics that make the {} “{}” easily identifiable in a satellite image.",
    #     "List how one can identify the {} “{}” based on visual cues in an aerial photo."
    # ],
    # "gtsrb": [
    #     "What does a “{}” traffic sign look like?",
    #     "Describe where the {} sign visually appears on roads.",
    #     "Visually describe a road sign that indicates “{}”.",
    #     "What are unique visual features of a “{}” traffic sign commonly found in Germany?",
    # ]
}

In [22]:
# Maps each dataset key to a short label; presumably a description of what one
# class in that dataset represents (TODO confirm intended use). Nothing in the
# cells shown here reads this mapping.
dataset_metadata = dict(
    cars="car model",
    pets="pet type",
    dtd="texture type",
    resisc45="aerial location",
    flowers="flower species",
    fgvc="aircraft model",
    ucf101="action",
    eurosat="prompted location",
    gtsrb="traffic sign type",
)

In [23]:
# Cap on the summed *estimated prompt* tokens written to a single batch .jsonl part;
# the batching loop starts a new part before this would be exceeded.
# NOTE(review): 90000 is presumably chosen to stay under an account-level Batch API
# token quota — confirm against the limits that apply to this organisation.
MAX_TOKENS_PER_FILE = 90000
# tiktoken encoding looked up for "gpt-4o", the same model name used in the batch
# request bodies, so counts are made with that model's tokenizer.
ENCODING = tiktoken.encoding_for_model("gpt-4o")

In [32]:
# Token counter functions
def count_tokens(text):
    """Return how many tokens the module-level ENCODING produces for ``text``."""
    token_ids = ENCODING.encode(text)
    return len(token_ids)

def count_chat_tokens(messages):
    """Estimate the prompt-side token count of a chat ``messages`` list.

    Every message costs a fixed overhead of 3 tokens plus the token count of
    each of its field values; a message that carries a "name" field costs 1
    more. A further 3 tokens are added once for reply priming. Only the
    messages themselves are counted, not any generated completion.

    Args:
        messages: List of dicts whose values are strings (e.g. role/content).

    Returns:
        The estimated token count as an int.
    """
    per_message_overhead = 3
    name_field_extra = 1
    reply_priming = 3

    total = reply_priming
    for message in messages:
        total += per_message_overhead
        total += sum(count_tokens(value) for value in message.values())
        if "name" in message:
            total += name_field_extra
    return total

# Main batching logic
BATCH_JOBS_DIR = "batch_jobs"


def _write_batch_file(dataset, part_idx, tasks, token_count):
    """Write one batch part as JSONL (one request object per line) and report it.

    The output directory is created here, so every write path works on a fresh
    checkout -- including the final flush of a dataset that never hit the token
    limit and therefore never went through the roll-over branch.
    """
    os.makedirs(BATCH_JOBS_DIR, exist_ok=True)
    file_name = f"{BATCH_JOBS_DIR}/{dataset}_batch_prompts_part{part_idx}.jsonl"
    with open(file_name, 'w') as file:
        for obj in tasks:
            file.write(json.dumps(obj) + '\n')
    print(f"Wrote {file_name} with {len(tasks)} tasks and ~{token_count} tokens.")


# The class lists file is the same for every dataset, so read it once up front.
with open("../configs/classes.json", 'r') as f:
    classes_by_dataset = json.load(f)

# The system message is identical for every request.
system_msg = "You are a helpful assistant. Give 10 numbered sentences answering prompt as visually identifiable descriptions."

# NOTE: part files left over from an earlier run that produced more parts are not
# removed here. The submission cell globs "<dataset>_batch_prompts*.jsonl", so
# delete stale parts by hand before submitting.
for dataset in prompt_templates.keys():
    classes = classes_by_dataset[dataset]

    current_file_idx = 0
    current_token_count = 0
    current_tasks = []

    for category in tqdm(classes, desc=f"Processing {dataset}"):
        prompts = fill_prompt_templates(prompt_templates[dataset], category)

        for index, curr_prompt in enumerate(prompts):
            user_msg = curr_prompt + f" Include '{category}' in each sentence."

            messages = [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ]
            estimated_tokens = count_chat_tokens(messages)

            # Roll over to a new part before this request would push the current
            # one past the limit. The `current_tasks` guard avoids emitting an
            # empty part when a single request alone exceeds the limit.
            if current_tasks and current_token_count + estimated_tokens > MAX_TOKENS_PER_FILE:
                _write_batch_file(dataset, current_file_idx, current_tasks, current_token_count)

                current_file_idx += 1
                current_tasks = []
                current_token_count = 0

            task = {
                # custom_id is parsed back later by splitting on "#".
                "custom_id": f"{dataset}#{category}#{index}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "gpt-4o",
                    "temperature": 0.1,
                    "messages": messages,
                }
            }

            current_tasks.append(task)
            current_token_count += estimated_tokens

    # Flush whatever is left for this dataset.
    if current_tasks:
        _write_batch_file(dataset, current_file_idx, current_tasks, current_token_count)

Processing cars: 100%|██████████| 196/196 [00:00<00:00, 2045.75it/s]

Wrote batch_jobs/cars_batch_prompts_part0.jsonl with 1477 tasks and ~89951 tokens.
Wrote batch_jobs/cars_batch_prompts_part1.jsonl with 483 tasks and ~29761 tokens.





In [33]:
# Dataset whose batch files are submitted, and where those files live.
dataset = "cars"
batch_folder = "batch_jobs"
pattern = os.path.join(batch_folder, f"{dataset}_batch_prompts*.jsonl")


def _batch_part_order(path):
    """Sort key: numeric part index first (files without one sort first), then path."""
    match = re.search(r"_part(\d+)\.jsonl$", path)
    return (int(match.group(1)) if match else -1, path)


# Order by the numeric part index. Plain string sorting would place "part10" before
# "part2", and the position in this list is what the submission loop reports as the
# part number in each batch's description.
file_paths = sorted(glob.glob(pattern), key=_batch_part_order)  # Match part files if exist, or the single file
client = OpenAI(api_key=API_KEY)

# Check if files exist
if not file_paths:
    raise FileNotFoundError(f"No batch prompt files found for dataset: {dataset}")

# Wait function to poll batch status
def wait_for_batch_completion(client, batch_id, poll_interval=30):
    """Poll a batch until it reaches a terminal status.

    Args:
        client: Object exposing ``batches.retrieve(batch_id)`` whose result has a
            ``status`` attribute (and optionally ``errors``).
        batch_id: Identifier of the batch to poll.
        poll_interval: Seconds to sleep between polls.

    Returns:
        True if the batch reached "completed"; False if it ended as "failed",
        "expired" or "cancelled".
    """
    # Statuses after which the batch can never become "completed". Without
    # treating these as terminal the loop would poll forever.
    unsuccessful_terminal_statuses = {"failed", "expired", "cancelled"}

    while True:
        batch = client.batches.retrieve(batch_id)
        batch_status = batch.status
        print(f"🔄 Waiting... Batch {batch_id} is currently: {batch_status}")

        if batch_status == "completed":
            print(f"✅ Batch {batch_id} completed.")
            return True
        if batch_status in unsuccessful_terminal_statuses:
            print(f"❌ Batch {batch_id} {batch_status}. Exiting.")
            # Surface whatever error detail the API attached so the failure is
            # diagnosable from the notebook output.
            errors = getattr(batch, "errors", None)
            if errors:
                print(f"   Reported errors: {errors}")
            return False
        time.sleep(poll_interval)

# Submit the batch files one at a time, in the order given by file_paths.
for idx, file_path in enumerate(file_paths):
    is_first_part = idx == 0
    if not is_first_part:
        # Hold the next upload until the batch submitted in the previous
        # iteration has finished; stop submitting if it did not succeed.
        success = wait_for_batch_completion(client, prev_batch_id)
        if not success:
            break

    print(f"\n📤 Submitting batch part {idx}: {file_path}")

    # Upload the JSONL request file, then create a batch that points at it.
    with open(file_path, 'rb') as file:
        batch_input_file = client.files.create(file=file, purpose="batch")
    batch_input_file_id = batch_input_file.id

    description = f"Sathira - Batch job for generating prompts for {dataset}, part {idx}."
    batch = client.batches.create(
        input_file_id=batch_input_file_id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={"description": description},
    )

    # Remembered so the next iteration knows which batch to wait for.
    prev_batch_id = batch.id
    print(f"🚀 Submitted batch (part {idx}) with ID: {batch.id}")


📤 Submitting batch part 0: batch_jobs/cars_batch_prompts_part0.jsonl
🚀 Submitted batch (part 0) with ID: batch_682abd7cf8408190be7178c87af44227
🔄 Waiting... Batch batch_682abd7cf8408190be7178c87af44227 is currently: validating
🔄 Waiting... Batch batch_682abd7cf8408190be7178c87af44227 is currently: failed
❌ Batch batch_682abd7cf8408190be7178c87af44227 failed. Exiting.


In [28]:
# Ad-hoc status check for one previously submitted batch (its ID is hard-coded
# below). Printing the returned Batch object shows, among other fields, its status,
# request counts and output_file_id.
client = OpenAI(api_key=API_KEY)
batch = client.batches.retrieve("batch_682aba65ad608190a35be87805ca2ab0")
print(batch)

Batch(id='batch_682aba65ad608190a35be87805ca2ab0', completion_window='24h', created_at=1747630693, endpoint='/v1/chat/completions', input_file_id='file-7aAmKAKNhA3LAfJdZkWNXb', object='batch', status='completed', cancelled_at=None, cancelling_at=None, completed_at=1747630821, error_file_id=None, errors=None, expired_at=None, expires_at=1747717093, failed_at=None, finalizing_at=1747630786, in_progress_at=1747630696, metadata={'description': 'Sathira - Batch job for generating prompts for cars, part 1.'}, output_file_id='file-7zRJXFB4S2GAh6wVqqdSWH', request_counts=BatchRequestCounts(completed=596, failed=0, total=596))


In [30]:
def parse_batch_response(file_ids, client: "OpenAI", output_dir="."):
    """Download batch output files, group the parsed list items, and save them.

    Each output line is expected to be a JSON object with a "custom_id" of the
    form "<dataset>#<class_name>#..." and a "response" holding a chat
    completion. The completion text is split with extract_list_items and the
    items are appended to the list for that dataset/class. One
    "<dataset>.json" file per dataset is written into ``output_dir``.

    Processing is best-effort: a file that cannot be downloaded, a line that is
    not valid JSON, or a response that cannot be parsed is reported and
    skipped without aborting the rest.

    Args:
        file_ids: Iterable of batch output file IDs to download.
        client: Client exposing ``files.content(file_id).text``.
        output_dir: Directory that receives the per-dataset JSON files.

    Returns:
        Plain nested dict ``{dataset: {class_name: [item, ...]}}`` with
        everything that was parsed (empty when nothing was).
    """
    dataset_wise = defaultdict(lambda: defaultdict(list))

    for file_id in file_ids:
        print(f"📥 Fetching file content from: {file_id}")
        try:
            # Download batch output text
            batch_text = client.files.content(file_id).text
        except Exception as e:
            print(f"❌ Error downloading or parsing file {file_id}: {e}")
            continue

        for line_no, line in enumerate(batch_text.strip().split("\n"), start=1):
            if not line.strip():
                continue

            # Parse each line on its own so one malformed line does not discard
            # the remainder of the file.
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"⚠️ Skipping malformed line {line_no} in {file_id}: {e}")
                continue

            # Skip lines without a valid response
            if not isinstance(data, dict) or "response" not in data:
                continue

            custom_id = data.get("custom_id", "")
            parts = custom_id.split("#")
            if len(parts) < 2:
                continue  # malformed custom_id

            dataset, class_name = parts[0], parts[1]

            try:
                content = data["response"]["body"]["choices"][0]["message"]["content"].strip()
                items = extract_list_items(content)
                dataset_wise[dataset][class_name].extend(items)
            except Exception as e:
                print(f"⚠️ Failed to process {custom_id}: {e}")

    # Save one JSON file per dataset
    os.makedirs(output_dir, exist_ok=True)
    for dataset, class_dict in dataset_wise.items():
        output_path = os.path.join(output_dir, f"{dataset}.json")
        with open(output_path, "w") as f:
            json.dump(class_dict, f, indent=4)
        print(f"✅ Saved: {output_path}")

    print(f"🎉 Done! Parsed datasets: {', '.join(dataset_wise.keys())}")

    # Hand the parsed data back (as plain dicts) so callers can keep using it;
    # previously the function returned None.
    return {dataset: dict(class_dict) for dataset, class_dict in dataset_wise.items()}

In [31]:
# Parse the listed batch output files with the client created earlier.
# NOTE(review): `descriptions` only holds data if parse_batch_response returns its
# parsed mapping — verify it is not None before using it further down.
descriptions = parse_batch_response(
    file_ids=[
        "file-MPn3syhgDFnrCUXvpVsCZA",
        "file-7zRJXFB4S2GAh6wVqqdSWH",
    ],
    client=client,
)

📥 Fetching file content from: file-MPn3syhgDFnrCUXvpVsCZA
📥 Fetching file content from: file-7zRJXFB4S2GAh6wVqqdSWH
✅ Saved: ./cars.json
🎉 Done! Parsed datasets: cars


In [89]:
# Persist `descriptions` as "<dataset>v2.json" in the current working directory.
# NOTE(review): both `dataset` and `descriptions` are leftovers from earlier cells —
# confirm they hold the intended values (in particular that `descriptions` is not
# None) before trusting the written file.
output_name = f"{dataset}v2.json"
with open(output_name, 'w') as f:
    json.dump(descriptions, f, indent=4)