package bedrockclient

import (
	"context"
	"encoding/json"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"

	"github.com/tmc/langchaingo/llms"
)

// Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html

// metaTextGenerationInput is the input to the model.
type metaTextGenerationInput struct {
	// The prompt that you want to pass to the model. Required.
	Prompt string `json:"prompt"`
	// Used to control the randomness of the generation. Optional, default = 0.5.
	Temperature float64 `json:"temperature,omitempty"`
	// Use a lower value to ignore less probable options. Optional, default = 0.9.
	TopP float64 `json:"top_p,omitempty"`
	// The maximum number of tokens to generate per result.
	// The model truncates the response once the generated text exceeds max_gen_len.
	// Optional, default = 512.
	MaxGenLen int `json:"max_gen_len,omitempty"`
}
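
// For illustration only (prompt and values here are hypothetical, not from
// this file): with all fields set, the marshaled request body looks roughly
// like:
//
//	{"prompt":"Tell me a joke.","temperature":0.5,"top_p":0.9,"max_gen_len":512}
//
// Because of the omitempty tags, zero-valued Temperature, TopP, and MaxGenLen
// are dropped from the JSON, letting Bedrock apply its own defaults.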

// metaTextGenerationOutput is the output from the model.
type metaTextGenerationOutput struct {
	// The generated text.
	Generation string `json:"generation"`
	// The number of tokens in the prompt.
	PromptTokenCount int `json:"prompt_token_count"`
	// The number of tokens in the generated text.
	GenerationTokenCount int `json:"generation_token_count"`
	// The reason why the response stopped generating text.
	// One of: ["stop", "length"]
	StopReason string `json:"stop_reason"`
}
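
// For illustration only (hypothetical values): the Bedrock response body
// unmarshals into metaTextGenerationOutput from JSON shaped roughly like:
//
//	{"generation":"Why did the ...","prompt_token_count":5,
//	 "generation_token_count":42,"stop_reason":"stop"}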

// Finish reasons for the completion of the generation.
const (
	MetaCompletionReasonStop   = "stop"
	MetaCompletionReasonLength = "length"
)
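
// Illustrative caller-side check (hypothetical code, not part of this file):
// a "length" stop reason means generation was truncated at max_gen_len.
//
//	if choice.StopReason == MetaCompletionReasonLength {
//		// Consider retrying with a larger llms.CallOptions.MaxTokens.
//	}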

// createMetaCompletion invokes a Meta (Llama) text-generation model on Amazon
// Bedrock and converts the result into an llms.ContentResponse.
func createMetaCompletion(ctx context.Context,
	client *bedrockruntime.Client,
	modelID string,
	messages []Message,
	options llms.CallOptions,
) (*llms.ContentResponse, error) {
	// Flatten the chat messages into a single prompt string.
	txt := processInputMessagesGeneric(messages)

	input := &metaTextGenerationInput{
		Prompt:      txt,
		Temperature: options.Temperature,
		TopP:        options.TopP,
		MaxGenLen:   options.MaxTokens,
	}

	body, err := json.Marshal(input)
	if err != nil {
		return nil, err
	}

	modelInput := &bedrockruntime.InvokeModelInput{
		ModelId:     aws.String(modelID),
		Accept:      aws.String("*/*"),
		ContentType: aws.String("application/json"),
		Body:        body,
	}
	resp, err := client.InvokeModel(ctx, modelInput)
	if err != nil {
		return nil, err
	}

	var output metaTextGenerationOutput
	err = json.Unmarshal(resp.Body, &output)
	if err != nil {
		return nil, err
	}

	// Map the Bedrock response onto the provider-agnostic response type,
	// surfacing token counts through GenerationInfo.
	return &llms.ContentResponse{
		Choices: []*llms.ContentChoice{
			{
				Content:    output.Generation,
				StopReason: output.StopReason,
				GenerationInfo: map[string]interface{}{
					"input_tokens":  output.PromptTokenCount,
					"output_tokens": output.GenerationTokenCount,
				},
			},
		},
	}, nil
}
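
// Usage sketch (illustrative; the client construction and model ID are
// assumptions, not taken from this file):
//
//	cfg, err := config.LoadDefaultConfig(ctx) // github.com/aws/aws-sdk-go-v2/config
//	if err != nil {
//		return err
//	}
//	client := bedrockruntime.NewFromConfig(cfg)
//	resp, err := createMetaCompletion(ctx, client, "meta.llama2-13b-chat-v1",
//		messages, llms.CallOptions{Temperature: 0.5, MaxTokens: 256})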