Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions answers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import (

type AnswerRequest struct {
Documents []string `json:"documents"`
File string `json:"file"`
Question string `json:"question"`
SearchModel string `json:"search_model"`
Model string `json:"model"`
ExamplesContext string `json:"examples_context"`
Examples [][]string `json:"examples"`
MaxTokens int `json:"max_tokens"`
Stop []string `json:"stop"`
Temperature *float64 `json:"temperature,omitempty"`
}

type AnswerResponse struct {
Expand Down
9 changes: 8 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,14 @@ func NewOrgClient(authToken, org string) *Client {
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
req.Header.Set("Accept", "application/json; charset=utf-8")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
req.Header.Set("Content-Type", "application/json; charset=utf-8")

// Check whether Content-Type is already set, Upload Files API requires
// Content-Type == multipart/form-data
contentType := req.Header.Get("Content-Type")
if contentType == "" {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

if len(c.idOrg) > 0 {
req.Header.Set("OpenAI-Organization", c.idOrg)
}
Expand Down
2 changes: 2 additions & 0 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type CompletionRequest struct {

LogProbs int `json:"logprobs,omitempty"`

Model *string `json:"model,omitempty"`

Echo bool `json:"echo,omitempty"`
Stop []string `json:"stop,omitempty"`

Expand Down
94 changes: 94 additions & 0 deletions files.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
package gogpt

import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"strings"
)

type FileRequest struct {
FileName string `json:"file"`
FilePath string `json:"-"`
Purpose string `json:"purpose"`
}

// File struct represents an OpenAPI file
type File struct {
Bytes int `json:"bytes"`
Expand All @@ -22,6 +34,88 @@ type FilesList struct {
Files []File `json:"data"`
}

// isUrl is a helper function that determines whether the given FilePath
// is a remote URL or a local file path
func isURL(path string) bool {
_, err := url.ParseRequestURI(path)
if err != nil {
return false
}

u, err := url.Parse(path)
if err != nil || u.Scheme == "" || u.Host == "" {
return false
}

return true
}

// CreateFile uploads a jsonl file to GPT3
// FilePath can be either a local file path or a URL
func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) {
var b bytes.Buffer
w := multipart.NewWriter(&b)

var fw, pw io.Writer
pw, err = w.CreateFormField("purpose")
if err != nil {
return
}

_, err = io.Copy(pw, strings.NewReader(request.Purpose))
if err != nil {
return
}

fw, err = w.CreateFormFile("file", request.FileName)
if err != nil {
return
}

var fileData io.ReadCloser
if isURL(request.FilePath) {
var remoteFile *http.Response
remoteFile, err = http.Get(request.FilePath)
if err != nil {
return
}

defer remoteFile.Body.Close()

// Check server response
if remoteFile.StatusCode != http.StatusOK {
err = fmt.Errorf("error, status code: %d, message: failed to fetch file", remoteFile.StatusCode)
return
}

fileData = remoteFile.Body
} else {
fileData, err = os.Open(request.FilePath)
if err != nil {
return
}
}

_, err = io.Copy(fw, fileData)
if err != nil {
return
}

w.Close()

req, err := http.NewRequest("POST", c.fullURL("/files"), &b)
if err != nil {
return
}

req = req.WithContext(ctx)
req.Header.Set("Content-Type", w.FormDataContentType())

err = c.sendRequest(req, &file)

return
}

// ListFiles Lists the currently available files,
// and provides basic information about each file such as the file name and purpose.
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
Expand Down