From bad7c307a9add198ea0c06a0e054d79cffa509f2 Mon Sep 17 00:00:00 2001 From: bogwi Date: Mon, 3 Nov 2025 01:06:28 +0900 Subject: [PATCH 1/2] feat: implement threading fallback with automatic client-side fallback Implement graceful fallback from server-side to client-side threading when preconditions aren't met, ensuring conversations continue successfully. - Add response ID chain validation to detect broken threading chains - Enhance CallModelWithOpts with automatic fallback logic - Add structured logging for threading decisions and fallbacks - Implement comprehensive test coverage for all fallback scenarios - Handle edge cases and update documentation --- contextwindow.go | 166 ++++++++--- contextwindow_test.go | 677 +++++++++++++++++++++++++++++++++++++++++- storage.go | 90 ++++++ storage_test.go | 411 +++++++++++++++++++++++++ 4 files changed, 1295 insertions(+), 49 deletions(-) create mode 100644 storage_test.go diff --git a/contextwindow.go b/contextwindow.go index 5d9a06b..8d2baa3 100644 --- a/contextwindow.go +++ b/contextwindow.go @@ -42,7 +42,7 @@ // cw.AddTool(lsTool, contextwindow.ToolRunnerFunc(func context.Context, // args json.RawMessage) (string, error) { // var treq struct { -// Dir string `json:"directory"` +// Dir string `json:"directory"` // } // json.Unmarshal(args, &treq) // // actually run ls, or pretend to @@ -64,7 +64,7 @@ // // summarizerModel, err := openai.New(apiKey, "gpt-3.5-turbo") // if err != nil { -// log.Fatalf("Failed to create summarizer: %v", err) +// log.Fatalf("Failed to create summarizer: %v", err) // } // // cw, err := contextwindow.New(model, summarizerModel, "") @@ -81,19 +81,33 @@ // LLM conversations are stored in SQLite. If you don't care about persistant // storage for your context, just specify ":memory:" as your database path. // +// # Threading and Fallback Behavior +// +// When server-side threading is enabled, the library attempts to use +// response_id-based threading for efficiency. However, several conditions +// can cause automatic fallback to client-side threading: +// +// The response_id chain is broken or invalid +// Tool calls are present (they break server-side threading) +// The model's threading API call fails +// - The context has no previous response_id (first call) +// +// back is automatic and transparent - conversations continue normally +// g full message history. Check logs for threading decisions. +// // # Thread Safety // // ContextWindow write operations (AddPrompt, SwitchContext, SetMaxTokens, etc.) // require external coordination when used concurrently. However, you can use // ContextWindow.Reader() to get a thread-safe read-only view: // -// reader := cw.Reader() -// go updateUI(reader) // safe for concurrent use -// go updateMetrics(reader) // safe for concurrent use +// reader := cw.Reader() +// go updateUI(reader) // safe for concurrent use +// go updateMetrics(reader) // safe for concurrent use // -// // Meanwhile, main thread can safely modify state: -// cw.SwitchContext("new-context") -// cw.SetMaxTokens(8192) +// // Meanwhile, main thread can safely modify state: +// cw.SwitchContext("new-context") +// cw.SetMaxTokens(8192) // // ContextReader provides access to read operations like LiveRecords(), TokenUsage(), // and context querying, all of which are safe for concurrent use. @@ -104,6 +118,7 @@ import ( "database/sql" "errors" "fmt" + "log/slog" "strings" "sync" "time" @@ -310,12 +325,13 @@ func (cw *ContextWindow) AddToolOutput(output string) error { // SetRecordLiveStateByRange updates the live status of records in the specified range. // Indices are based on the current LiveRecords() slice, with both start and end inclusive. -// This allows selective marking of context elements as active (live=true) or +// This allows selective marking of context elements as active (live=true) or // inactive (live=false) based on their position in the conversation. // // Examples: -// SetRecordLiveStateByRange(2, 4, false) // marks records at indices 2, 3, 4 as dead -// SetRecordLiveStateByRange(5, 5, false) // marks only record at index 5 as dead +// +// SetRecordLiveStateByRange(2, 4, false) // marks records at indices 2, 3, 4 as dead +// SetRecordLiveStateByRange(5, 5, false) // marks only record at index 5 as dead func (cw *ContextWindow) SetRecordLiveStateByRange(startIndex, endIndex int, live bool) error { if startIndex < 0 || endIndex < startIndex { return fmt.Errorf("invalid range: startIndex=%d, endIndex=%d", startIndex, endIndex) @@ -417,6 +433,55 @@ func (cw *ContextWindow) CallModel(ctx context.Context) (string, error) { return cw.CallModelWithOpts(ctx, CallModelOpts{}) } +// shouldAttemptServerSideThreading determines if server-side threading should be attempted. +// Returns: (shouldAttempt bool, reason string) +func (cw *ContextWindow) shouldAttemptServerSideThreading( + contextInfo Context, + recs []Record, +) (bool, string) { + // If threading not enabled, don't attempt + if !contextInfo.UseServerSideThreading { + return false, "server-side threading not enabled for context" + } + + // Check if model supports threading + _, ok := cw.model.(ServerSideThreadingCapable) + if !ok { + return false, "model does not support server-side threading" + } + + // Check if there's a LastResponseID (needed for threading) + if contextInfo.LastResponseID == nil || *contextInfo.LastResponseID == "" { + return false, "no last_response_id available (first call or chain broken)" + } + + // Validate response_id chain + contextID, err := getContextIDByName(cw.db, cw.currentContext) + if err != nil { + return false, fmt.Sprintf("cannot get context ID: %v", err) + } + + valid, reason := ValidateResponseIDChain(cw.db, contextID) + if !valid { + return false, fmt.Sprintf("response_id chain invalid: %s", reason) + } + + return true, "preconditions met" +} + +// logThreadingDecision logs threading decisions for observability +func (cw *ContextWindow) logThreadingDecision( + attemptServerSide bool, + reason string, + contextName string, +) { + slog.Info("threading decision", + "attempt_server_side", attemptServerSide, + "reason", reason, + "context", contextName, + ) +} + // CallModelWithOpts drives an LLM with options. It composes live messages, invokes cw.model.Call, // logs the response, updates token count, and triggers compaction. func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOpts) (string, error) { @@ -440,44 +505,67 @@ func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOp var tokensUsed int var responseID *string - // Serverside threading (`previous_response_id`) sends only the most recent prompt - // and a backlink to the last response, rather than sending the entire thread on - // every LLM call. - // TODO(tqbf): this stuff needs better testing; I don't really use it. - if contextInfo.UseServerSideThreading { - if threadingModel, ok := cw.model.(ServerSideThreadingCapable); ok { - if optsModel, ok := threadingModel.(CallOptsCapable); ok { - events, responseID, tokensUsed, err = optsModel.CallWithThreadingAndOpts( - ctx, - true, - contextInfo.LastResponseID, - recs, - opts, - ) - } else { - events, responseID, tokensUsed, err = threadingModel.CallWithThreading( - ctx, - true, - contextInfo.LastResponseID, - recs, - ) - } - if err != nil { - return "", fmt.Errorf("call model with threading: %w", err) - } + // Determine if we should attempt server-side threading + attemptServerSide, reason := cw.shouldAttemptServerSideThreading(contextInfo, recs) + loggedFallback := false + + if attemptServerSide { + // Log threading attempt + cw.logThreadingDecision(true, reason, cw.currentContext) + + // Attempt server-side threading + threadingModel := cw.model.(ServerSideThreadingCapable) + var err error + + if optsModel, ok := threadingModel.(CallOptsCapable); ok { + events, responseID, tokensUsed, err = optsModel.CallWithThreadingAndOpts( + ctx, + true, + contextInfo.LastResponseID, + recs, + opts, + ) } else { - return "", fmt.Errorf("model does not support server-side threading") + events, responseID, tokensUsed, err = threadingModel.CallWithThreading( + ctx, + true, + contextInfo.LastResponseID, + recs, + ) } - } else { - // Fall back to traditional client-side threading + + if err != nil { + // Log fallback reason + fallbackReason := fmt.Sprintf("server-side threading failed: %v", err) + cw.logThreadingDecision(false, fallbackReason, cw.currentContext) + loggedFallback = true + // Fall through to client-side threading + attemptServerSide = false + reason = fallbackReason + } + } + + // Use client-side threading (either as fallback or default) + if !attemptServerSide { + // Log reason for client-side threading (only if we didn't already log the fallback) + if !loggedFallback { + cw.logThreadingDecision(false, reason, cw.currentContext) + } + if optsModel, ok := cw.model.(CallOptsCapable); ok { events, tokensUsed, err = optsModel.CallWithOpts(ctx, recs, opts) } else { events, tokensUsed, err = cw.model.Call(ctx, recs) } if err != nil { + // Include fallback context in error message if we fell back from server-side + if contextInfo.UseServerSideThreading { + return "", fmt.Errorf("call model (fallback to client-side threading): %w", err) + } return "", fmt.Errorf("call model: %w", err) } + // Client-side threading doesn't return responseID + responseID = nil } cw.metrics.Add(tokensUsed) diff --git a/contextwindow_test.go b/contextwindow_test.go index 397af87..23e394d 100644 --- a/contextwindow_test.go +++ b/contextwindow_test.go @@ -869,6 +869,8 @@ func TestSetContextServerSideThreading(t *testing.T) { assert.False(t, updatedCtx.UseServerSideThreading) } +// TestServerSideThreadingFallback verifies fallback behavior when server-side threading +// cannot be used, ensuring conversations continue successfully func TestServerSideThreadingFallback(t *testing.T) { db, err := NewContextDB(":memory:") assert.NoError(t, err) @@ -882,28 +884,60 @@ func TestServerSideThreadingFallback(t *testing.T) { assert.NoError(t, err) defer cw.Close() - // Add a prompt and make first call (no previous response ID, should use client-side) + // Add a prompt and make first call (no previous response ID, should fallback to client-side) err = cw.AddPrompt("Hello") assert.NoError(t, err) ctx := context.Background() - _, err = cw.CallModel(ctx) + resp1, err := cw.CallModel(ctx) assert.NoError(t, err) + assert.NotEmpty(t, resp1) // Verify the mock was called with client-side threading (no previous response ID) + // First call falls back because there's no LastResponseID yet assert.False(t, mockModel.lastCallUsedServerSide) assert.Nil(t, mockModel.lastPreviousResponseID) - // Add another prompt - this should use server-side threading + // Verify conversation continues successfully - check that response was recorded + recs, err := cw.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs), 0) + + // Manually set LastResponseID to simulate what would happen after a server-side call + // In a real scenario, this would be set by the previous server-side threading call + ctxInfo, err := GetContextByName(db, "test-fallback") + assert.NoError(t, err) + testResponseID := "mock_response_123" + err = UpdateContextLastResponseID(db, ctxInfo.ID, testResponseID) + assert.NoError(t, err) + + // Also set responseID on the last model response record to make chain valid + recs, err = cw.LiveRecords() + assert.NoError(t, err) + for i := len(recs) - 1; i >= 0; i-- { + if recs[i].Source == ModelResp { + _, err = db.Exec(`UPDATE records SET response_id = ? WHERE id = ?`, testResponseID, recs[i].ID) + assert.NoError(t, err) + break + } + } + + // Add another prompt - this should use server-side threading now that we have LastResponseID err = cw.AddPrompt("How are you?") assert.NoError(t, err) - _, err = cw.CallModel(ctx) + resp2, err := cw.CallModel(ctx) assert.NoError(t, err) + assert.NotEmpty(t, resp2) - // Verify the mock was called with server-side threading + // Verify the mock was called with server-side threading (now we have LastResponseID) assert.True(t, mockModel.lastCallUsedServerSide) assert.NotNil(t, mockModel.lastPreviousResponseID) + + // Verify conversation continues successfully with both responses + recs, err = cw.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs), 2) // Should have prompts and responses } // Mock model for testing server-side threading behavior @@ -914,7 +948,17 @@ type mockResponsesModel struct { } func (m *mockResponsesModel) Call(ctx context.Context, inputs []Record) ([]Record, int, error) { - events, _, tokens, err := m.CallWithThreading(ctx, false, nil, inputs) + // CallWithThreading with client-side threading (useServerSideThreading=false) + // Still return a responseID so LastResponseID gets set + events, responseID, tokens, err := m.CallWithThreading(ctx, false, nil, inputs) + // Set responseID on events for consistency + if responseID != nil && len(events) > 0 { + for i := range events { + if events[i].Source == ModelResp { + events[i].ResponseID = responseID + } + } + } return events, tokens, err } @@ -1499,6 +1543,7 @@ func TestContextContinuation(t *testing.T) { } // TestThreadingBehaviorResume tests threading behavior when resuming contexts +// and ensures fallback doesn't break existing behavior func TestThreadingBehaviorResume(t *testing.T) { db, err := NewContextDB(":memory:") assert.NoError(t, err) @@ -1516,25 +1561,50 @@ func TestThreadingBehaviorResume(t *testing.T) { err = cw1.AddPrompt("Initial prompt") assert.NoError(t, err) - _, err = cw1.CallModel(context.Background()) + resp1, err := cw1.CallModel(context.Background()) assert.NoError(t, err) + assert.NotEmpty(t, resp1) + + // Verify initial response was recorded + recs1, err := cw1.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs1), 0) // "Close" and reopen context cw2, err := NewContextWindowWithThreading(db, threadingModel, "threading-resume", true) assert.NoError(t, err) + // Verify context was resumed correctly + recs2, err := cw2.LiveRecords() + assert.NoError(t, err) + assert.Equal(t, len(recs1), len(recs2)) // Should have same records + // Add another prompt err = cw2.AddPrompt("Follow-up prompt") assert.NoError(t, err) // Verify the model receives the call appropriately - _, err = cw2.CallModel(context.Background()) + resp2, err := cw2.CallModel(context.Background()) + assert.NoError(t, err) + assert.NotEmpty(t, resp2) + + // Verify conversation continues successfully + recs3, err := cw2.LiveRecords() assert.NoError(t, err) + assert.Greater(t, len(recs3), len(recs2)) // Should have more records now // Check that the context was properly continued ctx, err := GetContextByName(db, "threading-resume") assert.NoError(t, err) assert.Equal(t, true, ctx.UseServerSideThreading) + + // Verify fallback doesn't break behavior - context should still work + err = cw2.AddPrompt("Third prompt") + assert.NoError(t, err) + + resp3, err := cw2.CallModel(context.Background()) + assert.NoError(t, err) + assert.NotEmpty(t, resp3) }) t.Run("ResumeWithClientSideThreading", func(t *testing.T) { @@ -1546,23 +1616,43 @@ func TestThreadingBehaviorResume(t *testing.T) { err = cw1.AddPrompt("First client prompt") assert.NoError(t, err) - _, err = cw1.CallModel(context.Background()) + resp1, err := cw1.CallModel(context.Background()) assert.NoError(t, err) + assert.NotEmpty(t, resp1) // Reopen context cw2, err := NewContextWindowWithThreading(db, threadingModel, "client-resume", false) assert.NoError(t, err) + // Verify context was resumed + recs1, err := cw2.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs1), 0) + err = cw2.AddPrompt("Second client prompt") assert.NoError(t, err) - _, err = cw2.CallModel(context.Background()) + resp2, err := cw2.CallModel(context.Background()) assert.NoError(t, err) + assert.NotEmpty(t, resp2) + + // Verify conversation continues successfully + recs2, err := cw2.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs2), len(recs1)) // Verify context settings ctx, err := GetContextByName(db, "client-resume") assert.NoError(t, err) assert.Equal(t, false, ctx.UseServerSideThreading) + + // Verify fallback doesn't break behavior - context should still work + err = cw2.AddPrompt("Third client prompt") + assert.NoError(t, err) + + resp3, err := cw2.CallModel(context.Background()) + assert.NoError(t, err) + assert.NotEmpty(t, resp3) }) } @@ -2089,4 +2179,571 @@ func TestContextWindow_SetRecordLiveStateByRange_Revive(t *testing.T) { liveRecordsAfterDead, err := cw.LiveRecords() assert.NoError(t, err) assert.Len(t, liveRecordsAfterDead, 0) +} + +// TestShouldAttemptServerSideThreading tests the helper function that determines +// if server-side threading should be attempted. +func TestShouldAttemptServerSideThreading(t *testing.T) { + db, err := NewContextDB(":memory:") + assert.NoError(t, err) + defer db.Close() + + t.Run("threading enabled with valid chain", func(t *testing.T) { + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, "test-valid", true) + assert.NoError(t, err) + + // Add prompt and get response to establish chain + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + // First call - no LastResponseID, uses client-side (no responseID returned) + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Manually set LastResponseID to simulate what would happen after a server-side call + // In real usage, this would be set by a previous server-side threading call + ctx, err := GetContextByName(db, "test-valid") + assert.NoError(t, err) + testResponseID := "test_response_123" + err = UpdateContextLastResponseID(db, ctx.ID, testResponseID) + assert.NoError(t, err) + + // Also set responseID on the last model response record to make chain valid + recs, err := cw.LiveRecords() + assert.NoError(t, err) + for i := len(recs) - 1; i >= 0; i-- { + if recs[i].Source == ModelResp { + // Update the record's responseID + _, err = db.Exec(`UPDATE records SET response_id = ? WHERE id = ?`, testResponseID, recs[i].ID) + assert.NoError(t, err) + break + } + } + + // Now get context info + contextInfo, err := cw.GetCurrentContextInfo() + assert.NoError(t, err) + + recs, err = cw.LiveRecords() + assert.NoError(t, err) + + // Should attempt threading (has LastResponseID now) + shouldAttempt, reason := cw.shouldAttemptServerSideThreading(contextInfo, recs) + assert.True(t, shouldAttempt) + assert.Equal(t, "preconditions met", reason) + }) + + t.Run("threading disabled", func(t *testing.T) { + mockModel := &mockResponsesModel{} + cw, err := NewContextWindow(db, mockModel, "test-disabled") + assert.NoError(t, err) + + contextInfo, err := cw.GetCurrentContextInfo() + assert.NoError(t, err) + + recs, err := cw.LiveRecords() + assert.NoError(t, err) + + shouldAttempt, reason := cw.shouldAttemptServerSideThreading(contextInfo, recs) + assert.False(t, shouldAttempt) + assert.Contains(t, reason, "server-side threading not enabled") + }) + + t.Run("model does not support threading", func(t *testing.T) { + // Use a model that doesn't implement ServerSideThreadingCapable + nonThreadingModel := &MockModel{} + cw, err := NewContextWindowWithThreading(db, nonThreadingModel, "test-no-support", true) + assert.NoError(t, err) + + contextInfo, err := cw.GetCurrentContextInfo() + assert.NoError(t, err) + + recs, err := cw.LiveRecords() + assert.NoError(t, err) + + shouldAttempt, reason := cw.shouldAttemptServerSideThreading(contextInfo, recs) + assert.False(t, shouldAttempt) + assert.Contains(t, reason, "model does not support server-side threading") + }) + + t.Run("missing LastResponseID", func(t *testing.T) { + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, "test-no-last-id", true) + assert.NoError(t, err) + + // Add prompt but don't call model yet (no LastResponseID) + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + contextInfo, err := cw.GetCurrentContextInfo() + assert.NoError(t, err) + + recs, err := cw.LiveRecords() + assert.NoError(t, err) + + shouldAttempt, reason := cw.shouldAttemptServerSideThreading(contextInfo, recs) + assert.False(t, shouldAttempt) + assert.Contains(t, reason, "no last_response_id available") + }) + + t.Run("invalid response_id chain", func(t *testing.T) { + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, "test-invalid-chain", true) + assert.NoError(t, err) + + // Add prompt and get response + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Manually set LastResponseID so we can test chain validation + ctx, err := GetContextByName(db, "test-invalid-chain") + assert.NoError(t, err) + testResponseID := "test_response_123" + err = UpdateContextLastResponseID(db, ctx.ID, testResponseID) + assert.NoError(t, err) + + // Add tool call to break the chain + err = cw.AddToolCall("test_tool", "{}") + assert.NoError(t, err) + + contextInfo, err := cw.GetCurrentContextInfo() + assert.NoError(t, err) + + recs, err := cw.LiveRecords() + assert.NoError(t, err) + + shouldAttempt, reason := cw.shouldAttemptServerSideThreading(contextInfo, recs) + assert.False(t, shouldAttempt) + assert.Contains(t, reason, "response_id chain invalid") + }) +} + +// TestServerSideThreadingFallbackOnError tests fallback when server-side threading fails +func TestServerSideThreadingFallbackOnError(t *testing.T) { + db, err := NewContextDB(":memory:") + assert.NoError(t, err) + defer db.Close() + + // Create a mock model that fails on threading calls + failingModel := &failingThreadingModel{} + + cw, err := NewContextWindowWithThreading(db, failingModel, "test-error-fallback", true) + assert.NoError(t, err) + defer cw.Close() + + // Set up a valid threading scenario + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + // First call - no LastResponseID, will use client-side + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Add another prompt - now we have LastResponseID + err = cw.AddPrompt("How are you?") + assert.NoError(t, err) + + // This call should attempt server-side, fail, and fallback to client-side + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) // Should succeed with fallback + + // Verify that the fallback occurred (the model should have been called with client-side) + assert.True(t, failingModel.fallbackOccurred) +} + +// failingThreadingModel fails on threading calls but succeeds on regular calls +type failingThreadingModel struct { + fallbackOccurred bool +} + +func (m *failingThreadingModel) Call(ctx context.Context, inputs []Record) ([]Record, int, error) { + m.fallbackOccurred = true + return []Record{ + { + Source: ModelResp, + Content: "Fallback response", + Live: true, + EstTokens: 10, + }, + }, 10, nil +} + +func (m *failingThreadingModel) CallWithThreading( + ctx context.Context, + useServerSideThreading bool, + lastResponseID *string, + inputs []Record, +) ([]Record, *string, int, error) { + // Always fail to simulate threading failure + return nil, nil, 0, fmt.Errorf("server-side threading failed") +} + +func (m *failingThreadingModel) SetToolExecutor(executor ToolExecutor) { + // No-op +} + +func (m *failingThreadingModel) SetMiddleware(middleware []Middleware) { + // No-op +} + +// TestServerSideThreadingFallbackOnBrokenChain tests fallback when response_id chain is broken +func TestServerSideThreadingFallbackOnBrokenChain(t *testing.T) { + db, err := NewContextDB(":memory:") + assert.NoError(t, err) + defer db.Close() + + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, "test-broken-chain", true) + assert.NoError(t, err) + defer cw.Close() + + // Set up initial valid chain + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Add tool call to break the chain + err = cw.AddToolCall("test_tool", "{}") + assert.NoError(t, err) + + err = cw.AddToolOutput("Tool output") + assert.NoError(t, err) + + // Add another prompt + err = cw.AddPrompt("Follow-up") + assert.NoError(t, err) + + // This call should detect broken chain and fallback to client-side + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Verify fallback occurred (should use client-side threading) + assert.False(t, mockModel.lastCallUsedServerSide) +} + +// TestThreadingFallbackOnMissingResponseID tests fallback when threading is enabled +// but no LastResponseID exists (e.g., first call or after chain break) +func TestThreadingFallbackOnMissingResponseID(t *testing.T) { + db, err := NewContextDB(":memory:") + assert.NoError(t, err) + defer db.Close() + + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, "test-missing-response-id", true) + assert.NoError(t, err) + defer cw.Close() + + // Add a prompt - this is the first call, so no LastResponseID exists + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + // Call model - should use client-side threading since no LastResponseID + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Verify that client-side threading was used (no LastResponseID available) + assert.False(t, mockModel.lastCallUsedServerSide) + assert.Nil(t, mockModel.lastPreviousResponseID) + + // Verify conversation continues successfully + recs, err := cw.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs), 0) + + // Verify response was recorded + foundResponse := false + for _, rec := range recs { + if rec.Source == ModelResp { + foundResponse = true + break + } + } + assert.True(t, foundResponse, "Model response should be recorded") +} + +// TestThreadingFallbackOnToolCalls tests fallback when tool calls are present +// Tool calls break server-side threading, so should always use client-side +func TestThreadingFallbackOnToolCalls(t *testing.T) { + db, err := NewContextDB(":memory:") + assert.NoError(t, err) + defer db.Close() + + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, "test-tool-calls-fallback", true) + assert.NoError(t, err) + defer cw.Close() + + // Set up initial valid chain + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Manually set LastResponseID to simulate what would happen after a server-side call + ctxInfo, err := GetContextByName(db, "test-tool-calls-fallback") + assert.NoError(t, err) + testResponseID := "test_response_123" + err = UpdateContextLastResponseID(db, ctxInfo.ID, testResponseID) + assert.NoError(t, err) + + // Also set responseID on the last model response record to make chain valid + recs, err := cw.LiveRecords() + assert.NoError(t, err) + for i := len(recs) - 1; i >= 0; i-- { + if recs[i].Source == ModelResp { + _, err = db.Exec(`UPDATE records SET response_id = ? WHERE id = ?`, testResponseID, recs[i].ID) + assert.NoError(t, err) + break + } + } + + // Add tool calls - these break server-side threading + err = cw.AddToolCall("test_tool", `{"arg": "value"}`) + assert.NoError(t, err) + + err = cw.AddToolOutput("Tool output") + assert.NoError(t, err) + + // Add another prompt + err = cw.AddPrompt("Follow-up after tool call") + assert.NoError(t, err) + + // Reset the mock state + mockModel.lastCallUsedServerSide = false + mockModel.lastPreviousResponseID = nil + + // Call model - should fallback to client-side threading because of tool calls + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Verify fallback occurred (should use client-side threading) + assert.False(t, mockModel.lastCallUsedServerSide) + + // Verify conversation continues successfully + recs, err = cw.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs), 0) +} + +// TestEmptyContextWithThreadingEnabled tests that empty contexts with threading enabled +// work correctly (first call should use client-side threading, should not error) +func TestEmptyContextWithThreadingEnabled(t *testing.T) { + db, err := NewContextDB(":memory:") + assert.NoError(t, err) + defer db.Close() + + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, "test-empty-threading", true) + assert.NoError(t, err) + defer cw.Close() + + // Verify threading is enabled + enabled, err := cw.IsServerSideThreadingEnabled() + assert.NoError(t, err) + assert.True(t, enabled) + + // Verify context is empty + recs, err := cw.LiveRecords() + assert.NoError(t, err) + assert.Len(t, recs, 0) + + // Add a prompt - this is the first call, so no LastResponseID exists + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + // Call model - should use client-side threading (no previous response) + // Should not error even though threading is enabled + resp, err := cw.CallModel(context.Background()) + assert.NoError(t, err) + assert.NotEmpty(t, resp) + + // Verify client-side threading was used (no LastResponseID available) + assert.False(t, mockModel.lastCallUsedServerSide) + assert.Nil(t, mockModel.lastPreviousResponseID) + + // Verify response was recorded + recs, err = cw.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs), 0) + + // Verify a response record exists + foundResponse := false + for _, rec := range recs { + if rec.Source == ModelResp { + foundResponse = true + break + } + } + assert.True(t, foundResponse, "Model response should be recorded") +} + +// TestLastResponseIDNoMatchingRecordFallback tests fallback when LastResponseID +// exists but no matching record is found (export/import scenario) +func TestLastResponseIDNoMatchingRecordFallback(t *testing.T) { + db, err := NewContextDB(":memory:") + assert.NoError(t, err) + defer db.Close() + + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, "test-no-matching-record", true) + assert.NoError(t, err) + defer cw.Close() + + // Add a prompt and get response + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Manually set a LastResponseID that doesn't match any record + // This simulates what might happen after export/import + ctxInfo, err := GetContextByName(db, "test-no-matching-record") + assert.NoError(t, err) + nonexistentResponseID := "resp-nonexistent-999" + err = UpdateContextLastResponseID(db, ctxInfo.ID, nonexistentResponseID) + assert.NoError(t, err) + + // Add another prompt + err = cw.AddPrompt("Follow-up") + assert.NoError(t, err) + + // Reset mock state + mockModel.lastCallUsedServerSide = false + mockModel.lastPreviousResponseID = nil + + // Call model - should detect broken chain and fallback to client-side + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Verify fallback occurred (should use client-side threading) + assert.False(t, mockModel.lastCallUsedServerSide) + + // Verify conversation continues successfully + recs, err := cw.LiveRecords() + assert.NoError(t, err) + assert.Greater(t, len(recs), 0) +} + +// TestConcurrentCallModelWithFallback tests that fallback decisions work correctly +// when called from different contexts. Note: Database operations require external +// coordination for true concurrency (as documented), but fallback decision logic +// itself reads from database safely. +func TestConcurrentCallModelWithFallback(t *testing.T) { + // Test sequential calls from different contexts to verify fallback logic + // works correctly without requiring true database concurrency + path := filepath.Join(t.TempDir(), "concurrent.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create multiple context windows sequentially to verify fallback logic + for i := 0; i < 5; i++ { + mockModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, mockModel, fmt.Sprintf("test-concurrent-%d", i), true) + assert.NoError(t, err) + + // Add a prompt + err = cw.AddPrompt(fmt.Sprintf("Prompt %d", i)) + assert.NoError(t, err) + + // Call model - should use client-side threading (no LastResponseID) + // This verifies that fallback decision logic works correctly + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err, "CallModel should succeed for context %d", i) + + // Verify fallback occurred (should use client-side threading) + assert.False(t, mockModel.lastCallUsedServerSide, "Should use client-side threading for first call") + assert.Nil(t, mockModel.lastPreviousResponseID, "Should not have previous response ID") + } +} + +// TestErrorMessagesIndicateFallback tests that error messages properly indicate +// when fallback to client-side threading occurred +func TestErrorMessagesIndicateFallback(t *testing.T) { + db, err := NewContextDB(":memory:") + assert.NoError(t, err) + defer db.Close() + + // Start with a working model to set up the state + workingModel := &mockResponsesModel{} + cw, err := NewContextWindowWithThreading(db, workingModel, "test-error-messages", true) + assert.NoError(t, err) + defer cw.Close() + + // Set up a scenario where server-side threading would be attempted + err = cw.AddPrompt("Hello") + assert.NoError(t, err) + + // First call - no LastResponseID, uses client-side, succeeds + _, err = cw.CallModel(context.Background()) + assert.NoError(t, err) + + // Manually set LastResponseID to trigger server-side attempt + ctxInfo, err := GetContextByName(db, "test-error-messages") + assert.NoError(t, err) + testResponseID := "test_response_123" + err = UpdateContextLastResponseID(db, ctxInfo.ID, testResponseID) + assert.NoError(t, err) + + // Set responseID on last model response to make chain valid + recs, err := cw.LiveRecords() + assert.NoError(t, err) + for i := len(recs) - 1; i >= 0; i-- { + if recs[i].Source == ModelResp { + _, err = db.Exec(`UPDATE records SET response_id = ? WHERE id = ?`, testResponseID, recs[i].ID) + assert.NoError(t, err) + break + } + } + + // Now switch to a failing model to test error message + failingModel := &failingClientSideModel{} + cw.model = failingModel + + // Add another prompt + err = cw.AddPrompt("Follow-up") + assert.NoError(t, err) + + // This should attempt server-side, fail, fallback to client-side, which also fails + // The error message should indicate fallback occurred + _, err = cw.CallModel(context.Background()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "fallback to client-side threading", + "Error message should indicate fallback occurred") +} + +// failingClientSideModel fails on both threading and regular calls to test error messages +type failingClientSideModel struct { + fallbackOccurred bool +} + +func (m *failingClientSideModel) Call(ctx context.Context, inputs []Record) ([]Record, int, error) { + m.fallbackOccurred = true + return nil, 0, fmt.Errorf("client-side call failed") +} + +func (m *failingClientSideModel) CallWithThreading( + ctx context.Context, + useServerSideThreading bool, + lastResponseID *string, + inputs []Record, +) ([]Record, *string, int, error) { + // Always fail to simulate threading failure + return nil, nil, 0, fmt.Errorf("server-side threading failed") +} + +func (m *failingClientSideModel) SetToolExecutor(executor ToolExecutor) { + // No-op +} + +func (m *failingClientSideModel) SetMiddleware(middleware []Middleware) { + // No-op } \ No newline at end of file diff --git a/storage.go b/storage.go index 9cc36b1..d81beef 100644 --- a/storage.go +++ b/storage.go @@ -673,3 +673,93 @@ func CloneContext(db *sql.DB, sourceName, destName string) error { return nil } + +// ValidateResponseIDChain checks if a response_id chain is valid for server-side threading. +// Returns true if the chain is valid, false otherwise, along with a reason. +func ValidateResponseIDChain(db *sql.DB, contextID string) (bool, string) { + // Get context to check LastResponseID + ctx, err := GetContext(db, contextID) + if err != nil { + return false, fmt.Sprintf("cannot get context: %v", err) + } + + // If no LastResponseID, chain is invalid + if ctx.LastResponseID == nil || *ctx.LastResponseID == "" { + return false, "no last_response_id set" + } + + // Get all live records + records, err := ListLiveRecords(db, contextID) + if err != nil { + return false, fmt.Sprintf("cannot list records: %v", err) + } + + // Check for tool calls - these break server-side threading + for _, rec := range records { + if rec.Source == ToolCall || rec.Source == ToolOutput { + return false, "tool calls present (break server-side threading)" + } + } + + // Find all ModelResp records + var modelResponses []Record + for _, rec := range records { + if rec.Source == ModelResp { + modelResponses = append(modelResponses, rec) + } + } + + // If no model responses, chain is valid (first call) + if len(modelResponses) == 0 { + return true, "no model responses yet (first call)" + } + + // Check for gaps in response_id chain + hasResponseIDs := false + hasMissingResponseIDs := false + lastResponseIDExists := false + + for _, rec := range modelResponses { + if rec.ResponseID != nil && *rec.ResponseID != "" { + hasResponseIDs = true + // Check if this matches the context's LastResponseID (for existence check) + if ctx.LastResponseID != nil && *rec.ResponseID == *ctx.LastResponseID { + lastResponseIDExists = true + } + } else { + hasMissingResponseIDs = true + } + } + + // Edge case: Context has LastResponseID but no matching record exists + // This can happen after export/import or manual database edits + if ctx.LastResponseID != nil && *ctx.LastResponseID != "" && !lastResponseIDExists { + return false, fmt.Sprintf("last_response_id (%v) does not exist in records (chain broken, possibly after export/import)", + *ctx.LastResponseID) + } + + // Mixed state is invalid + if hasResponseIDs && hasMissingResponseIDs { + return false, "mixed response_id state (some records missing IDs)" + } + + // Last response must match context's LastResponseID + // Get the last response ID from records + lastResponseID := getLastResponseID(records) + if lastResponseID == nil || ctx.LastResponseID == nil || *lastResponseID != *ctx.LastResponseID { + return false, fmt.Sprintf("last response_id (%v) does not match context (%v)", + lastResponseID, ctx.LastResponseID) + } + + return true, "chain valid" +} + +// getLastResponseID is a helper to get the last response ID from records. +func getLastResponseID(records []Record) *string { + for i := len(records) - 1; i >= 0; i-- { + if records[i].Source == ModelResp && records[i].ResponseID != nil { + return records[i].ResponseID + } + } + return nil +} diff --git a/storage_test.go b/storage_test.go new file mode 100644 index 0000000..b4c40ba --- /dev/null +++ b/storage_test.go @@ -0,0 +1,411 @@ +package contextwindow + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + _ "modernc.org/sqlite" +) + +func TestValidateResponseIDChain_ValidChain(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Add a model response with response_id + responseID := "resp-123" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "test response", true, &responseID) + assert.NoError(t, err) + + // Update context with last response ID + err = UpdateContextLastResponseID(db, ctx.ID, responseID) + assert.NoError(t, err) + + // Validate chain - should be valid + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.True(t, valid) + assert.Equal(t, "chain valid", reason) +} + +func TestValidateResponseIDChain_MissingLastResponseID(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled but no LastResponseID set + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Validate chain - should be invalid (no LastResponseID) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.False(t, valid) + assert.Equal(t, "no last_response_id set", reason) +} + +func TestValidateResponseIDChain_ToolCallsPresent(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Add a model response with response_id + responseID := "resp-123" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "test response", true, &responseID) + assert.NoError(t, err) + + // Update context with last response ID + err = UpdateContextLastResponseID(db, ctx.ID, responseID) + assert.NoError(t, err) + + // Add a tool call - this should break server-side threading + _, err = InsertRecord(db, ctx.ID, ToolCall, "tool call", true) + assert.NoError(t, err) + + // Validate chain - should be invalid (tool calls present) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.False(t, valid) + assert.Equal(t, "tool calls present (break server-side threading)", reason) +} + +func TestValidateResponseIDChain_ToolOutputPresent(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Add a model response with response_id + responseID := "resp-123" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "test response", true, &responseID) + assert.NoError(t, err) + + // Update context with last response ID + err = UpdateContextLastResponseID(db, ctx.ID, responseID) + assert.NoError(t, err) + + // Add a tool output - this should break server-side threading + _, err = InsertRecord(db, ctx.ID, ToolOutput, "tool output", true) + assert.NoError(t, err) + + // Validate chain - should be invalid (tool calls present) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.False(t, valid) + assert.Equal(t, "tool calls present (break server-side threading)", reason) +} + +func TestValidateResponseIDChain_MixedResponseIDState(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Add a model response with response_id + responseID1 := "resp-123" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 1", true, &responseID1) + assert.NoError(t, err) + + // Add a model response WITHOUT response_id (mixed state) + _, err = InsertRecord(db, ctx.ID, ModelResp, "response 2", true) + assert.NoError(t, err) + + // Update context with last response ID + err = UpdateContextLastResponseID(db, ctx.ID, responseID1) + assert.NoError(t, err) + + // Validate chain - should be invalid (mixed state) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.False(t, valid) + assert.Equal(t, "mixed response_id state (some records missing IDs)", reason) +} + +func TestValidateResponseIDChain_LastResponseIDMismatch(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Add a model response with response_id + responseID1 := "resp-123" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 1", true, &responseID1) + assert.NoError(t, err) + + // Add another model response with different response_id + responseID2 := "resp-456" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 2", true, &responseID2) + assert.NoError(t, err) + + // Update context with a different response ID that doesn't exist in records + // This is more serious than a mismatch - the ID doesn't exist at all + wrongResponseID := "resp-wrong" + err = UpdateContextLastResponseID(db, ctx.ID, wrongResponseID) + assert.NoError(t, err) + + // Validate chain - should be invalid (LastResponseID doesn't exist in records) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.False(t, valid) + assert.Contains(t, reason, "does not exist in records") + assert.Contains(t, reason, "export/import") +} + +func TestValidateResponseIDChain_LastResponseIDMismatchWithExistingID(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Add a model response with response_id + responseID1 := "resp-123" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 1", true, &responseID1) + assert.NoError(t, err) + + // Add another model response with different response_id (this is the last one) + responseID2 := "resp-456" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 2", true, &responseID2) + assert.NoError(t, err) + + // Update context with responseID1 (exists but doesn't match the last response) + err = UpdateContextLastResponseID(db, ctx.ID, responseID1) + assert.NoError(t, err) + + // Validate chain - should be invalid (LastResponseID doesn't match last response) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.False(t, valid) + assert.Contains(t, reason, "does not match context") +} + +func TestValidateResponseIDChain_EmptyContext(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Set a LastResponseID even though there are no records + responseID := "resp-123" + err = UpdateContextLastResponseID(db, ctx.ID, responseID) + assert.NoError(t, err) + + // Validate chain - should be valid (no model responses yet, first call) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.True(t, valid) + assert.Equal(t, "no model responses yet (first call)", reason) +} + +func TestValidateResponseIDChain_NoModelResponsesButHasPrompt(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt but no model responses + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Set a LastResponseID + responseID := "resp-123" + err = UpdateContextLastResponseID(db, ctx.ID, responseID) + assert.NoError(t, err) + + // Validate chain - should be valid (no model responses yet, first call) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.True(t, valid) + assert.Equal(t, "no model responses yet (first call)", reason) +} + +func TestValidateResponseIDChain_ContextNotFound(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Try to validate with non-existent context ID + valid, reason := ValidateResponseIDChain(db, "non-existent-id") + assert.False(t, valid) + assert.Contains(t, reason, "cannot get context") +} + +func TestValidateResponseIDChain_MultipleValidResponses(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Add multiple model responses with response_ids + responseID1 := "resp-123" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 1", true, &responseID1) + assert.NoError(t, err) + + responseID2 := "resp-456" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 2", true, &responseID2) + assert.NoError(t, err) + + // Update context with the last response ID + err = UpdateContextLastResponseID(db, ctx.ID, responseID2) + assert.NoError(t, err) + + // Validate chain - should be valid (last response matches) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.True(t, valid) + assert.Equal(t, "chain valid", reason) +} + +func TestGetLastResponseID(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add records + _, err = InsertRecord(db, ctx.ID, Prompt, "prompt", true) + assert.NoError(t, err) + + responseID1 := "resp-1" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 1", true, &responseID1) + assert.NoError(t, err) + + responseID2 := "resp-2" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 2", true, &responseID2) + assert.NoError(t, err) + + // Get records + records, err := ListLiveRecords(db, ctx.ID) + assert.NoError(t, err) + + // Test getLastResponseID helper + lastID := getLastResponseID(records) + assert.NotNil(t, lastID) + assert.Equal(t, responseID2, *lastID) +} + +func TestGetLastResponseID_NoResponseIDs(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add records without response IDs + _, err = InsertRecord(db, ctx.ID, Prompt, "prompt", true) + assert.NoError(t, err) + _, err = InsertRecord(db, ctx.ID, ModelResp, "response", true) + assert.NoError(t, err) + + // Get records + records, err := ListLiveRecords(db, ctx.ID) + assert.NoError(t, err) + + // Test getLastResponseID helper - should return nil + lastID := getLastResponseID(records) + assert.Nil(t, lastID) +} + +// TestValidateResponseIDChain_LastResponseIDNoMatchingRecord tests the edge case +// where a context has a LastResponseID but no matching record exists in the database. +// This can happen after export/import or manual database edits. +func TestValidateResponseIDChain_LastResponseIDNoMatchingRecord(t *testing.T) { + path := filepath.Join(t.TempDir(), "cw.db") + db, err := NewContextDB(path) + assert.NoError(t, err) + defer db.Close() + + // Create context with threading enabled + ctx, err := CreateContextWithThreading(db, "test-context", true) + assert.NoError(t, err) + + // Add a prompt + _, err = InsertRecord(db, ctx.ID, Prompt, "test prompt", true) + assert.NoError(t, err) + + // Add a model response with a different response_id + responseID1 := "resp-123" + _, err = InsertRecordWithResponseID(db, ctx.ID, ModelResp, "response 1", true, &responseID1) + assert.NoError(t, err) + + // Set a LastResponseID that doesn't match any record (simulating export/import scenario) + wrongResponseID := "resp-nonexistent-999" + err = UpdateContextLastResponseID(db, ctx.ID, wrongResponseID) + assert.NoError(t, err) + + // Validate chain - should be invalid (LastResponseID doesn't exist in records) + valid, reason := ValidateResponseIDChain(db, ctx.ID) + assert.False(t, valid) + assert.Contains(t, reason, "does not exist in records") + assert.Contains(t, reason, "export/import") +} + From 40bc5151ef473cd2813433def4841bc9f1123e2c Mon Sep 17 00:00:00 2001 From: "Thomas bH. Ptacek" Date: Tue, 11 Nov 2025 14:57:01 -0600 Subject: [PATCH 2/2] cleanups don't log, pass context not ID if we're just immediately look it up from ID anyways, lose obvious comments. tested with local agent, seems to work peachy --- contextwindow.go | 132 +++++++++++++++++------------------------------ storage.go | 40 ++++++++------ storage_test.go | 37 ++++--------- 3 files changed, 80 insertions(+), 129 deletions(-) diff --git a/contextwindow.go b/contextwindow.go index 8d2baa3..a905c99 100644 --- a/contextwindow.go +++ b/contextwindow.go @@ -4,27 +4,27 @@ // // # Abbreviated usage // -// model, err := NewOpenAIResponsesModel(shared.ResponsesModel4o) -// if err != nil { -// log.Fatalf("Failed to create model: %v", err) -// } +// model, err := NewOpenAIResponsesModel(shared.ResponsesModel4o) +// if err != nil { +// log.Fatalf("Failed to create model: %v", err) +// } // -// cw, err := contextwindow.New(model, nil, "") -// if err != nil { -// log.Fatalf("Failed to create context window: %v", err) -// } -// defer cw.Close() +// cw, err := contextwindow.New(model, nil, "") +// if err != nil { +// log.Fatalf("Failed to create context window: %v", err) +// } +// defer cw.Close() // -// if err := cw.AddPrompt(ctx, "how's the weather?"); err != nil { -// log.Fatalf("Failed to add prompt: %v", err) -// } +// if err := cw.AddPrompt(ctx, "how's the weather?"); err != nil { +// log.Fatalf("Failed to add prompt: %v", err) +// } // -// response, err := cw.CallModel(ctx) -// if err != nil { -// log.Fatalf("Failed to call model: %v", err) -// } +// response, err := cw.CallModel(ctx) +// if err != nil { +// log.Fatalf("Failed to call model: %v", err) +// } // -// fmt.Printf("response: %s\n", response) +// fmt.Printf("response: %s\n", response) // // # System prompts // @@ -35,19 +35,19 @@ // // Instruct LLMs to call tools locally with [ContextWindow.AddTool] (and [NewTool]). // -// lsTool := contextwindow.NewTool("list_files", ` -// This tool lists files in the specified directory. -// `).AddStringParameter("directory", "Directory to list", true) +// lsTool := contextwindow.NewTool("list_files", ` +// This tool lists files in the specified directory. +// `).AddStringParameter("directory", "Directory to list", true) // -// cw.AddTool(lsTool, contextwindow.ToolRunnerFunc(func context.Context, -// args json.RawMessage) (string, error) { -// var treq struct { -// Dir string `json:"directory"` -// } -// json.Unmarshal(args, &treq) -// // actually run ls, or pretend to -// return "here\nare\nsome\nfiles.exe\n", nil -// }) +// cw.AddTool(lsTool, contextwindow.ToolRunnerFunc(func context.Context, +// args json.RawMessage) (string, error) { +// var treq struct { +// Dir string `json:"directory"` +// } +// json.Unmarshal(args, &treq) +// // actually run ls, or pretend to +// return "here\nare\nsome\nfiles.exe\n", nil +// }) // // You can selectively enable and disable tools with [ContextWindow.CallModelWithOpts]. // @@ -62,12 +62,12 @@ // // You can provide a summarizer model to automatically compact your context window: // -// summarizerModel, err := openai.New(apiKey, "gpt-3.5-turbo") -// if err != nil { -// log.Fatalf("Failed to create summarizer: %v", err) -// } +// summarizerModel, err := openai.New(apiKey, "gpt-3.5-turbo") +// if err != nil { +// log.Fatalf("Failed to create summarizer: %v", err) +// } // -// cw, err := contextwindow.New(model, summarizerModel, "") +// cw, err := contextwindow.New(model, summarizerModel, "") // // And then "compress" your context with [ContextWindow.SummarizeLiveContent]. // @@ -118,7 +118,6 @@ import ( "database/sql" "errors" "fmt" - "log/slog" "strings" "sync" "time" @@ -433,35 +432,25 @@ func (cw *ContextWindow) CallModel(ctx context.Context) (string, error) { return cw.CallModelWithOpts(ctx, CallModelOpts{}) } -// shouldAttemptServerSideThreading determines if server-side threading should be attempted. -// Returns: (shouldAttempt bool, reason string) func (cw *ContextWindow) shouldAttemptServerSideThreading( - contextInfo Context, + ci Context, recs []Record, -) (bool, string) { - // If threading not enabled, don't attempt - if !contextInfo.UseServerSideThreading { +) (should bool, reason string /* not using this yet but seems like a good idea */) { + + if !ci.UseServerSideThreading { return false, "server-side threading not enabled for context" } - // Check if model supports threading _, ok := cw.model.(ServerSideThreadingCapable) if !ok { return false, "model does not support server-side threading" } - // Check if there's a LastResponseID (needed for threading) - if contextInfo.LastResponseID == nil || *contextInfo.LastResponseID == "" { + if ci.LastResponseID == nil || *ci.LastResponseID == "" { return false, "no last_response_id available (first call or chain broken)" } - // Validate response_id chain - contextID, err := getContextIDByName(cw.db, cw.currentContext) - if err != nil { - return false, fmt.Sprintf("cannot get context ID: %v", err) - } - - valid, reason := ValidateResponseIDChain(cw.db, contextID) + valid, reason := ValidateResponseIDChain(cw.db, ci) if !valid { return false, fmt.Sprintf("response_id chain invalid: %s", reason) } @@ -469,19 +458,6 @@ func (cw *ContextWindow) shouldAttemptServerSideThreading( return true, "preconditions met" } -// logThreadingDecision logs threading decisions for observability -func (cw *ContextWindow) logThreadingDecision( - attemptServerSide bool, - reason string, - contextName string, -) { - slog.Info("threading decision", - "attempt_server_side", attemptServerSide, - "reason", reason, - "context", contextName, - ) -} - // CallModelWithOpts drives an LLM with options. It composes live messages, invokes cw.model.Call, // logs the response, updates token count, and triggers compaction. func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOpts) (string, error) { @@ -501,19 +477,15 @@ func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOp return "", fmt.Errorf("list live records: %w", err) } - var events []Record - var tokensUsed int - var responseID *string + var ( + events []Record + tokensUsed int + responseID *string + ) - // Determine if we should attempt server-side threading - attemptServerSide, reason := cw.shouldAttemptServerSideThreading(contextInfo, recs) - loggedFallback := false + attemptServerSide, _ := cw.shouldAttemptServerSideThreading(contextInfo, recs) if attemptServerSide { - // Log threading attempt - cw.logThreadingDecision(true, reason, cw.currentContext) - - // Attempt server-side threading threadingModel := cw.model.(ServerSideThreadingCapable) var err error @@ -535,30 +507,19 @@ func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOp } if err != nil { - // Log fallback reason - fallbackReason := fmt.Sprintf("server-side threading failed: %v", err) - cw.logThreadingDecision(false, fallbackReason, cw.currentContext) - loggedFallback = true // Fall through to client-side threading attemptServerSide = false - reason = fallbackReason } } // Use client-side threading (either as fallback or default) if !attemptServerSide { - // Log reason for client-side threading (only if we didn't already log the fallback) - if !loggedFallback { - cw.logThreadingDecision(false, reason, cw.currentContext) - } - if optsModel, ok := cw.model.(CallOptsCapable); ok { events, tokensUsed, err = optsModel.CallWithOpts(ctx, recs, opts) } else { events, tokensUsed, err = cw.model.Call(ctx, recs) } if err != nil { - // Include fallback context in error message if we fell back from server-side if contextInfo.UseServerSideThreading { return "", fmt.Errorf("call model (fallback to client-side threading): %w", err) } @@ -585,7 +546,6 @@ func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOp lastMsg = event.Content } - // Update the context's last response ID if we got one if responseID != nil { err = UpdateContextLastResponseID(cw.db, contextID, *responseID) if err != nil { diff --git a/storage.go b/storage.go index d81beef..f5797c4 100644 --- a/storage.go +++ b/storage.go @@ -675,21 +675,27 @@ func CloneContext(db *sql.DB, sourceName, destName string) error { } // ValidateResponseIDChain checks if a response_id chain is valid for server-side threading. -// Returns true if the chain is valid, false otherwise, along with a reason. -func ValidateResponseIDChain(db *sql.DB, contextID string) (bool, string) { - // Get context to check LastResponseID - ctx, err := GetContext(db, contextID) - if err != nil { - return false, fmt.Sprintf("cannot get context: %v", err) - } +func ValidateResponseIDChain(db *sql.DB, ctx Context) (valid bool, reason string) { + var ( + err error + isValid = func(id *string) bool { + return id != nil && *id != "" + } + ) // If no LastResponseID, chain is invalid - if ctx.LastResponseID == nil || *ctx.LastResponseID == "" { + if !isValid(ctx.LastResponseID) { + ctx, err = GetContext(db, ctx.ID) + if err != nil { + return false, "can't load context from db" + } + } + + if !isValid(ctx.LastResponseID) { return false, "no last_response_id set" } - // Get all live records - records, err := ListLiveRecords(db, contextID) + records, err := ListLiveRecords(db, ctx.ID) if err != nil { return false, fmt.Sprintf("cannot list records: %v", err) } @@ -701,7 +707,6 @@ func ValidateResponseIDChain(db *sql.DB, contextID string) (bool, string) { } } - // Find all ModelResp records var modelResponses []Record for _, rec := range records { if rec.Source == ModelResp { @@ -715,12 +720,14 @@ func ValidateResponseIDChain(db *sql.DB, contextID string) (bool, string) { } // Check for gaps in response_id chain - hasResponseIDs := false - hasMissingResponseIDs := false - lastResponseIDExists := false + var ( + hasResponseIDs = false + hasMissingResponseIDs = false + lastResponseIDExists = false + ) for _, rec := range modelResponses { - if rec.ResponseID != nil && *rec.ResponseID != "" { + if isValid(rec.ResponseID) { hasResponseIDs = true // Check if this matches the context's LastResponseID (for existence check) if ctx.LastResponseID != nil && *rec.ResponseID == *ctx.LastResponseID { @@ -733,12 +740,11 @@ func ValidateResponseIDChain(db *sql.DB, contextID string) (bool, string) { // Edge case: Context has LastResponseID but no matching record exists // This can happen after export/import or manual database edits - if ctx.LastResponseID != nil && *ctx.LastResponseID != "" && !lastResponseIDExists { + if isValid(ctx.LastResponseID) && !lastResponseIDExists { return false, fmt.Sprintf("last_response_id (%v) does not exist in records (chain broken, possibly after export/import)", *ctx.LastResponseID) } - // Mixed state is invalid if hasResponseIDs && hasMissingResponseIDs { return false, "mixed response_id state (some records missing IDs)" } diff --git a/storage_test.go b/storage_test.go index b4c40ba..41ca734 100644 --- a/storage_test.go +++ b/storage_test.go @@ -32,7 +32,7 @@ func TestValidateResponseIDChain_ValidChain(t *testing.T) { assert.NoError(t, err) // Validate chain - should be valid - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.True(t, valid) assert.Equal(t, "chain valid", reason) } @@ -52,7 +52,7 @@ func TestValidateResponseIDChain_MissingLastResponseID(t *testing.T) { assert.NoError(t, err) // Validate chain - should be invalid (no LastResponseID) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.False(t, valid) assert.Equal(t, "no last_response_id set", reason) } @@ -85,7 +85,7 @@ func TestValidateResponseIDChain_ToolCallsPresent(t *testing.T) { assert.NoError(t, err) // Validate chain - should be invalid (tool calls present) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.False(t, valid) assert.Equal(t, "tool calls present (break server-side threading)", reason) } @@ -118,7 +118,7 @@ func TestValidateResponseIDChain_ToolOutputPresent(t *testing.T) { assert.NoError(t, err) // Validate chain - should be invalid (tool calls present) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.False(t, valid) assert.Equal(t, "tool calls present (break server-side threading)", reason) } @@ -151,7 +151,7 @@ func TestValidateResponseIDChain_MixedResponseIDState(t *testing.T) { assert.NoError(t, err) // Validate chain - should be invalid (mixed state) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.False(t, valid) assert.Equal(t, "mixed response_id state (some records missing IDs)", reason) } @@ -187,7 +187,7 @@ func TestValidateResponseIDChain_LastResponseIDMismatch(t *testing.T) { assert.NoError(t, err) // Validate chain - should be invalid (LastResponseID doesn't exist in records) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.False(t, valid) assert.Contains(t, reason, "does not exist in records") assert.Contains(t, reason, "export/import") @@ -222,7 +222,7 @@ func TestValidateResponseIDChain_LastResponseIDMismatchWithExistingID(t *testing assert.NoError(t, err) // Validate chain - should be invalid (LastResponseID doesn't match last response) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.False(t, valid) assert.Contains(t, reason, "does not match context") } @@ -243,7 +243,7 @@ func TestValidateResponseIDChain_EmptyContext(t *testing.T) { assert.NoError(t, err) // Validate chain - should be valid (no model responses yet, first call) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.True(t, valid) assert.Equal(t, "no model responses yet (first call)", reason) } @@ -268,23 +268,11 @@ func TestValidateResponseIDChain_NoModelResponsesButHasPrompt(t *testing.T) { assert.NoError(t, err) // Validate chain - should be valid (no model responses yet, first call) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.True(t, valid) assert.Equal(t, "no model responses yet (first call)", reason) } -func TestValidateResponseIDChain_ContextNotFound(t *testing.T) { - path := filepath.Join(t.TempDir(), "cw.db") - db, err := NewContextDB(path) - assert.NoError(t, err) - defer db.Close() - - // Try to validate with non-existent context ID - valid, reason := ValidateResponseIDChain(db, "non-existent-id") - assert.False(t, valid) - assert.Contains(t, reason, "cannot get context") -} - func TestValidateResponseIDChain_MultipleValidResponses(t *testing.T) { path := filepath.Join(t.TempDir(), "cw.db") db, err := NewContextDB(path) @@ -313,7 +301,7 @@ func TestValidateResponseIDChain_MultipleValidResponses(t *testing.T) { assert.NoError(t, err) // Validate chain - should be valid (last response matches) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, reason := ValidateResponseIDChain(db, ctx) assert.True(t, valid) assert.Equal(t, "chain valid", reason) } @@ -403,9 +391,6 @@ func TestValidateResponseIDChain_LastResponseIDNoMatchingRecord(t *testing.T) { assert.NoError(t, err) // Validate chain - should be invalid (LastResponseID doesn't exist in records) - valid, reason := ValidateResponseIDChain(db, ctx.ID) + valid, _ := ValidateResponseIDChain(db, ctx) assert.False(t, valid) - assert.Contains(t, reason, "does not exist in records") - assert.Contains(t, reason, "export/import") } -