From 07b400793559eaf388d3dc7d5af3ee92f16d5469 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 10:53:37 +0400 Subject: [PATCH 01/19] feat: add context cancellation to Stream.Next and handle empty chunks as normal --- chat/chat.go | 2 +- chat/types.go | 8 ++++---- connect/openai/stream.go | 24 ++++++++++++++++++------ 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/chat/chat.go b/chat/chat.go index a09327f..fec4822 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -20,7 +20,7 @@ func (c *Chat) Complete(ctx context.Context, client Client) <-chan StreamEvent { } defer stream.Close() - for stream.Next() { + for stream.Next(ctx) { event := stream.Current() select { diff --git a/chat/types.go b/chat/types.go index 032b1eb..6b17e0c 100644 --- a/chat/types.go +++ b/chat/types.go @@ -51,10 +51,10 @@ func (m *Messages) Snapshot() []StreamEvent { } type Stream[T any] interface { - Next() bool // advance; returns false on EOF or error - Current() T // the current element; valid only if Last Next() returned true - Err() error // non-nil if the stream ended because of an error - Close() error // release resources, ensure Next() returns false + Next(ctx context.Context) bool // advance; returns false on EOF or error + Current() T // the current element; valid only if Last Next() returned true + Err() error // non-nil if the stream ended because of an error + Close() error // release resources, ensure Next() returns false } type Verdict struct { diff --git a/connect/openai/stream.go b/connect/openai/stream.go index 69200a1..ad875fb 100644 --- a/connect/openai/stream.go +++ b/connect/openai/stream.go @@ -1,7 +1,7 @@ package openai_connect import ( - "errors" + "context" "github.com/openai/openai-go/v3" "github.com/x2d7/interlude/chat" @@ -27,22 +27,34 @@ type OpenAIStream struct { SSEStream sseStreamer } -func (s *OpenAIStream) Next() bool { +func (s *OpenAIStream) Next(ctx context.Context) bool { if s.err != nil { return false } + // сheck context cancellation before trying to get next chunk + select { + case <-ctx.Done(): + s.err = ctx.Err() + return false + default: + } + // creating a queue if it's empty if len(s.queue) == 0 { if proceed := s.SSEStream.Next(); proceed { // parsing events queue, err := s.handleRawChunk(s.SSEStream.Current()) - // fall if error appears (or empty event list) if err != nil { s.err = err return false } + // skip empty chunks and try next one + if len(queue) == 0 { + return s.Next(ctx) + } + // updating queue to new parsed events s.queue = queue } else { @@ -76,9 +88,6 @@ func (s *OpenAIStream) handleRawChunk(chunk openai.ChatCompletionChunk) ([]chat. if err != nil { return nil, err } - if len(events) == 0 { - return nil, errors.New("empty events") - } return events, nil } @@ -97,6 +106,9 @@ func (s *OpenAIStream) _handleRawChunk(chunk openai.ChatCompletionChunk) ([]chat refusal := delta.Refusal tools := delta.ToolCalls + // TODO: Использование FinishReason + _ = choice.FinishReason + if content != "" { result = append(result, chat.NewEventNewToken(content)) } From 4e2345dc349d36299a47396b93a565103abf9d7b Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 10:55:07 +0400 Subject: [PATCH 02/19] refactor(stream): add context parameter to Next() method for tests Updates Next() to accept context.Context for improved streaming control. - Update test calls with context.Background() - Remove deprecated empty event error handling tests --- chat/chat_test.go | 2 +- connect/openai/stream_test.go | 58 ++++++++--------------------------- 2 files changed, 13 insertions(+), 47 deletions(-) diff --git a/chat/chat_test.go b/chat/chat_test.go index d1eb418..46abff5 100644 --- a/chat/chat_test.go +++ b/chat/chat_test.go @@ -27,7 +27,7 @@ func NewMockStream(events []StreamEvent, err error) *MockStream { } } -func (s *MockStream) Next() bool { +func (s *MockStream) Next(ctx context.Context) bool { s.mu.Lock() defer s.mu.Unlock() if s.index >= len(s.events)-1 { diff --git a/connect/openai/stream_test.go b/connect/openai/stream_test.go index 92410ff..00d06ac 100644 --- a/connect/openai/stream_test.go +++ b/connect/openai/stream_test.go @@ -1,6 +1,7 @@ package openai_connect import ( + "context" "errors" "testing" @@ -211,24 +212,6 @@ func TestHandleRawChunk_EmptyDelta_ReturnsEmptyList(t *testing.T) { // ==================== handleRawChunk (decorator) Tests ==================== -func TestHandleRawChunkDecorator_EmptyEvents_ReturnsError(t *testing.T) { - s := newStream(newMockSSEStream(nil, nil)) - // Empty delta → _handleRawChunk returns empty slice → decorator returns error - chunk := makeChunk("", "", nil) - - events, err := s.handleRawChunk(chunk) - - if err == nil { - t.Fatal("Expected error for empty events, got nil") - } - if err.Error() != "empty events" { - t.Errorf("Expected 'empty events' error, got '%s'", err.Error()) - } - if events != nil { - t.Errorf("Expected nil events on error, got %v", events) - } -} - func TestHandleRawChunkDecorator_NonEmptyEvents_PassesThrough(t *testing.T) { s := newStream(newMockSSEStream(nil, nil)) chunk := makeChunk("token", "", nil) @@ -250,7 +233,7 @@ func TestNext_SingleChunk_ReturnsTrueAndParsesEvent(t *testing.T) { mock := newMockSSEStream([]openai.ChatCompletionChunk{chunk}, nil) s := newStream(mock) - if !s.Next() { + if !s.Next(context.Background()) { t.Fatal("Expected Next() = true for first chunk") } event := s.Current() @@ -269,7 +252,7 @@ func TestNext_MultipleEventsInOneChunk_DrainsQueue(t *testing.T) { s := newStream(mock) // First Next() — fetches chunk, puts 2 events in queue, returns first - if !s.Next() { + if !s.Next(context.Background()) { t.Fatal("Expected Next() = true (1st)") } first := s.Current() @@ -278,7 +261,7 @@ func TestNext_MultipleEventsInOneChunk_DrainsQueue(t *testing.T) { } // Second Next() — takes from queue without fetching new chunk - if !s.Next() { + if !s.Next(context.Background()) { t.Fatal("Expected Next() = true (2nd, from queue)") } second := s.Current() @@ -287,7 +270,7 @@ func TestNext_MultipleEventsInOneChunk_DrainsQueue(t *testing.T) { } // Third Next() — queue empty, SSE exhausted - if s.Next() { + if s.Next(context.Background()) { t.Fatal("Expected Next() = false after all events consumed") } } @@ -296,7 +279,7 @@ func TestNext_SSEExhausted_ReturnsFalse(t *testing.T) { mock := newMockSSEStream([]openai.ChatCompletionChunk{}, nil) s := newStream(mock) - if s.Next() { + if s.Next(context.Background()) { t.Fatal("Expected Next() = false for empty SSE stream") } } @@ -306,7 +289,7 @@ func TestNext_SSEExhausted_SetsNilErr(t *testing.T) { mock := newMockSSEStream([]openai.ChatCompletionChunk{}, nil) s := newStream(mock) - s.Next() + s.Next(context.Background()) if s.Err() != nil { t.Errorf("Expected nil error after clean SSE end, got %v", s.Err()) @@ -318,7 +301,7 @@ func TestNext_SSEError_StopsAndSetsErr(t *testing.T) { mock := newMockSSEStream([]openai.ChatCompletionChunk{}, apiErr) s := newStream(mock) - if s.Next() { + if s.Next(context.Background()) { t.Fatal("Expected Next() = false when SSE has error") } if !errors.Is(s.Err(), apiErr) { @@ -326,23 +309,6 @@ func TestNext_SSEError_StopsAndSetsErr(t *testing.T) { } } -func TestNext_EmptyChunk_StopsWithEmptyEventsError(t *testing.T) { - // Chunk with empty delta → handleRawChunk returns "empty events" error - chunk := makeChunk("", "", nil) - mock := newMockSSEStream([]openai.ChatCompletionChunk{chunk}, nil) - s := newStream(mock) - - if s.Next() { - t.Fatal("Expected Next() = false when chunk produces empty events") - } - if s.Err() == nil { - t.Fatal("Expected error to be set when chunk is empty") - } - if s.Err().Error() != "empty events" { - t.Errorf("Expected 'empty events' error, got '%s'", s.Err().Error()) - } -} - func TestNext_ErrAlreadySet_ReturnsFalseImmediately(t *testing.T) { chunk := makeChunk("Hello", "", nil) mock := newMockSSEStream([]openai.ChatCompletionChunk{chunk}, nil) @@ -351,7 +317,7 @@ func TestNext_ErrAlreadySet_ReturnsFalseImmediately(t *testing.T) { // Manually set error before calling Next() s.err = errors.New("pre-existing error") - if s.Next() { + if s.Next(context.Background()) { t.Fatal("Expected Next() = false when err is already set") } // SSE should not have been called (mock index remains -1) @@ -370,7 +336,7 @@ func TestNext_MultipleChunks_AllEventsReceived(t *testing.T) { s := newStream(mock) var received []chat.StreamEvent - for s.Next() { + for s.Next(context.Background()) { received = append(received, s.Current()) } @@ -406,7 +372,7 @@ func TestCurrent_AfterNext_ReturnsEvent(t *testing.T) { mock := newMockSSEStream([]openai.ChatCompletionChunk{chunk}, nil) s := newStream(mock) - s.Next() + s.Next(context.Background()) if s.Current() == nil { t.Error("Expected non-nil Current() after Next()") } @@ -428,7 +394,7 @@ func TestErr_AfterSSEError(t *testing.T) { mock := newMockSSEStream(nil, expected) s := newStream(mock) - s.Next() // triggers SSE read → SSE.Next() returns false → s.err = SSE.Err() + s.Next(context.Background()) // triggers SSE read → SSE.Next() returns false → s.err = SSE.Err() if !errors.Is(s.Err(), expected) { t.Errorf("Expected stream error, got %v", s.Err()) From dc5420dc239964ee5ec5ea91b828391edac8cfc7 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 11:03:48 +0400 Subject: [PATCH 03/19] test: add context cancellation test for OpenAIStream.Next() --- connect/openai/stream_test.go | 63 +++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/connect/openai/stream_test.go b/connect/openai/stream_test.go index 00d06ac..9915e9e 100644 --- a/connect/openai/stream_test.go +++ b/connect/openai/stream_test.go @@ -415,3 +415,66 @@ func TestClose_DelegatesToSSEStream(t *testing.T) { t.Error("Expected SSEStream.Close() to be called") } } + +// ==================== OpenAIStream.Next() Context Cancellation Tests ==================== + +// infiniteMockSSEStream simulates an infinite stream that never ends +type infiniteMockSSEStream struct { + currentChunk openai.ChatCompletionChunk +} + +func newInfiniteMockSSEStream() *infiniteMockSSEStream { + return &infiniteMockSSEStream{ + currentChunk: makeChunk("token", "", nil), + } +} + +func (m *infiniteMockSSEStream) Next() bool { + return true +} + +func (m *infiniteMockSSEStream) Current() openai.ChatCompletionChunk { + return m.currentChunk +} + +func (m *infiniteMockSSEStream) Err() error { + return nil +} + +func (m *infiniteMockSSEStream) Close() error { + return nil +} + +func TestNext_ContextCancellation_StopsStream(t *testing.T) { + mock := newInfiniteMockSSEStream() + s := newStream(mock) + + // Create a cancellable context + ctx, cancel := context.WithCancel(context.Background()) + + // Get first token - should succeed + if !s.Next(ctx) { + t.Fatal("Expected Next() = true for first chunk") + } + + // Get second token - should succeed + if !s.Next(ctx) { + t.Fatal("Expected Next() = true for second chunk") + } + + // Cancel the context + cancel() + + // Next() should now return false because context is cancelled + if s.Next(ctx) { + t.Fatal("Expected Next() = false after context cancellation") + } + + // Verify that the error is set to context cancellation error + if s.Err() == nil { + t.Fatal("Expected error to be set after context cancellation") + } + if !errors.Is(s.Err(), context.Canceled) { + t.Errorf("Expected context.Canceled error, got %v", s.Err()) + } +} From ec6444c808ba723a19b1c95f1a8e9926bbec28f4 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 18:06:01 +0400 Subject: [PATCH 04/19] test: add Execute() calls to verify JSON unmarshalling in tools tests --- chat/tools/tools_test.go | 317 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 313 insertions(+), 4 deletions(-) diff --git a/chat/tools/tools_test.go b/chat/tools/tools_test.go index ce4cb06..a5129a9 100644 --- a/chat/tools/tools_test.go +++ b/chat/tools/tools_test.go @@ -53,6 +53,9 @@ type RecursiveStruct struct { func TestNewTool_WithStruct(t *testing.T) { tool, err := NewTool("test_struct", "test description", func(s SimpleStruct) (string, error) { + if s.Name != "John" || s.Age != 30 { + return "", errors.New("invalid input") + } return "ok", nil }) @@ -72,6 +75,23 @@ func TestNewTool_WithStruct(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("test_struct", `{"name": "John", "age": 30}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_WithPrimitive(t *testing.T) { @@ -102,15 +122,31 @@ func TestNewTool_WithPrimitive(t *testing.T) { if !strings.Contains(string(schemaJSON), "input") { t.Errorf("NewTool() schema should contain 'input' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_primitive", `{"input": "hello"}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_NestedStruct(t *testing.T) { tool, err := NewTool("test_nested", "nested struct test", func(n NestedStruct) (string, error) { + if n.User.Name != "John" || n.User.Age != 30 || n.Active != true { + return "", errors.New("invalid input") + } return "ok", nil }) if err != nil { - t.Errorf("NewTool() error = %v", err) return } @@ -128,10 +164,28 @@ func TestNewTool_NestedStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "user") || !strings.Contains(string(schemaJSON), "active") { t.Errorf("NewTool() schema should contain 'user' and 'active' fields, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_nested", `{"user": {"name": "John", "age": 30}, "active": true}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_EmbeddedStruct(t *testing.T) { tool, err := NewTool("test_embedded", "embedded struct test", func(e EmbeddedStruct) (string, error) { + if e.Name != "John" || e.Age != 30 { + return "", errors.New("invalid input") + } return "ok", nil }) @@ -154,6 +208,21 @@ func TestNewTool_EmbeddedStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "name") || !strings.Contains(string(schemaJSON), "age") { t.Errorf("NewTool() schema should contain embedded fields 'name' and 'age', got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_embedded", `{"name": "John", "age": 30}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithMapPrimitive(t *testing.T) { @@ -162,6 +231,11 @@ func TestNewTool_PrimitiveWithMapPrimitive(t *testing.T) { } tool, err := NewTool("test_map_primitive", "map primitive test", func(m MapPrimitive) (string, error) { + scores := m.Scores + if scores["a"] != 1 || scores["b"] != 2 { + t.Errorf("invalid scores: %v", m) + return "", errors.New("invalid scores") + } return "ok", nil }) @@ -183,10 +257,29 @@ func TestNewTool_PrimitiveWithMapPrimitive(t *testing.T) { if !strings.Contains(string(schemaJSON), "scores") { t.Errorf("NewTool() schema should contain 'scores' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_map_primitive", `{"input": {"a": 1, "b": 2}}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithMapStruct(t *testing.T) { tool, err := NewTool("test_map_struct", "map struct test", func(m MapStruct) (string, error) { + users := m.Users + if users["alice"].Age != 25 || users["bob"].Age != 30 { + return "", errors.New("invalid users") + } return "ok", nil }) @@ -208,10 +301,29 @@ func TestNewTool_PrimitiveWithMapStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "users") { t.Errorf("NewTool() schema should contain 'users' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_map_struct", `{"users": {"alice": {"age": 25}, "bob": {"age": 30}}}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithSlicePrimitive(t *testing.T) { tool, err := NewTool("test_slice_primitive", "slice primitive test", func(s SlicePrimitive) (string, error) { + ids := s.IDs + if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { + return "", errors.New("invalid ids") + } return "ok", nil }) @@ -233,10 +345,29 @@ func TestNewTool_PrimitiveWithSlicePrimitive(t *testing.T) { if !strings.Contains(string(schemaJSON), "ids") { t.Errorf("NewTool() schema should contain 'ids' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_slice_primitive", `{"ids": [1, 2, 3]}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithSliceStruct(t *testing.T) { tool, err := NewTool("test_slice_struct", "slice struct test", func(s SliceStruct) (string, error) { + items := s.Items + if len(items) != 2 || items[0].Name != "Alice" || items[0].Age != 25 || items[1].Name != "Bob" || items[1].Age != 30 { + return "", errors.New("invalid items") + } return "ok", nil }) @@ -258,6 +389,20 @@ func TestNewTool_PrimitiveWithSliceStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "items") { t.Errorf("NewTool() schema should contain 'items' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("test_slice_struct", `{"items": [{"name": "Alice", "age": 25}, {"name": "Bob", "age": 30}]}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithMap(t *testing.T) { @@ -267,6 +412,10 @@ func TestNewTool_PrimitiveWithMap(t *testing.T) { } tool, err := NewTool("test_map", "map test", func(m MapInput) (string, error) { + val, ok := m.Data["key"] + if !ok || val != "value" { + return "", errors.New("invalid data") + } return "ok", nil }) @@ -278,10 +427,33 @@ func TestNewTool_PrimitiveWithMap(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("test_map", `{"data": {"key": "value"}}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_RecursiveStruct(t *testing.T) { tool, err := NewTool("test_recursive", "recursive struct test", func(r RecursiveStruct) (string, error) { + if r.Value != 42 { + return "", errors.New("invalid value") + } + if r.Child == nil || r.Child.Value != 100 { + return "", errors.New("invalid child") + } return "ok", nil }) @@ -303,6 +475,23 @@ func TestNewTool_RecursiveStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "value") || !strings.Contains(string(schemaJSON), "child") { t.Errorf("NewTool() schema should contain 'value' and 'child' fields, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("test_recursive", `{"value": 42, "child": {"value": 100}}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } // ============================================================================ @@ -563,6 +752,9 @@ func TestNewTool_PointerType(t *testing.T) { } tool, err := NewTool("pointer_test", "test pointer type", func(s PointerStruct) (string, error) { + if s.Name == nil || *s.Name != "John" { + return "", errors.New("invalid name") + } return "ok", nil }) @@ -574,6 +766,23 @@ func TestNewTool_PointerType(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for pointer type") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("pointer_test", `{"name": "John"}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_InterfaceType(t *testing.T) { @@ -582,6 +791,9 @@ func TestNewTool_InterfaceType(t *testing.T) { } tool, err := NewTool("interface_test", "test interface type", func(s InterfaceStruct) (string, error) { + if s.Data == nil { + return "", errors.New("data is nil") + } return "ok", nil }) @@ -593,6 +805,23 @@ func TestNewTool_InterfaceType(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for interface type") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("interface_test", `{"data": "test"}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_AnonymousStruct(t *testing.T) { @@ -600,6 +829,9 @@ func TestNewTool_AnonymousStruct(t *testing.T) { tool, err := NewTool("anon_struct", "test anonymous struct", func(s struct { Name string `json:"name"` }) (string, error) { + if s.Name != "John" { + return "", errors.New("invalid name") + } return "ok", nil }) @@ -611,11 +843,31 @@ func TestNewTool_AnonymousStruct(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for anonymous struct") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("anon_struct", `{"name": "John"}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_WithIntPrimitive(t *testing.T) { tool, err := NewTool("int_primitive", "test int primitive", func(n int) (string, error) { - return "received", nil + if n != 42 { + return "", errors.New("invalid value") + } + return "ok", nil }) if err != nil { @@ -632,11 +884,31 @@ func TestNewTool_WithIntPrimitive(t *testing.T) { if !strings.Contains(string(schemaJSON), "input") { t.Errorf("NewTool() schema should contain 'input' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("int_primitive", `{"input": 42}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_WithBoolPrimitive(t *testing.T) { tool, err := NewTool("bool_primitive", "test bool primitive", func(b bool) (string, error) { - return "received", nil + if b != true { + return "", errors.New("invalid value") + } + return "ok", nil }) if err != nil { @@ -647,6 +919,23 @@ func TestNewTool_WithBoolPrimitive(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for bool primitive") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("bool_primitive", `{"input": true}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } // ============================================================================ @@ -729,7 +1018,10 @@ func containsHelper(s, substr string) bool { func TestNewTool_WithFloatPrimitive(t *testing.T) { tool, err := NewTool("float_primitive", "test float64 primitive", func(f float64) (string, error) { - return "received", nil + if f != 3.14 { + return "", errors.New("invalid value") + } + return "ok", nil }) if err != nil { @@ -740,4 +1032,21 @@ func TestNewTool_WithFloatPrimitive(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for float64 primitive") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("float_primitive", `{"input": 3.14}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } From 9119c32f700ca7939a1c7cb862689bb0e2c118ff Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 18:25:30 +0400 Subject: [PATCH 05/19] fix: handle primitive types in tool input unmarshaling via reflection --- chat/tools/tools.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/chat/tools/tools.go b/chat/tools/tools.go index 62fdb92..a2fba1c 100644 --- a/chat/tools/tools.go +++ b/chat/tools/tools.go @@ -3,19 +3,25 @@ package tools import ( "encoding/json" "fmt" + "reflect" ) func NewTool[T any](name, description string, f func(T) (string, error)) (tool, error) { inputType := ensureInputStructType[T]() - wrapper := func(input string) (string, error) { - raw := []byte(input) + var extract func(reflect.Value) T + if inputType == reflect.TypeFor[T]() { + extract = func(v reflect.Value) T { return v.Interface().(T) } + } else { + extract = func(v reflect.Value) T { return v.Field(0).Interface().(T) } + } - var parsed T - if err := json.Unmarshal(raw, &parsed); err != nil { - return "", fmt.Errorf("unmarshal into %T: %w", parsed, err) + wrapper := func(input string) (string, error) { + ptr := reflect.New(inputType) + if err := json.Unmarshal([]byte(input), ptr.Interface()); err != nil { + return "", fmt.Errorf("unmarshal into %v: %w", inputType, err) } - return f(parsed) + return f(extract(ptr.Elem())) } t := tool{ From 36ecaa4f4988d2642cf69ec727cdea2c50e5e198 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 18:37:59 +0400 Subject: [PATCH 06/19] fix(tests): use correct 'scores' field in TestNewTool_PrimitiveWithMapPrimitive execution part --- chat/tools/tools_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat/tools/tools_test.go b/chat/tools/tools_test.go index a5129a9..0cdea5b 100644 --- a/chat/tools/tools_test.go +++ b/chat/tools/tools_test.go @@ -266,7 +266,7 @@ func TestNewTool_PrimitiveWithMapPrimitive(t *testing.T) { } // Try to execute with valid input - result, ok := tools.Execute("test_map_primitive", `{"input": {"a": 1, "b": 2}}`) + result, ok := tools.Execute("test_map_primitive", `{"scores": {"a": 1, "b": 2}}`) t.Log("result: ", result) if !ok { t.Errorf("Execute() error = %v", result) From 320c74636ee6f5ebbc6f982a5646393f5d801c4f Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 20:56:20 +0400 Subject: [PATCH 07/19] test: add context cancellation tests for streaming and partial tool call assembling tests --- chat/chat_test.go | 419 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 419 insertions(+) diff --git a/chat/chat_test.go b/chat/chat_test.go index 46abff5..1ce5051 100644 --- a/chat/chat_test.go +++ b/chat/chat_test.go @@ -3,8 +3,11 @@ package chat import ( "context" "errors" + "fmt" + "strings" "sync" "testing" + "time" "github.com/x2d7/interlude/chat/tools" ) @@ -863,3 +866,419 @@ func TestHelpers_SendSystemStream(t *testing.T) { cancel() } + +// ==================== Context Cancellation Tests ==================== + +// TestComplete_ContextCancelledDuringStream tests context cancellation while streaming tokens +func TestComplete_ContextCancelledDuringStream(t *testing.T) { + chat := &Chat{ + Messages: NewMessages(), + Tools: &tools.Tools{}, + } + + // Slow stream that will be cancelled + events := []StreamEvent{ + NewEventNewToken("Hello"), + NewEventNewToken(" world"), + // No completion - stream will be stuck waiting + } + + mockClient := NewMockClient() + mockClient.SetStreamingEvents(events) + + ctx, cancel := context.WithCancel(context.Background()) + result := chat.Complete(ctx, mockClient) + + // Receive first event + event1 := <-result + if event1.GetType() != eventNewToken { + t.Errorf("Expected first event to be token, got %v", event1.GetType()) + } + + // Cancel context + cancel() + + // Give some time for the goroutine to process cancellation + time.Sleep(50 * time.Millisecond) + + // Drain remaining events - after cancellation there should be no more + // because the channel is closed after all events are processed + for event := range result { + t.Logf("Got event after cancellation: %v", event.GetType()) + } +} + +// TestSession_ContextCancelledDuringTokenCollection tests context cancellation during token collection +func TestSession_ContextCancelledDuringTokenCollection(t *testing.T) { + chat := &Chat{ + Messages: NewMessages(), + Tools: &tools.Tools{}, + } + + mockClient := NewMockClient() + // Stream that won't send completion - simulates stuck stream + mockClient.SetStreamingEvents([]StreamEvent{ + NewEventNewToken("Hello"), + // Missing EventCompletionEnded - stream hangs + }) + + ctx, cancel := context.WithCancel(context.Background()) + result := chat.Session(ctx, mockClient) + + // Receive first token event + var receivedToken bool + for event := range result { + if event.GetType() == eventNewToken { + receivedToken = true + break + } + } + + if !receivedToken { + t.Fatal("Expected to receive at least one token") + } + + // Cancel context - should stop the session + cancel() + + // Key thing is session stopped - we verified we received at least one token above +} + +// TestSession_ContextCancelledWhileWaitingForApproval tests context cancellation while waiting for tool approval +func TestSession_ContextCancelledWhileWaitingForApproval(t *testing.T) { + chatTools := tools.NewTools() + chat := &Chat{ + Messages: NewMessages(), + Tools: &chatTools, + } + + tool, err := tools.NewTool[map[string]string]("test-tool", "Test tool", + func(input map[string]string) (string, error) { + return "result", nil + }) + if err != nil { + t.Fatalf("Failed to create tool: %v", err) + } + chat.Tools.Add(tool) + + // Round 1: tool call that requires approval + mockClient := NewMultiRoundMockClient([][]StreamEvent{ + {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`), NewEventCompletionEnded()}, + // Round 2 never comes - context cancelled + }) + + ctx, cancel := context.WithCancel(context.Background()) + result := chat.Session(ctx, mockClient) + + // Receive tool call event + var toolCallEvent EventNewToolCall + for event := range result { + if tc, ok := event.(EventNewToolCall); ok { + toolCallEvent = tc + break + } + } + + if toolCallEvent.CallID != "call-1" { + t.Fatalf("Expected to receive tool call event") + } + + // Cancel context BEFORE resolving - simulates user cancelling while waiting for approval + cancel() + + // Tool should NOT have been executed (context was cancelled) + messages := chat.Messages.Snapshot() + toolExecuted := false + for _, msg := range messages { + if msg.GetType() == eventNewToolMessage { + toolExecuted = true + break + } + } + + if toolExecuted { + t.Error("Tool should not have been executed when context was cancelled") + } +} + +// TestSession_ContextCancelledBetweenRounds tests context cancellation between completion rounds +func TestSession_ContextCancelledBetweenRounds(t *testing.T) { + chatTools := tools.NewTools() + chat := &Chat{ + Messages: NewMessages(), + Tools: &chatTools, + } + + tool, err := tools.NewTool[map[string]string]("test-tool", "Test tool", + func(input map[string]string) (string, error) { + return "result", nil + }) + if err != nil { + t.Fatalf("Failed to create tool: %v", err) + } + chat.Tools.Add(tool) + + // First round: tool call + // Second round: would be waiting for new completion + mockClient := NewMultiRoundMockClient([][]StreamEvent{ + {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`), NewEventCompletionEnded()}, + }) + + ctx, cancel := context.WithCancel(context.Background()) + result := chat.Session(ctx, mockClient) + + // Handle tool call - resolve it + go func() { + for event := range result { + if tc, ok := event.(EventNewToolCall); ok { + tc.Resolve(true) + } + } + }() + + // Let first round complete + time.Sleep(100 * time.Millisecond) + + // Cancel context between rounds + cancel() + + // Verify only first tool was executed, not second round + messages := chat.Messages.Snapshot() + toolMessageCount := 0 + for _, msg := range messages { + if msg.GetType() == eventNewToolMessage { + toolMessageCount++ + } + } + + if toolMessageCount > 1 { + t.Errorf("Expected at most 1 tool message, got %d", toolMessageCount) + } +} + +// ==================== Tool Call Assembly Tests ==================== + +// TestSession_ToolCallAssembly_Basic tests basic tool call assembly from multiple chunks +func TestSession_ToolCallAssembly_Basic(t *testing.T) { + chatTools := tools.NewTools() + chat := &Chat{ + Messages: NewMessages(), + Tools: &chatTools, + } + + tool, err := tools.NewTool[map[string]string]("weather", "Get weather", + func(input map[string]string) (string, error) { + return "sunny", nil + }) + if err != nil { + t.Fatalf("Failed to create tool: %v", err) + } + chat.Tools.Add(tool) + + // Simulate streaming tool call in chunks: + // 1. First chunk with CallID (start of tool call) + // 2. Subsequent chunks without CallID (continuation) + mockClient := NewMultiRoundMockClient([][]StreamEvent{ + { + NewEventNewToolCall("call-1", "weather", `{"city": "`), + NewEventNewToolCall("", "weather", "Moscow"), + NewEventNewToolCall("", "weather", `"}`), + NewEventCompletionEnded(), + }, + {NewEventCompletionEnded()}, + }) + + ctx := context.Background() + result := chat.Session(ctx, mockClient) + + // Resolve tool call + go func() { + for event := range result { + if tc, ok := event.(EventNewToolCall); ok { + tc.Resolve(true) + } + } + }() + + // Wait for completion + time.Sleep(200 * time.Millisecond) + + // Check that tool call was assembled correctly in history + messages := chat.Messages.Snapshot() + var toolCall EventNewToolCall + for _, msg := range messages { + if tc, ok := msg.(EventNewToolCall); ok { + toolCall = tc + break + } + } + + if toolCall.CallID != "call-1" { + t.Errorf("Expected CallID 'call-1', got '%s'", toolCall.CallID) + } + + // The content should contain assembled JSON + if !strings.Contains(toolCall.Content, "Moscow") { + t.Errorf("Expected assembled content to contain 'Moscow', got '%s'", toolCall.Content) + } +} + +// TestSession_ToolCallAssembly_MultipleToolsWithAssembly tests multiple tool calls where some are assembled +func TestSession_ToolCallAssembly_MultipleToolsWithAssembly(t *testing.T) { + chatTools := tools.NewTools() + chat := &Chat{ + Messages: NewMessages(), + Tools: &chatTools, + } + + tool1, err := tools.NewTool[map[string]string]("tool1", "Tool 1", + func(input map[string]string) (string, error) { + return "result1", nil + }) + if err != nil { + t.Fatalf("Failed to create tool: %v", err) + } + tool2, err := tools.NewTool[map[string]string]("tool2", "Tool 2", + func(input map[string]string) (string, error) { + return "result2", nil + }) + if err != nil { + t.Fatalf("Failed to create tool: %v", err) + } + chat.Tools.Add(tool1) + chat.Tools.Add(tool2) + + // 3 tool calls: + // 1) assembled from chunks (call-1) + // 2) assembled from chunks (call-2) + // 3) complete (call-3) + mockClient := NewMultiRoundMockClient([][]StreamEvent{ + { + NewEventNewToolCall("call-1", "tool1", `{"param": `), + NewEventNewToolCall("", "tool1", "123"), + NewEventNewToolCall("", "tool1", "}"), + NewEventNewToolCall("call-2", "tool2", `{"val": `), + NewEventNewToolCall("", "tool2", "456"), + NewEventNewToolCall("", "tool2", "}"), + NewEventNewToolCall("call-3", "tool2", `{"value": 789}`), + NewEventCompletionEnded(), + }, + {NewEventCompletionEnded()}, + }) + + ctx := context.Background() + result := chat.Session(ctx, mockClient) + + // Resolve tool calls - use channel to avoid race + resolvedCh := make(chan int, 1) + go func() { + resolved := 0 + for event := range result { + if tc, ok := event.(EventNewToolCall); ok { + tc.Resolve(true) + resolved++ + } + } + resolvedCh <- resolved + }() + + // Wait for completion + resolved := <-resolvedCh + close(resolvedCh) + + if resolved != 3 { + t.Errorf("Expected 3 tool calls resolved, got %d", resolved) + } + + // Verify all tool calls were assembled correctly in history + messages := chat.Messages.Snapshot() + toolCalls := make([]EventNewToolCall, 0) + for _, msg := range messages { + if tc, ok := msg.(EventNewToolCall); ok { + toolCalls = append(toolCalls, tc) + } + } + + if len(toolCalls) != 3 { + t.Fatalf("Expected 3 tool calls, got %d", len(toolCalls)) + } + + // First tool (assembled from chunks) + if !strings.Contains(toolCalls[0].Content, "123") { + t.Errorf("First tool should contain assembled '123', got '%s'", toolCalls[0].Content) + } + // Second tool (assembled from chunks) + if !strings.Contains(toolCalls[1].Content, "456") { + t.Errorf("Second tool should contain assembled '456', got '%s'", toolCalls[1].Content) + } + // Third tool (complete) + if !strings.Contains(toolCalls[2].Content, "789") { + t.Errorf("Third tool should contain '789', got '%s'", toolCalls[2].Content) + } +} + +// TestSession_ToolCallAssembly_LargeContent tests assembly of larger content +func TestSession_ToolCallAssembly_LargeContent(t *testing.T) { + chatTools := tools.NewTools() + chat := &Chat{ + Messages: NewMessages(), + Tools: &chatTools, + } + + tool, err := tools.NewTool[map[string]string]("search", "Search", + func(input map[string]string) (string, error) { + return "results", nil + }) + if err != nil { + t.Fatalf("Failed to create tool: %v", err) + } + chat.Tools.Add(tool) + + // Build a large JSON by streaming many small chunks + // Using array syntax: ["chunk0", "chunk1", ...] + events := []StreamEvent{ + NewEventNewToolCall("call-1", "search", `["chunk0"`), + } + // Add more chunks without CallID + for i := 1; i < 20; i++ { + chunk := fmt.Sprintf(",\"chunk%d\"", i) + events = append(events, NewEventNewToolCall("", "search", chunk)) + } + events = append(events, NewEventNewToolCall("", "search", `]`)) + events = append(events, NewEventCompletionEnded()) + + mockClient := NewMultiRoundMockClient([][]StreamEvent{ + events, + {NewEventCompletionEnded()}, + }) + + ctx := context.Background() + result := chat.Session(ctx, mockClient) + + // Resolve tool call + go func() { + for event := range result { + if tc, ok := event.(EventNewToolCall); ok { + tc.Resolve(true) + } + } + }() + + // Wait for completion + time.Sleep(300 * time.Millisecond) + + // Verify tool call was assembled in history + messages := chat.Messages.Snapshot() + var toolCall EventNewToolCall + for _, msg := range messages { + if tc, ok := msg.(EventNewToolCall); ok { + toolCall = tc + break + } + } + + // Should contain assembled content + if !strings.Contains(toolCall.Content, "chunk0") || !strings.Contains(toolCall.Content, "chunk19") { + t.Errorf("Expected assembled content with chunks, got '%s'", toolCall.Content) + } +} From 92331bfc0213aae35d7e8055bf5508213575b1a3 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 20:58:53 +0400 Subject: [PATCH 08/19] feat: assemble partial tool calls --- chat/chat.go | 49 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/chat/chat.go b/chat/chat.go index fec4822..b2dfadd 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -46,6 +46,23 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { result := make(chan StreamEvent, 16) events := c.Complete(ctx, client) + // delivers a StreamEvent to the result channel + // skips nil events + send := func(event StreamEvent) bool { + if event == nil { + if ctx.Err() != nil { + return false + } + return true + } + select { + case result <- event: + return true + case <-ctx.Done(): + return false + } + } + // event handling go func() { defer close(result) @@ -72,6 +89,13 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { callAmount := len(toolCalls) + // send every call from the queue + for _, call := range toolCalls { + if !send(call) { + return + } + } + // ending current completion result <- NewEventCompletionEnded() @@ -110,14 +134,24 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { // in case of ev changes inside "collecting events" block var modifiedEvent StreamEvent + // in case if we need to skip event + var skipEvent bool + // collecting events switch event := ev.(type) { case EventNewToken: stringBuilder.WriteString(event.Content) case EventNewToolCall: - approval.Attach(&event) - modifiedEvent = event - toolCalls = append(toolCalls, event) + // prevent adding tool call immediately — we need to wait until end of completion + skipEvent = true + // if callid is present — it's the start of a new tool call + if event.CallID != "" { + approval.Attach(&event) + toolCalls = append(toolCalls, event) + } else { + // add token to the last tool call + toolCalls[len(toolCalls)-1].Content += event.Content + } case EventNewRefusal: c.AppendEvent(event) } @@ -127,10 +161,13 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { ev = modifiedEvent } + // skipping event + if skipEvent { + continue + } + // sending events to the channel - select { - case result <- ev: - case <-ctx.Done(): + if !send(ev) { return } } From f7fd1af80e58d91038c5e308156dd5afb5dd115f Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 21:03:59 +0400 Subject: [PATCH 09/19] refactor(chat): remove event modification logic in Session handler --- chat/chat.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/chat/chat.go b/chat/chat.go index b2dfadd..c0597fa 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -131,9 +131,6 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { continue } - // in case of ev changes inside "collecting events" block - var modifiedEvent StreamEvent - // in case if we need to skip event var skipEvent bool @@ -156,11 +153,6 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { c.AppendEvent(event) } - // modifying event - if modifiedEvent != nil { - ev = modifiedEvent - } - // skipping event if skipEvent { continue From db8b13a45934c9a1da8633e0274a98eb51dfc8a0 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 23:12:11 +0400 Subject: [PATCH 10/19] test: remove EventCompletionEnded from mock streams since Session handles it --- chat/chat_test.go | 60 ++++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/chat/chat_test.go b/chat/chat_test.go index 1ce5051..da0f0f7 100644 --- a/chat/chat_test.go +++ b/chat/chat_test.go @@ -126,7 +126,8 @@ func (c *MultiRoundMockClient) NewStreaming(ctx context.Context) Stream[StreamEv c.mu.Lock() defer c.mu.Unlock() if c.roundIndex >= len(c.Rounds) { - return NewMockStream([]StreamEvent{NewEventCompletionEnded()}, nil) + // Return empty stream - Session will handle completion signal itself + return NewMockStream([]StreamEvent{}, nil) } events := c.Rounds[c.roundIndex] c.roundIndex++ @@ -395,12 +396,11 @@ func TestSession_CollectsTokens(t *testing.T) { } mockClient := NewMockClient() - // Need to send EventCompletionEnded to trigger collection of tokens // Note: tokens are accumulated into a single message + // EventCompletionEnded is generated by Session, not the mock mockClient.SetStreamingEvents([]StreamEvent{ NewEventNewToken("Hello"), NewEventNewToken(" world"), - NewEventCompletionEnded(), }) ctx := context.Background() @@ -441,11 +441,11 @@ func TestSession_CollectsToolCalls(t *testing.T) { Tools: &tools.Tools{}, } - // Round 1: one tool call + completion signal - // Round 2: just completion (session exits cleanly, no extra tool call) + // Round 1: one tool call (Session generates CompletionEnded internally) + // Round 2: empty (session exits cleanly, no extra tool call) mockClient := NewMultiRoundMockClient([][]StreamEvent{ - {NewEventNewToolCall("call-1", "tool1", `{}`), NewEventCompletionEnded()}, - {NewEventCompletionEnded()}, + {NewEventNewToolCall("call-1", "tool1", `{}`)}, + {}, }) ctx := context.Background() @@ -483,9 +483,9 @@ func TestSession_EmitsCompletionEnded(t *testing.T) { } mockClient := NewMockClient() + // Note: EventCompletionEnded is generated by Session, not the mock mockClient.SetStreamingEvents([]StreamEvent{ NewEventNewToken("test"), - NewEventCompletionEnded(), }) ctx, cancel := context.WithCancel(context.Background()) @@ -514,9 +514,9 @@ func TestSession_Refusal(t *testing.T) { } mockClient := NewMockClient() + // Note: EventCompletionEnded is generated by Session, not the mock mockClient.SetStreamingEvents([]StreamEvent{ NewEventNewRefusal("I cannot help with that"), - NewEventCompletionEnded(), }) ctx, cancel := context.WithCancel(context.Background()) @@ -564,10 +564,10 @@ func TestSession_ToolAccepted(t *testing.T) { } chat.Tools.Add(tool) - // Round 1: tool call, Round 2: empty completion to finish + // Round 1: tool call, Round 2: empty to finish (Session generates CompletionEnded) mockClient := NewMultiRoundMockClient([][]StreamEvent{ - {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`), NewEventCompletionEnded()}, - {NewEventCompletionEnded()}, + {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`)}, + {}, }) ctx := context.Background() @@ -620,10 +620,10 @@ func TestSession_ToolRejected(t *testing.T) { } chat.Tools.Add(tool) - // Round 1: tool call, Round 2: empty completion to finish + // Round 1: tool call, Round 2: empty to finish (Session generates CompletionEnded) mockClient := NewMultiRoundMockClient([][]StreamEvent{ - {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`), NewEventCompletionEnded()}, - {NewEventCompletionEnded()}, + {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`)}, + {}, }) ctx := context.Background() @@ -668,10 +668,10 @@ func TestSession_NonExistentTool(t *testing.T) { Tools: &chatTools, } - // Round 1: tool call, Round 2: empty completion to finish + // Round 1: tool call, Round 2: empty to finish (Session generates CompletionEnded) mockClient := NewMultiRoundMockClient([][]StreamEvent{ - {NewEventNewToolCall("call-1", "non-existent", `{}`), NewEventCompletionEnded()}, - {NewEventCompletionEnded()}, + {NewEventNewToolCall("call-1", "non-existent", `{}`)}, + {}, }) ctx := context.Background() @@ -961,9 +961,9 @@ func TestSession_ContextCancelledWhileWaitingForApproval(t *testing.T) { } chat.Tools.Add(tool) - // Round 1: tool call that requires approval + // Round 1: tool call that requires approval (Session generates CompletionEnded) mockClient := NewMultiRoundMockClient([][]StreamEvent{ - {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`), NewEventCompletionEnded()}, + {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`)}, // Round 2 never comes - context cancelled }) @@ -1018,10 +1018,9 @@ func TestSession_ContextCancelledBetweenRounds(t *testing.T) { } chat.Tools.Add(tool) - // First round: tool call - // Second round: would be waiting for new completion + // First round: tool call (Session generates CompletionEnded) mockClient := NewMultiRoundMockClient([][]StreamEvent{ - {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`), NewEventCompletionEnded()}, + {NewEventNewToolCall("call-1", "test-tool", `{"key": "value"}`)}, }) ctx, cancel := context.WithCancel(context.Background()) @@ -1078,14 +1077,15 @@ func TestSession_ToolCallAssembly_Basic(t *testing.T) { // Simulate streaming tool call in chunks: // 1. First chunk with CallID (start of tool call) // 2. Subsequent chunks without CallID (continuation) + // Session generates CompletionEnded internally, so we don't add it to mocks mockClient := NewMultiRoundMockClient([][]StreamEvent{ { NewEventNewToolCall("call-1", "weather", `{"city": "`), NewEventNewToolCall("", "weather", "Moscow"), NewEventNewToolCall("", "weather", `"}`), - NewEventCompletionEnded(), }, - {NewEventCompletionEnded()}, + // Second round: empty (session exits cleanly) + {}, }) ctx := context.Background() @@ -1152,6 +1152,7 @@ func TestSession_ToolCallAssembly_MultipleToolsWithAssembly(t *testing.T) { // 1) assembled from chunks (call-1) // 2) assembled from chunks (call-2) // 3) complete (call-3) + // Session generates CompletionEnded internally, so we don't add it to mocks mockClient := NewMultiRoundMockClient([][]StreamEvent{ { NewEventNewToolCall("call-1", "tool1", `{"param": `), @@ -1161,9 +1162,9 @@ func TestSession_ToolCallAssembly_MultipleToolsWithAssembly(t *testing.T) { NewEventNewToolCall("", "tool2", "456"), NewEventNewToolCall("", "tool2", "}"), NewEventNewToolCall("call-3", "tool2", `{"value": 789}`), - NewEventCompletionEnded(), }, - {NewEventCompletionEnded()}, + // Second round: empty (session exits cleanly) + {}, }) ctx := context.Background() @@ -1236,6 +1237,7 @@ func TestSession_ToolCallAssembly_LargeContent(t *testing.T) { // Build a large JSON by streaming many small chunks // Using array syntax: ["chunk0", "chunk1", ...] + // Session generates CompletionEnded internally, so we don't add it to mocks events := []StreamEvent{ NewEventNewToolCall("call-1", "search", `["chunk0"`), } @@ -1245,11 +1247,11 @@ func TestSession_ToolCallAssembly_LargeContent(t *testing.T) { events = append(events, NewEventNewToolCall("", "search", chunk)) } events = append(events, NewEventNewToolCall("", "search", `]`)) - events = append(events, NewEventCompletionEnded()) mockClient := NewMultiRoundMockClient([][]StreamEvent{ events, - {NewEventCompletionEnded()}, + // Second round: empty (session exits cleanly) + {}, }) ctx := context.Background() From caf8a8cd152bcb9c144799fd1d4e822de10f6347 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 23:19:53 +0400 Subject: [PATCH 11/19] test: add test for interleaved token and tool call event ordering --- chat/chat_test.go | 132 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/chat/chat_test.go b/chat/chat_test.go index da0f0f7..a104815 100644 --- a/chat/chat_test.go +++ b/chat/chat_test.go @@ -1284,3 +1284,135 @@ func TestSession_ToolCallAssembly_LargeContent(t *testing.T) { t.Errorf("Expected assembled content with chunks, got '%s'", toolCall.Content) } } + +// TestSession_MixedTokensAndToolCalls verifies the interleaved event ordering +// when text tokens and tool calls alternate in a single completion round. +// +// Stream sequence: token("A") → toolCall("call-1") → token("B") → toolCall("call-2") → token("C") +// +// Expected consumer order: A → call-1 → B → call-2 → C → CompletionEnded +// +// This confirms that tool calls are emitted at the point they occur +// (when the next non-toolcall event arrives), NOT batched at the end of the round. +func TestSession_MixedTokensAndToolCalls(t *testing.T) { + chatTools := tools.NewTools() + chat := &Chat{ + Messages: NewMessages(), + Tools: &chatTools, + } + + tool, err := tools.NewTool[map[string]string]("lookup", "Lookup tool", + func(input map[string]string) (string, error) { + return "ok", nil + }) + if err != nil { + t.Fatalf("Failed to create tool: %v", err) + } + chat.Tools.Add(tool) + + // No NewEventCompletionEnded() here — Session generates it from channel close. + // Round 2 is empty — causes Session to emit final CompletionEnded and exit. + mockClient := NewMultiRoundMockClient([][]StreamEvent{ + { + NewEventNewToken("A"), + NewEventNewToolCall("call-1", "lookup", `{"q":"1"}`), + NewEventNewToken("B"), + NewEventNewToolCall("call-2", "lookup", `{"q":"2"}`), + NewEventNewToken("C"), + }, + {}, // empty round — triggers final CompletionEnded + Session exit + }) + + ctx := context.Background() + result := chat.Session(ctx, mockClient) + + type record struct { + kind string // "token" | "tool" | "end" + content string // token text or tool CallID + } + + var order []record + + for event := range result { + switch e := event.(type) { + case EventNewToken: + order = append(order, record{"token", e.Content}) + case EventNewToolCall: + order = append(order, record{"tool", e.CallID}) + e.Resolve(true) // accept all tool calls + case EventCompletionEnded: + order = append(order, record{"end", ""}) + } + } + + // --- helpers --- + posOf := func(kind, content string) int { + for i, r := range order { + if r.kind == kind && r.content == content { + return i + } + } + return -1 + } + + posA := posOf("token", "A") + posCall1 := posOf("tool", "call-1") + posB := posOf("token", "B") + posCall2 := posOf("tool", "call-2") + posC := posOf("token", "C") + posEnd := posOf("end", "") + + t.Logf("Event order: %v", order) + + // All events must be present + for name, pos := range map[string]int{ + "token:A": posA, + "call-1": posCall1, + "token:B": posB, + "call-2": posCall2, + "token:C": posC, + "end": posEnd, + } { + if pos == -1 { + t.Errorf("Missing event: %s", name) + } + } + + if t.Failed() { + return + } + + // Verify interleaved ordering: + // A comes before call-1 (token before its following tool) + if posA >= posCall1 { + t.Errorf("Expected token:A (%d) before call-1 (%d)", posA, posCall1) + } + // call-1 is flushed BEFORE token:B (tool before next token) + if posCall1 >= posB { + t.Errorf("Expected call-1 (%d) before token:B (%d)", posCall1, posB) + } + // B comes before call-2 + if posB >= posCall2 { + t.Errorf("Expected token:B (%d) before call-2 (%d)", posB, posCall2) + } + // call-2 is flushed BEFORE token:C + if posCall2 >= posC { + t.Errorf("Expected call-2 (%d) before token:C (%d)", posCall2, posC) + } + // CompletionEnded is last + if posC >= posEnd { + t.Errorf("Expected token:C (%d) before CompletionEnded (%d)", posC, posEnd) + } + + // Verify tool calls also appear in history with correct data + messages := chat.Messages.Snapshot() + toolCallsInHistory := 0 + for _, msg := range messages { + if msg.GetType() == eventNewToolCall { + toolCallsInHistory++ + } + } + if toolCallsInHistory != 2 { + t.Errorf("Expected 2 tool calls in history, got %d", toolCallsInHistory) + } +} From 550d36ae2500ab9db67e0581933848fdef47c8e3 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 23:21:59 +0400 Subject: [PATCH 12/19] fix: correct tool call event ordering in stream processing --- chat/chat.go | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/chat/chat.go b/chat/chat.go index c0597fa..a437425 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -70,6 +70,7 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { // event collectors var stringBuilder strings.Builder toolCalls := make([]EventNewToolCall, 0) + var lastToolCall *EventNewToolCall approval := NewApproveWaiter() @@ -89,9 +90,9 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { callAmount := len(toolCalls) - // send every call from the queue - for _, call := range toolCalls { - if !send(call) { + // send last tool call if it wasn't sent yet + if lastToolCall != nil { + if !send(*lastToolCall) { return } } @@ -121,6 +122,7 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { // reset collectors stringBuilder.Reset() toolCalls = make([]EventNewToolCall, 0) + lastToolCall = nil // reset approval waiter approval = NewApproveWaiter() @@ -131,6 +133,16 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { continue } + // flush last tool call if event type switched away from tool call stream + if lastToolCall != nil { + if _, isToolCall := ev.(EventNewToolCall); !isToolCall { + if !send(*lastToolCall) { + return + } + lastToolCall = nil + } + } + // in case if we need to skip event var skipEvent bool @@ -141,13 +153,19 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { case EventNewToolCall: // prevent adding tool call immediately — we need to wait until end of completion skipEvent = true - // if callid is present — it's the start of a new tool call if event.CallID != "" { + // flush the previous tool call — it's now complete + if lastToolCall != nil { + if !send(*lastToolCall) { + return + } + } approval.Attach(&event) toolCalls = append(toolCalls, event) + lastToolCall = &toolCalls[len(toolCalls)-1] } else { // add token to the last tool call - toolCalls[len(toolCalls)-1].Content += event.Content + lastToolCall.Content += event.Content } case EventNewRefusal: c.AppendEvent(event) From ed56bad6c63c8d9853587767b587f50ca5e0cfc7 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 23:23:02 +0400 Subject: [PATCH 13/19] fix: handle empty choices in OpenAI streaming response --- connect/openai/stream.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/connect/openai/stream.go b/connect/openai/stream.go index ad875fb..41dd02e 100644 --- a/connect/openai/stream.go +++ b/connect/openai/stream.go @@ -98,6 +98,9 @@ func (s *OpenAIStream) handleRawChunk(chunk openai.ChatCompletionChunk) ([]chat. // Should not return empty list. It would be considered as an error func (s *OpenAIStream) _handleRawChunk(chunk openai.ChatCompletionChunk) ([]chat.StreamEvent, error) { result := make([]chat.StreamEvent, 0) + if len(chunk.Choices) == 0 { + return result, nil + } choice := chunk.Choices[0] delta := choice.Delta From 5805eb0b9d33de868c2bdead58f158da9dfaa464 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 23:45:58 +0400 Subject: [PATCH 14/19] refactor(chat): extract session state into dedicated struct --- chat/chat.go | 66 ++++++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/chat/chat.go b/chat/chat.go index a437425..feff752 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -38,6 +38,20 @@ func (c *Chat) Complete(ctx context.Context, client Client) <-chan StreamEvent { return result } +type sessionState struct { + builder strings.Builder + toolCalls []EventNewToolCall + lastToolCall *EventNewToolCall + approval *ApproveWaiter +} + +func (s *sessionState) reset() { + s.builder.Reset() + s.toolCalls = s.toolCalls[:0] + s.lastToolCall = nil + s.approval = NewApproveWaiter() +} + func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { // insert chat context into client input configuration client = client.SyncInput(c) @@ -67,12 +81,9 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { go func() { defer close(result) - // event collectors - var stringBuilder strings.Builder - toolCalls := make([]EventNewToolCall, 0) - var lastToolCall *EventNewToolCall - - approval := NewApproveWaiter() + // session state + state := &sessionState{} + state.reset() for { select { @@ -81,18 +92,18 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { case ev, ok := <-events: if !ok { // adding collected events to the chat (assistant's tokens and tool calls) - if stringBuilder.Len() != 0 { - c.AppendEvent(NewEventNewToken(stringBuilder.String())) + if state.builder.Len() != 0 { + c.AppendEvent(NewEventNewToken(state.builder.String())) } - for _, call := range toolCalls { + for _, call := range state.toolCalls { c.AppendEvent(call) } - callAmount := len(toolCalls) + callAmount := len(state.toolCalls) // send last tool call if it wasn't sent yet - if lastToolCall != nil { - if !send(*lastToolCall) { + if state.lastToolCall != nil { + if !send(*state.lastToolCall) { return } } @@ -105,7 +116,7 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { } // initializing approval waiter - verdicts := approval.Wait(ctx, callAmount) + verdicts := state.approval.Wait(ctx, callAmount) // processing user verdicts for verdict := range verdicts { @@ -119,13 +130,8 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { } } - // reset collectors - stringBuilder.Reset() - toolCalls = make([]EventNewToolCall, 0) - lastToolCall = nil - - // reset approval waiter - approval = NewApproveWaiter() + // reset state + state.reset() // resume text completion client = client.SyncInput(c) @@ -134,12 +140,12 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { } // flush last tool call if event type switched away from tool call stream - if lastToolCall != nil { + if state.lastToolCall != nil { if _, isToolCall := ev.(EventNewToolCall); !isToolCall { - if !send(*lastToolCall) { + if !send(*state.lastToolCall) { return } - lastToolCall = nil + state.lastToolCall = nil } } @@ -149,23 +155,23 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { // collecting events switch event := ev.(type) { case EventNewToken: - stringBuilder.WriteString(event.Content) + state.builder.WriteString(event.Content) case EventNewToolCall: // prevent adding tool call immediately — we need to wait until end of completion skipEvent = true if event.CallID != "" { // flush the previous tool call — it's now complete - if lastToolCall != nil { - if !send(*lastToolCall) { + if state.lastToolCall != nil { + if !send(*state.lastToolCall) { return } } - approval.Attach(&event) - toolCalls = append(toolCalls, event) - lastToolCall = &toolCalls[len(toolCalls)-1] + state.approval.Attach(&event) + state.toolCalls = append(state.toolCalls, event) + state.lastToolCall = &state.toolCalls[len(state.toolCalls)-1] } else { // add token to the last tool call - lastToolCall.Content += event.Content + state.lastToolCall.Content += event.Content } case EventNewRefusal: c.AppendEvent(event) From 4c88ac55d103cff82df2244757938ae3024ba983 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 23:49:27 +0400 Subject: [PATCH 15/19] refactor(chat): extract lastToolCall flush logic into dedicated method --- chat/chat.go | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/chat/chat.go b/chat/chat.go index feff752..13a7bd6 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -52,6 +52,15 @@ func (s *sessionState) reset() { s.approval = NewApproveWaiter() } +func (s *sessionState) flushLastToolCall(send func(StreamEvent) bool) bool { + if s.lastToolCall == nil { + return true + } + ok := send(*s.lastToolCall) + s.lastToolCall = nil + return ok +} + func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { // insert chat context into client input configuration client = client.SyncInput(c) @@ -102,10 +111,8 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { callAmount := len(state.toolCalls) // send last tool call if it wasn't sent yet - if state.lastToolCall != nil { - if !send(*state.lastToolCall) { - return - } + if !state.flushLastToolCall(send) { + return } // ending current completion @@ -140,12 +147,9 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { } // flush last tool call if event type switched away from tool call stream - if state.lastToolCall != nil { - if _, isToolCall := ev.(EventNewToolCall); !isToolCall { - if !send(*state.lastToolCall) { - return - } - state.lastToolCall = nil + if _, isToolCall := ev.(EventNewToolCall); !isToolCall { + if !state.flushLastToolCall(send) { + return } } @@ -161,10 +165,8 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { skipEvent = true if event.CallID != "" { // flush the previous tool call — it's now complete - if state.lastToolCall != nil { - if !send(*state.lastToolCall) { - return - } + if !state.flushLastToolCall(send) { + return } state.approval.Attach(&event) state.toolCalls = append(state.toolCalls, event) From 7b026b940deee724a0f95032615e08f35f1bbed7 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Thu, 19 Feb 2026 23:51:34 +0400 Subject: [PATCH 16/19] refactor: send completion ended event via helper instead of direct channel send --- chat/chat.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chat/chat.go b/chat/chat.go index 13a7bd6..3f4916b 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -116,7 +116,9 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { } // ending current completion - result <- NewEventCompletionEnded() + if !send(NewEventCompletionEnded()) { + return + } if callAmount == 0 { return From aeb79054e76a43704c307f6c0e5f6591b9a41a09 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Fri, 20 Feb 2026 00:03:09 +0400 Subject: [PATCH 17/19] refactor(chat): move client and events into sessionState struct --- chat/chat.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/chat/chat.go b/chat/chat.go index 3f4916b..ee3bc02 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -39,6 +39,13 @@ func (c *Chat) Complete(ctx context.Context, client Client) <-chan StreamEvent { } type sessionState struct { + // session context + + client Client + events <-chan StreamEvent + + // session state variables + builder strings.Builder toolCalls []EventNewToolCall lastToolCall *EventNewToolCall @@ -67,7 +74,6 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { // creating the channels result := make(chan StreamEvent, 16) - events := c.Complete(ctx, client) // delivers a StreamEvent to the result channel // skips nil events @@ -91,14 +97,17 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { defer close(result) // session state - state := &sessionState{} + state := &sessionState{ + client: client, + events: c.Complete(ctx, client), + } state.reset() for { select { case <-ctx.Done(): return - case ev, ok := <-events: + case ev, ok := <-state.events: if !ok { // adding collected events to the chat (assistant's tokens and tool calls) if state.builder.Len() != 0 { @@ -143,8 +152,8 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { state.reset() // resume text completion - client = client.SyncInput(c) - events = c.Complete(ctx, client) + state.client = state.client.SyncInput(c) + state.events = c.Complete(ctx, state.client) continue } From ab78269b6c181e91138b1e26f1dcae96ec87ef50 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Fri, 20 Feb 2026 00:13:17 +0400 Subject: [PATCH 18/19] refactor(chat): extract handleCompletionEnd method and move send to sessionState --- chat/chat.go | 116 ++++++++++++++++++++++++++++----------------------- 1 file changed, 63 insertions(+), 53 deletions(-) diff --git a/chat/chat.go b/chat/chat.go index ee3bc02..ba490f3 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -43,7 +43,8 @@ type sessionState struct { client Client events <-chan StreamEvent - + send func(StreamEvent) bool + // session state variables builder strings.Builder @@ -59,13 +60,13 @@ func (s *sessionState) reset() { s.approval = NewApproveWaiter() } -func (s *sessionState) flushLastToolCall(send func(StreamEvent) bool) bool { - if s.lastToolCall == nil { - return true - } - ok := send(*s.lastToolCall) - s.lastToolCall = nil - return ok +func (s *sessionState) flushLastToolCall() bool { + if s.lastToolCall == nil { + return true + } + ok := s.send(*s.lastToolCall) + s.lastToolCall = nil + return ok } func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { @@ -100,6 +101,7 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { state := &sessionState{ client: client, events: c.Complete(ctx, client), + send: send, } state.reset() @@ -109,57 +111,15 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { return case ev, ok := <-state.events: if !ok { - // adding collected events to the chat (assistant's tokens and tool calls) - if state.builder.Len() != 0 { - c.AppendEvent(NewEventNewToken(state.builder.String())) - } - for _, call := range state.toolCalls { - c.AppendEvent(call) - } - - callAmount := len(state.toolCalls) - - // send last tool call if it wasn't sent yet - if !state.flushLastToolCall(send) { - return - } - - // ending current completion - if !send(NewEventCompletionEnded()) { + if !c.handleCompletionEnd(ctx, state) { return } - - if callAmount == 0 { - return - } - - // initializing approval waiter - verdicts := state.approval.Wait(ctx, callAmount) - - // processing user verdicts - for verdict := range verdicts { - call := verdict.call - - if verdict.Accepted { - callResult, success := c.Tools.Execute(call.Name, call.Content) - c.AppendEvent(NewEventNewToolMessage(call.CallID, callResult, success)) - } else { - c.AppendEvent(NewEventNewToolMessage(call.CallID, "User declined the tool call", false)) - } - } - - // reset state - state.reset() - - // resume text completion - state.client = state.client.SyncInput(c) - state.events = c.Complete(ctx, state.client) continue } // flush last tool call if event type switched away from tool call stream if _, isToolCall := ev.(EventNewToolCall); !isToolCall { - if !state.flushLastToolCall(send) { + if !state.flushLastToolCall() { return } } @@ -176,7 +136,7 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { skipEvent = true if event.CallID != "" { // flush the previous tool call — it's now complete - if !state.flushLastToolCall(send) { + if !state.flushLastToolCall() { return } state.approval.Attach(&event) @@ -205,3 +165,53 @@ func (c *Chat) Session(ctx context.Context, client Client) <-chan StreamEvent { return result } + +func (c *Chat) handleCompletionEnd(ctx context.Context, state *sessionState) (proceed bool) { + // adding collected events to the chat (assistant's tokens and tool calls) + if state.builder.Len() != 0 { + c.AppendEvent(NewEventNewToken(state.builder.String())) + } + for _, call := range state.toolCalls { + c.AppendEvent(call) + } + + callAmount := len(state.toolCalls) + + // send last tool call if it wasn't sent yet + if !state.flushLastToolCall() { + return false + } + + // ending current completion + if !state.send(NewEventCompletionEnded()) { + return false + } + + if callAmount == 0 { + return false + } + + // initializing approval waiter + verdicts := state.approval.Wait(ctx, callAmount) + + // processing user verdicts + for verdict := range verdicts { + call := verdict.call + + if verdict.Accepted { + callResult, success := c.Tools.Execute(call.Name, call.Content) + c.AppendEvent(NewEventNewToolMessage(call.CallID, callResult, success)) + } else { + c.AppendEvent(NewEventNewToolMessage(call.CallID, "User declined the tool call", false)) + } + } + + // reset state + state.reset() + + // resume text completion + state.client = state.client.SyncInput(c) + state.events = c.Complete(ctx, state.client) + + return true +} From 0cabafb90d55f84914d26e9455307f95447d8426 Mon Sep 17 00:00:00 2001 From: 2xd7 Date: Fri, 20 Feb 2026 00:20:32 +0400 Subject: [PATCH 19/19] chore: add TODO comment for EventCompletionEnded struct --- chat/events.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat/events.go b/chat/events.go index 533b936..03c65bc 100644 --- a/chat/events.go +++ b/chat/events.go @@ -31,7 +31,7 @@ type EventBase struct { Content string } -type EventCompletionEnded struct{} +type EventCompletionEnded struct{} // TODO: можно добавлять список вызовов инструментов и другую информацию о генерации func (e EventCompletionEnded) GetType() eventType { return eventCompletionEnded }