Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any

from celeste.artifacts import ImageArtifact
from celeste.io import Chunk
from celeste.mime_types import ImageMimeType
from celeste_image_generation.io import ImageGenerationChunk, ImageGenerationUsage
from celeste_image_generation.streaming import ImageGenerationStream
Expand All @@ -21,7 +20,7 @@ def __init__(self, sse_iterator: AsyncIterator[dict[str, Any]]) -> None:
super().__init__(sse_iterator)
self._completed_usage: ImageGenerationUsage | None = None

def _parse_chunk(self, chunk_data: dict[str, Any]) -> Chunk | None:
def _parse_chunk(self, chunk_data: dict[str, Any]) -> ImageGenerationChunk | None:
"""Parse chunk from SSE event."""
event_type = chunk_data.get("type")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any

from celeste.artifacts import ImageArtifact
from celeste.io import Chunk
from celeste_image_generation.io import ImageGenerationChunk, ImageGenerationUsage
from celeste_image_generation.streaming import ImageGenerationStream

Expand All @@ -15,7 +14,7 @@
class OpenAIImageGenerationStream(ImageGenerationStream):
"""OpenAI streaming for image generation."""

def _parse_chunk(self, chunk_data: dict[str, Any]) -> Chunk | None:
def _parse_chunk(self, chunk_data: dict[str, Any]) -> ImageGenerationChunk | None:
"""Parse chunk from SSE event.

OpenAI returns two event types:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from celeste_image_generation.parameters import ImageGenerationParameters


class ImageGenerationStream(Stream[ImageGenerationOutput, ImageGenerationParameters]):
class ImageGenerationStream(
Stream[ImageGenerationOutput, ImageGenerationParameters, ImageGenerationChunk]
):
"""Streaming for image generation."""

def _parse_output(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Callable
from typing import Any, Unpack

from celeste.io import Chunk
from celeste_text_generation.io import (
TextGenerationChunk,
TextGenerationFinishReason,
Expand Down Expand Up @@ -34,7 +33,7 @@ def __init__(
self._transform_output = transform_output
self._last_finish_reason: TextGenerationFinishReason | None = None

def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None:
"""Parse SSE event into Chunk."""
event_type = event.get("type")
if not event_type:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections.abc import Callable
from typing import Any, Unpack

from celeste.io import Chunk
from celeste_text_generation.io import (
TextGenerationChunk,
TextGenerationFinishReason,
Expand Down Expand Up @@ -36,7 +35,7 @@ def __init__(
super().__init__(sse_iterator, **parameters)
self._transform_output = transform_output

def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None:
"""Parse SSE event into Chunk, extracting text deltas and metadata."""
event_type = event.get("type")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Callable
from typing import Any, Unpack

from celeste.io import Chunk
from celeste_text_generation.io import (
TextGenerationChunk,
TextGenerationFinishReason,
Expand Down Expand Up @@ -33,7 +32,7 @@ def __init__(
super().__init__(sse_iterator, **parameters)
self._transform_output = transform_output

def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None:
"""Parse SSE event into Chunk.

Extract text delta from candidates[0].content.parts[0].text.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections.abc import Callable
from typing import Any, Unpack

from celeste.io import Chunk
from celeste_text_generation.io import (
TextGenerationChunk,
TextGenerationFinishReason,
Expand Down Expand Up @@ -36,7 +35,7 @@ def __init__(
super().__init__(sse_iterator, **parameters)
self._transform_output = transform_output

def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None:
"""Parse chunk from SSE event.

Extract from choices[0].delta.content (content delta events).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Callable
from typing import Any, Unpack

from celeste.io import Chunk
from celeste_text_generation.io import (
TextGenerationChunk,
TextGenerationFinishReason,
Expand All @@ -27,7 +26,7 @@ def __init__(
super().__init__(sse_iterator, **parameters)
self._transform_output = transform_output

def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None:
"""Parse SSE event into Chunk."""
event_type = event.get("type")
if not event_type:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from abc import abstractmethod
from typing import Any, Unpack

from celeste.io import Chunk
from celeste.streaming import Stream
from celeste_text_generation.io import (
TextGenerationChunk,
Expand All @@ -13,11 +12,13 @@
from celeste_text_generation.parameters import TextGenerationParameters


class TextGenerationStream(Stream[TextGenerationOutput, TextGenerationParameters]):
class TextGenerationStream(
Stream[TextGenerationOutput, TextGenerationParameters, TextGenerationChunk]
):
"""Streaming for text generation."""

@abstractmethod
def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
def _parse_chunk(self, event: dict[str, Any]) -> TextGenerationChunk | None:
"""Parse SSE event into Chunk (provider-specific)."""

def _parse_output(
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ ignore_missing_imports = true
module = [
"celeste_text_generation.*",
"celeste_text_generation.client",
"celeste_text_generation.streaming",
"celeste_text_generation.providers.*",
]
disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assignment", "no-any-return"]
Expand All @@ -174,7 +173,6 @@ disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assig
module = [
"celeste_image_generation.*",
"celeste_image_generation.client",
"celeste_image_generation.streaming",
"celeste_image_generation.providers.*",
]
disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assignment", "no-any-return"]
Expand Down
6 changes: 3 additions & 3 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
UnsupportedCapabilityError,
)
from celeste.http import HTTPClient, get_http_client
from celeste.io import FinishReason, Input, Output, Usage
from celeste.io import Chunk, FinishReason, Input, Output, Usage
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
from celeste.streaming import Stream
Expand Down Expand Up @@ -75,7 +75,7 @@ def stream(
self,
*args: Any, # noqa: ANN401
**parameters: Unpack[Params], # type: ignore[misc]
) -> Stream[Out, Params]:
) -> Stream[Out, Params, Chunk]:
"""Stream content - signature varies by capability.

Args:
Expand Down Expand Up @@ -160,7 +160,7 @@ async def _make_request(
"""Make HTTP request(s) and return response object."""
...

def _stream_class(self) -> type[Stream[Out, Params]]:
def _stream_class(self) -> type[Stream[Out, Params, Chunk]]:
"""Return the Stream class for this client."""
raise StreamingNotSupportedError(model_id=self.model.id)

Expand Down
5 changes: 3 additions & 2 deletions src/celeste/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from typing import Any, Self, Unpack

from celeste.exceptions import StreamEmptyError, StreamNotExhaustedError
from celeste.io import Chunk, Output
from celeste.io import Chunk as ChunkBase
from celeste.io import Output
from celeste.parameters import Parameters


class Stream[Out: Output, Params: Parameters](ABC):
class Stream[Out: Output, Params: Parameters, Chunk: ChunkBase](ABC):
"""Async iterator wrapper providing final Output access after stream exhaustion."""

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StreamingNotSupportedError,
UnsupportedCapabilityError,
)
from celeste.io import Input, Output, Usage
from celeste.io import Chunk, Input, Output, Usage
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
from celeste.streaming import Stream
Expand Down Expand Up @@ -202,7 +202,7 @@ async def _make_request( # type: ignore[override]
request=httpx.Request("POST", "https://test.com"),
)

def _stream_class(self) -> type[Stream[Output, Parameters]]:
def _stream_class(self) -> type[Stream[Output, Parameters, Chunk]]:
"""Return the Stream class for this client."""
raise NotImplementedError("Streaming not implemented in test client")

Expand Down
3 changes: 2 additions & 1 deletion tests/unit_tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
close_all_http_clients,
get_http_client,
)
from celeste.mime_types import ApplicationMimeType


@pytest.fixture
Expand Down Expand Up @@ -183,7 +184,7 @@ async def test_post_request_with_all_parameters(
url = "https://api.example.com/generate"
headers = {
"Authorization": "Bearer sk-test",
"Content-Type": "application/json",
"Content-Type": ApplicationMimeType.JSON,
}
json_body = {"prompt": "Hello", "max_tokens": 100}
timeout = 30.0
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ConcreteOutput(Output[str]):
pass


class ConcreteStream(Stream[ConcreteOutput, Parameters]):
class ConcreteStream(Stream[ConcreteOutput, Parameters, Chunk]):
"""Concrete Stream implementation for testing abstract behavior."""

def __init__(
Expand Down Expand Up @@ -369,7 +369,7 @@ async def test_subclass_without_parse_chunk_fails(
"""Subclass without _parse_chunk implementation must fail instantiation."""

# Arrange
class IncompleteStream(Stream[ConcreteOutput, Parameters]):
class IncompleteStream(Stream[ConcreteOutput, Parameters, Chunk]):
"""Missing _parse_chunk implementation."""

def _parse_output( # type: ignore[override]
Expand All @@ -388,7 +388,7 @@ async def test_subclass_without_parse_output_fails(
"""Subclass without _parse_output implementation must fail instantiation."""

# Arrange
class IncompleteStream(Stream[ConcreteOutput, Parameters]):
class IncompleteStream(Stream[ConcreteOutput, Parameters, Chunk]):
"""Missing _parse_output implementation."""

def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
Expand Down Expand Up @@ -621,7 +621,7 @@ class TypedOutput(Output[str]):
)
finish_reason: TypedFinishReason | None = None

class TypedStream(Stream[TypedOutput, Parameters]):
class TypedStream(Stream[TypedOutput, Parameters, TypedChunk]):
"""Stream using typed classes."""

def _parse_chunk(self, event: dict[str, Any]) -> TypedChunk | None:
Expand Down
Loading