Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/together/abstract/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,14 +619,29 @@ def _interpret_response(
) -> Tuple[TogetherResponse | Iterator[TogetherResponse], bool]:
"""Returns the response(s) and a bool indicating whether it is a stream."""
content_type = result.headers.get("Content-Type", "")

if stream and "text/event-stream" in content_type:
# SSE format streaming
return (
self._interpret_response_line(
line, result.status_code, result.headers, stream=True
)
for line in parse_stream(result.iter_lines())
), True
elif stream and content_type in [
"audio/wav",
"audio/mpeg",
"application/octet-stream",
]:
# Binary audio streaming - return chunks as binary data
def binary_stream_generator() -> Iterator[TogetherResponse]:
for chunk in result.iter_content(chunk_size=8192):
if chunk: # Skip empty chunks
yield TogetherResponse(chunk, dict(result.headers))

return binary_stream_generator(), True
else:
# Non-streaming response
if content_type in ["application/octet-stream", "audio/wav", "audio/mpeg"]:
content = result.content
else:
Expand All @@ -648,23 +663,49 @@ async def _interpret_async_response(
| tuple[TogetherResponse, bool]
):
"""Returns the response(s) and a bool indicating whether it is a stream."""
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
content_type = result.headers.get("Content-Type", "")

if stream and "text/event-stream" in content_type:
# SSE format streaming
return (
self._interpret_response_line(
line, result.status, result.headers, stream=True
)
async for line in parse_stream_async(result.content)
), True
elif stream and content_type in [
"audio/wav",
"audio/mpeg",
"application/octet-stream",
]:
# Binary audio streaming - return chunks as binary data
async def binary_stream_generator() -> (
AsyncGenerator[TogetherResponse, None]
):
async for chunk in result.content.iter_chunked(8192):
if chunk: # Skip empty chunks
yield TogetherResponse(chunk, dict(result.headers))

return binary_stream_generator(), True
else:
# Non-streaming response
try:
await result.read()
content = await result.read()
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
raise error.Timeout("Request timed out") from e
except aiohttp.ClientError as e:
utils.log_warn(e, body=result.content)

if content_type in ["application/octet-stream", "audio/wav", "audio/mpeg"]:
# Binary content - keep as bytes
response_content: str | bytes = content
else:
# Text content - decode to string
response_content = content.decode("utf-8")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Async Error Handling Fails to Define Content

When an aiohttp.ClientError is caught during await result.read() in _interpret_async_response, the content variable is not defined. While the error is logged, execution proceeds, causing a NameError when content is later referenced.

Fix in Cursor Fix in Web


return (
self._interpret_response_line(
(await result.read()).decode("utf-8"),
response_content,
result.status,
result.headers,
stream=False,
Expand Down
125 changes: 112 additions & 13 deletions src/together/types/audio_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,126 @@ class AudioSpeechStreamResponse(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

def stream_to_file(self, file_path: str) -> None:
def stream_to_file(
self, file_path: str, response_format: AudioResponseFormat | str | None = None
) -> None:
"""
Save the audio response to a file.

For non-streaming responses, writes the complete file as received.
For streaming responses, collects binary chunks and constructs a valid
file format based on the response_format parameter.

Args:
file_path: Path where the audio file should be saved.
response_format: Format of the audio (wav, mp3, or raw). If not provided,
will attempt to infer from file extension or default to wav.
"""
# Determine response format
if response_format is None:
# Infer from file extension
ext = file_path.lower().split(".")[-1] if "." in file_path else ""
if ext in ["wav"]:
response_format = AudioResponseFormat.WAV
elif ext in ["mp3", "mpeg"]:
response_format = AudioResponseFormat.MP3
elif ext in ["raw", "pcm"]:
response_format = AudioResponseFormat.RAW
else:
# Default to WAV if unknown
response_format = AudioResponseFormat.WAV

if isinstance(response_format, str):
response_format = AudioResponseFormat(response_format)

if isinstance(self.response, TogetherResponse):
# save response to file
# Non-streaming: save complete file
with open(file_path, "wb") as f:
f.write(self.response.data)

elif isinstance(self.response, Iterator):
# Streaming: collect binary chunks
audio_chunks = []
for chunk in self.response:
if isinstance(chunk.data, bytes):
audio_chunks.append(chunk.data)
elif isinstance(chunk.data, dict):
# SSE format with JSON/base64
try:
stream_event = AudioSpeechStreamEventResponse(
response={"data": chunk.data}
)
if isinstance(stream_event.response, StreamSentinel):
break
audio_chunks.append(
base64.b64decode(stream_event.response.data.b64)
)
except Exception:
continue # Skip malformed chunks

if not audio_chunks:
raise ValueError("No audio data received in streaming response")

# Concatenate all chunks
audio_data = b"".join(audio_chunks)

with open(file_path, "wb") as f:
for chunk in self.response:
# Try to parse as stream chunk
stream_event_response = AudioSpeechStreamEventResponse(
response={"data": chunk.data}
if response_format == AudioResponseFormat.WAV:
if audio_data.startswith(b"RIFF"):
# Already a valid WAV file
f.write(audio_data)
else:
# Raw PCM - add WAV header
self._write_wav_header(f, audio_data)
elif response_format == AudioResponseFormat.MP3:
# MP3 format: Check if data is actually MP3 or raw PCM
# MP3 files start with ID3 tag or sync word (0xFF 0xFB/0xFA/0xF3/0xF2)
is_mp3 = audio_data.startswith(b"ID3") or (
len(audio_data) > 0
and audio_data[0:1] == b"\xff"
and len(audio_data) > 1
and audio_data[1] & 0xE0 == 0xE0
)

if isinstance(stream_event_response.response, StreamSentinel):
break

# decode base64
audio = base64.b64decode(stream_event_response.response.data.b64)

f.write(audio)
if is_mp3:
f.write(audio_data)
else:
raise ValueError("Invalid MP3 data received.")
else:
# RAW format: write PCM data as-is
f.write(audio_data)

@staticmethod
def _write_wav_header(file_handle: BinaryIO, audio_data: bytes) -> None:
"""
Write WAV file header for raw PCM audio data.

Uses default TTS parameters: 16-bit PCM, mono, 24000 Hz sample rate.
"""
import struct

sample_rate = 24000
num_channels = 1
bits_per_sample = 16
byte_rate = sample_rate * num_channels * bits_per_sample // 8
block_align = num_channels * bits_per_sample // 8
data_size = len(audio_data)

# Write WAV header
file_handle.write(b"RIFF")
file_handle.write(struct.pack("<I", 36 + data_size)) # File size - 8
file_handle.write(b"WAVE")
file_handle.write(b"fmt ")
file_handle.write(struct.pack("<I", 16)) # fmt chunk size
file_handle.write(struct.pack("<H", 1)) # Audio format (1 = PCM)
file_handle.write(struct.pack("<H", num_channels))
file_handle.write(struct.pack("<I", sample_rate))
file_handle.write(struct.pack("<I", byte_rate))
file_handle.write(struct.pack("<H", block_align))
file_handle.write(struct.pack("<H", bits_per_sample))
file_handle.write(b"data")
file_handle.write(struct.pack("<I", data_size))
file_handle.write(audio_data)


class AudioTranscriptionResponseFormat(str, Enum):
Expand Down