From 5f4e0f708c32c9f476955111a8f3030583175fce Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 3 Sep 2025 16:17:33 -0500 Subject: [PATCH 1/5] feat: Add chat_batch() for batch jobs --- chatlas/__init__.py | 10 ++ chatlas/_batch_chat.py | 217 ++++++++++++++++++++++++++++ chatlas/_batch_job.py | 231 ++++++++++++++++++++++++++++++ chatlas/_chat.py | 17 +-- chatlas/_provider.py | 88 ++++++++++++ chatlas/_provider_anthropic.py | 107 +++++++++++++- chatlas/_provider_openai.py | 155 ++++++++++++++++++-- tests/batch/country-capitals.json | 140 ++++++++++++++++++ tests/conftest.py | 4 + tests/test_batch_chat.py | 182 +++++++++++++++++++++++ 10 files changed, 1125 insertions(+), 26 deletions(-) create mode 100644 chatlas/_batch_chat.py create mode 100644 chatlas/_batch_job.py create mode 100644 tests/batch/country-capitals.json create mode 100644 tests/test_batch_chat.py diff --git a/chatlas/__init__.py b/chatlas/__init__.py index ece25354..7a066cb5 100644 --- a/chatlas/__init__.py +++ b/chatlas/__init__.py @@ -1,5 +1,11 @@ from . import types from ._auto import ChatAuto +from ._batch_chat import ( + batch_chat, + batch_chat_completed, + batch_chat_structured, + batch_chat_text, +) from ._chat import Chat from ._content import ( ContentToolRequest, @@ -36,6 +42,10 @@ __version__ = "0.0.0" # stub value for docs __all__ = ( + "batch_chat", + "batch_chat_completed", + "batch_chat_structured", + "batch_chat_text", "ChatAnthropic", "ChatAuto", "ChatBedrockAnthropic", diff --git a/chatlas/_batch_chat.py b/chatlas/_batch_chat.py new file mode 100644 index 00000000..b52d1972 --- /dev/null +++ b/chatlas/_batch_chat.py @@ -0,0 +1,217 @@ +""" +Batch chat processing for submitting multiple requests simultaneously. + +This module provides functionality for submitting multiple chat requests +in batches to providers that support it (currently OpenAI and Anthropic). +Batch processing can take up to 24 hours but offers significant cost savings +(up to 50% less than regular requests). +""" + +from __future__ import annotations + +import copy +from pathlib import Path +from typing import TYPE_CHECKING, TypeVar, Union + +from pydantic import BaseModel + +from ._batch_job import BatchJob, ContentT +from ._chat import Chat + +if TYPE_CHECKING: + from ._provider_anthropic import MessageBatchIndividualResponse as AnthropicResult + from ._provider_openai import BatchResult as OpenAIResult + + BatchResult = Union[OpenAIResult, AnthropicResult] + +ChatT = TypeVar("ChatT", bound=Chat) +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) + + +def batch_chat( + chat: ChatT, + prompts: list[ContentT] | list[list[ContentT]], + path: Union[str, Path], + wait: bool = True, +) -> list[ChatT | None]: + """ + Submit multiple chat requests in a batch. + + This function allows you to submit multiple chat requests simultaneously + using provider batch APIs (currently OpenAI and Anthropic). Batch processing + can take up to 24 hours but offers significant cost savings. + + Parameters + ---------- + chat + Chat instance to use for the batch + prompts + List of prompts to process. Each can be a string or list of strings. + path + Path to file (with .json extension) to store batch state + wait + If True, wait for batch to complete. If False, return None if incomplete. + + Returns + ------- + List of Chat objects (one per prompt) if complete, None if wait=False and incomplete. + Individual Chat objects may be None if their request failed. 
+ + Example + ------- + + ```python + from chatlas import ChatOpenAI + + chat = ChatOpenAI() + prompts = [ + "What's the capital of France?", + "What's the capital of Germany?", + "What's the capital of Italy?", + ] + + chats = batch_chat(chat, prompts, "capitals.json") + for i, result_chat in enumerate(chats): + if result_chat: + print(f"Prompt {i + 1}: {result_chat.get_last_turn().text}") + ``` + """ + job = BatchJob(chat, prompts, path, wait=wait) + job.step_until_done() + + chats = [] + assistant_turns = job.result_turns() + for user, assistant in zip(job.user_turns, assistant_turns): + if assistant is not None: + new_chat = copy.deepcopy(chat) + new_chat.add_turn(user) + new_chat.add_turn(assistant) + chats.append(new_chat) + else: + chats.append(None) + + return chats + + +def batch_chat_text( + chat: Chat, + prompts: list[ContentT] | list[list[ContentT]], + path: Union[str, Path], + wait: bool = True, +) -> list[str | None]: + """ + Submit multiple chat requests in a batch and return text responses. + + This is a convenience function that returns just the text of the responses + rather than full Chat objects. + + Parameters + ---------- + chat + Chat instance to use for the batch + prompts + List of prompts to process + path + Path to file (with .json extension) to store batch state + wait + If True, wait for batch to complete + + Return + ------ + List of text responses (or None for failed requests) + """ + chats = batch_chat(chat, prompts, path, wait=wait) + + texts = [] + for x in chats: + if x is None: + texts.append(None) + continue + last_turn = x.get_last_turn() + if last_turn is None: + texts.append(None) + continue + texts.append(last_turn.text) + + return texts + + +def batch_chat_structured( + chat: Chat, + prompts: list[ContentT] | list[list[ContentT]], + path: Union[str, Path], + data_model: type[BaseModelT], + wait: bool = True, +) -> list[BaseModelT | None]: + """ + Submit multiple structured data requests in a batch. + + Parameters + ---------- + chat + Chat instance to use for the batch + prompts + List of prompts to process + path + Path to file (with .json extension) to store batch state + data_model + Pydantic model class for structured responses + wait + If True, wait for batch to complete + + Return + ------ + List of structured data objects (or None for failed requests) + """ + job = BatchJob(chat, prompts, path, data_model=data_model, wait=wait) + result = job.step_until_done() + + if result is None: + return [] + + res: list[BaseModelT | None] = [] + assistant_turns = job.result_turns() + for turn in assistant_turns: + if turn is None: + res.append(None) + else: + json = chat._extract_turn_json(turn) + model = data_model.model_validate(json) + res.append(model) + + return res + + +def batch_chat_completed( + chat: Chat, + prompts: list[ContentT] | list[list[ContentT]], + path: Union[str, Path], +) -> bool: + """ + Check if a batch job is completed without waiting. 
+ + Parameters + ---------- + chat + Chat instance used for the batch + prompts + List of prompts used for the batch + path + Path to batch state file + + Returns + ------- + True if batch is complete, False otherwise + """ + job = BatchJob(chat, prompts, path, wait=False) + stage = job.stage + + if stage == "submitting": + return False + elif stage == "waiting": + status = job._poll() + return not status.working + elif stage == "retrieving" or stage == "done": + return True + else: + raise ValueError(f"Unknown batch stage: {stage}") diff --git a/chatlas/_batch_job.py b/chatlas/_batch_job.py new file mode 100644 index 00000000..13e2f190 --- /dev/null +++ b/chatlas/_batch_job.py @@ -0,0 +1,231 @@ +import hashlib +import json +import time +from datetime import timedelta +from pathlib import Path +from typing import Any, Literal, Optional, TypedDict, TypeVar, Union + +from pydantic import BaseModel +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +from ._chat import Chat +from ._content import Content +from ._provider import BatchStatus +from ._turn import Turn, user_turn + +BatchStage = Literal["submitting", "waiting", "retrieving", "done"] + + +class BatchStateHash(TypedDict): + provider: str + model: str + prompts: str + user_turns: str + + +class BatchState(BaseModel): + version: int + stage: BatchStage + batch: dict[str, Any] + results: list[dict[str, Any]] + started_at: int + hash: BatchStateHash + + +ContentT = TypeVar("ContentT", bound=Union[str, Content]) + + +class BatchJob: + """ + Manages the lifecycle of a batch processing job. + + A batch job goes through several stages: + 1. "submitting" - Initial submission to the provider + 2. "waiting" - Waiting for processing to complete + 3. "retrieving" - Downloading results + 4. "done" - Processing complete + """ + + def __init__( + self, + chat: Chat, + prompts: list[ContentT] | list[list[ContentT]], + path: Union[str, Path], + data_model: Optional[type[BaseModel]] = None, + wait: bool = True, + ): + if not chat.provider.has_batch_support(): + raise ValueError("Batch requests are not supported by this provider") + + self.chat = chat + self.prompts = prompts + self.path = Path(path) + self.data_model = data_model + self.should_wait = wait + + # Convert prompts to user turns + self.user_turns: list[Turn] = [] + for prompt in prompts: + if not isinstance(prompt, (str, Content)): + turn = user_turn(*prompt) + else: + turn = user_turn(prompt) + self.user_turns.append(turn) + + # Job state management + self.provider = chat.provider + self.stage: BatchStage = "submitting" + self.batch: dict[str, Any] = {} + self.results: list[dict[str, Any]] = [] + + # Load existing state if file exists and is not empty + if self.path.exists() and self.path.stat().st_size > 0: + self._load_state() + else: + self.started_at = time.time() + + def _load_state(self) -> None: + with open(self.path, "r") as f: + state = BatchState.model_validate_json(f.read()) + + self.stage = state.stage + self.batch = state.batch + self.results = state.results + self.started_at = state.started_at + + # Verify hash to ensure consistency + stored_hash = state.hash + current_hash = self._compute_hash() + + for key, value in current_hash.items(): + if stored_hash.get(key) != value: + raise ValueError( + f"Batch state mismatch: {key} doesn't match stored value. " + f"Do you need to pick a different path?" 
+ ) + + def _save_state(self) -> None: + state = BatchState( + version=1, + stage=self.stage, + batch=self.batch, + results=self.results, + started_at=int(self.started_at) if self.started_at else 0, + hash=self._compute_hash(), + ) + + with open(self.path, "w") as f: + f.write(state.model_dump_json(indent=2)) + + def _compute_hash(self) -> BatchStateHash: + turns = self.chat.get_turns(include_system_prompt=True) + return { + "provider": self.provider.name, + "model": self.provider.model, + "prompts": self._hash([str(p) for p in self.prompts]), + "user_turns": self._hash([str(turn) for turn in turns]), + } + + @staticmethod + def _hash(x: Any) -> str: + return hashlib.md5(json.dumps(x, sort_keys=True).encode()).hexdigest() + + def step(self) -> bool: + if self.stage == "submitting": + return self._submit() + elif self.stage == "waiting": + return self._wait() + elif self.stage == "retrieving": + return self._retrieve() + else: + raise ValueError(f"Unknown stage: {self.stage}") + + def step_until_done(self) -> Optional["BatchJob"]: + while self.stage != "done": + if not self.step(): + return None + return self + + def _submit(self) -> bool: + existing_turns = self.chat.get_turns(include_system_prompt=True) + + conversations = [] + for turn in self.user_turns: + conversation = existing_turns + [turn] + conversations.append(conversation) + + self.batch = self.provider.batch_submit(conversations, self.data_model) + self.stage = "waiting" + self._save_state() + return True + + def _wait(self) -> bool: + # Always poll once, even when wait=False + status = self._poll() + + if self.should_wait: + console = Console() + + with Progress( + SpinnerColumn(), + TextColumn("Processing..."), + TextColumn("[{task.fields[elapsed]}]"), + TextColumn("{task.fields[n_processing]} pending |"), + TextColumn("[green]{task.fields[n_succeeded]}[/green] done |"), + TextColumn("[red]{task.fields[n_failed]}[/red] failed"), + console=console, + ) as progress: + task = progress.add_task( + "processing", + elapsed=self._elapsed(), + n_processing=status.n_processing, + n_succeeded=status.n_succeeded, + n_failed=status.n_failed, + ) + + while status.working: + time.sleep(0.5) + status = self._poll() + progress.update( + task, + elapsed=self._elapsed(), + n_processing=status.n_processing, + n_succeeded=status.n_succeeded, + n_failed=status.n_failed, + ) + + if not status.working: + self.stage = "retrieving" + self._save_state() + return True + else: + return False + + def _poll(self) -> "BatchStatus": + if not self.batch: + raise ValueError("No batch to poll") + self.batch = self.provider.batch_poll(self.batch) + self._save_state() + return self.provider.batch_status(self.batch) + + def _elapsed(self) -> str: + return str(timedelta(seconds=int(time.time()) - int(self.started_at))) + + def _retrieve(self) -> bool: + if not self.batch: + raise ValueError("No batch to retrieve") + self.results = self.provider.batch_retrieve(self.batch) + self.stage = "done" + self._save_state() + return True + + def result_turns(self) -> list[Turn | None]: + turns = [] + for result in self.results: + turn = self.provider.batch_result_turn( + result, has_data_model=self.data_model is not None + ) + turns.append(turn) + + return turns diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 945e0605..0d97c1bd 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -1086,18 +1086,7 @@ def _submit_and_extract_data( turn = self.get_last_turn() assert turn is not None - res: list[ContentJson] = [] - for x in turn.contents: - if isinstance(x, 
ContentJson): - res.append(x) - - if len(res) != 1: - raise ValueError( - f"Data extraction failed: {len(res)} data results received." - ) - - json = res[0] - return json.value + return Chat._extract_turn_json(turn) async def chat_structured_async( self, @@ -1188,6 +1177,10 @@ async def _submit_and_extract_data_async( turn = self.get_last_turn() assert turn is not None + return Chat._extract_turn_json(turn) + + @staticmethod + def _extract_turn_json(turn: Turn) -> dict[str, Any]: res: list[ContentJson] = [] for x in turn.contents: if isinstance(x, ContentJson): diff --git a/chatlas/_provider.py b/chatlas/_provider.py index deabcdf6..7e216ec2 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from datetime import date from typing import ( + Any, AsyncIterable, Generic, Iterable, @@ -100,6 +101,16 @@ class StandardModelParams(TypedDict, total=False): ] +# Provider-agnostic batch status info +class BatchStatus(BaseModel): + """Status information for a batch job.""" + + working: bool + n_processing: int + n_succeeded: int + n_failed: int + + class Provider( ABC, Generic[ @@ -261,3 +272,80 @@ def translate_model_params( @abstractmethod def supported_model_params(self) -> set[StandardModelParamNames]: ... + + def has_batch_support(self) -> bool: + """ + Returns whether this provider supports batch processing. + Override this method to return True for providers that implement batch methods. + """ + return False + + def batch_submit( + self, + conversations: list[list[Turn]], + data_model: Optional[type[BaseModel]] = None, + ) -> dict[str, Any]: + """ + Submit a batch of conversations for processing. + + Args: + conversations: List of conversation histories (each is a list of Turns) + data_model: Optional structured data model for responses + + Returns: + BatchInfo containing batch job information + """ + raise NotImplementedError("This provider does not support batch processing") + + def batch_poll(self, batch: dict[str, Any]) -> dict[str, Any]: + """ + Poll the status of a submitted batch. + + Args: + batch: Batch information returned from batch_submit + + Returns: + Updated batch information + """ + raise NotImplementedError("This provider does not support batch processing") + + def batch_status(self, batch: dict[str, Any]) -> BatchStatus: + """ + Get the status of a batch. + + Args: + batch: Batch information + + Returns: + BatchStatus with processing status information + """ + raise NotImplementedError("This provider does not support batch processing") + + def batch_retrieve(self, batch: dict[str, Any]) -> list[dict[str, Any]]: + """ + Retrieve results from a completed batch. + + Args: + batch: Batch information + + Returns: + List of BatchResult objects, one for each request in the batch + """ + raise NotImplementedError("This provider does not support batch processing") + + def batch_result_turn( + self, + result: dict[str, Any], + has_data_model: bool = False, + ) -> Turn | None: + """ + Convert a batch result to a Turn. 
+ + Args: + result: Individual BatchResult from batch_retrieve + has_data_model: Whether the request used a structured data model + + Returns: + Turn object or None if the result was an error + """ + raise NotImplementedError("This provider does not support batch processing") diff --git a/chatlas/_provider_anthropic.py b/chatlas/_provider_anthropic.py index 29fbba55..4ca64cc3 100644 --- a/chatlas/_provider_anthropic.py +++ b/chatlas/_provider_anthropic.py @@ -1,10 +1,12 @@ from __future__ import annotations import base64 +import re import warnings from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast, overload import orjson +from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel from ._chat import Chat @@ -21,13 +23,20 @@ ContentToolResultResource, ) from ._logging import log_model_default -from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams +from ._provider import ( + BatchStatus, + ModelInfo, + Provider, + StandardModelParamNames, + StandardModelParams, +) from ._tokens import get_token_pricing, tokens_log from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, user_turn from ._utils import split_http_client_kwargs if TYPE_CHECKING: + import anthropic from anthropic.types import ( Message, MessageParam, @@ -38,11 +47,12 @@ ) from anthropic.types.document_block_param import DocumentBlockParam from anthropic.types.image_block_param import ImageBlockParam + from anthropic.types.message_create_params import MessageCreateParamsNonStreaming + from anthropic.types.messages.batch_create_params import Request as BatchRequest from anthropic.types.model_param import ModelParam from anthropic.types.text_block_param import TextBlockParam from anthropic.types.tool_result_block_param import ToolResultBlockParam from anthropic.types.tool_use_block_param import ToolUseBlockParam - from openai.types.chat import ChatCompletionToolParam from .types.anthropic import ChatBedrockClientArgs, ChatClientArgs, SubmitInputArgs @@ -631,6 +641,99 @@ def _as_turn(self, completion: Message, has_data_model=False) -> Turn: completion=completion, ) + def has_batch_support(self) -> bool: + return True + + def batch_submit( + self, + conversations: list[list[Turn]], + data_model: Optional[type[BaseModel]] = None, + ): + requests: list["BatchRequest"] = [] + + for i, turns in enumerate(conversations): + kwargs = self._chat_perform_args( + stream=False, + turns=turns, + tools={}, + data_model=data_model, + ) + + params: "MessageCreateParamsNonStreaming" = { + "messages": kwargs.get("messages", {}), + "model": self.model, + "max_tokens": kwargs.get("max_tokens", 4096), + } + + # If data_model, tools/tool_choice should be present + tools = kwargs.get("tools") + tool_choice = kwargs.get("tool_choice") + if tools and not isinstance(tools, anthropic.NotGiven): + params["tools"] = tools + if tool_choice and not isinstance(tool_choice, anthropic.NotGiven): + params["tool_choice"] = tool_choice + + requests.append({"custom_id": f"request-{i}", "params": params}) + + batch = self._client.messages.batches.create(requests=requests) + return batch.model_dump() + + def batch_poll(self, batch): + from anthropic.types.messages import MessageBatch + + batch = MessageBatch.model_validate(batch) + b = self._client.messages.batches.retrieve(batch.id) + return b.model_dump() + + def batch_status(self, batch) -> "BatchStatus": + from anthropic.types.messages import MessageBatch + + batch = MessageBatch.model_validate(batch) + status = 
batch.processing_status + counts = batch.request_counts + + return BatchStatus( + working=status != "ended", + n_processing=counts.processing, + n_succeeded=counts.succeeded, + n_failed=counts.errored + counts.canceled + counts.expired, + ) + + # https://docs.anthropic.com/en/api/retrieving-message-batch-results + def batch_retrieve(self, batch): + from anthropic.types.messages import MessageBatch + + batch = MessageBatch.model_validate(batch) + if batch.results_url is None: + raise ValueError("Batch has no results URL") + + results: list[dict[str, Any]] = [] + for res in self._client.messages.batches.results(batch.id): + results.append(res.model_dump()) + + # Sort by custom_id to maintain order + def extract_id(x: str): + match = re.search(r"-(\d+)$", x) + return int(match.group(1)) if match else 0 + + results.sort(key=lambda x: extract_id(x.get("custom_id", ""))) + + return results + + def batch_result_turn(self, result, has_data_model: bool = False) -> Turn | None: + from anthropic.types.messages.message_batch_individual_response import ( + MessageBatchIndividualResponse, + ) + + result = MessageBatchIndividualResponse.model_validate(result) + if result.result.type != "succeeded": + # TODO: offer advice on what to do? + warnings.warn(f"Batch request didn't succeed: {result.result}") + return None + + message = result.result.message + return self._as_turn(message, has_data_model) + def ChatBedrockAnthropic( *, diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py index 6744fbef..d41eb5ef 100644 --- a/chatlas/_provider_openai.py +++ b/chatlas/_provider_openai.py @@ -1,11 +1,18 @@ from __future__ import annotations import base64 +import json +import os +import re +import tempfile +import warnings from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload import orjson from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI +from openai.types.batch import Batch +from openai.types.chat import ChatCompletion, ChatCompletionChunk from pydantic import BaseModel from ._chat import Chat @@ -24,18 +31,20 @@ ) from ._logging import log_model_default from ._merge import merge_dicts -from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams +from ._provider import ( + BatchStatus, + ModelInfo, + Provider, + StandardModelParamNames, + StandardModelParams, +) from ._tokens import get_token_pricing, tokens_log from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, user_turn from ._utils import MISSING, MISSING_TYPE, is_testing, split_http_client_kwargs if TYPE_CHECKING: - from openai.types.chat import ( - ChatCompletion, - ChatCompletionChunk, - ChatCompletionMessageParam, - ) + from openai.types.chat import ChatCompletionMessageParam from openai.types.chat.chat_completion_assistant_message_param import ( ContentArrayOfContentPart, ) @@ -45,10 +54,6 @@ from openai.types.chat_model import ChatModel from .types.openai import ChatAzureClientArgs, ChatClientArgs, SubmitInputArgs -else: - ChatCompletion = object - ChatCompletionChunk = object - # The dictionary form of ChatCompletion (TODO: stronger typing)? 
ChatCompletionDict = dict[str, Any] @@ -171,6 +176,21 @@ def ChatOpenAI( ) +# Seems there is no native typing support for `files.content()` results +# so mock them based on the docs here +# https://platform.openai.com/docs/guides/batch#5-retrieve-the-results +class BatchResult(BaseModel): + id: str + custom_id: str + response: BatchResultResponse + + +class BatchResultResponse(BaseModel): + status_code: int + request_id: str + body: ChatCompletionDict + + class OpenAIProvider( Provider[ChatCompletion, ChatCompletionChunk, ChatCompletionDict, "SubmitInputArgs"] ): @@ -353,8 +373,6 @@ def stream_merge_chunks(self, completion, chunk): return merge_dicts(completion, chunkd) def stream_turn(self, completion, has_data_model) -> Turn: - from openai.types.chat import ChatCompletion - delta = completion["choices"][0].pop("delta") # type: ignore completion["choices"][0]["message"] = delta # type: ignore completion = ChatCompletion.construct(**completion) @@ -662,6 +680,119 @@ def supported_model_params(self) -> set[StandardModelParamNames]: "stop_sequences", } + def has_batch_support(self) -> bool: + return True + + def batch_submit( + self, + conversations: list[list[Turn]], + data_model: Optional[type[BaseModel]] = None, + ): + # First put the requests in a file + # https://platform.openai.com/docs/api-reference/batch/request-input + # https://platform.openai.com/docs/api-reference/batch + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + temp_path = f.name + + for i, turns in enumerate(conversations): + kwargs = self._chat_perform_args( + stream=False, + turns=turns, + tools={}, + data_model=data_model, + ) + + body = { + "messages": kwargs.get("messages", []), + "model": self.model, + } + + if "response_format" in kwargs: + body["response_format"] = kwargs["response_format"] + + request = { + "custom_id": f"request-{i}", + "method": "POST", + "url": "/v1/chat/completions", + "body": body, + } + + f.write(orjson.dumps(request).decode() + "\n") + + try: + with open(temp_path, "rb") as f: + file_response = self._client.files.create(file=f, purpose="batch") + + batch = self._client.batches.create( + input_file_id=file_response.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + return batch.model_dump() + finally: + os.unlink(temp_path) + + def batch_poll(self, batch): + batch = Batch.model_validate(batch) + b = self._client.batches.retrieve(batch.id) + return b.model_dump() + + def batch_status(self, batch): + batch = Batch.model_validate(batch) + counts = batch.request_counts + total, completed, failed = 0, 0, 0 + if counts is not None: + total = counts.total + completed = counts.completed + failed = counts.failed + + return BatchStatus( + working=batch.status not in ["completed", "failed", "cancelled"], + n_processing=total - completed - failed, + n_succeeded=completed, + n_failed=failed, + ) + + def batch_retrieve(self, batch): + batch = Batch.model_validate(batch) + if batch.output_file_id is None: + raise ValueError("Batch has no output file") + + # Download and parse JSONL results + response = self._client.files.content(batch.output_file_id) + results: list[dict[str, Any]] = [] + for line in response.text.splitlines(): + results.append(json.loads(line)) + + # Sort by custom_id to maintain order + def extract_id(x: str): + match = re.search(r"-(\d+)$", x) + return int(match.group(1)) if match else 0 + + results.sort(key=lambda x: int(extract_id(x.get("custom_id", "")))) + + return results + + def batch_result_turn( + self, + result, + 
has_data_model: bool = False, + ) -> Turn | None: + response = BatchResult.model_validate(result).response + if response.status_code != 200: + # TODO: offer advice on what to do? + warnings.warn(f"Batch request failed: {response.body}") + return None + + completion = ChatCompletion.construct(**response.body) + return self._as_turn(completion, has_data_model) + + +# ------------------------------------------------------------------------------------- +# Azure OpenAI Chat +# ------------------------------------------------------------------------------------- + def ChatAzureOpenAI( *, diff --git a/tests/batch/country-capitals.json b/tests/batch/country-capitals.json new file mode 100644 index 00000000..a2e58515 --- /dev/null +++ b/tests/batch/country-capitals.json @@ -0,0 +1,140 @@ +{ + "version": 1, + "stage": "done", + "batch": { + "id": "batch_68bf52330768819090342407c7ab923f", + "completion_window": "24h", + "created_at": 1757368883, + "endpoint": "/v1/chat/completions", + "input_file_id": "file-X3SVvZ5wstpVkJCZEPcHun", + "object": "batch", + "status": "completed", + "cancelled_at": null, + "cancelling_at": null, + "completed_at": 1757385690, + "error_file_id": null, + "errors": null, + "expired_at": null, + "expires_at": 1757455283, + "failed_at": null, + "finalizing_at": 1757385689, + "in_progress_at": 1757368945, + "metadata": null, + "output_file_id": "file-HaWC2o5DbxVxjK5c77gk6n", + "request_counts": { + "completed": 2, + "failed": 0, + "total": 2 + }, + "usage": { + "input_tokens": 26, + "output_tokens": 14, + "total_tokens": 40, + "input_tokens_details": { + "cached_tokens": 0 + }, + "output_tokens_details": { + "reasoning_tokens": 0 + } + } + }, + "results": [ + { + "id": "batch_req_68bf93da18a48190a48d32933accb13f", + "custom_id": "request-0", + "response": { + "status_code": 200, + "request_id": "2d1f76c76844fd86c89e415880aa2196", + "body": { + "id": "chatcmpl-CDiWm6i1on5vXZobDQTYtHisxpoWy", + "object": "chat.completion", + "created": 1757385464, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The capital of France is Paris.", + "refusal": null, + "annotations": [] + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 7, + "total_tokens": 20, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": "default", + "system_fingerprint": "fp_8bda4d3a2c" + } + }, + "error": null + }, + { + "id": "batch_req_68bf93da20208190802c1abcf31b622c", + "custom_id": "request-1", + "response": { + "status_code": 200, + "request_id": "c85dae11995328f94bca24f518d50efd", + "body": { + "id": "chatcmpl-CDiWmegbCu3MA9KsewSkdCRpCxtP6", + "object": "chat.completion", + "created": 1757385464, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The capital of Germany is Berlin.", + "refusal": null, + "annotations": [] + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 7, + "total_tokens": 20, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": 
"default", + "system_fingerprint": "fp_8bda4d3a2c" + } + }, + "error": null + } + ], + "started_at": 1757368881, + "hash": { + "provider": "OpenAI", + "model": "gpt-4o-mini", + "prompts": "123ef0f8b30649a2e428cd300ed0e78a", + "user_turns": "d751713988987e9331980363e24189ce" + } +} \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index e1257832..f4ae79c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -293,3 +293,7 @@ def assert_list_models(chat_fun: ChatFun): @pytest.fixture def test_images_dir(): return Path(__file__).parent / "images" + +@pytest.fixture +def test_batch_dir(): + return Path(__file__).parent / "batch" diff --git a/tests/test_batch_chat.py b/tests/test_batch_chat.py new file mode 100644 index 00000000..a1372699 --- /dev/null +++ b/tests/test_batch_chat.py @@ -0,0 +1,182 @@ +import tempfile + +import pytest +from pydantic import BaseModel + +from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI +from chatlas._batch_chat import ( + BatchJob, + batch_chat, + batch_chat_completed, + batch_chat_text, +) +from chatlas._provider import BatchStatus + + +class CountryCapital(BaseModel): + name: str + + +def test_can_retrieve_batch(test_batch_dir): + chat = ChatOpenAI(model="gpt-4o-mini") + prompts = ["What's the capital of France?", "What's the capital of Germany?"] + + chats = batch_chat( + chat, + prompts, + test_batch_dir / "country-capitals.json", + ) + assert len(chats) == 2 + assert chats[0] is not None + assert chats[1] is not None + + out = batch_chat_text( + chat, + prompts, + test_batch_dir / "country-capitals.json", + ) + assert len(out) == 2 + assert out[0] is not None + assert out[1] is not None + assert "Paris" in out[0] + assert "Berlin" in out[1] + + # TODO: incorporate structured output test + # capitals = batch_chat_structured( + # chat, + # prompts, + # test_batch_dir / "country-capitals-structured.json", + # CountryCapital, + # ) + # assert len(capitals) == 2 + # assert capitals[0] is not None + # assert capitals[1] is not None + # assert capitals[0].name == "Paris" + # assert capitals[1].name == "Berlin" + + +def test_informative_errors(test_batch_dir): + with pytest.raises(ValueError, match="not supported by this provider"): + batch_chat( + ChatGoogle(), + [], + "foo.json", + ) + + with pytest.raises(ValueError, match="provider doesn't match stored value"): + batch_chat( + ChatAnthropic(), + [], + test_batch_dir / "country-capitals.json", + ) + + with pytest.raises(ValueError, match="model doesn't match stored value"): + batch_chat( + ChatOpenAI(model="gpt-5"), + [], + test_batch_dir / "country-capitals.json", + ) + + with pytest.raises(ValueError, match="prompts doesn't match stored value"): + batch_chat( + ChatOpenAI(model="gpt-4o-mini"), + ["foo"], + test_batch_dir / "country-capitals.json", + ) + + with pytest.raises(ValueError, match="user_turns doesn't match stored value"): + batch_chat( + ChatOpenAI(model="gpt-4o-mini", system_prompt="foo"), + [ + "What's the capital of France?", + "What's the capital of Germany?", + ], + test_batch_dir / "country-capitals.json", + ) + + +def ChatOpenAIMockBatchSubmit(*, working: bool = False, **kwargs): + chat = ChatOpenAI(**kwargs) + + def batch_submit_mock(*args, **kwargs): + return {"id": "123"} + + def batch_poll_mock(*args, **kwargs): + return {"id": "123", "results": True} + + def batch_status_mock(*args, **kwargs): + return BatchStatus( + working=working, + n_processing=0, + n_failed=0, + n_succeeded=1, + ) + + def batch_retrieve_mock(*args, **kwargs): + return [{"x": 
1, "y": 2}] + + chat.provider.batch_submit = batch_submit_mock + chat.provider.batch_poll = batch_poll_mock + chat.provider.batch_status = batch_status_mock + chat.provider.batch_retrieve = batch_retrieve_mock + + return chat + + +def test_steps_in_logical_order(): + with tempfile.NamedTemporaryFile() as temp_file: + chat = ChatOpenAIMockBatchSubmit() + prompts = ["What's your name?"] + job = BatchJob(chat, prompts, temp_file.name) + + def completed(): + return batch_chat_completed(chat, prompts, temp_file.name) + + assert job.stage == "submitting" + assert not completed() + + job.step() + assert job.stage == "waiting" + job._load_state() + assert job.stage == "waiting" + assert job.batch == {"id": "123"} + assert completed() + + job.step() + assert job.stage == "retrieving" + job._load_state() + assert job.stage == "retrieving" + assert job.batch == {"id": "123", "results": True} + assert completed() + + job.step() + assert job.stage == "done" + job._load_state() + assert job.stage == "done" + assert job.results == [{"x": 1, "y": 2}] + assert completed() + + +def test_run_all_steps_at_once(): + with tempfile.NamedTemporaryFile() as temp_file: + chat = ChatOpenAIMockBatchSubmit() + prompts = ["What's your name?"] + job = BatchJob(chat, prompts, temp_file.name) + + job = job.step_until_done() + assert job is not None + assert job.stage == "done" + assert job.results == [{"x": 1, "y": 2}] + + +def test_can_avoid_blocking(): + with tempfile.NamedTemporaryFile() as temp_file: + chat = ChatOpenAIMockBatchSubmit(working=True) + job = BatchJob( + chat, + ["What's your name?"], + temp_file.name, + wait=False, + ) + + assert job.step_until_done() is None From 02be12c7fae76de9b4f48e80dd895c88cbc417d0 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 9 Sep 2025 11:45:54 -0500 Subject: [PATCH 2/5] TypedDict version compatibility --- chatlas/_batch_job.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chatlas/_batch_job.py b/chatlas/_batch_job.py index 13e2f190..16c94916 100644 --- a/chatlas/_batch_job.py +++ b/chatlas/_batch_job.py @@ -3,7 +3,7 @@ import time from datetime import timedelta from pathlib import Path -from typing import Any, Literal, Optional, TypedDict, TypeVar, Union +from typing import Any, Literal, Optional, TypeVar, Union from pydantic import BaseModel from rich.console import Console @@ -13,6 +13,7 @@ from ._content import Content from ._provider import BatchStatus from ._turn import Turn, user_turn +from ._typing_extensions import TypedDict BatchStage = Literal["submitting", "waiting", "retrieving", "done"] From 422795791be61f2d4549e32e2d50bb8d794a6715 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 9 Sep 2025 11:57:25 -0500 Subject: [PATCH 3/5] Future annotations; test batch submission --- chatlas/_batch_job.py | 2 ++ tests/test_batch_chat.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/chatlas/_batch_job.py b/chatlas/_batch_job.py index 16c94916..1e9d00b6 100644 --- a/chatlas/_batch_job.py +++ b/chatlas/_batch_job.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import hashlib import json import time diff --git a/tests/test_batch_chat.py b/tests/test_batch_chat.py index a1372699..7b1c1b28 100644 --- a/tests/test_batch_chat.py +++ b/tests/test_batch_chat.py @@ -55,6 +55,26 @@ def test_can_retrieve_batch(test_batch_dir): # assert capitals[1].name == "Berlin" +def test_can_submit_openai_batch(): + with tempfile.NamedTemporaryFile() as temp_file: + chat = ChatOpenAI() + prompts = ["What's the capital of France?", "What's the 
capital of Germany?"] + job = BatchJob(chat, prompts, temp_file.name, wait=False) + assert job.stage == "submitting" + job.step() + assert job.stage == "waiting" + + +def test_can_submit_anthropic_batch(): + with tempfile.NamedTemporaryFile() as temp_file: + chat = ChatAnthropic() + prompts = ["What's the capital of France?", "What's the capital of Germany?"] + job = BatchJob(chat, prompts, temp_file.name, wait=False) + assert job.stage == "submitting" + job.step() + assert job.stage == "waiting" + + def test_informative_errors(test_batch_dir): with pytest.raises(ValueError, match="not supported by this provider"): batch_chat( From 018e5dacb09296aefb7fe31eba374e061a50cd64 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 9 Sep 2025 13:42:43 -0500 Subject: [PATCH 4/5] Simplify; test batch_chat_structured() result --- chatlas/_batch_chat.py | 8 +- tests/batch/country-capitals-structured.json | 140 +++++++++++++++++++ tests/test_batch_chat.py | 24 ++-- 3 files changed, 153 insertions(+), 19 deletions(-) create mode 100644 tests/batch/country-capitals-structured.json diff --git a/chatlas/_batch_chat.py b/chatlas/_batch_chat.py index b52d1972..9a55baf3 100644 --- a/chatlas/_batch_chat.py +++ b/chatlas/_batch_chat.py @@ -11,19 +11,13 @@ import copy from pathlib import Path -from typing import TYPE_CHECKING, TypeVar, Union +from typing import TypeVar, Union from pydantic import BaseModel from ._batch_job import BatchJob, ContentT from ._chat import Chat -if TYPE_CHECKING: - from ._provider_anthropic import MessageBatchIndividualResponse as AnthropicResult - from ._provider_openai import BatchResult as OpenAIResult - - BatchResult = Union[OpenAIResult, AnthropicResult] - ChatT = TypeVar("ChatT", bound=Chat) BaseModelT = TypeVar("BaseModelT", bound=BaseModel) diff --git a/tests/batch/country-capitals-structured.json b/tests/batch/country-capitals-structured.json new file mode 100644 index 00000000..bd674d27 --- /dev/null +++ b/tests/batch/country-capitals-structured.json @@ -0,0 +1,140 @@ +{ + "version": 1, + "stage": "done", + "batch": { + "id": "batch_68c045e8d5108190949582bc3c028af4", + "completion_window": "24h", + "created_at": 1757431272, + "endpoint": "/v1/chat/completions", + "input_file_id": "file-2oi7LNA358NPBuSrWXcZsF", + "object": "batch", + "status": "completed", + "cancelled_at": null, + "cancelling_at": null, + "completed_at": 1757440259, + "error_file_id": null, + "errors": null, + "expired_at": null, + "expires_at": 1757517672, + "failed_at": null, + "finalizing_at": 1757440258, + "in_progress_at": 1757431335, + "metadata": null, + "output_file_id": "file-KBzTN6egssBMxEJ9E3iGaq", + "request_counts": { + "completed": 2, + "failed": 0, + "total": 2 + }, + "usage": { + "input_tokens": 86, + "output_tokens": 10, + "total_tokens": 96, + "input_tokens_details": { + "cached_tokens": 0 + }, + "output_tokens_details": { + "reasoning_tokens": 0 + } + } + }, + "results": [ + { + "id": "batch_req_68c069035300819093210276a3530651", + "custom_id": "request-0", + "response": { + "status_code": 200, + "request_id": "d968ee5ca37299bddef923f7217ff987", + "body": { + "id": "chatcmpl-CDwm9AC2NCS5m1z5IcY0XwBbvIfdP", + "object": "chat.completion", + "created": 1757440233, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "{\"name\":\"Paris\"}", + "refusal": null, + "annotations": [] + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 43, + "completion_tokens": 5, + "total_tokens": 48, + 
"prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": "default", + "system_fingerprint": "fp_e665f7564b" + } + }, + "error": null + }, + { + "id": "batch_req_68c06903978c8190ad48a117b7882887", + "custom_id": "request-1", + "response": { + "status_code": 200, + "request_id": "eb6973b88aba352cfa2be3a905a20094", + "body": { + "id": "chatcmpl-CDwmCHXOgh9gssdwGcVGEj3pE9zZl", + "object": "chat.completion", + "created": 1757440236, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "{\"name\":\"Berlin\"}", + "refusal": null, + "annotations": [] + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 43, + "completion_tokens": 5, + "total_tokens": 48, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": "default", + "system_fingerprint": "fp_8bda4d3a2c" + } + }, + "error": null + } + ], + "started_at": 1757431270, + "hash": { + "provider": "OpenAI", + "model": "gpt-4o-mini", + "prompts": "123ef0f8b30649a2e428cd300ed0e78a", + "user_turns": "d751713988987e9331980363e24189ce" + } +} \ No newline at end of file diff --git a/tests/test_batch_chat.py b/tests/test_batch_chat.py index 7b1c1b28..401335c7 100644 --- a/tests/test_batch_chat.py +++ b/tests/test_batch_chat.py @@ -8,6 +8,7 @@ BatchJob, batch_chat, batch_chat_completed, + batch_chat_structured, batch_chat_text, ) from chatlas._provider import BatchStatus @@ -41,18 +42,17 @@ def test_can_retrieve_batch(test_batch_dir): assert "Paris" in out[0] assert "Berlin" in out[1] - # TODO: incorporate structured output test - # capitals = batch_chat_structured( - # chat, - # prompts, - # test_batch_dir / "country-capitals-structured.json", - # CountryCapital, - # ) - # assert len(capitals) == 2 - # assert capitals[0] is not None - # assert capitals[1] is not None - # assert capitals[0].name == "Paris" - # assert capitals[1].name == "Berlin" + capitals = batch_chat_structured( + chat, + prompts, + test_batch_dir / "country-capitals-structured.json", + CountryCapital, + ) + assert len(capitals) == 2 + assert capitals[0] is not None + assert capitals[1] is not None + assert capitals[0].name == "Paris" + assert capitals[1].name == "Berlin" def test_can_submit_openai_batch(): From 72639016fdd82348701ab237b8e5ed174cf61612 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 9 Sep 2025 15:58:57 -0500 Subject: [PATCH 5/5] Add to reference; update changelog --- CHANGELOG.md | 1 + docs/_quarto.yml | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 028a2329..24376a29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features +* Added support for submitting multiple chats in one batch. With batch submission, results can take up to 24 hours to complete, but in return you pay ~50% less than usual. For more, see the [reference](https://posit-dev.github.io/chatlas/reference/) for `batch_chat()`, `batch_chat_text()`, `batch_chat_structured()` and `batch_chat_completed()`. 
(#177)
 * The `Chat` class gains new `.chat_structured()` (and `.chat_structured_async()`) methods. These methods supersede the now deprecated `.extract_data()` (and `.extract_data_async()`). The only difference is that the new methods return a `BaseModel` instance (instead of a `dict()`), leading to a better type hinting/checking experience. (#175)
 
 ## [0.12.0] - 2025-09-08
diff --git a/docs/_quarto.yml b/docs/_quarto.yml
index 94a7c434..a7fa88f2 100644
--- a/docs/_quarto.yml
+++ b/docs/_quarto.yml
@@ -151,16 +151,23 @@ quartodoc:
       contents:
         - content_pdf_file
         - content_pdf_url
-    - title: Prompt interpolation
-      desc: Interpolate variables into prompt templates
-      contents:
-        - interpolate
-        - interpolate_file
     - title: Tool calling
       desc: Add context to python function before registering it as a tool.
       contents:
         - Tool
         - ToolRejectError
+    - title: Batch chat
+      desc: Submit multiple chats in one batch
+      contents:
+        - batch_chat
+        - batch_chat_text
+        - batch_chat_structured
+        - batch_chat_completed
+    - title: Prompt interpolation
+      desc: Interpolate variables into prompt templates
+      contents:
+        - interpolate
+        - interpolate_file
     - title: Turns
       desc: A provider-agnostic representation of content generated during an assistant/user turn.
       contents:
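
For reviewers, a short end-to-end sketch of how the new functions fit together, based only on the signatures added in this patch series (the non-blocking `wait=False` path, `batch_chat_completed()` polling, and the `data_model` argument of `batch_chat_structured()`). The file paths and the `Capital` model below are illustrative, not part of the change.

```python
from pydantic import BaseModel

from chatlas import (
    ChatOpenAI,
    batch_chat_completed,
    batch_chat_structured,
    batch_chat_text,
)


class Capital(BaseModel):
    name: str


chat = ChatOpenAI(model="gpt-4o-mini")
prompts = ["What's the capital of France?", "What's the capital of Germany?"]

# Kick off the batch without blocking; the batch id (and eventually the results)
# is persisted to the .json state file, so later calls resume from that state.
batch_chat_text(chat, prompts, "capitals.json", wait=False)

# ...some time later, possibly in a new process...
if batch_chat_completed(chat, prompts, "capitals.json"):
    # Text of each response (None for any request that failed).
    texts = batch_chat_text(chat, prompts, "capitals.json")

    # Structured results are a separate batch submission (hence a separate
    # state file), validated against the pydantic model.
    capitals = batch_chat_structured(
        chat, prompts, "capitals-structured.json", Capital
    )
```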