-
-
Notifications
You must be signed in to change notification settings - Fork 537
/
provider_amazon.go
119 lines (103 loc) · 3.63 KB
/
provider_amazon.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package bedrockclient

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"

	"github.com/tmc/langchaingo/llms"
)
// Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
//
// amazonTextGenerationConfigInput is the input for the text generation configuration for Amazon Models.
// All fields use omitempty, so a zero value is omitted from the request and the
// service-side default (noted per field) applies instead.
type amazonTextGenerationConfigInput struct {
	// MaxTokens is the maximum number of tokens to generate per result. Optional, default = 512.
	MaxTokens int `json:"maxTokenCount,omitempty"`
	// TopP: use a lower value to ignore less probable options and decrease the diversity of responses. Optional, default = 1.
	TopP float64 `json:"topP,omitempty"`
	// Temperature: use a lower value to decrease randomness in responses. Optional, default = 0.0.
	// NOTE(review): because of omitempty an explicit 0.0 cannot be sent; the service default is used instead.
	Temperature float64 `json:"temperature,omitempty"`
	// StopSequences specifies character sequences indicating where the model should stop.
	// Currently only supports: ["|", "User:"]
	StopSequences []string `json:"stopSequences,omitempty"`
}
// amazonTextGenerationInput is the input for the text generation for Amazon Models.
// It is the JSON body sent to bedrockruntime.InvokeModel for Titan text models.
type amazonTextGenerationInput struct {
	// InputText is the text which the model is requested to continue.
	InputText string `json:"inputText"`
	// TextGenerationConfig is the configuration for the text generation.
	TextGenerationConfig amazonTextGenerationConfigInput `json:"textGenerationConfig"`
}
// amazonTextGenerationOutput is the output for the text generation for Amazon Models.
// It mirrors the JSON body returned by bedrockruntime.InvokeModel for Titan text models.
type amazonTextGenerationOutput struct {
	// InputTextTokenCount is the number of tokens in the prompt.
	// It is reported once for the whole request, not per result.
	InputTextTokenCount int `json:"inputTextTokenCount"`
	// Results holds one entry per generated completion.
	Results []struct {
		// TokenCount is the number of tokens in this result's response.
		TokenCount int `json:"tokenCount"`
		// OutputText is the generated text.
		OutputText string `json:"outputText"`
		// CompletionReason is the reason for the completion of the generation.
		// One of: FINISH, LENGTH, CONTENT_FILTERED (see the Amazon* constants below).
		CompletionReason string `json:"completionReason"`
	} `json:"results"`
}
// Finish reasons for the completion of the generation for Amazon Models.
// These are the possible values of amazonTextGenerationOutput.Results[i].CompletionReason.
const (
	// AmazonCompletionReasonFinish: the model finished generating naturally.
	AmazonCompletionReasonFinish = "FINISH"
	// AmazonCompletionReasonMaxTokens: generation stopped at the token limit.
	AmazonCompletionReasonMaxTokens = "LENGTH"
	// AmazonCompletionReasonContentFiltered: output was cut by content filtering.
	AmazonCompletionReasonContentFiltered = "CONTENT_FILTERED"
)
// createAmazonCompletion invokes an Amazon Titan text model on Bedrock and
// converts the response into an llms.ContentResponse.
//
// The messages are flattened into a single prompt via
// processInputMessagesGeneric, marshaled into the Titan-specific request
// payload, and sent through client.InvokeModel. Each result in the response
// becomes one ContentChoice; token counts are exposed in GenerationInfo under
// "input_tokens" and "output_tokens".
//
// It returns an error if marshaling, invocation, or unmarshaling fails, or if
// the model returns no results.
func createAmazonCompletion(ctx context.Context,
	client *bedrockruntime.Client,
	modelID string,
	messages []Message,
	options llms.CallOptions,
) (*llms.ContentResponse, error) {
	txt := processInputMessagesGeneric(messages)

	inputContent := amazonTextGenerationInput{
		InputText: txt,
		TextGenerationConfig: amazonTextGenerationConfigInput{
			MaxTokens:     options.MaxTokens,
			TopP:          options.TopP,
			Temperature:   options.Temperature,
			StopSequences: options.StopWords,
		},
	}

	body, err := json.Marshal(inputContent)
	if err != nil {
		return nil, fmt.Errorf("marshaling amazon completion request: %w", err)
	}

	modelInput := &bedrockruntime.InvokeModelInput{
		ModelId:     aws.String(modelID),
		Accept:      aws.String("*/*"),
		ContentType: aws.String("application/json"),
		Body:        body,
	}

	resp, err := client.InvokeModel(ctx, modelInput)
	if err != nil {
		return nil, fmt.Errorf("invoking model %q: %w", modelID, err)
	}

	var output amazonTextGenerationOutput
	if err := json.Unmarshal(resp.Body, &output); err != nil {
		return nil, fmt.Errorf("unmarshaling amazon completion response: %w", err)
	}

	if len(output.Results) == 0 {
		return nil, errors.New("no results")
	}

	contentChoices := make([]*llms.ContentChoice, len(output.Results))
	for i, result := range output.Results {
		contentChoices[i] = &llms.ContentChoice{
			Content:    result.OutputText,
			StopReason: result.CompletionReason,
			GenerationInfo: map[string]any{
				// Titan reports the prompt token count once for the whole
				// request; it is repeated on every choice for convenience.
				"input_tokens":  output.InputTextTokenCount,
				"output_tokens": result.TokenCount,
			},
		}
	}

	return &llms.ContentResponse{
		Choices: contentChoices,
	}, nil
}