diff --git a/answers.go b/answers.go index 662a3609c..84e1322a8 100644 --- a/answers.go +++ b/answers.go @@ -9,6 +9,7 @@ import ( type AnswerRequest struct { Documents []string `json:"documents"` + File string `json:"file"` Question string `json:"question"` SearchModel string `json:"search_model"` Model string `json:"model"` @@ -16,6 +17,7 @@ type AnswerRequest struct { Examples [][]string `json:"examples"` MaxTokens int `json:"max_tokens"` Stop []string `json:"stop"` + Temperature *float64 `json:"temperature,omitempty"` } type AnswerResponse struct { diff --git a/api.go b/api.go index d8e4c7e4a..9023a3e5c 100644 --- a/api.go +++ b/api.go @@ -46,7 +46,14 @@ func NewOrgClient(authToken, org string) *Client { func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken)) - req.Header.Set("Content-Type", "application/json; charset=utf-8") + + // Check whether Content-Type is already set, Upload Files API requires + // Content-Type == multipart/form-data + contentType := req.Header.Get("Content-Type") + if contentType == "" { + req.Header.Set("Content-Type", "application/json; charset=utf-8") + } + if len(c.idOrg) > 0 { req.Header.Set("OpenAI-Organization", c.idOrg) } diff --git a/completion.go b/completion.go index 647bd7b9b..b68889667 100644 --- a/completion.go +++ b/completion.go @@ -20,6 +20,8 @@ type CompletionRequest struct { LogProbs int `json:"logprobs,omitempty"` + Model *string `json:"model,omitempty"` + Echo bool `json:"echo,omitempty"` Stop []string `json:"stop,omitempty"` diff --git a/files.go b/files.go index 75ed5402e..69763bd20 100644 --- a/files.go +++ b/files.go @@ -1,11 +1,23 @@ package gogpt import ( + "bytes" "context" "fmt" + "io" + "mime/multipart" "net/http" + "net/url" + "os" + "strings" ) +type FileRequest struct { + FileName string `json:"file"` + FilePath string `json:"-"` + Purpose string `json:"purpose"` +} + // File struct represents an OpenAPI file type File struct { Bytes int `json:"bytes"` @@ -22,6 +34,88 @@ type FilesList struct { Files []File `json:"data"` } +// isUrl is a helper function that determines whether the given FilePath +// is a remote URL or a local file path +func isURL(path string) bool { + _, err := url.ParseRequestURI(path) + if err != nil { + return false + } + + u, err := url.Parse(path) + if err != nil || u.Scheme == "" || u.Host == "" { + return false + } + + return true +} + +// CreateFile uploads a jsonl file to GPT3 +// FilePath can be either a local file path or a URL +func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { + var b bytes.Buffer + w := multipart.NewWriter(&b) + + var fw, pw io.Writer + pw, err = w.CreateFormField("purpose") + if err != nil { + return + } + + _, err = io.Copy(pw, strings.NewReader(request.Purpose)) + if err != nil { + return + } + + fw, err = w.CreateFormFile("file", request.FileName) + if err != nil { + return + } + + var fileData io.ReadCloser + if isURL(request.FilePath) { + var remoteFile *http.Response + remoteFile, err = http.Get(request.FilePath) + if err != nil { + return + } + + defer remoteFile.Body.Close() + + // Check server response + if remoteFile.StatusCode != http.StatusOK { + err = fmt.Errorf("error, status code: %d, message: failed to fetch file", remoteFile.StatusCode) + return + } + + fileData = remoteFile.Body + } else { + fileData, err = os.Open(request.FilePath) + if err != nil { + return + } + } + + _, err = io.Copy(fw, fileData) + if err != nil { + return + } + + w.Close() + + req, err := http.NewRequest("POST", c.fullURL("/files"), &b) + if err != nil { + return + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", w.FormDataContentType()) + + err = c.sendRequest(req, &file) + + return +} + // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {