Skip to content

Commit f75a77a

Browse files
committed
done
1 parent 65d8434 commit f75a77a

File tree

5 files changed

+56
-34
lines changed

5 files changed

+56
-34
lines changed

examples/agent/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ func main() {
1616
if err != nil {
1717
fmt.Println(err)
1818
}
19-
prompt := "青を英語で言うと、何か?"
19+
prompt := "日本の首相の名前は?"
2020
result, err := agentExecutor.Invoke(prompt)
2121
if err != nil {
2222
fmt.Println(err)
2323
}
24-
fmt.Printf("Response: %+v\n", result)
24+
fmt.Printf("%+v\n", result)
2525
}

pkg/agent/agent.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func CreatePrompt(tools map[string]Tool) (*prompt.ChatPromptTemplate, error) {
5555
return prompt, nil
5656
}
5757

58-
func (a *Agent) Plan(intermediateSteps []string) (*NextAction, error) {
58+
func (a *Agent) Plan(intermediateSteps []string) (*NextAction, string, error) {
5959
var nextaction *NextAction
6060
agentScratchpad := strings.Join(intermediateSteps, "\n")
6161
m := map[string]string{
@@ -65,14 +65,18 @@ func (a *Agent) Plan(intermediateSteps []string) (*NextAction, error) {
6565

6666
agentDecision, err := a.LLMChain.Invoke(m)
6767
if err != nil {
68-
return nil, err
68+
return nil, "", err
69+
}
70+
decision := agentDecision.(string)
71+
72+
err = json.Unmarshal([]byte(decision), &nextaction)
73+
if strings.HasPrefix(decision, "FinalAnswer") {
74+
return nil, decision, nil
6975
}
70-
fmt.Printf("agentDecision: %v\n", agentDecision)
71-
err = json.Unmarshal([]byte(agentDecision.(string)), &nextaction)
7276
if err != nil {
73-
return nil, fmt.Errorf("got an error during Plan: %w", err)
77+
return nil, "", fmt.Errorf("got an error during Plan: %w", err)
7478
}
75-
return nextaction, err
79+
return nextaction, "", err
7680
}
7781

7882
type NextAction struct {

pkg/agent/agent_executor.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,48 +26,51 @@ func InitializeAgent(tools map[string]Tool, llm lib.Runnable) (*AgentExecutor, e
2626
}
2727

2828
func (a *AgentExecutor) Invoke(input any) (any, error) {
29-
switch input.(type) {
30-
case string:
31-
a.Agent.UserInput = input.(string)
32-
default:
33-
return nil, nil
34-
}
29+
a.Agent.UserInput = input.(string)
3530
output, err := a.call()
3631
if err != nil {
3732
return nil, err
3833
}
3934
return output, nil
4035
}
4136

42-
func (a *AgentExecutor) takeNextStep(intermediateSteps []string) (any, error) {
43-
var observation any
44-
output, err := a.Agent.Plan(intermediateSteps)
37+
func (a *AgentExecutor) takeNextStep(intermediateSteps []string) (string, string, error) {
38+
var observation string
39+
nextaction, finalanswer, err := a.Agent.Plan(intermediateSteps)
4540
if err != nil {
46-
return nil, err
41+
return "", "", err
4742
}
48-
if tool, ok := a.Tools[output.Action.Action_name]; ok {
49-
observation, err = tool.run(output.Action.Action_input)
43+
if len(finalanswer) > 1 {
44+
return "", finalanswer, nil
45+
}
46+
if tool, ok := a.Tools[nextaction.Action.Action_name]; ok {
47+
observation, err = tool.run(nextaction.Action.Action_input)
5048
if err != nil {
51-
return nil, fmt.Errorf("error during takeNextStep: %w", err)
49+
return "", "", fmt.Errorf("error during takeNextStep: %w", err)
5250
}
5351
}
52+
nextactionstring := fmt.Sprintf("Thought: %s\nAction_name: %s\nAction_input: %s", nextaction.Thought, nextaction.Action.Action_name, nextaction.Action.Action_input)
5453

55-
return observation, nil
54+
return observation, nextactionstring, nil
5655
}
5756

5857
func (a *AgentExecutor) call() (any, error) {
5958
iterations := 0
6059
var nextStepOutput any
6160
var intermediateSteps []string
6261
for a.MaxIterations > iterations {
63-
nextStepOutput, err := a.takeNextStep(intermediateSteps)
62+
nextStepOutput, pastaction, err := a.takeNextStep(intermediateSteps)
6463
if err != nil {
6564
return nil, err
6665
}
66+
if nextStepOutput == "" {
67+
return pastaction, nil
68+
}
6769
jsonData, err := json.Marshal(nextStepOutput)
6870
if err != nil {
6971
return nil, err
7072
}
73+
intermediateSteps = append(intermediateSteps, pastaction)
7174
intermediateSteps = append(intermediateSteps, string(jsonData))
7275
iterations += 1
7376
}

pkg/agent/prompt.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@ ALWAYS use the following format:
2626
Observation: the result of the action
2727
... (this Thought/Action/Observation can repeat N times)
2828
Thought: I now know the final answer
29-
Final Answer: the final answer to the original input question`
30-
const SYSTEM_MESSAGE_SUFFIX = "Begin! Reminder to always use the exact characters `Final Answer` when responding."
29+
FinalAnswer: the final answer to the original input question`
30+
const SYSTEM_MESSAGE_SUFFIX = "Begin! Reminder to always use the exact characters `FinalAnswer` when responding."
3131
const HUMAN_MESSAGE = "{{ .Input }}\n\n{{ .Agent_scratchpad }}"

pkg/agent/tool.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"net/url"
99
"os"
10+
"strings"
1011
)
1112

1213
var AvailableTools = map[string]Tool{
@@ -15,7 +16,7 @@ var AvailableTools = map[string]Tool{
1516
type Tool interface {
1617
name() string
1718
description() string
18-
run(input any) (any, error)
19+
run(input any) (string, error)
1920
}
2021

2122
func LoadTools(tools []string) map[string]Tool {
@@ -30,8 +31,11 @@ func LoadTools(tools []string) map[string]Tool {
3031
}
3132

3233
type OrganicResult struct {
33-
Title string `json:"title"`
34-
Link string `json:"link"`
34+
Title string `json:"title"`
35+
Link string `json:"link"`
36+
Snippet string `json:"snippet"`
37+
SnippetHighlightedWords []string `json:"snippet_highlighted_words"`
38+
RichSnippet string `json:"rich_snippet"`
3539
}
3640

3741
type SerpApiResponse struct {
@@ -49,7 +53,7 @@ func (s *SerpAPI) description() string {
4953
return "A search engine. Useful for when you need to answer questions about current events. Input should be a search query."
5054
}
5155

52-
func (s *SerpAPI) run(input any) (any, error) {
56+
func (s *SerpAPI) run(input any) (string, error) {
5357
apiKey := os.Getenv("SERPAPI_API_KEY")
5458
searchQuery := input.(string)
5559
serpApiURL := "https://serpapi.com/search"
@@ -61,17 +65,28 @@ func (s *SerpAPI) run(input any) (any, error) {
6165

6266
resp, err := http.Get(fmt.Sprintf("%s?%s", serpApiURL, params.Encode()))
6367
if err != nil {
64-
return nil, err
68+
return "", err
6569
}
6670
body, err := io.ReadAll(resp.Body)
6771
if err != nil {
68-
return nil, err
72+
return "", err
6973
}
7074

7175
var response SerpApiResponse
7276
if err := json.Unmarshal(body, &response); err != nil {
73-
return nil, err
77+
return "", err
7478
}
75-
fmt.Printf("serpapi run: %+v\n", response)
76-
return response, nil
79+
snippets := []string{}
80+
for _, result := range response.OrganicResults {
81+
if len(result.Snippet) > 0 {
82+
snippets = append(snippets, result.Snippet)
83+
}
84+
if len(result.SnippetHighlightedWords) > 0 {
85+
snippets = append(snippets, strings.Join(result.SnippetHighlightedWords, "\n"))
86+
}
87+
if len(result.RichSnippet) > 0 {
88+
snippets = append(snippets, result.RichSnippet)
89+
}
90+
}
91+
return strings.Join(snippets, "\n"), nil
7792
}

0 commit comments

Comments
 (0)