Skip to content

Commit

Permalink
Unit tests for the util package (#14)
Browse files Browse the repository at this point in the history
Co-authored-by: Nestor Oprysk <noprysk-ua@singlestore.com>
  • Loading branch information
noprysk-ua and Nestor Oprysk committed Jun 1, 2023
1 parent 148043b commit 5d91e78
Show file tree
Hide file tree
Showing 10 changed files with 361 additions and 9 deletions.
4 changes: 3 additions & 1 deletion internal/provider/util/converters.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package util

import (
"strings"

otypes "github.com/deepmap/oapi-codegen/pkg/types"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/singlestore-labs/singlestore-go/management"
Expand Down Expand Up @@ -58,7 +60,7 @@ func WorkspaceStateString(wgs types.String) *management.WorkspaceState {
management.WorkspaceStateSUSPENDED,
management.WorkspaceStateTERMINATED,
} {
if wgs.ValueString() == string(s) {
if strings.EqualFold(wgs.ValueString(), string(s)) {
return &s
}
}
Expand Down
79 changes: 79 additions & 0 deletions internal/provider/util/converters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package util_test

import (
"testing"

"github.com/google/uuid"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/singlestore-labs/singlestore-go/management"
"github.com/singlestore-labs/terraform-provider-singlestoredb/internal/provider/util"
"github.com/stretchr/testify/require"
)

func TestMaybeString(t *testing.T) {
require.Nil(t, util.MaybeString(types.StringNull()))
require.Nil(t, util.MaybeString(types.StringUnknown()))
s := "bar"
require.Equal(t, &s, util.MaybeString(types.StringValue(s)))
}

func TestToString(t *testing.T) {
require.Empty(t, util.ToString(types.StringNull()))
require.Empty(t, util.ToString(types.StringUnknown()))
s := "buzz"
require.Equal(t, s, util.ToString(types.StringValue(s)))
}

func TestMaybeStringValue(t *testing.T) {
require.Equal(t, types.StringNull(), util.MaybeStringValue(nil))
s := "fizz"
require.Equal(t, types.StringValue(s), util.MaybeStringValue(&s))
}

func TestMaybeBool(t *testing.T) {
require.Nil(t, util.MaybeBool(types.BoolNull()))
require.Nil(t, util.MaybeBool(types.BoolUnknown()))
require.True(t, util.Deref(util.MaybeBool(types.BoolValue(true))))
}

func TestMaybeBoolValue(t *testing.T) {
require.Equal(t, types.BoolNull(), util.MaybeBoolValue(nil))
require.Equal(t, types.BoolValue(true), util.MaybeBoolValue(util.Ptr(true)))
}

func TestUUIDStringValue(t *testing.T) {
id := "9966fccf-5116-437e-a34f-008ee32e8d94"
require.Equal(t, types.StringValue(id), util.UUIDStringValue(uuid.MustParse(id)))
}

func TestStringFirewallRanges(t *testing.T) {
a := "192.168.5.10/24"
b := "192.168.5.10/32"
result := util.StringFirewallRanges([]types.String{types.StringValue(a), types.StringValue(b)})
require.Equal(t, []string{a, b}, result)
}

func TestFirewallRanges(t *testing.T) {
a := "192.168.5.10/24"
b := "192.168.5.10/32"
result := util.FirewallRanges(nil)
require.Empty(t, result)
result = util.FirewallRanges(util.Ptr([]string{a, b}))
require.Equal(t, []types.String{types.StringValue(a), types.StringValue(b)}, result)
}

func TestWorkspaceGroupStateStringValue(t *testing.T) {
state := management.ACTIVE
require.Equal(t, string(state), util.WorkspaceGroupStateStringValue(state).ValueString())
}

func TestWorkspaceStateString(t *testing.T) {
require.Nil(t, util.WorkspaceStateString(types.StringValue("something")))
active := string(management.WorkspaceStateACTIVE)
require.Equal(t, management.WorkspaceStateACTIVE, util.Deref(util.WorkspaceStateString(types.StringValue(active))))
}

func TestWorkspaceStateStringValue(t *testing.T) {
state := management.WorkspaceStateACTIVE
require.Equal(t, string(state), util.WorkspaceStateStringValue(state).ValueString())
}
40 changes: 40 additions & 0 deletions internal/provider/util/httpclient.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,54 @@
package util

import (
"errors"
"fmt"
"io"
"net/http"

"github.com/hashicorp/go-retryablehttp"
)

const respReadLimit = int64(4096)

// NewHTTPClient creates an HTTP client for the Terraform provider.
func NewHTTPClient() *http.Client {
result := retryablehttp.NewClient()
result.ErrorHandler = HandleError

return result.StandardClient()
}

var _ retryablehttp.ErrorHandler = HandleError

// HandleError overrides the default behavior of the library
// by exposing the underlying issue because the underlying issue may be useful, e.g.,
// a customer running out of credits and still closing the body.
//
// This function is called if retries are expired, containing the last status
// from the http library. If not specified, default behavior for the library is
// to close the body and return an error indicating how many tries were attempted.
//
// The function is called only when server returns 500s.
func HandleError(resp *http.Response, ierr error, numTries int) (*http.Response, error) {
defer resp.Body.Close()

body, err := io.ReadAll(io.LimitReader(resp.Body, respReadLimit))
if err != nil {
result := fmt.Sprintf("giving up after %d attempts, unable to read response body, status code: %s, error: %s", numTries, http.StatusText(resp.StatusCode), err)

return nil, maybeWithExtraError(result, ierr)
}

result := fmt.Sprintf("giving up after %d attempts, unexpected status code: %s, response: %s", numTries, http.StatusText(resp.StatusCode), body)

return nil, maybeWithExtraError(result, ierr)
}

func maybeWithExtraError(main string, extra error) error {
if extra == nil {
return errors.New(main)
}

return fmt.Errorf("%s: %w", main, extra)
}
85 changes: 85 additions & 0 deletions internal/provider/util/httpclient_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package util_test

import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"testing/iotest"

"github.com/singlestore-labs/terraform-provider-singlestoredb/internal/provider/util"
"github.com/stretchr/testify/require"
)

func TestHTTPClientStatusOK(t *testing.T) {
body := []byte("fizz")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write(body)
require.NoError(t, err)
}))
t.Cleanup(server.Close)

client := util.NewHTTPClient()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
result, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, body, result)
}

func TestHTTPClientStatusInternalServerError(t *testing.T) {
body := []byte("fizz")
attempts := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(http.StatusInternalServerError)
_, err := w.Write(body)
require.NoError(t, err)
}))
t.Cleanup(server.Close)

client := util.NewHTTPClient()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
require.NoError(t, err)
_, err = client.Do(req) //nolint: bodyclose
require.ErrorContains(t, err, string(body), "returns an error on 500s")
require.ErrorContains(t, err, http.StatusText(http.StatusInternalServerError))
require.Greater(t, attempts, 1, "retries 500s")
}

func TestHTTPClientStatusConflict(t *testing.T) {
body := []byte("insufficient credits")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusConflict)
_, err := w.Write(body)
require.NoError(t, err)
}))
t.Cleanup(server.Close)

client := util.NewHTTPClient()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err, "returns no error and a body on not 500s")
defer resp.Body.Close()
result, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, body, result)
}

func TestHandleError(t *testing.T) {
readErr := errors.New("failed to read")
extra := errors.New("extra")
numTries := 3
_, err := util.HandleError(&http.Response{Body: io.NopCloser(iotest.ErrReader(readErr))}, extra, numTries) //nolint: bodyclose
require.ErrorContains(t, err, readErr.Error())
require.ErrorContains(t, err, extra.Error())
require.ErrorContains(t, err, strconv.Itoa(numTries))
}
4 changes: 2 additions & 2 deletions internal/provider/util/statuscoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func StatusOK(resp StatusCoder, ierr error,
Detail: "An unsuccessful status code occurred when calling SingleStore API. " +
config.InvalidAPIKeyErrorDetail +
config.CreateProviderIssueIfNotClearErrorDetail +
"\n\nSingleStore client response body: " + maybeBody(resp),
"\n\nSingleStore client response body: " + MaybeBody(resp),
}
}

Expand All @@ -56,7 +56,7 @@ func ReturnNilOnNotFound(code int) (bool, *SummaryWithDetailError) {
return false, nil
}

func maybeBody(resp StatusCoder) string {
func MaybeBody(resp StatusCoder) string {
v := reflect.ValueOf(resp)

if v.Kind() == reflect.Ptr {
Expand Down
52 changes: 52 additions & 0 deletions internal/provider/util/statuscoder_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package util_test

import (
"errors"
"net/http"
"testing"

"github.com/singlestore-labs/singlestore-go/management"
Expand All @@ -13,8 +15,58 @@ func TestStatusOK(t *testing.T) {
Body: []byte("foo-bar-buzz-yes"),
}
result := util.StatusOK(input, nil)
require.NotNil(t, result)
require.Contains(t, result.Detail, string(input.Body))

result = util.StatusOK(&input, nil)
require.NotNil(t, result)
require.Contains(t, result.Detail, string(input.Body), "should deref pointer")

ierr := errors.New("foo")
result = util.StatusOK(nil, ierr)
require.NotNil(t, result)
require.Contains(t, result.Detail, ierr.Error())

result = util.StatusOK(management.GetV1RegionsResponse{
HTTPResponse: &http.Response{StatusCode: http.StatusNotFound},
}, nil)
require.NotNil(t, result)

result = util.StatusOK(management.GetV1RegionsResponse{
HTTPResponse: &http.Response{StatusCode: http.StatusNotFound},
}, nil, util.ReturnNilOnNotFound)
require.Nil(t, result)

result = util.StatusOK(management.GetV1RegionsResponse{
HTTPResponse: &http.Response{StatusCode: http.StatusInternalServerError},
}, nil, util.ReturnNilOnNotFound)
require.NotNil(t, result)

result = util.StatusOK(management.GetV1RegionsResponse{
HTTPResponse: &http.Response{StatusCode: http.StatusOK},
}, nil)
require.Nil(t, result)
}

type statusCoderNotStruct int

func (sc statusCoderNotStruct) StatusCode() int {
return int(sc)
}

type statusCoderWithoutBody struct {
Code int
}

func (sc statusCoderWithoutBody) StatusCode() int {
return sc.Code
}

func TestMaybeBody(t *testing.T) {
require.Empty(t, util.MaybeBody(statusCoderNotStruct(0)))
require.Empty(t, util.MaybeBody(statusCoderWithoutBody{Code: 0}))
body := "buzz"
require.Equal(t, body, util.MaybeBody(management.GetV1RegionsResponse{
Body: []byte(body),
}))
}
6 changes: 3 additions & 3 deletions internal/provider/util/timevalidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
)

var _ validator.String = timeValidator{}
var _ validator.String = &timeValidator{}

// timeValidator validates that a string Attribute's value matches the expected time format.
type timeValidator struct {
Expand All @@ -31,7 +31,7 @@ func (v timeValidator) MarkdownDescription(ctx context.Context) string {
}

// Validate performs the validation.
func (v timeValidator) ValidateString(ctx context.Context, request validator.StringRequest, response *validator.StringResponse) {
func (v *timeValidator) ValidateString(ctx context.Context, request validator.StringRequest, response *validator.StringResponse) {
if request.ConfigValue.IsNull() || request.ConfigValue.IsUnknown() {
return
}
Expand All @@ -56,7 +56,7 @@ func (v timeValidator) ValidateString(ctx context.Context, request validator.Str
//
// Null (unconfigured) and unknown (known after apply) values are skipped.
func NewTimeValidator() validator.String {
return timeValidator{}
return &timeValidator{}
}

// parseTime parses time in RFC3339.
Expand Down
Loading

0 comments on commit 5d91e78

Please sign in to comment.