Skip to content

Commit

Permalink
Forward LLM completion in a sequential model, concurrently with filte…
Browse files Browse the repository at this point in the history
…ring. (#62111)

* Tests broken

* Tests run OK

* Use guardrails v2 completions handler behind feature flag

* Add some comments

* BAZEL fix

* Update internal/guardrails/attribution_filter2.go

* gofmt
  • Loading branch information
cbart committed Apr 26, 2024
1 parent bece762 commit d10d4f0
Show file tree
Hide file tree
Showing 5 changed files with 375 additions and 320 deletions.
11 changes: 10 additions & 1 deletion internal/completions/httpapi/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,16 @@ func newStreamingResponseHandler(logger log.Logger, db database.DB, feature type
f := guardrails.NoopCompletionsFilter(eventSink)
if cf := conf.GetConfigFeatures(conf.SiteConfig()); cf != nil && cf.Attribution &&
featureflag.FromContext(ctx).GetBoolOr("autocomplete-attribution", true) {
ff, err := guardrails.NewCompletionsFilter(guardrails.CompletionsFilterConfig{
factory := guardrails.NewCompletionsFilter
// TODO(#61828) - Validate & cleanup:
// 1. If experiments are successful on S2 and we do not see any panics,
// please switch the feature flag default value to true.
// 2. Afterwards cleanup the implementation and only use V2 completion filter.
// Remove v1 implementation completely.
if featureflag.FromContext(ctx).GetBoolOr("autocomplete-attribution-v2", false) {
factory = guardrails.NewCompletionsFilter2
}
ff, err := factory(guardrails.CompletionsFilterConfig{
Sink: eventSink,
Test: test,
AttributionError: attributionErrorLog,
Expand Down
2 changes: 1 addition & 1 deletion internal/guardrails/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
name = "guardrails",
srcs = [
"attribution_filter.go",
"attribution_filter2.go",
"attribution_threshold.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/guardrails",
Expand All @@ -22,7 +23,6 @@ go_test(
deps = [
":guardrails",
"//internal/completions/types",
"//lib/errors",
"@com_github_stretchr_testify//require",
],
)
1 change: 1 addition & 0 deletions internal/guardrails/attribution_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func NewCompletionsFilter(config CompletionsFilterConfig) (CompletionsFilter, er
func (a *completionsFilter) Send(ctx context.Context, e types.CompletionResponse) error {
if err := ctx.Err(); err != nil {
a.blockSending()
return err
}
if a.attributionResultPermissive() {
return a.send(e)
Expand Down
114 changes: 114 additions & 0 deletions internal/guardrails/attribution_filter2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package guardrails

import (
"context"
"sync"
"sync/atomic"

"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

// NewCompletionsFilter2 builds a ready-to-use streaming completions filter
// backed by attributionRunFilter.
//
// The returned filter keeps per-stream state, so it must be used for one code
// completion stream only. `Send` and `WaitDone` are expected to be called from
// a single goroutine, with `WaitDone` invoked only after every `Send` call
// has returned.
//
// An error is returned when any of the config's Sink, Test or AttributionError
// callbacks is nil.
func NewCompletionsFilter2(config CompletionsFilterConfig) (CompletionsFilter, error) {
	switch {
	case config.Sink == nil, config.Test == nil, config.AttributionError == nil:
		return nil, errors.New("Attribution filtering misconfigured.")
	}
	filter := &attributionRunFilter{
		config: config,
		// Signalling-only channel; it is closed, never sent on.
		closeOnSearchFinished: make(chan struct{}),
	}
	return filter, nil
}

// attributionRunFilter is an implementation of CompletionsFilter that runs attribution
// search for snippets above a certain threshold defined by `SnippetLowerBound`.
// It's inspired by the idea of [communicating sequential processes](https://en.wikipedia.org/wiki/Communicating_sequential_processes):
// - Attribution search runs asynchronously to passing down the completion from the LLM.
// - All the work of actually forwarding LLM completions happens only on the _caller_
// goroutine, that is the one that controls the filter by calling `Send` and `WaitDone`.
// The intent is to ensure proper teardown of response sending and no races
// with attribution search.
// - The synchronization with attribution search happens on 3 elements:
// 1. `sync.Once` is used to ensure attribution search is only fired once.
// This simplifies the logic of starting search once a snippet passes the threshold.
// 2. `atomic.Bool` is set to true once attribution search confirms the snippet can be used.
// This makes it easy for `Send` to decide on the spot whether or not
// to forward a completion back to the client.
// 3. `chan struct{}` is closed when attribution search finishes.
// This lets `WaitDone` wait (via select) on either context cancellation (timeout)
// or attribution search finishing. The close is guarded by the same `sync.Once`,
// so the channel is closed at most once.
type attributionRunFilter struct {
// Callbacks supplied by the caller: Sink forwards completions, Test runs
// attribution search, AttributionError reports search errors.
config CompletionsFilterConfig
// Guards starting attribution search so that it happens at most once.
attributionSearch sync.Once
// Attribution search result: true = search finished and the snippet can be used,
// false is any of {not run, pending, failed, error}.
attributionSucceeded atomic.Bool
// Closed when attribution search is finished; WaitDone selects on it.
closeOnSearchFinished chan struct{}
// Last seen completion that was not sent, nil if the last seen completion was sent.
// Unsynchronized - only referred to from Send and WaitDone, which are expected
// to be executed in the same request handler goroutine.
last *types.CompletionResponse
}

// Send passes the completion r on to the configured sink when the filter's
// current view of attribution allows it, and otherwise remembers r in
// `attributionRunFilter#last` so that WaitDone can deliver it later.
//
//   - If ctx is already done, its error is returned and nothing is forwarded.
//   - Completions that are short enough are always forwarded.
//   - The first completion that is not short enough starts the attribution
//     search in a separate goroutine (at most once per filter).
//   - Completions that are not short enough are forwarded only once the
//     attribution search has succeeded; until then they are held back and
//     nil is returned.
func (d *attributionRunFilter) Send(ctx context.Context, r types.CompletionResponse) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	if !d.shortEnough(r) {
		// Kick off attribution search on the first snippet crossing the threshold.
		d.attributionSearch.Do(func() { go d.runAttribution(ctx, r) })
		if !d.attributionSucceeded.Load() {
			// Search not (yet) successful: hold this completion back.
			d.last = &r
			return nil
		}
	}
	// Either the snippet is short or attribution succeeded: forward it,
	// which means there is no unsent completion left behind.
	d.last = nil
	return d.config.Sink(r)
}

// WaitDone blocks until attribution search has finished or ctx is done,
// whichever comes first.
//
// The caller calls WaitDone only after all calls to Send, so the LLM is done
// streaming. That is why, once the search has finished successfully, it is
// enough to forward the last memoized completion here.
//
// Returns:
//   - ctx.Err() if the context is done before the search finishes,
//   - the sink's error if the held-back completion is forwarded,
//   - nil otherwise (including when no attribution search was ever needed).
func (d *attributionRunFilter) WaitDone(ctx context.Context) error {
	// If no completion ever crossed the search threshold, Send never started
	// the attribution search and nothing would ever close
	// closeOnSearchFinished, leaving us blocked until ctx expires.
	// Claim the sync.Once here and close the channel ourselves in that case.
	// When Send has already started the search this Do is a no-op, so the
	// channel is still closed exactly once.
	d.attributionSearch.Do(func() { close(d.closeOnSearchFinished) })
	select {
	case <-ctx.Done():
		// Request cancelled or timed out, return.
		return ctx.Err()
	case <-d.closeOnSearchFinished:
		// When search finished successfully and the last seen completion
		// was not sent, send it now, and finish.
		if d.attributionSucceeded.Load() && d.last != nil {
			return d.config.Sink(*d.last)
		}
		return nil
	}
}

// shortEnough reports whether the completion in r does not require
// attribution search, i.e. the threshold returned by NewThreshold says
// it should not be searched.
func (d *attributionRunFilter) shortEnough(r types.CompletionResponse) bool {
	needsSearch := NewThreshold().ShouldSearch(r.Completion)
	return !needsSearch
}

// runAttribution is the body of the attribution search goroutine. It blocks
// on the configured Test callback for the completion in r and then publishes
// the outcome to the goroutine driving the filter:
//   - on error, the error is handed to the AttributionError callback,
//   - on a positive result, the `atomic.Bool` success flag is set,
//   - in every case, `closeOnSearchFinished` is closed on return so that
//     `WaitDone` is notified.
func (d *attributionRunFilter) runAttribution(ctx context.Context, r types.CompletionResponse) {
	// Deferred so the channel is closed after the flag has been stored.
	defer close(d.closeOnSearchFinished)
	switch canUse, err := d.config.Test(ctx, r.Completion); {
	case err != nil:
		d.config.AttributionError(err)
	case canUse:
		d.attributionSucceeded.Store(true)
	}
}
Loading

0 comments on commit d10d4f0

Please sign in to comment.