diff --git a/api_test.go b/api_test.go index 1e1c5d086..8848d48d4 100644 --- a/api_test.go +++ b/api_test.go @@ -121,6 +121,30 @@ func TestEdits(t *testing.T) { } } +// TestModeration Tests the moderations endpoint of the API using the mocked server. +func TestModerations(t *testing.T) { + // create the test server + var err error + ts := OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(testAPIToken) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + // create an edit request + model := "text-moderation-stable" + moderationReq := ModerationRequest{ + Model: &model, + Input: "I want to kill them.", + } + _, err = client.Moderations(ctx, moderationReq) + if err != nil { + t.Fatalf("Moderation error: %v", err) + } +} + func TestEmbedding(t *testing.T) { embeddedModels := []EmbeddingModel{ AdaSimilarity, @@ -160,6 +184,25 @@ func TestEmbedding(t *testing.T) { } } +func TestImages(t *testing.T) { + // create the test server + var err error + ts := OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(testAPIToken) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + req := ImageRequest{} + req.Prompt = "Lorem ipsum" + _, err = client.CreateImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + // getEditBody Returns the body of the request to create an edit. func getEditBody(r *http.Request) (EditsRequest, error) { edit := EditsRequest{} @@ -261,6 +304,21 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } +// getCompletionBody Returns the body of the request to create a completion. +func getCompletionBody(r *http.Request) (CompletionRequest, error) { + completion := CompletionRequest{} + // read the request body + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + return CompletionRequest{}, err + } + err = json.Unmarshal(reqBody, &completion) + if err != nil { + return CompletionRequest{}, err + } + return completion, nil +} + // handleImageEndpoint Handles the images endpoint by the test server. func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { var err error @@ -296,34 +354,78 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } -// getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} +// getImageBody Returns the body of the request to create a image. +func getImageBody(r *http.Request) (ImageRequest, error) { + image := ImageRequest{} // read the request body reqBody, err := ioutil.ReadAll(r.Body) if err != nil { - return CompletionRequest{}, err + return ImageRequest{}, err } - err = json.Unmarshal(reqBody, &completion) + err = json.Unmarshal(reqBody, &image) if err != nil { - return CompletionRequest{}, err + return ImageRequest{}, err } - return completion, nil + return image, nil } -// getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} +// handleModerationEndpoint Handles the moderation endpoint by the test server. +func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var moderationReq ModerationRequest + if moderationReq, err = getModerationBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + + resCat := ResultCategories{} + resCatScore := ResultCategoryScores{} + switch { + case strings.Contains(moderationReq.Input, "kill"): + resCat = ResultCategories{Violence: true} + resCatScore = ResultCategoryScores{Violence: 1} + case strings.Contains(moderationReq.Input, "hate"): + resCat = ResultCategories{Hate: true} + resCatScore = ResultCategoryScores{Hate: 1} + case strings.Contains(moderationReq.Input, "suicide"): + resCat = ResultCategories{SelfHarm: true} + resCatScore = ResultCategoryScores{SelfHarm: 1} + case strings.Contains(moderationReq.Input, "porn"): + resCat = ResultCategories{Sexual: true} + resCatScore = ResultCategoryScores{Sexual: 1} + } + + result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + + res := ModerationResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Model: *moderationReq.Model, + } + res.Results = append(res.Results, result) + + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getModerationBody Returns the body of the request to do a moderation. +func getModerationBody(r *http.Request) (ModerationRequest, error) { + moderation := ModerationRequest{} // read the request body reqBody, err := ioutil.ReadAll(r.Body) if err != nil { - return ImageRequest{}, err + return ModerationRequest{}, err } - err = json.Unmarshal(reqBody, &image) + err = json.Unmarshal(reqBody, &moderation) if err != nil { - return ImageRequest{}, err + return ModerationRequest{}, err } - return image, nil + return moderation, nil } // numTokens Returns the number of GPT-3 encoded tokens in the given text. @@ -335,25 +437,6 @@ func numTokens(s string) int { return int(float32(len(s)) / 4) } -func TestImages(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() - ts.Start() - defer ts.Close() - - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } -} - // OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. func OpenAITestServer() *httptest.Server { return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -373,6 +456,8 @@ func OpenAITestServer() *httptest.Server { case "/v1/completions": handleCompletionEndpoint(w, r) return + case "/v1/moderations": + handleModerationEndpoint(w, r) case "/v1/images/generations": handleImageEndpoint(w, r) // TODO: implement the other endpoints