diff --git a/packages/image-generation/src/celeste_image_generation/providers/bytedance/streaming.py b/packages/image-generation/src/celeste_image_generation/providers/bytedance/streaming.py index 215ba82..4a3fafe 100644 --- a/packages/image-generation/src/celeste_image_generation/providers/bytedance/streaming.py +++ b/packages/image-generation/src/celeste_image_generation/providers/bytedance/streaming.py @@ -5,7 +5,6 @@ from typing import Any from celeste.artifacts import ImageArtifact -from celeste.io import Chunk from celeste.mime_types import ImageMimeType from celeste_image_generation.io import ImageGenerationChunk, ImageGenerationUsage from celeste_image_generation.streaming import ImageGenerationStream @@ -21,7 +20,7 @@ def __init__(self, sse_iterator: AsyncIterator[dict[str, Any]]) -> None: super().__init__(sse_iterator) self._completed_usage: ImageGenerationUsage | None = None - def _parse_chunk(self, chunk_data: dict[str, Any]) -> Chunk | None: + def _parse_chunk(self, chunk_data: dict[str, Any]) -> ImageGenerationChunk | None: """Parse chunk from SSE event.""" event_type = chunk_data.get("type") diff --git a/packages/image-generation/src/celeste_image_generation/providers/openai/streaming.py b/packages/image-generation/src/celeste_image_generation/providers/openai/streaming.py index 182ab15..b0a5794 100644 --- a/packages/image-generation/src/celeste_image_generation/providers/openai/streaming.py +++ b/packages/image-generation/src/celeste_image_generation/providers/openai/streaming.py @@ -5,7 +5,6 @@ from typing import Any from celeste.artifacts import ImageArtifact -from celeste.io import Chunk from celeste_image_generation.io import ImageGenerationChunk, ImageGenerationUsage from celeste_image_generation.streaming import ImageGenerationStream @@ -15,7 +14,7 @@ class OpenAIImageGenerationStream(ImageGenerationStream): """OpenAI streaming for image generation.""" - def _parse_chunk(self, chunk_data: dict[str, Any]) -> Chunk | None: + def _parse_chunk(self, chunk_data: dict[str, Any]) -> ImageGenerationChunk | None: """Parse chunk from SSE event. OpenAI returns two event types: diff --git a/packages/image-generation/src/celeste_image_generation/streaming.py b/packages/image-generation/src/celeste_image_generation/streaming.py index 4e07f7d..5574a57 100644 --- a/packages/image-generation/src/celeste_image_generation/streaming.py +++ b/packages/image-generation/src/celeste_image_generation/streaming.py @@ -12,7 +12,9 @@ from celeste_image_generation.parameters import ImageGenerationParameters -class ImageGenerationStream(Stream[ImageGenerationOutput, ImageGenerationParameters]): +class ImageGenerationStream( + Stream[ImageGenerationOutput, ImageGenerationParameters, ImageGenerationChunk] +): """Streaming for image generation.""" def _parse_output( diff --git a/packages/text-generation/src/celeste_text_generation/providers/anthropic/streaming.py b/packages/text-generation/src/celeste_text_generation/providers/anthropic/streaming.py index 8c516ca..7e12bfc 100644 --- a/packages/text-generation/src/celeste_text_generation/providers/anthropic/streaming.py +++ b/packages/text-generation/src/celeste_text_generation/providers/anthropic/streaming.py @@ -3,7 +3,6 @@ from collections.abc import Callable from typing import Any, Unpack -from celeste.io import Chunk from celeste_text_generation.io import ( TextGenerationChunk, TextGenerationFinishReason, @@ -34,7 +33,7 @@ def __init__( self._transform_output = transform_output self._last_finish_reason: TextGenerationFinishReason | None = None - def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None: + def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None: """Parse SSE event into Chunk.""" event_type = event.get("type") if not event_type: diff --git a/packages/text-generation/src/celeste_text_generation/providers/cohere/streaming.py b/packages/text-generation/src/celeste_text_generation/providers/cohere/streaming.py index 4c31966..5573575 100644 --- a/packages/text-generation/src/celeste_text_generation/providers/cohere/streaming.py +++ b/packages/text-generation/src/celeste_text_generation/providers/cohere/streaming.py @@ -4,7 +4,6 @@ from collections.abc import Callable from typing import Any, Unpack -from celeste.io import Chunk from celeste_text_generation.io import ( TextGenerationChunk, TextGenerationFinishReason, @@ -36,7 +35,7 @@ def __init__( super().__init__(sse_iterator, **parameters) self._transform_output = transform_output - def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None: + def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None: """Parse SSE event into Chunk, extracting text deltas and metadata.""" event_type = event.get("type") diff --git a/packages/text-generation/src/celeste_text_generation/providers/google/streaming.py b/packages/text-generation/src/celeste_text_generation/providers/google/streaming.py index 8e17478..de60696 100644 --- a/packages/text-generation/src/celeste_text_generation/providers/google/streaming.py +++ b/packages/text-generation/src/celeste_text_generation/providers/google/streaming.py @@ -3,7 +3,6 @@ from collections.abc import Callable from typing import Any, Unpack -from celeste.io import Chunk from celeste_text_generation.io import ( TextGenerationChunk, TextGenerationFinishReason, @@ -33,7 +32,7 @@ def __init__( super().__init__(sse_iterator, **parameters) self._transform_output = transform_output - def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None: + def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None: """Parse SSE event into Chunk. Extract text delta from candidates[0].content.parts[0].text. diff --git a/packages/text-generation/src/celeste_text_generation/providers/mistral/streaming.py b/packages/text-generation/src/celeste_text_generation/providers/mistral/streaming.py index fcab62d..c36fe5f 100644 --- a/packages/text-generation/src/celeste_text_generation/providers/mistral/streaming.py +++ b/packages/text-generation/src/celeste_text_generation/providers/mistral/streaming.py @@ -4,7 +4,6 @@ from collections.abc import Callable from typing import Any, Unpack -from celeste.io import Chunk from celeste_text_generation.io import ( TextGenerationChunk, TextGenerationFinishReason, @@ -36,7 +35,7 @@ def __init__( super().__init__(sse_iterator, **parameters) self._transform_output = transform_output - def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None: + def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None: """Parse chunk from SSE event. Extract from choices[0].delta.content (content delta events). diff --git a/packages/text-generation/src/celeste_text_generation/providers/openai/streaming.py b/packages/text-generation/src/celeste_text_generation/providers/openai/streaming.py index 0c1f647..2dc1530 100644 --- a/packages/text-generation/src/celeste_text_generation/providers/openai/streaming.py +++ b/packages/text-generation/src/celeste_text_generation/providers/openai/streaming.py @@ -3,7 +3,6 @@ from collections.abc import Callable from typing import Any, Unpack -from celeste.io import Chunk from celeste_text_generation.io import ( TextGenerationChunk, TextGenerationFinishReason, @@ -27,7 +26,7 @@ def __init__( super().__init__(sse_iterator, **parameters) self._transform_output = transform_output - def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None: + def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None: """Parse SSE event into Chunk.""" event_type = event.get("type") if not event_type: diff --git a/packages/text-generation/src/celeste_text_generation/streaming.py b/packages/text-generation/src/celeste_text_generation/streaming.py index ba3ddf7..0b610cb 100644 --- a/packages/text-generation/src/celeste_text_generation/streaming.py +++ b/packages/text-generation/src/celeste_text_generation/streaming.py @@ -3,7 +3,6 @@ from abc import abstractmethod from typing import Any, Unpack -from celeste.io import Chunk from celeste.streaming import Stream from celeste_text_generation.io import ( TextGenerationChunk, @@ -13,11 +12,13 @@ from celeste_text_generation.parameters import TextGenerationParameters -class TextGenerationStream(Stream[TextGenerationOutput, TextGenerationParameters]): +class TextGenerationStream( + Stream[TextGenerationOutput, TextGenerationParameters, TextGenerationChunk] +): """Streaming for text generation.""" @abstractmethod - def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None: + def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None: """Parse SSE event into Chunk (provider-specific).""" def _parse_output( diff --git a/pyproject.toml b/pyproject.toml index 0d4fd7f..1036efc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,7 +165,6 @@ ignore_missing_imports = true module = [ "celeste_text_generation.*", "celeste_text_generation.client", - "celeste_text_generation.streaming", "celeste_text_generation.providers.*", ] disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assignment", "no-any-return"] @@ -174,7 +173,6 @@ disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assig module = [ "celeste_image_generation.*", "celeste_image_generation.client", - "celeste_image_generation.streaming", "celeste_image_generation.providers.*", ] disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assignment", "no-any-return"] diff --git a/src/celeste/client.py b/src/celeste/client.py index 45fa165..c2510b8 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -15,7 +15,7 @@ UnsupportedCapabilityError, ) from celeste.http import HTTPClient, get_http_client -from celeste.io import FinishReason, Input, Output, Usage +from celeste.io import Chunk, FinishReason, Input, Output, Usage from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters from celeste.streaming import Stream @@ -75,7 +75,7 @@ def stream( self, *args: Any, # noqa: ANN401 **parameters: Unpack[Params], # type: ignore[misc] - ) -> Stream[Out, Params]: + ) -> Stream[Out, Params, Chunk]: """Stream content - signature varies by capability. Args: @@ -160,7 +160,7 @@ async def _make_request( """Make HTTP request(s) and return response object.""" ... - def _stream_class(self) -> type[Stream[Out, Params]]: + def _stream_class(self) -> type[Stream[Out, Params, Chunk]]: """Return the Stream class for this client.""" raise StreamingNotSupportedError(model_id=self.model.id) diff --git a/src/celeste/streaming.py b/src/celeste/streaming.py index cdb2a8e..36de0a4 100644 --- a/src/celeste/streaming.py +++ b/src/celeste/streaming.py @@ -6,11 +6,12 @@ from typing import Any, Self, Unpack from celeste.exceptions import StreamEmptyError, StreamNotExhaustedError -from celeste.io import Chunk, Output +from celeste.io import Chunk as ChunkBase +from celeste.io import Output from celeste.parameters import Parameters -class Stream[Out: Output, Params: Parameters](ABC): +class Stream[Out: Output, Params: Parameters, Chunk: ChunkBase](ABC): """Async iterator wrapper providing final Output access after stream exhaustion.""" def __init__( diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 8bc96f5..8d027f2 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -15,7 +15,7 @@ StreamingNotSupportedError, UnsupportedCapabilityError, ) -from celeste.io import Input, Output, Usage +from celeste.io import Chunk, Input, Output, Usage from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters from celeste.streaming import Stream @@ -202,7 +202,7 @@ async def _make_request( # type: ignore[override] request=httpx.Request("POST", "https://test.com"), ) - def _stream_class(self) -> type[Stream[Output, Parameters]]: + def _stream_class(self) -> type[Stream[Output, Parameters, Chunk]]: """Return the Stream class for this client.""" raise NotImplementedError("Streaming not implemented in test client") diff --git a/tests/unit_tests/test_http.py b/tests/unit_tests/test_http.py index 07ddaaf..6fbd219 100644 --- a/tests/unit_tests/test_http.py +++ b/tests/unit_tests/test_http.py @@ -13,6 +13,7 @@ close_all_http_clients, get_http_client, ) +from celeste.mime_types import ApplicationMimeType @pytest.fixture @@ -183,7 +184,7 @@ async def test_post_request_with_all_parameters( url = "https://api.example.com/generate" headers = { "Authorization": "Bearer sk-test", - "Content-Type": "application/json", + "Content-Type": ApplicationMimeType.JSON, } json_body = {"prompt": "Hello", "max_tokens": 100} timeout = 30.0 diff --git a/tests/unit_tests/test_streaming.py b/tests/unit_tests/test_streaming.py index 296bfc1..5b76eb7 100644 --- a/tests/unit_tests/test_streaming.py +++ b/tests/unit_tests/test_streaming.py @@ -19,7 +19,7 @@ class ConcreteOutput(Output[str]): pass -class ConcreteStream(Stream[ConcreteOutput, Parameters]): +class ConcreteStream(Stream[ConcreteOutput, Parameters, Chunk]): """Concrete Stream implementation for testing abstract behavior.""" def __init__( @@ -369,7 +369,7 @@ async def test_subclass_without_parse_chunk_fails( """Subclass without _parse_chunk implementation must fail instantiation.""" # Arrange - class IncompleteStream(Stream[ConcreteOutput, Parameters]): + class IncompleteStream(Stream[ConcreteOutput, Parameters, Chunk]): """Missing _parse_chunk implementation.""" def _parse_output( # type: ignore[override] @@ -388,7 +388,7 @@ async def test_subclass_without_parse_output_fails( """Subclass without _parse_output implementation must fail instantiation.""" # Arrange - class IncompleteStream(Stream[ConcreteOutput, Parameters]): + class IncompleteStream(Stream[ConcreteOutput, Parameters, Chunk]): """Missing _parse_output implementation.""" def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None: @@ -621,7 +621,7 @@ class TypedOutput(Output[str]): ) finish_reason: TypedFinishReason | None = None - class TypedStream(Stream[TypedOutput, Parameters]): + class TypedStream(Stream[TypedOutput, Parameters, TypedChunk]): """Stream using typed classes.""" def _parse_chunk(self, event: dict[str, Any]) -> TypedChunk | None: