diff --git a/src/together/abstract/api_requestor.py b/src/together/abstract/api_requestor.py index 7e37eaf8..e956bb3a 100644 --- a/src/together/abstract/api_requestor.py +++ b/src/together/abstract/api_requestor.py @@ -619,14 +619,29 @@ def _interpret_response( ) -> Tuple[TogetherResponse | Iterator[TogetherResponse], bool]: """Returns the response(s) and a bool indicating whether it is a stream.""" content_type = result.headers.get("Content-Type", "") + if stream and "text/event-stream" in content_type: + # SSE format streaming return ( self._interpret_response_line( line, result.status_code, result.headers, stream=True ) for line in parse_stream(result.iter_lines()) ), True + elif stream and content_type in [ + "audio/wav", + "audio/mpeg", + "application/octet-stream", + ]: + # Binary audio streaming - return chunks as binary data + def binary_stream_generator() -> Iterator[TogetherResponse]: + for chunk in result.iter_content(chunk_size=8192): + if chunk: # Skip empty chunks + yield TogetherResponse(chunk, dict(result.headers)) + + return binary_stream_generator(), True else: + # Non-streaming response if content_type in ["application/octet-stream", "audio/wav", "audio/mpeg"]: content = result.content else: @@ -648,23 +663,49 @@ async def _interpret_async_response( | tuple[TogetherResponse, bool] ): """Returns the response(s) and a bool indicating whether it is a stream.""" - if stream and "text/event-stream" in result.headers.get("Content-Type", ""): + content_type = result.headers.get("Content-Type", "") + + if stream and "text/event-stream" in content_type: + # SSE format streaming return ( self._interpret_response_line( line, result.status, result.headers, stream=True ) async for line in parse_stream_async(result.content) ), True + elif stream and content_type in [ + "audio/wav", + "audio/mpeg", + "application/octet-stream", + ]: + # Binary audio streaming - return chunks as binary data + async def binary_stream_generator() -> ( + AsyncGenerator[TogetherResponse, None] + ): + async for chunk in result.content.iter_chunked(8192): + if chunk: # Skip empty chunks + yield TogetherResponse(chunk, dict(result.headers)) + + return binary_stream_generator(), True else: + # Non-streaming response try: - await result.read() + content = await result.read() except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e: raise error.Timeout("Request timed out") from e except aiohttp.ClientError as e: utils.log_warn(e, body=result.content) + + if content_type in ["application/octet-stream", "audio/wav", "audio/mpeg"]: + # Binary content - keep as bytes + response_content: str | bytes = content + else: + # Text content - decode to string + response_content = content.decode("utf-8") + return ( self._interpret_response_line( - (await result.read()).decode("utf-8"), + response_content, result.status, result.headers, stream=False, diff --git a/src/together/types/audio_speech.py b/src/together/types/audio_speech.py index bb54cc7f..3bb4e00b 100644 --- a/src/together/types/audio_speech.py +++ b/src/together/types/audio_speech.py @@ -82,27 +82,126 @@ class AudioSpeechStreamResponse(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - def stream_to_file(self, file_path: str) -> None: + def stream_to_file( + self, file_path: str, response_format: AudioResponseFormat | str | None = None + ) -> None: + """ + Save the audio response to a file. + + For non-streaming responses, writes the complete file as received. + For streaming responses, collects binary chunks and constructs a valid + file format based on the response_format parameter. + + Args: + file_path: Path where the audio file should be saved. + response_format: Format of the audio (wav, mp3, or raw). If not provided, + will attempt to infer from file extension or default to wav. + """ + # Determine response format + if response_format is None: + # Infer from file extension + ext = file_path.lower().split(".")[-1] if "." in file_path else "" + if ext in ["wav"]: + response_format = AudioResponseFormat.WAV + elif ext in ["mp3", "mpeg"]: + response_format = AudioResponseFormat.MP3 + elif ext in ["raw", "pcm"]: + response_format = AudioResponseFormat.RAW + else: + # Default to WAV if unknown + response_format = AudioResponseFormat.WAV + + if isinstance(response_format, str): + response_format = AudioResponseFormat(response_format) + if isinstance(self.response, TogetherResponse): - # save response to file + # Non-streaming: save complete file with open(file_path, "wb") as f: f.write(self.response.data) elif isinstance(self.response, Iterator): + # Streaming: collect binary chunks + audio_chunks = [] + for chunk in self.response: + if isinstance(chunk.data, bytes): + audio_chunks.append(chunk.data) + elif isinstance(chunk.data, dict): + # SSE format with JSON/base64 + try: + stream_event = AudioSpeechStreamEventResponse( + response={"data": chunk.data} + ) + if isinstance(stream_event.response, StreamSentinel): + break + audio_chunks.append( + base64.b64decode(stream_event.response.data.b64) + ) + except Exception: + continue # Skip malformed chunks + + if not audio_chunks: + raise ValueError("No audio data received in streaming response") + + # Concatenate all chunks + audio_data = b"".join(audio_chunks) + with open(file_path, "wb") as f: - for chunk in self.response: - # Try to parse as stream chunk - stream_event_response = AudioSpeechStreamEventResponse( - response={"data": chunk.data} + if response_format == AudioResponseFormat.WAV: + if audio_data.startswith(b"RIFF"): + # Already a valid WAV file + f.write(audio_data) + else: + # Raw PCM - add WAV header + self._write_wav_header(f, audio_data) + elif response_format == AudioResponseFormat.MP3: + # MP3 format: Check if data is actually MP3 or raw PCM + # MP3 files start with ID3 tag or sync word (0xFF 0xFB/0xFA/0xF3/0xF2) + is_mp3 = audio_data.startswith(b"ID3") or ( + len(audio_data) > 0 + and audio_data[0:1] == b"\xff" + and len(audio_data) > 1 + and audio_data[1] & 0xE0 == 0xE0 ) - if isinstance(stream_event_response.response, StreamSentinel): - break - - # decode base64 - audio = base64.b64decode(stream_event_response.response.data.b64) - - f.write(audio) + if is_mp3: + f.write(audio_data) + else: + raise ValueError("Invalid MP3 data received.") + else: + # RAW format: write PCM data as-is + f.write(audio_data) + + @staticmethod + def _write_wav_header(file_handle: BinaryIO, audio_data: bytes) -> None: + """ + Write WAV file header for raw PCM audio data. + + Uses default TTS parameters: 16-bit PCM, mono, 24000 Hz sample rate. + """ + import struct + + sample_rate = 24000 + num_channels = 1 + bits_per_sample = 16 + byte_rate = sample_rate * num_channels * bits_per_sample // 8 + block_align = num_channels * bits_per_sample // 8 + data_size = len(audio_data) + + # Write WAV header + file_handle.write(b"RIFF") + file_handle.write(struct.pack("