Skip to content

Commit

Permalink
Merge pull request #24 from otiai10/develop
Browse files Browse the repository at this point in the history
Set max-auto-funcall to limit recursive call
  • Loading branch information
otiai10 committed Jul 9, 2023
2 parents 813db98 + 1b5382c commit cbc5236
Showing 1 changed file with 55 additions and 15 deletions.
70 changes: 55 additions & 15 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"github.com/otiai10/openaigo/functioncall"
)

const DefaultMaxAutoFunctionCall = 8

type Client struct {
openaigo.Client `json:"-"`

Expand Down Expand Up @@ -85,24 +87,31 @@ type Client struct {

// FunctionCall: You ain't need it. Default is "auto".
FunctionCall string `json:"function_call,omitempty"`

// Max number of calling function automatically
MaxFunctionCallHandling int `json:"-"`
}

type Message openaigo.Message
type Message struct {
openaigo.Message
autocalled bool
}

func New(apikey, model string) *Client {
return &Client{
Client: openaigo.Client{
APIKey: apikey,
},
Model: model,
Model: model,
MaxFunctionCallHandling: DefaultMaxAutoFunctionCall,
}
}

func (c *Client) Chat(ctx context.Context, conv []Message) ([]Message, error) {
// Create messages from conv
messages := make([]openaigo.Message, len(conv))
for i, m := range conv {
messages[i] = openaigo.Message(m)
messages[i] = openaigo.Message(m.Message)
}
// Create request
req := openaigo.ChatRequest{
Expand All @@ -116,35 +125,66 @@ func (c *Client) Chat(ctx context.Context, conv []Message) ([]Message, error) {
if err != nil {
return conv, err
}
conv = append(conv, Message(res.Choices[0].Message))
conv = append(conv, Message{
Message: res.Choices[0].Message,
})

if res.Choices[0].Message.FunctionCall != nil {
call := res.Choices[0].Message.FunctionCall
conv = append(conv, Func(call.Name(), c.Functions.Call(call)))
return c.Chat(ctx, conv)
if c.shouldCallFunction(conv) {
call := res.Choices[0].Message.FunctionCall
m := Func(call.Name(), c.Functions.Call(call))
m.autocalled = true
conv, err = c.Chat(ctx, append(conv, m))
}
}

return conv, nil
// Now clean up the auto-called flags
// so that the caller can reuse this slice to restart chat.
for i := range conv {
conv[i].autocalled = false
}

return conv, err
}

func (c *Client) shouldCallFunction(conv []Message) bool {
// Always allow if negative
if c.MaxFunctionCallHandling < 0 {
return true
}
cnt := 0
for _, m := range conv {
if m.autocalled {
cnt++
}
}
return cnt < c.MaxFunctionCallHandling
}

func User(message string) Message {
return Message{
Role: "user",
Content: message,
Message: openaigo.Message{
Role: "user",
Content: message,
},
}
}

func Func(name string, data interface{}) Message {
return Message{
Role: "function",
Name: name,
Content: fmt.Sprintf("%+v\n", data),
Message: openaigo.Message{
Role: "function",
Name: name,
Content: fmt.Sprintf("%+v\n", data),
},
}
}

func System(message string) Message {
return Message{
Role: "system",
Content: message,
Message: openaigo.Message{
Role: "system",
Content: message,
},
}
}

0 comments on commit cbc5236

Please sign in to comment.