Skip to content

Commit 3cc6f3b

Browse files
committed
parser
1 parent d615b90 commit 3cc6f3b

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

pkg/main.go renamed to examples/normal/main.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"golangchain/pkg/lib"
66
"golangchain/pkg/openai"
7+
"golangchain/pkg/parser"
78
"golangchain/pkg/prompt"
89
)
910

@@ -13,9 +14,10 @@ func main() {
1314
fmt.Println(err)
1415
}
1516
prompt, err := prompt.NewPromptTemplate("{{.Word}}の意味を教えて。")
17+
parser := parser.NewStrOutputParser()
1618

1719
pipeline := lib.NewPipeline()
18-
pipeline.Pipe(prompt).Pipe(llm)
20+
pipeline.Pipe(prompt).Pipe(llm).Pipe(parser)
1921
m := map[string]string{
2022
"Word": "因果応報",
2123
}
@@ -25,5 +27,4 @@ func main() {
2527
fmt.Println(err)
2628
}
2729
fmt.Printf("Response: %+v\n", response)
28-
2930
}

pkg/parser/stroutputparser.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package parser
2+
3+
import (
4+
"golangchain/pkg/openai"
5+
)
6+
7+
type StrOutputParser struct {
8+
}
9+
10+
func NewStrOutputParser() *StrOutputParser {
11+
return &StrOutputParser{}
12+
}
13+
14+
func (p *StrOutputParser) Invoke(input any) (any, error) {
15+
var output string
16+
res, ok := input.(*openai.Response)
17+
if ok {
18+
output = res.Choices[0].Message.Content
19+
} else {
20+
return nil, nil
21+
}
22+
return output, nil
23+
}

pkg/parser/stroutputparser_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package parser
2+
3+
import (
4+
"golangchain/pkg/openai"
5+
"testing"
6+
)
7+
8+
func TestInvoke(t *testing.T) {
9+
tests := []struct {
10+
name string
11+
input any
12+
expected string
13+
}{
14+
{
15+
name: "template is string",
16+
input: &openai.Response{
17+
Choices: []openai.Choice{
18+
{Message: openai.Message{Role: "assistant", Content: "This is the test response"}},
19+
},
20+
},
21+
expected: "This is the test response",
22+
},
23+
}
24+
for _, tc := range tests {
25+
t.Run(tc.name, func(t *testing.T) {
26+
parser := NewStrOutputParser()
27+
have, err := parser.Invoke(tc.input)
28+
if err != nil {
29+
t.Fatalf("Error happens: %v", err)
30+
}
31+
if have != tc.expected {
32+
t.Fatalf("unexpected string: %v != %v", have, tc.expected)
33+
}
34+
})
35+
}
36+
}

0 commit comments

Comments
 (0)