Skip to content

Commit ed0d685

Browse files
committed
agent_etc
1 parent c816c67 commit ed0d685

File tree

11 files changed

+233
-45
lines changed

11 files changed

+233
-45
lines changed

examples/normal/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ package main
33
import (
44
"fmt"
55
"golangchain/pkg/lib"
6-
"golangchain/pkg/openai"
6+
"golangchain/pkg/llm"
77
"golangchain/pkg/parser"
88
"golangchain/pkg/prompt"
99
)
1010

1111
func main() {
12-
llm, err := openai.NewChatOpenAI("gpt-3.5-turbo")
12+
llm, err := llm.NewChatOpenAI("gpt-3.5-turbo")
1313
if err != nil {
1414
fmt.Println(err)
1515
}

pkg/agent/agent.go

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,72 @@ package agent
22

33
import (
44
"fmt"
5+
"golangchain/pkg/lib"
6+
"golangchain/pkg/parser"
7+
"golangchain/pkg/prompt"
58
"strings"
69
)
710

8-
type Agent struct{}
11+
type Agent struct {
12+
LLMChain *lib.Pipeline
13+
Prompt *prompt.ChatPromptTemplate
14+
UserInput string
15+
}
916

10-
func NewAgent() *Agent {
11-
return &Agent{}
17+
func NewAgent(tools map[string]Tool, llm lib.Runnable) (*Agent, error) {
18+
parser := parser.NewStrOutputParser()
19+
prompt, err := CreatePrompt(tools)
20+
if err != nil {
21+
return nil, err
22+
}
23+
pl := lib.NewPipeline()
24+
pl.Pipe(prompt).Pipe(llm).Pipe(parser)
25+
agent := &Agent{
26+
LLMChain: pl,
27+
Prompt: prompt,
28+
}
29+
30+
return agent, nil
1231
}
1332

14-
func (a *Agent) CreatePrompt(tools []Tool) {
33+
func CreatePrompt(tools map[string]Tool) (*prompt.ChatPromptTemplate, error) {
1534
var toolStrings []string
1635
var toolNames []string
1736
for _, tool := range tools {
18-
toolStrings = append(toolStrings, fmt.Sprintf("%s: %s", tool.name, tool.description))
19-
toolNames = append(toolNames, tool.name)
37+
toolStrings = append(toolStrings, fmt.Sprintf("%s: %s", tool.name(), tool.description()))
38+
toolNames = append(toolNames, tool.name())
2039
}
2140
toolStringsJoined := strings.Join(toolStrings, "\n")
2241
toolNamesJoined := strings.Join(toolNames, ",")
2342
formatInstructions := strings.Replace(FORMAT_INSTRUCTIONS, "{{.ToolNames}}", toolNamesJoined, 1)
24-
fmt.Println(toolStringsJoined)
25-
fmt.Println(formatInstructions)
43+
instructions := []string{
44+
SYSTEM_MESSAGE_PREFIX,
45+
toolStringsJoined,
46+
formatInstructions,
47+
SYSTEM_MESSAGE_SUFFIX,
48+
}
49+
instruction := strings.Join(instructions, "\n\n")
50+
prompt, err := prompt.NewChatPromptTemplate(instruction, HUMAN_MESSAGE)
51+
if err != nil {
52+
return nil, err
53+
}
54+
return prompt, nil
55+
}
2656

57+
func (a *Agent) Plan(intermediateSteps []string) (any, error) {
58+
agentScratchpad := strings.Join(intermediateSteps, "\n")
59+
m := map[string]string{
60+
"Input": a.UserInput,
61+
"Agent_scratchpad": agentScratchpad,
62+
}
63+
prompt, err := a.Prompt.Invoke(m)
64+
if err != nil {
65+
return nil, err
66+
}
67+
agentDecision, err := a.LLMChain.Invoke(prompt)
68+
if err != nil {
69+
return nil, err
70+
}
71+
fmt.Println(agentDecision)
72+
return agentDecision, err
2773
}

pkg/agent/agent_executor.go

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,67 @@ import (
77

88
type AgentExecutor struct {
99
Agent *Agent
10-
Tools []Tool
10+
Tools map[string]Tool
1111
MaxIterations int
1212
}
1313

14-
func InitializeAgent(tools []Tool, llm *lib.Runnable) *AgentExecutor {
15-
agent := NewAgent()
14+
func InitializeAgent(tools map[string]Tool, llm lib.Runnable) (*AgentExecutor, error) {
15+
agent, err := NewAgent(tools, llm)
16+
if err != nil {
17+
return nil, err
18+
}
1619

1720
return &AgentExecutor{
1821
Agent: agent,
1922
Tools: tools,
2023
MaxIterations: 15,
21-
}
24+
}, nil
2225
}
2326

2427
func (a *AgentExecutor) Invoke(input any) (any, error) {
25-
return nil, nil
28+
switch input.(type) {
29+
case string:
30+
a.Agent.UserInput = input.(string)
31+
default:
32+
return nil, nil
33+
}
34+
output, err := a.call()
35+
if err != nil {
36+
return nil, err
37+
}
38+
return output, nil
2639
}
2740

28-
func (a *AgentExecutor) TakeNextStep() any {
29-
return nil
41+
func (a *AgentExecutor) takeNextStep(intermediateSteps []string) (any, error) {
42+
var results []any
43+
output, err := a.Agent.Plan(intermediateSteps)
44+
if err != nil {
45+
return nil, err
46+
}
47+
48+
action := output.(string)
49+
var observation any
50+
if tool, ok := a.Tools[action]; ok {
51+
observation, err = tool.run("aaaaa")
52+
if err != nil {
53+
return nil, err
54+
}
55+
}
56+
results = append(results, observation)
57+
58+
return results, nil
3059
}
3160

3261
func (a *AgentExecutor) call() (any, error) {
3362
iterations := 0
63+
var intermediateSteps []string
3464
var nextStepOutput any
3565
for a.MaxIterations > iterations {
36-
nextStepOutput := a.TakeNextStep()
66+
nextStepOutput, err := a.takeNextStep(intermediateSteps)
67+
if err != nil {
68+
return nil, err
69+
}
70+
// TODO: fix
3771
fmt.Println(nextStepOutput)
3872
iterations += 1
3973
}

pkg/agent/prompt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ Observation: the result of the action
2626
Thought: I now know the final answer
2727
Final Answer: the final answer to the original input question`
2828
const SYSTEM_MESSAGE_SUFFIX = "Begin! Reminder to always use the exact characters `Final Answer` when responding."
29-
const HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}"
29+
const HUMAN_MESSAGE = "{{ .Input }}\n\n{{ .Agent_scratchpad }}"

pkg/agent/tool.go

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,76 @@
11
package agent
22

3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"net/url"
9+
"os"
10+
)
11+
312
var AvailableTools = map[string]Tool{
4-
"llm-math": {
5-
name: "Calculator",
6-
description: "Useful for when you need to answer questions about math.",
7-
},
8-
"serpapi": {
9-
name: "Search",
10-
description: "A search engine. Useful for when you need to answer questions about current events. Input should be a search query.",
11-
}}
12-
13-
type Tool struct {
14-
name string
15-
description string
16-
}
17-
18-
func LoadTools(tools []string) []Tool {
19-
confirmedTools := []Tool{}
13+
"serpapi": &SerpAPI{}}
14+
15+
type Tool interface {
16+
name() string
17+
description() string
18+
run(input any) (any, error)
19+
}
20+
21+
func LoadTools(tools []string) map[string]Tool {
22+
confirmedTools := map[string]Tool{}
2023
for _, toolname := range tools {
2124
if tool, ok := AvailableTools[toolname]; ok {
2225
confirmedTool := tool
23-
confirmedTools = append(confirmedTools, confirmedTool)
26+
confirmedTools[toolname] = confirmedTool
2427
}
2528
}
2629
return confirmedTools
2730
}
31+
32+
type OrganicResult struct {
33+
Title string `json:"title"`
34+
Link string `json:"link"`
35+
}
36+
37+
type SerpApiResponse struct {
38+
OrganicResults []OrganicResult `json:"organic_results"`
39+
}
40+
41+
type SerpAPI struct {
42+
}
43+
44+
func (s *SerpAPI) name() string {
45+
return "Search"
46+
}
47+
48+
func (s *SerpAPI) description() string {
49+
return "A search engine. Useful for when you need to answer questions about current events. Input should be a search query."
50+
}
51+
52+
func (s *SerpAPI) run(input any) (any, error) {
53+
apiKey := os.Getenv("SERPAPI_API_KEY")
54+
searchQuery := input.(string)
55+
serpApiURL := "https://serpapi.com/search"
56+
57+
params := url.Values{}
58+
params.Add("engine", "google")
59+
params.Add("q", searchQuery)
60+
params.Add("api_key", apiKey)
61+
62+
resp, err := http.Get(fmt.Sprintf("%s?%s", serpApiURL, params.Encode()))
63+
if err != nil {
64+
return nil, err
65+
}
66+
body, err := io.ReadAll(resp.Body)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
var response SerpApiResponse
72+
if err := json.Unmarshal(body, &response); err != nil {
73+
return nil, err
74+
}
75+
return response, nil
76+
}

pkg/openai/openai.go renamed to pkg/llm/openai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package openai
1+
package llm
22

33
import (
44
"bytes"

pkg/openai/openai_test.go renamed to pkg/llm/openai_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package openai
1+
package llm
22

33
import (
44
"net/http"

pkg/openai/struct.go renamed to pkg/llm/struct.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package openai
1+
package llm
22

33
// message represents a single message in the chat
44
type Message struct {

pkg/parser/stroutputparser.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package parser
22

33
import (
4-
"golangchain/pkg/openai"
4+
"golangchain/pkg/llm"
55
)
66

77
type StrOutputParser struct {
@@ -13,7 +13,7 @@ func NewStrOutputParser() *StrOutputParser {
1313

1414
func (p *StrOutputParser) Invoke(input any) (any, error) {
1515
var output string
16-
res, ok := input.(*openai.Response)
16+
res, ok := input.(*llm.Response)
1717
if ok {
1818
output = res.Choices[0].Message.Content
1919
} else {

pkg/parser/stroutputparser_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package parser
22

33
import (
4-
"golangchain/pkg/openai"
4+
"golangchain/pkg/llm"
55
"testing"
66
)
77

@@ -13,9 +13,9 @@ func TestInvoke(t *testing.T) {
1313
}{
1414
{
1515
name: "template is string",
16-
input: &openai.Response{
17-
Choices: []openai.Choice{
18-
{Message: openai.Message{Role: "assistant", Content: "This is the test response"}},
16+
input: &llm.Response{
17+
Choices: []llm.Choice{
18+
{Message: llm.Message{Role: "assistant", Content: "This is the test response"}},
1919
},
2020
},
2121
expected: "This is the test response",

pkg/prompt/chat_prompt.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package prompt
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"golangchain/pkg/llm"
7+
"text/template"
8+
)
9+
10+
type ChatTemplate struct {
11+
system *template.Template
12+
human *template.Template
13+
}
14+
15+
type ChatPromptTemplate struct {
16+
template *ChatTemplate
17+
}
18+
19+
func NewChatPromptTemplate(system string, human string) (*ChatPromptTemplate, error) {
20+
sysTmpl, err := template.New("system").Parse(system)
21+
if err != nil {
22+
return nil, fmt.Errorf("failed to parse template: %w", err)
23+
}
24+
humanTmpl, err := template.New("human").Parse(human)
25+
if err != nil {
26+
return nil, fmt.Errorf("failed to parse template: %w", err)
27+
}
28+
Chat := &ChatTemplate{
29+
system: sysTmpl,
30+
human: humanTmpl,
31+
}
32+
return &ChatPromptTemplate{
33+
template: Chat,
34+
}, nil
35+
}
36+
37+
func (t *ChatPromptTemplate) Invoke(input any) (any, error) {
38+
var sbuf bytes.Buffer
39+
if err := t.template.system.Execute(&sbuf, input); err != nil {
40+
return "", err
41+
}
42+
sMes := llm.Message{
43+
Role: "system",
44+
Content: sbuf.String(),
45+
}
46+
var hbuf bytes.Buffer
47+
if err := t.template.human.Execute(&hbuf, input); err != nil {
48+
return "", err
49+
}
50+
hMes := llm.Message{
51+
Role: "user",
52+
Content: hbuf.String(),
53+
}
54+
messages := []llm.Message{
55+
sMes,
56+
hMes,
57+
}
58+
return messages, nil
59+
}

0 commit comments

Comments
 (0)