Backport: Client Compatible Bedrock ARN handling (#62720) (#62793)
Client Compatible Bedrock ARN handling (#62720)

* Improve Bedrock ARN handling

* Fix up PR comments

(cherry picked from commit 6a7666c)
RXminuS committed May 20, 2024
1 parent 660dee6 commit c06fd8d
Showing 10 changed files with 290 additions and 21 deletions.
1 change: 1 addition & 0 deletions internal/completions/client/awsbedrock/BUILD.bazel
@@ -12,6 +12,7 @@ go_library(
deps = [
"//internal/completions/tokenusage",
"//internal/completions/types",
"//internal/conf/conftypes",
"//internal/httpcli",
"//lib/errors",
"@com_github_aws_aws_sdk_go_v2//aws",
55 changes: 41 additions & 14 deletions internal/completions/client/awsbedrock/bedrock.go
@@ -22,6 +22,7 @@ import (

"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/httpcli"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
@@ -68,7 +69,8 @@ func (c *awsBedrockAnthropicCompletionStreamClient) Complete(
completion += content.Text
}

err = c.tokenManager.UpdateTokenCountsFromModelUsage(response.Usage.InputTokens, response.Usage.OutputTokens, "anthropic/"+requestParams.Model, string(feature), tokenusage.AwsBedrock)
parsedModelId := conftypes.NewBedrockModelRefFromModelID(requestParams.Model)
err = c.tokenManager.UpdateTokenCountsFromModelUsage(response.Usage.InputTokens, response.Usage.OutputTokens, "anthropic/"+parsedModelId.Model, string(feature), tokenusage.AwsBedrock)
if err != nil {
return nil, err
}
@@ -153,7 +155,8 @@ func (a *awsBedrockAnthropicCompletionStreamClient) Stream(
case "message_delta":
if event.Delta != nil {
stopReason = event.Delta.StopReason
err = a.tokenManager.UpdateTokenCountsFromModelUsage(inputPromptTokens, event.Usage.OutputTokens, "anthropic/"+requestParams.Model, string(feature), tokenusage.AwsBedrock)
parsedModelId := conftypes.NewBedrockModelRefFromModelID(requestParams.Model)
err = a.tokenManager.UpdateTokenCountsFromModelUsage(inputPromptTokens, event.Usage.OutputTokens, "anthropic/"+parsedModelId.Model, string(feature), tokenusage.AwsBedrock)
if err != nil {
logger.Warn("Failed to count tokens with the token manager %w ", log.Error(err))
}
@@ -232,19 +235,8 @@ func (c *awsBedrockAnthropicCompletionStreamClient) makeRequest(ctx context.Cont
if err != nil {
return nil, errors.Wrap(err, "marshalling request body")
}
apiURL, err := url.Parse(c.endpoint)
if err != nil || apiURL.Scheme == "" {
apiURL = &url.URL{
Scheme: "https",
Host: fmt.Sprintf("bedrock-runtime.%s.amazonaws.com", defaultConfig.Region),
}
}

if stream {
apiURL.Path = fmt.Sprintf("/model/%s/invoke-with-response-stream", requestParams.Model)
} else {
apiURL.Path = fmt.Sprintf("/model/%s/invoke", requestParams.Model)
}
apiURL := buildApiUrl(c.endpoint, requestParams.Model, stream, defaultConfig.Region)

req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL.String(), bytes.NewReader(reqBody))
if err != nil {
@@ -282,6 +274,41 @@ func (c *awsBedrockAnthropicCompletionStreamClient) makeRequest(ctx context.Cont
return resp, nil
}

// buildApiUrl builds a Bedrock API URL from the configured endpoint URL.
// If the endpoint isn't valid, it falls back to the default endpoint for the specified fallbackRegion.
func buildApiUrl(endpoint string, model string, stream bool, fallbackRegion string) *url.URL {
apiURL, err := url.Parse(endpoint)
if err != nil || apiURL.Scheme == "" {
apiURL = &url.URL{
Scheme: "https",
Host: fmt.Sprintf("bedrock-runtime.%s.amazonaws.com", fallbackRegion),
}
}

bedrockModelRef := conftypes.NewBedrockModelRefFromModelID(model)

if bedrockModelRef.ProvisionedCapacity != nil {
// We need to query-escape the provisioned capacity ARN, since otherwise
// the AWS API gateway treats the ARN's slashes as path separators and doesn't route
// to the Bedrock service. This would result in opaque Coral errors.
if stream {
apiURL.RawPath = fmt.Sprintf("/model/%s/invoke-with-response-stream", url.QueryEscape(*bedrockModelRef.ProvisionedCapacity))
apiURL.Path = fmt.Sprintf("/model/%s/invoke-with-response-stream", *bedrockModelRef.ProvisionedCapacity)
} else {
apiURL.RawPath = fmt.Sprintf("/model/%s/invoke", url.QueryEscape(*bedrockModelRef.ProvisionedCapacity))
apiURL.Path = fmt.Sprintf("/model/%s/invoke", *bedrockModelRef.ProvisionedCapacity)
}
} else {
if stream {
apiURL.Path = fmt.Sprintf("/model/%s/invoke-with-response-stream", bedrockModelRef.Model)
} else {
apiURL.Path = fmt.Sprintf("/model/%s/invoke", bedrockModelRef.Model)
}
}

return apiURL
}

func awsConfigOptsForKeyConfig(endpoint string, accessToken string) []func(*config.LoadOptions) error {
configOpts := []func(*config.LoadOptions) error{}
if endpoint != "" {
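For context on the Path/RawPath pair set in buildApiUrl above, here is a minimal standalone sketch (illustrative only, not part of this commit) of the net/url behaviour it relies on; the host and ARN are the example values used in the tests below.

package main

import (
	"fmt"
	"net/url"
)

func main() {
	arn := "arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl"

	// With only Path set, url.URL leaves the ARN's ':' and '/' unescaped, so an
	// AWS API gateway sees the ARN as several path segments and fails to route.
	plain := url.URL{
		Scheme: "https",
		Host:   "bedrock-runtime.us-west-2.amazonaws.com",
		Path:   fmt.Sprintf("/model/%s/invoke", arn),
	}
	fmt.Println(plain.String())
	// https://bedrock-runtime.us-west-2.amazonaws.com/model/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl/invoke

	// Setting RawPath to a query-escaped form that decodes back to Path makes
	// String() emit the ARN as a single percent-encoded path segment.
	escaped := plain
	escaped.RawPath = fmt.Sprintf("/model/%s/invoke", url.QueryEscape(arn))
	fmt.Println(escaped.String())
	// https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A012345678901%3Aprovisioned-model%2Fabcdefghijkl/invoke
}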
45 changes: 44 additions & 1 deletion internal/completions/client/awsbedrock/bedrock_test.go
@@ -2,13 +2,56 @@ package awsbedrock

import (
"context"
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/stretchr/testify/require"
)

func TestAwsConfigOptsForKeyConfig(t *testing.T) {
func Test_BedrockProvisionedThroughputModel(t *testing.T) {
tests := []struct {
want string
endpoint string
model string
fallbackRegion string
stream bool
}{
{
want: "https://bedrock-runtime.us-west-2.amazonaws.com/model/amazon.titan-text-express-v1/invoke",
endpoint: "",
model: "amazon.titan-text-express-v1",
fallbackRegion: "us-west-2",
stream: false,
},
{
want: "https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0:200k/invoke",
endpoint: "",
model: "anthropic.claude-3-sonnet-20240229-v1:0:200k",
fallbackRegion: "us-west-2",
stream: false,
},
{
want: "https://vpce-12345678910.bedrock-runtime.us-west-2.vpce.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A012345678901%3Aprovisioned-model%2Fabcdefghijkl/invoke-with-response-stream",
endpoint: "https://vpce-12345678910.bedrock-runtime.us-west-2.vpce.amazonaws.com",
model: "anthropic.claude-instant-v1/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl",
fallbackRegion: "us-east-1",
stream: true,
},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%q", tt.want), func(t *testing.T) {
got := buildApiUrl(tt.endpoint, tt.model, tt.stream, tt.fallbackRegion)
if got.String() != tt.want {
t.Logf("got %q but wanted %q", got, tt.want)
t.Fail()
}
})
}
}

func Test_AwsConfigOptsForKeyConfig(t *testing.T) {

t.Run("With endpoint as URL", func(t *testing.T) {
endpoint := "https://example.com"
1 change: 1 addition & 0 deletions internal/completions/client/azureopenai/BUILD.bazel
@@ -22,6 +22,7 @@ go_library(
go_test(
name = "azureopenai_test",
srcs = ["openai_test.go"],
data = glob(["testdata/**"]),
embed = [":azureopenai"],
deps = [
"//internal/completions/tokenusage",
35 changes: 29 additions & 6 deletions internal/conf/computed.go
@@ -706,6 +706,11 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
completionsConfig.ChatModel = completionsConfig.Model
}

// This records whether the model IDs have been canonicalized by the provider-
// specific configuration. By default, ToLower is applied to the model IDs
// if no other canonicalization has already been applied. In particular, this
// is because BedrockModelRefs need a different canonicalization.
canonicalized := false
if completionsConfig.Provider == string(conftypes.CompletionsProviderNameSourcegraph) {
// If no endpoint is configured, use a default value.
if completionsConfig.Endpoint == "" {
@@ -850,12 +855,27 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
if completionsConfig.CompletionModel == "" {
completionsConfig.CompletionModel = "anthropic.claude-instant-v1"
}

// We apply BedrockModelRef-specific canonicalization here to
// make sure models are always treated case-insensitively.
chatModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.ChatModel)
completionsConfig.ChatModel = chatModelRef.CanonicalizedModelID()

fastChatModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.FastChatModel)
completionsConfig.FastChatModel = fastChatModelRef.CanonicalizedModelID()

completionsModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.CompletionModel)
completionsConfig.CompletionModel = completionsModelRef.CanonicalizedModelID()
canonicalized = true
}

// Make sure models are always treated case-insensitive.
completionsConfig.ChatModel = strings.ToLower(completionsConfig.ChatModel)
completionsConfig.FastChatModel = strings.ToLower(completionsConfig.FastChatModel)
completionsConfig.CompletionModel = strings.ToLower(completionsConfig.CompletionModel)
// Only apply canonicalization if it hasn't already been applied; not all model IDs can simply be lowercased.
if !canonicalized {
// Make sure models are always treated case-insensitive.
completionsConfig.ChatModel = strings.ToLower(completionsConfig.ChatModel)
completionsConfig.FastChatModel = strings.ToLower(completionsConfig.FastChatModel)
completionsConfig.CompletionModel = strings.ToLower(completionsConfig.CompletionModel)
}

// If after trying to set default we still have not all models configured, completions are
// not available.
@@ -1185,8 +1205,9 @@ func defaultMaxPromptTokens(provider conftypes.CompletionsProviderName, model st
// this is a sane default for GPT in general.
return 7_000
case conftypes.CompletionsProviderNameAWSBedrock:
if strings.HasPrefix(model, "anthropic.") {
return anthropicDefaultMaxPromptTokens(strings.TrimPrefix(model, "anthropic."))
parsed := conftypes.NewBedrockModelRefFromModelID(model)
if strings.HasPrefix(parsed.Model, "anthropic.") {
return anthropicDefaultMaxPromptTokens(strings.TrimPrefix(parsed.Model, "anthropic."))
}
// Fallback for weird values.
return 9_000
@@ -1197,6 +1218,8 @@
}

func anthropicDefaultMaxPromptTokens(model string) int {
// TODO: this doesn't cover nearly all the ways that token size can be specified.
// See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
if strings.HasSuffix(model, "-100k") {
return 100_000

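To illustrate why defaultMaxPromptTokens now parses the model ID first, here is a standalone sketch (illustrative only, with made-up example values): without the parse, an appended provisioned-capacity ARN hides the "-100k" suffix that selects the larger token limit.

package main

import (
	"fmt"
	"strings"

	"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
)

func main() {
	model := "anthropic.claude-3-haiku-20240307-v1:0-100k/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl"
	parsed := conftypes.NewBedrockModelRefFromModelID(model)

	fmt.Println(strings.HasSuffix(model, "-100k"))        // false: the ARN hides the suffix
	fmt.Println(strings.HasSuffix(parsed.Model, "-100k")) // true: check the parsed model instead
}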
24 changes: 24 additions & 0 deletions internal/conf/computed_test.go
@@ -546,6 +546,30 @@ func TestGetCompletionsConfig(t *testing.T) {
Endpoint: "us-west-2",
},
},
{
name: "AWS Bedrock completions with Provisioned Throughput for some of the models",
siteConfig: schema.SiteConfiguration{
CodyEnabled: pointers.Ptr(true),
LicenseKey: licenseKey,
Completions: &schema.Completions{
Provider: "aws-bedrock",
Endpoint: "us-west-2",
ChatModel: "anthropic.claude-3-haiku-20240307-v1:0-100k/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl",
FastChatModel: "anthropic.claude-v2",
},
},
wantConfig: &conftypes.CompletionsConfig{
ChatModel: "anthropic.claude-3-haiku-20240307-v1:0-100k/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl",
ChatModelMaxTokens: 100_000,
FastChatModel: "anthropic.claude-v2",
FastChatModelMaxTokens: 12000,
CompletionModel: "anthropic.claude-instant-v1",
CompletionModelMaxTokens: 9000,
AccessToken: "",
Provider: "aws-bedrock",
Endpoint: "us-west-2",
},
},
{
name: "zero-config cody gateway completions without license key",
siteConfig: schema.SiteConfiguration{
48 changes: 48 additions & 0 deletions internal/conf/conftypes/conftypes.go
@@ -2,6 +2,7 @@ package conftypes

import (
"reflect"
"strings"
"time"

"google.golang.org/protobuf/types/known/durationpb"
@@ -101,3 +102,50 @@ func (r *RawUnified) FromProto(in *proto.RawUnified) {
func (r RawUnified) Equal(other RawUnified) bool {
return r.Site == other.Site && reflect.DeepEqual(r.ServiceConnections, other.ServiceConnections)
}

// Bedrock Model IDs can be in one of two forms:
// - A static model ID, e.g. "anthropic.claude-v2".
// - A model ID and ARN for provisioned capacity, e.g.
// "anthropic.claude-v2/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/xxxxxxxx"
//
// See the AWS docs for more information:
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_CreateProvisionedModelThroughput.html
type BedrockModelRef struct {
// Model is the underlying LLM model Bedrock is serving, e.g. "anthropic.claude-3-haiku-20240307-v1:0".
Model string
// If the configuration is using provisioned capacity, this will
// contain the ARN of the provisioned model to use for making API calls,
// e.g. "arn:aws:bedrock:us-west-2:012345678901:provisioned-model/xxxxxxxx".
ProvisionedCapacity *string
}

func NewBedrockModelRefFromModelID(modelID string) BedrockModelRef {
parts := strings.SplitN(modelID, "/", 2)

if parts == nil { // this shouldn't really happen
return BedrockModelRef{Model: modelID}
}

parsed := BedrockModelRef{
Model: parts[0],
}

if len(parts) == 2 {
parsed.ProvisionedCapacity = &parts[1]
}
return parsed
}

// CanonicalizedModelID ensures that all case-insensitive parts of the model ID are lowercased so
// that they can be compared.
func (bmr BedrockModelRef) CanonicalizedModelID() string {
// Bedrock model IDs are case-sensitive if they contain an ARN,
// so make sure to only lowercase the non-ARN part.
model := strings.ToLower(bmr.Model)

if bmr.ProvisionedCapacity != nil {
return strings.Join([]string{model, *bmr.ProvisionedCapacity}, "/")
}
return model
}
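A brief usage sketch of the new BedrockModelRef helpers (illustrative only, not part of this commit; the mixed-case model ID and ARN are made-up example values):

package main

import (
	"fmt"

	"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
)

func main() {
	// Static model ID: the whole ID is case-insensitive and gets lowercased.
	plain := conftypes.NewBedrockModelRefFromModelID("Anthropic.Claude-v2")
	fmt.Println(plain.CanonicalizedModelID()) // anthropic.claude-v2

	// Model ID plus provisioned-capacity ARN: only the model part is lowercased;
	// the case-sensitive ARN is kept verbatim.
	ref := conftypes.NewBedrockModelRefFromModelID(
		"Anthropic.Claude-v2/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/ABCdefGHIjkl")
	fmt.Println(ref.Model)                  // Anthropic.Claude-v2
	fmt.Println(*ref.ProvisionedCapacity)   // arn:aws:bedrock:us-west-2:012345678901:provisioned-model/ABCdefGHIjkl
	fmt.Println(ref.CanonicalizedModelID()) // anthropic.claude-v2/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/ABCdefGHIjkl
}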
1 change: 1 addition & 0 deletions internal/conf/validation/BUILD.bazel
@@ -47,6 +47,7 @@ go_test(
name = "validation_test",
srcs = [
"auth_test.go",
"cody_test.go",
"prometheus_test.go",
"txemail_test.go",
],
32 changes: 32 additions & 0 deletions internal/conf/validation/cody.go
@@ -2,6 +2,7 @@ package validation

import (
"fmt"
"strings"
"time"

"github.com/sourcegraph/sourcegraph/internal/conf"
@@ -19,6 +20,8 @@ func init() {
conf.ContributeValidator(embeddingsConfigValidator)
}

const bedrockArnMessageTemplate = "completions.%s is invalid. Provisioned Capacity IDs must be formatted like \"model_id/provisioned_capacity_arn\".\nFor example \"anthropic.claude-instant-v1/%s\""

func completionsConfigValidator(q conftypes.SiteConfigQuerier) conf.Problems {
problems := []string{}
completionsConf := q.SiteConfig().Completions
@@ -30,6 +33,35 @@
problems = append(problems, "'completions.enabled' has been superseded by 'cody.enabled', please migrate to the new configuration.")
}

// Check for Bedrock Provisioned Capacity ARNs, which should instead be
// formatted like:
// "anthropic.claude-v2/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/xxxxxxxx"
if completionsConf.Provider == string(conftypes.CompletionsProviderNameAWSBedrock) {
type modelID struct {
value string
field string
}
allModelIds := []modelID{
{value: completionsConf.ChatModel, field: "chatModel"},
{value: completionsConf.FastChatModel, field: "fastChatModel"},
{value: completionsConf.CompletionModel, field: "completionModel"},
}
var modelIdsToCheck []modelID
for _, modelId := range allModelIds {
if modelId.value != "" {
modelIdsToCheck = append(modelIdsToCheck, modelId)
}
}

for _, modelId := range modelIdsToCheck {
// When using provisioned capacity we expect an admin would just put the ARN
// here directly, but we need both the model AND the ARN. Hence the check.
if strings.HasPrefix(modelId.value, "arn:aws:") {
problems = append(problems, fmt.Sprintf(bedrockArnMessageTemplate, modelId.field, modelId.value))
}
}
}

if len(problems) > 0 {
return conf.NewSiteProblems(problems...)
}
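A small standalone sketch (illustrative only, not part of this commit) of the prefix check the validator applies to each configured Bedrock model ID:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Accepted: model ID plus provisioned-capacity ARN, as the error message instructs.
	ok := "anthropic.claude-instant-v1/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl"
	// Rejected: a bare ARN, because the underlying model ID is missing.
	bare := "arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl"

	for _, value := range []string{ok, bare} {
		if strings.HasPrefix(value, "arn:aws:") {
			fmt.Printf("%q would be reported as a misconfigured completions model\n", value)
		} else {
			fmt.Printf("%q passes the ARN-format check\n", value)
		}
	}
}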