diff --git a/src/chat/CHANGELOG.md b/src/chat/CHANGELOG.md index 64267a841..de18d90a9 100644 --- a/src/chat/CHANGELOG.md +++ b/src/chat/CHANGELOG.md @@ -4,6 +4,11 @@ CHANGELOG 0.1 --- +* Add streaming support to `ChatInterface::submit()` + - Add `StreamableStoreInterface` which indicates `StoreInterface` implementation can be configured with streaming + - Add `AccumulatingStreamResult` wrapper class which adds accumulation logic & callback chaining to `StreamResult` implementations (can wrap both `Agent` and `Platform` variants) to return the full message once `Generator` is exhausted + - Streamed responses now also create `AssistantMessage` & are added to `Store` in `Chat::submit()` + - Bugfixed loss of metadata in `Chat::submit()` * Introduce the component * Add support for external message stores: - Doctrine diff --git a/src/chat/src/Bridge/Doctrine/DoctrineDbalMessageStore.php b/src/chat/src/Bridge/Doctrine/DoctrineDbalMessageStore.php index 6a2ef80fb..31e526253 100644 --- a/src/chat/src/Bridge/Doctrine/DoctrineDbalMessageStore.php +++ b/src/chat/src/Bridge/Doctrine/DoctrineDbalMessageStore.php @@ -24,6 +24,7 @@ use Symfony\AI\Chat\ManagedStoreInterface; use Symfony\AI\Chat\MessageNormalizer; use Symfony\AI\Chat\MessageStoreInterface; +use Symfony\AI\Chat\StreamableStoreInterface; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\MessageInterface; use Symfony\Component\Serializer\Encoder\JsonEncoder; @@ -34,7 +35,7 @@ /** * @author Guillaume Loulier */ -final class DoctrineDbalMessageStore implements ManagedStoreInterface, MessageStoreInterface +final class DoctrineDbalMessageStore implements ManagedStoreInterface, MessageStoreInterface, StreamableStoreInterface { public function __construct( private readonly string $tableName, diff --git a/src/chat/src/Bridge/Local/CacheStore.php b/src/chat/src/Bridge/Local/CacheStore.php index 0b2626fa3..57bec1b05 100644 --- a/src/chat/src/Bridge/Local/CacheStore.php +++ b/src/chat/src/Bridge/Local/CacheStore.php @@ -15,12 +15,13 @@ use Symfony\AI\Agent\Exception\RuntimeException; use Symfony\AI\Chat\ManagedStoreInterface; use Symfony\AI\Chat\MessageStoreInterface; +use Symfony\AI\Chat\StreamableStoreInterface; use Symfony\AI\Platform\Message\MessageBag; /** * @author Christopher Hertel */ -final class CacheStore implements ManagedStoreInterface, MessageStoreInterface +final class CacheStore implements ManagedStoreInterface, MessageStoreInterface, StreamableStoreInterface { public function __construct( private readonly CacheItemPoolInterface $cache, diff --git a/src/chat/src/Bridge/Local/InMemoryStore.php b/src/chat/src/Bridge/Local/InMemoryStore.php index 362a90472..e75f10cbd 100644 --- a/src/chat/src/Bridge/Local/InMemoryStore.php +++ b/src/chat/src/Bridge/Local/InMemoryStore.php @@ -13,12 +13,13 @@ use Symfony\AI\Chat\ManagedStoreInterface; use Symfony\AI\Chat\MessageStoreInterface; +use Symfony\AI\Chat\StreamableStoreInterface; use Symfony\AI\Platform\Message\MessageBag; /** * @author Christopher Hertel */ -final class InMemoryStore implements ManagedStoreInterface, MessageStoreInterface +final class InMemoryStore implements ManagedStoreInterface, MessageStoreInterface, StreamableStoreInterface { /** * @var MessageBag[] diff --git a/src/chat/src/Bridge/Meilisearch/MessageStore.php b/src/chat/src/Bridge/Meilisearch/MessageStore.php index 83b7add7c..b72fba2b5 100644 --- a/src/chat/src/Bridge/Meilisearch/MessageStore.php +++ b/src/chat/src/Bridge/Meilisearch/MessageStore.php @@ -16,6 +16,7 @@ use Symfony\AI\Chat\ManagedStoreInterface; use Symfony\AI\Chat\MessageNormalizer; use Symfony\AI\Chat\MessageStoreInterface; +use Symfony\AI\Chat\StreamableStoreInterface; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\MessageInterface; use Symfony\Component\Clock\ClockInterface; @@ -31,7 +32,7 @@ /** * @author Guillaume Loulier */ -final class MessageStore implements ManagedStoreInterface, MessageStoreInterface +final class MessageStore implements ManagedStoreInterface, MessageStoreInterface, StreamableStoreInterface { public function __construct( private readonly HttpClientInterface $httpClient, diff --git a/src/chat/src/Bridge/MongoDb/MessageStore.php b/src/chat/src/Bridge/MongoDb/MessageStore.php index 3d167e8f1..eb7c2a857 100644 --- a/src/chat/src/Bridge/MongoDb/MessageStore.php +++ b/src/chat/src/Bridge/MongoDb/MessageStore.php @@ -15,6 +15,7 @@ use Symfony\AI\Chat\ManagedStoreInterface; use Symfony\AI\Chat\MessageNormalizer; use Symfony\AI\Chat\MessageStoreInterface; +use Symfony\AI\Chat\StreamableStoreInterface; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\MessageInterface; use Symfony\Component\Serializer\Encoder\JsonEncoder; @@ -27,7 +28,7 @@ /** * @author Guillaume Loulier */ -final class MessageStore implements ManagedStoreInterface, MessageStoreInterface +final class MessageStore implements ManagedStoreInterface, MessageStoreInterface, StreamableStoreInterface { public function __construct( private readonly Client $client, diff --git a/src/chat/src/Bridge/Pogocache/MessageStore.php b/src/chat/src/Bridge/Pogocache/MessageStore.php index 2a7a28c6d..c80cec290 100644 --- a/src/chat/src/Bridge/Pogocache/MessageStore.php +++ b/src/chat/src/Bridge/Pogocache/MessageStore.php @@ -15,6 +15,7 @@ use Symfony\AI\Chat\ManagedStoreInterface; use Symfony\AI\Chat\MessageNormalizer; use Symfony\AI\Chat\MessageStoreInterface; +use Symfony\AI\Chat\StreamableStoreInterface; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\MessageInterface; use Symfony\Component\Serializer\Encoder\JsonEncoder; @@ -28,7 +29,7 @@ /** * @author Guillaume Loulier */ -final class MessageStore implements ManagedStoreInterface, MessageStoreInterface +final class MessageStore implements ManagedStoreInterface, MessageStoreInterface, StreamableStoreInterface { public function __construct( private readonly HttpClientInterface $httpClient, diff --git a/src/chat/src/Bridge/Redis/MessageStore.php b/src/chat/src/Bridge/Redis/MessageStore.php index d8d964fa2..0e176764b 100644 --- a/src/chat/src/Bridge/Redis/MessageStore.php +++ b/src/chat/src/Bridge/Redis/MessageStore.php @@ -14,6 +14,7 @@ use Symfony\AI\Chat\ManagedStoreInterface; use Symfony\AI\Chat\MessageNormalizer; use Symfony\AI\Chat\MessageStoreInterface; +use Symfony\AI\Chat\StreamableStoreInterface; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\MessageInterface; use Symfony\Component\Serializer\Encoder\JsonEncoder; @@ -24,7 +25,7 @@ /** * @author Guillaume Loulier */ -final class MessageStore implements ManagedStoreInterface, MessageStoreInterface +final class MessageStore implements ManagedStoreInterface, MessageStoreInterface, StreamableStoreInterface { public function __construct( private readonly \Redis $redis, diff --git a/src/chat/src/Bridge/SurrealDb/MessageStore.php b/src/chat/src/Bridge/SurrealDb/MessageStore.php index 707a4c0be..d77a88187 100644 --- a/src/chat/src/Bridge/SurrealDb/MessageStore.php +++ b/src/chat/src/Bridge/SurrealDb/MessageStore.php @@ -16,6 +16,7 @@ use Symfony\AI\Chat\ManagedStoreInterface; use Symfony\AI\Chat\MessageNormalizer; use Symfony\AI\Chat\MessageStoreInterface; +use Symfony\AI\Chat\StreamableStoreInterface; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\MessageInterface; use Symfony\Component\Serializer\Encoder\JsonEncoder; @@ -29,7 +30,7 @@ /** * @author Guillaume Loulier */ -final class MessageStore implements ManagedStoreInterface, MessageStoreInterface +final class MessageStore implements ManagedStoreInterface, MessageStoreInterface, StreamableStoreInterface { private string $authenticationToken = ''; diff --git a/src/chat/src/Chat.php b/src/chat/src/Chat.php index 153672f02..b181b34d7 100644 --- a/src/chat/src/Chat.php +++ b/src/chat/src/Chat.php @@ -12,10 +12,14 @@ namespace Symfony\AI\Chat; use Symfony\AI\Agent\AgentInterface; +use Symfony\AI\Agent\Exception\RuntimeException; +use Symfony\AI\Agent\Toolbox\StreamResult as ToolboxStreamResult; +use Symfony\AI\Chat\Result\AccumulatingStreamResult; use Symfony\AI\Platform\Message\AssistantMessage; use Symfony\AI\Platform\Message\Message; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Result\StreamResult; use Symfony\AI\Platform\Result\TextResult; /** @@ -35,18 +39,31 @@ public function initiate(MessageBag $messages): void $this->store->save($messages); } - public function submit(UserMessage $message): AssistantMessage + public function submit(UserMessage $message): AssistantMessage|AccumulatingStreamResult { $messages = $this->store->load(); $messages->add($message); $result = $this->agent->call($messages); + if ($result instanceof StreamResult || $result instanceof ToolboxStreamResult) { + if (!$this->store instanceof StreamableStoreInterface) { + throw new RuntimeException($this->store::class.' does not support streaming.'); + } + + return new AccumulatingStreamResult($result, function (AssistantMessage $assistantMessage) use ($messages) { + $messages->add($assistantMessage); + $this->store->save($messages); + }); + } + \assert($result instanceof TextResult); $assistantMessage = Message::ofAssistant($result->getContent()); - $messages->add($assistantMessage); + $assistantMessage->getMetadata()->set($result->getMetadata()->all()); + + $messages->add($assistantMessage); $this->store->save($messages); return $assistantMessage; diff --git a/src/chat/src/ChatInterface.php b/src/chat/src/ChatInterface.php index 727146131..d9122f983 100644 --- a/src/chat/src/ChatInterface.php +++ b/src/chat/src/ChatInterface.php @@ -12,6 +12,7 @@ namespace Symfony\AI\Chat; use Symfony\AI\Agent\Exception\ExceptionInterface; +use Symfony\AI\Chat\Result\AccumulatingStreamResult; use Symfony\AI\Platform\Message\AssistantMessage; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\UserMessage; @@ -26,5 +27,5 @@ public function initiate(MessageBag $messages): void; /** * @throws ExceptionInterface When the chat submission fails due to agent errors */ - public function submit(UserMessage $message): AssistantMessage; + public function submit(UserMessage $message): AssistantMessage|AccumulatingStreamResult; } diff --git a/src/chat/src/Result/AccumulatingStreamResult.php b/src/chat/src/Result/AccumulatingStreamResult.php new file mode 100644 index 000000000..2b467218a --- /dev/null +++ b/src/chat/src/Result/AccumulatingStreamResult.php @@ -0,0 +1,81 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Chat\Result; + +use Symfony\AI\Agent\Toolbox\StreamResult as ToolboxStreamResult; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Metadata\Metadata; +use Symfony\AI\Platform\Result\StreamResult; +use Symfony\AI\Platform\Result\ToolCallResult; + +/** + * @author Marco van Angeren + */ +final class AccumulatingStreamResult +{ + private ?\Closure $onComplete = null; + + public function __construct( + private readonly StreamResult|ToolboxStreamResult $innerResult, + ?\Closure $onComplete = null, + ) { + $this->onComplete = $onComplete; + } + + public function addOnComplete(\Closure $callback): void + { + $existingCallback = $this->onComplete; + + $this->onComplete = $existingCallback + ? function (AssistantMessage $message) use ($existingCallback, $callback) { + $existingCallback($message); + $callback($message); + } + : $callback; + } + + public function getContent(): \Generator + { + $accumulatedContent = ''; + $toolCalls = []; + + try { + foreach ($this->innerResult->getContent() as $value) { + if ($value instanceof ToolCallResult) { + array_push($toolCalls, ...$value->getContent()); + yield $value; + continue; + } + + $accumulatedContent .= $value; + yield $value; + } + } finally { + if (null !== $this->onComplete) { + $assistantMessage = Message::ofAssistant( + '' === $accumulatedContent ? null : $accumulatedContent, + $toolCalls ?: null + ); + + $assistantMessage->getMetadata()->set($this->innerResult->getMetadata()->all()); + + ($this->onComplete)($assistantMessage); + } + } + } + + public function getMetadata(): Metadata + { + return $this->innerResult->getMetadata(); + } +} diff --git a/src/chat/src/StreamableStoreInterface.php b/src/chat/src/StreamableStoreInterface.php new file mode 100644 index 000000000..45268b0f6 --- /dev/null +++ b/src/chat/src/StreamableStoreInterface.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Chat; + +/** + * @author Marco van Angeren + */ +interface StreamableStoreInterface +{ +} diff --git a/src/chat/tests/ChatTest.php b/src/chat/tests/ChatTest.php index b6ad22d9a..6a1b9bfbe 100644 --- a/src/chat/tests/ChatTest.php +++ b/src/chat/tests/ChatTest.php @@ -16,10 +16,14 @@ use Symfony\AI\Agent\AgentInterface; use Symfony\AI\Chat\Bridge\Local\InMemoryStore; use Symfony\AI\Chat\Chat; +use Symfony\AI\Chat\Result\AccumulatingStreamResult; use Symfony\AI\Platform\Message\AssistantMessage; use Symfony\AI\Platform\Message\Message; use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Result\StreamResult; use Symfony\AI\Platform\Result\TextResult; +use Symfony\AI\Platform\Result\ToolCall; +use Symfony\AI\Platform\Result\ToolCallResult; final class ChatTest extends TestCase { @@ -113,4 +117,118 @@ public function testItHandlesEmptyMessageStore() $this->assertSame($assistantContent, $result->getContent()); $this->assertCount(2, $this->store->load()); } + + public function testItSupportsStreaming() + { + $userMessage = Message::ofUser('What is your favourite song?'); + $generator = (function () { + yield 'Bitter Sweet'; + yield ' '; + yield 'Symfony'; + })(); + + $streamResult = new StreamResult($generator); + + $this->agent->expects($this->once()) + ->method('call') + ->willReturn($streamResult); + + $result = $this->chat->submit($userMessage); + $this->assertInstanceOf(AccumulatingStreamResult::class, $result); + + $chunks = iterator_to_array($result->getContent()); + $this->assertSame(['Bitter Sweet', ' ', 'Symfony'], $chunks); + + $storedMessages = $this->store->load(); + $this->assertCount(2, $storedMessages); + + $messages = $storedMessages->getMessages(); + $lastMessage = end($messages); + $this->assertInstanceOf(AssistantMessage::class, $lastMessage); + $this->assertSame('Bitter Sweet Symfony', $lastMessage->getContent()); + } + + public function testStreamingPreservesMetadata() + { + $userMessage = Message::ofUser('Hello'); + $generator = (function () { + yield 'Test'; + })(); + + $streamResult = new StreamResult($generator); + $streamResult->getMetadata()->add('key1', 'value1'); + $streamResult->getMetadata()->add('key2', 'value2'); + + $this->agent->expects($this->once()) + ->method('call') + ->willReturn($streamResult); + + $result = $this->chat->submit($userMessage); + + iterator_to_array($result->getContent()); + + $storedMessages = $this->store->load(); + $lastMessage = $storedMessages->getMessages()[1]; + $this->assertTrue($lastMessage->getMetadata()->has('key1')); + $this->assertTrue($lastMessage->getMetadata()->has('key2')); + $this->assertSame('value1', $lastMessage->getMetadata()->get('key1')); + $this->assertSame('value2', $lastMessage->getMetadata()->get('key2')); + } + + public function testStreamingWithToolCalls() + { + $userMessage = Message::ofUser('Hello'); + $toolCall = new ToolCall('call_123', 'test_tool', ['param' => 'value']); + $toolCallResult = new ToolCallResult($toolCall); + + $generator = (function () use ($toolCallResult) { + yield 'Some text'; + yield $toolCallResult; + })(); + + $streamResult = new StreamResult($generator); + + $this->agent->expects($this->once()) + ->method('call') + ->willReturn($streamResult); + + $result = $this->chat->submit($userMessage); + + iterator_to_array($result->getContent()); + + $storedMessages = $this->store->load(); + $lastMessage = $storedMessages->getMessages()[1]; + $this->assertInstanceOf(AssistantMessage::class, $lastMessage); + $this->assertSame('Some text', $lastMessage->getContent()); + $this->assertTrue($lastMessage->hasToolCalls()); + } + + public function testStreamingCallbackFiresEvenIfIterationStopsEarly() + { + $userMessage = Message::ofUser('Hello'); + $generator = (function () { + yield 'Chunk1'; + yield 'Chunk2'; + yield 'Chunk3'; + })(); + + $streamResult = new StreamResult($generator); + + $this->agent->expects($this->once()) + ->method('call') + ->willReturn($streamResult); + + $result = $this->chat->submit($userMessage); + + $content = $result->getContent(); + $content->current(); + $content->next(); + + while ($content->valid()) { + $content->next(); + } + + $storedMessages = $this->store->load(); + $this->assertCount(2, $storedMessages); + } }