Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 118 additions & 33 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,30 @@ func TestEdits(t *testing.T) {
}
}

// TestModeration Tests the moderations endpoint of the API using the mocked server.
func TestModerations(t *testing.T) {
// create the test server
var err error
ts := OpenAITestServer()
ts.Start()
defer ts.Close()

client := NewClient(testAPIToken)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

// create an edit request
model := "text-moderation-stable"
moderationReq := ModerationRequest{
Model: &model,
Input: "I want to kill them.",
}
_, err = client.Moderations(ctx, moderationReq)
if err != nil {
t.Fatalf("Moderation error: %v", err)
}
}

func TestEmbedding(t *testing.T) {
embeddedModels := []EmbeddingModel{
AdaSimilarity,
Expand Down Expand Up @@ -160,6 +184,25 @@ func TestEmbedding(t *testing.T) {
}
}

func TestImages(t *testing.T) {
// create the test server
var err error
ts := OpenAITestServer()
ts.Start()
defer ts.Close()

client := NewClient(testAPIToken)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

req := ImageRequest{}
req.Prompt = "Lorem ipsum"
_, err = client.CreateImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

// getEditBody Returns the body of the request to create an edit.
func getEditBody(r *http.Request) (EditsRequest, error) {
edit := EditsRequest{}
Expand Down Expand Up @@ -261,6 +304,21 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, string(resBytes))
}

// getCompletionBody Returns the body of the request to create a completion.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
completion := CompletionRequest{}
// read the request body
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
return CompletionRequest{}, err
}
err = json.Unmarshal(reqBody, &completion)
if err != nil {
return CompletionRequest{}, err
}
return completion, nil
}

// handleImageEndpoint Handles the images endpoint by the test server.
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
Expand Down Expand Up @@ -296,34 +354,78 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, string(resBytes))
}

// getCompletionBody Returns the body of the request to create a completion.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
completion := CompletionRequest{}
// getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (ImageRequest, error) {
image := ImageRequest{}
// read the request body
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
return CompletionRequest{}, err
return ImageRequest{}, err
}
err = json.Unmarshal(reqBody, &completion)
err = json.Unmarshal(reqBody, &image)
if err != nil {
return CompletionRequest{}, err
return ImageRequest{}, err
}
return completion, nil
return image, nil
}

// getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (ImageRequest, error) {
image := ImageRequest{}
// handleModerationEndpoint Handles the moderation endpoint by the test server.
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var moderationReq ModerationRequest
if moderationReq, err = getModerationBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}

resCat := ResultCategories{}
resCatScore := ResultCategoryScores{}
switch {
case strings.Contains(moderationReq.Input, "kill"):
resCat = ResultCategories{Violence: true}
resCatScore = ResultCategoryScores{Violence: 1}
case strings.Contains(moderationReq.Input, "hate"):
resCat = ResultCategories{Hate: true}
resCatScore = ResultCategoryScores{Hate: 1}
case strings.Contains(moderationReq.Input, "suicide"):
resCat = ResultCategories{SelfHarm: true}
resCatScore = ResultCategoryScores{SelfHarm: 1}
case strings.Contains(moderationReq.Input, "porn"):
resCat = ResultCategories{Sexual: true}
resCatScore = ResultCategoryScores{Sexual: 1}
}

result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}

res := ModerationResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Model: *moderationReq.Model,
}
res.Results = append(res.Results, result)

resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getModerationBody Returns the body of the request to do a moderation.
func getModerationBody(r *http.Request) (ModerationRequest, error) {
moderation := ModerationRequest{}
// read the request body
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
return ImageRequest{}, err
return ModerationRequest{}, err
}
err = json.Unmarshal(reqBody, &image)
err = json.Unmarshal(reqBody, &moderation)
if err != nil {
return ImageRequest{}, err
return ModerationRequest{}, err
}
return image, nil
return moderation, nil
}

// numTokens Returns the number of GPT-3 encoded tokens in the given text.
Expand All @@ -335,25 +437,6 @@ func numTokens(s string) int {
return int(float32(len(s)) / 4)
}

func TestImages(t *testing.T) {
// create the test server
var err error
ts := OpenAITestServer()
ts.Start()
defer ts.Close()

client := NewClient(testAPIToken)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

req := ImageRequest{}
req.Prompt = "Lorem ipsum"
_, err = client.CreateImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
func OpenAITestServer() *httptest.Server {
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -373,6 +456,8 @@ func OpenAITestServer() *httptest.Server {
case "/v1/completions":
handleCompletionEndpoint(w, r)
return
case "/v1/moderations":
handleModerationEndpoint(w, r)
case "/v1/images/generations":
handleImageEndpoint(w, r)
// TODO: implement the other endpoints
Expand Down