diff --git a/CHANGELOG.md b/CHANGELOG.md index 99b165832..789968aab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ - **New:** API for application load balancer - `cdn`: [v0.1.0](services/cdn/CHANGELOG.md#v010-2025-03-19) - **New:** Introduce new API for content delivery +- `core`: [v0.16.2](core/CHANGELOG.md#v0162-2025-03-21) + - **New:** If a custom http.Client is provided, the http.Transport is respected. This allows customizing the http.Client with custom timeouts or instrumentation. - `serverupdate`: [v1.0.0](services/serverupdate/CHANGELOG.md#v100-2025-03-19) - **Breaking Change:** The region is no longer specified within the client configuration. Instead, the region must be passed as a parameter to any region-specific request. - `serverbackup`: [v1.0.0](services/serverbackup/CHANGELOG.md#v100-2025-03-19) diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 194b06fe6..cd48cbe6d 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,3 +1,6 @@ +## v0.16.2 (2025-03-21) +- **New:** If a custom http.Client is provided, the http.Transport is respected. This allows customizing the http.Client with custom timeouts or instrumentation. + ## v0.16.1 (2025-02-25) - **Bugfix:** STACKIT_PRIVATE_KEY and STACKIT_SERVICE_ACCOUNT_KEY can be set via environment variable or via credentials file. diff --git a/core/auth/auth.go b/core/auth/auth.go index 19e400f94..7da6c9687 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -45,7 +45,7 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { if cfg.CustomAuth != nil { return cfg.CustomAuth, nil } else if cfg.NoAuth { - noAuthRoundTripper, err := NoAuth() + noAuthRoundTripper, err := NoAuth(cfg) if err != nil { return nil, fmt.Errorf("configuring no auth client: %w", err) } @@ -98,9 +98,22 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { // NoAuth configures a flow without authentication and returns an http.RoundTripper // that can be used to make unauthenticated requests -func NoAuth() (rt http.RoundTripper, err error) { +func NoAuth(cfgs ...*config.Configuration) (rt http.RoundTripper, err error) { noAuthConfig := clients.NoAuthFlowConfig{} noAuthRoundTripper := &clients.NoAuthFlow{} + + var cfg *config.Configuration + + if len(cfgs) > 0 { + cfg = cfgs[0] + } else { + cfg = &config.Configuration{} + } + + if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { + noAuthConfig.HTTPTransport = cfg.HTTPClient.Transport + } + if err := noAuthRoundTripper.Init(noAuthConfig); err != nil { return nil, fmt.Errorf("initializing client: %w", err) } @@ -130,6 +143,10 @@ func TokenAuth(cfg *config.Configuration) (http.RoundTripper, error) { ServiceAccountToken: cfg.Token, } + if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { + tokenCfg.HTTPTransport = cfg.HTTPClient.Transport + } + client := &clients.TokenFlow{} if err := client.Init(&tokenCfg); err != nil { return nil, fmt.Errorf("error initializing client: %w", err) @@ -187,6 +204,10 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) { BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, } + if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { + keyCfg.HTTPTransport = cfg.HTTPClient.Transport + } + client := &clients.KeyFlow{} if err := client.Init(&keyCfg); err != nil { return nil, fmt.Errorf("error initializing client: %w", err) diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index 65647bd68..413399bde 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "net/http" "os" "reflect" "testing" @@ -125,6 +126,7 @@ func TestSetupAuth(t *testing.T) { t.Fatalf("Creating temporary file: %s", err) } defer func() { + _ = credentialsKeyFile.Close() err := os.Remove(credentialsKeyFile.Name()) if err != nil { t.Fatalf("Removing temporary file: %s", err) @@ -361,6 +363,7 @@ func TestDefaultAuth(t *testing.T) { t.Fatalf("Creating temporary file: %s", err) } defer func() { + _ = saKeyFile.Close() err := os.Remove(saKeyFile.Name()) if err != nil { t.Fatalf("Removing temporary file: %s", err) @@ -377,19 +380,13 @@ func TestDefaultAuth(t *testing.T) { t.Fatalf("Writing private key to temporary file: %s", err) } - defer func() { - err := saKeyFile.Close() - if err != nil { - t.Fatalf("Removing temporary file: %s", err) - } - }() - // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { t.Fatalf("Creating temporary file: %s", err) } defer func() { + _ = credentialsKeyFile.Close() err := os.Remove(credentialsKeyFile.Name()) if err != nil { t.Fatalf("Removing temporary file: %s", err) @@ -693,6 +690,28 @@ func TestNoAuth(t *testing.T) { } } +func TestNoAuthWithConfig(t *testing.T) { + for _, test := range []struct { + desc string + }{ + { + desc: "valid_case", + }, + } { + t.Run(test.desc, func(t *testing.T) { + setTemporaryHome(t) // Get the default authentication client and ensure that it's not nil + authClient, err := NoAuth(&config.Configuration{HTTPClient: http.DefaultClient}) + if err != nil { + t.Fatalf("Test returned error on valid test case: %v", err) + } + + if authClient == nil { + t.Fatalf("Client returned is nil for valid test case") + } + }) + } +} + func TestGetServiceAccountEmail(t *testing.T) { for _, test := range []struct { description string diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 59c90bd41..448c2bd44 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -34,9 +34,9 @@ const ( // KeyFlow handles auth with SA key type KeyFlow struct { - client *http.Client + rt http.RoundTripper + authClient *http.Client config *KeyFlowConfig - doer func(req *http.Request) (resp *http.Response, err error) key *ServiceAccountKeyResponse privateKey *rsa.PrivateKey privateKeyPEM []byte @@ -53,6 +53,8 @@ type KeyFlowConfig struct { ClientRetry *RetryConfig TokenUrl string BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil + HTTPTransport http.RoundTripper + AuthHTTPClient *http.Client } // TokenResponseBody is the API response @@ -124,7 +126,18 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { if c.config.TokenUrl == "" { c.config.TokenUrl = tokenAPI } - c.configureHTTPClient() + + if c.rt = cfg.HTTPTransport; c.rt == nil { + c.rt = http.DefaultTransport + } + + if c.authClient = cfg.AuthHTTPClient; cfg.AuthHTTPClient == nil { + c.authClient = &http.Client{ + Transport: c.rt, + Timeout: DefaultClientTimeout, + } + } + err := c.validate() if err != nil { return err @@ -163,7 +176,7 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { // Roundtrip performs the request func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { - if c.client == nil { + if c.rt == nil { return nil, fmt.Errorf("please run Init()") } @@ -172,17 +185,21 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return c.doer(req) + return c.rt.RoundTrip(req) } // GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field func (c *KeyFlow) GetAccessToken() (string, error) { - if c.client == nil { - return "", fmt.Errorf("nil http client, please run Init()") + if c.rt == nil { + return "", fmt.Errorf("nil http round tripper, please run Init()") } + var accessToken string + c.tokenMutex.RLock() - accessToken := c.token.AccessToken + if c.token != nil { + accessToken = c.token.AccessToken + } c.tokenMutex.RUnlock() accessTokenExpired, err := tokenExpired(accessToken) @@ -203,14 +220,6 @@ func (c *KeyFlow) GetAccessToken() (string, error) { return accessToken, nil } -// configureHTTPClient configures the HTTP client -func (c *KeyFlow) configureHTTPClient() { - client := &http.Client{} - client.Timeout = DefaultClientTimeout - c.client = client - c.doer = c.client.Do -} - // validate the client is configured well func (c *KeyFlow) validate() error { if c.config.ServiceAccountKey == nil { @@ -242,8 +251,12 @@ func (c *KeyFlow) validate() error { // recreateAccessToken is used to create a new access token // when the existing one isn't valid anymore func (c *KeyFlow) recreateAccessToken() error { + var refreshToken string + c.tokenMutex.RLock() - refreshToken := c.token.RefreshToken + if c.token != nil { + refreshToken = c.token.RefreshToken + } c.tokenMutex.RUnlock() refreshTokenExpired, err := tokenExpired(refreshToken) @@ -279,10 +292,6 @@ func (c *KeyFlow) createAccessToken() (err error) { // createAccessTokenWithRefreshToken creates an access token using // an existing pre-validated refresh token func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) { - if c.client == nil { - return fmt.Errorf("nil http client, please run Init()") - } - c.tokenMutex.RLock() refreshToken := c.token.RefreshToken c.tokenMutex.RUnlock() @@ -334,7 +343,8 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) return nil, err } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - return c.doer(req) + + return c.authClient.Do(req) } // parseTokenResponse parses the response from the server diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/key_flow_continuous_refresh.go index dfafc10ea..f5129aa02 100644 --- a/core/clients/key_flow_continuous_refresh.go +++ b/core/clients/key_flow_continuous_refresh.go @@ -46,9 +46,12 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { // Compute timestamp where we'll refresh token // Access token may be empty at this point, we have to check it var startRefreshTimestamp time.Time + var accessToken string refresher.keyFlow.tokenMutex.RLock() - accessToken := refresher.keyFlow.token.AccessToken + if refresher.keyFlow.token != nil { + accessToken = refresher.keyFlow.token.AccessToken + } refresher.keyFlow.tokenMutex.RUnlock() if accessToken == "" { startRefreshTimestamp = time.Now() diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index 960086636..76b29f55a 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -137,8 +137,9 @@ func TestContinuousRefreshToken(t *testing.T) { config: &KeyFlowConfig{ BackgroundTokenRefreshContext: ctx, }, - client: &http.Client{}, - doer: mockDo, + authClient: &http.Client{ + Transport: mockTransportFn{mockDo}, + }, token: &TokenResponseBody{ AccessToken: accessToken, RefreshToken: refreshToken, @@ -328,11 +329,13 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { } keyFlow := &KeyFlow{ - client: &http.Client{}, config: &KeyFlowConfig{ BackgroundTokenRefreshContext: ctx, }, - doer: mockDo, + authClient: &http.Client{ + Transport: mockTransportFn{mockDo}, + }, + rt: mockTransportFn{mockDo}, token: &TokenResponseBody{ AccessToken: accessTokenFirst, RefreshToken: refreshToken, diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 33645a2c6..b37b9593f 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -1,13 +1,18 @@ package clients import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" + "encoding/json" "encoding/pem" + "errors" "fmt" "io" "net/http" + "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -21,6 +26,8 @@ var ( testSigningKey = []byte(`Test`) ) +const testBearerToken = "eyJhbGciOiJub25lIn0.eyJleHAiOjIxNDc0ODM2NDd9." //nolint:gosec // linter false positive + func fixtureServiceAccountKey(mods ...func(*ServiceAccountKeyResponse)) *ServiceAccountKeyResponse { validUntil := time.Now().Add(time.Hour) serviceAccountKeyResponse := &ServiceAccountKeyResponse{ @@ -268,13 +275,14 @@ func TestRequestToken(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { - mockDo := func(_ *http.Request) (resp *http.Response, err error) { - return tt.mockResponse, tt.mockError - } - c := &KeyFlow{ + authClient: &http.Client{ + Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) { + return tt.mockResponse, tt.mockError + }}, + }, config: &KeyFlowConfig{}, - doer: mockDo, + rt: http.DefaultTransport, } res, err := c.requestToken(tt.grant, tt.assertion) @@ -289,7 +297,7 @@ func TestRequestToken(t *testing.T) { if tt.expectedError != nil { if err == nil { t.Errorf("Expected error '%v' but no error was returned", tt.expectedError) - } else if tt.expectedError.Error() != err.Error() { + } else if errors.Is(err, tt.expectedError) { t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err) } } else { @@ -303,3 +311,211 @@ func TestRequestToken(t *testing.T) { }) } } + +func TestKeyFlow_Do(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + keyFlow *KeyFlow + handlerFn func(tb testing.TB) http.HandlerFunc + want int + wantErr bool + }{ + { + name: "success", + keyFlow: &KeyFlow{rt: http.DefaultTransport, config: &KeyFlowConfig{}}, + handlerFn: func(tb testing.TB) http.HandlerFunc { + tb.Helper() + + return func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+testBearerToken { + tb.Errorf("expected Authorization header to be 'Bearer %s', but got %s", testBearerToken, r.Header.Get("Authorization")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } + }, + want: http.StatusOK, + wantErr: false, + }, + { + name: "success with code 500", + keyFlow: &KeyFlow{rt: http.DefaultTransport, config: &KeyFlowConfig{}}, + handlerFn: func(_ testing.TB) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintln(w, `Internal Server Error`) + } + }, + want: http.StatusInternalServerError, + wantErr: false, + }, + { + name: "success with custom transport", + keyFlow: &KeyFlow{ + rt: mockTransportFn{ + fn: func(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", "custom_transport") + + return http.DefaultTransport.RoundTrip(req) + }, + }, + config: &KeyFlowConfig{}, + }, + handlerFn: func(tb testing.TB) http.HandlerFunc { + tb.Helper() + + return func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("User-Agent") != "custom_transport" { + tb.Errorf("expected User-Agent header to be 'custom_transport', but got %s", r.Header.Get("User-Agent")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } + }, + want: http.StatusOK, + wantErr: false, + }, + { + name: "fail with custom proxy", + keyFlow: &KeyFlow{ + rt: &http.Transport{ + Proxy: func(_ *http.Request) (*url.URL, error) { + return nil, fmt.Errorf("proxy error") + }, + }, + config: &KeyFlowConfig{}, + }, + handlerFn: func(testing.TB) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } + }, + want: 0, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) // This cancels the refresher goroutine + + privateKeyBytes, err := generatePrivateKey() + if err != nil { + t.Errorf("no error is expected, but got %v", err) + } + + tt.keyFlow.config.ServiceAccountKey = fixtureServiceAccountKey() + tt.keyFlow.config.PrivateKey = string(privateKeyBytes) + tt.keyFlow.config.BackgroundTokenRefreshContext = ctx + tt.keyFlow.authClient = &http.Client{ + Transport: mockTransportFn{ + fn: func(_ *http.Request) (*http.Response, error) { + res := httptest.NewRecorder() + res.WriteHeader(http.StatusOK) + res.Header().Set("Content-Type", "application/json") + + token := &TokenResponseBody{ + AccessToken: testBearerToken, + ExpiresIn: 2147483647, + RefreshToken: testBearerToken, + TokenType: "Bearer", + } + + if err := json.NewEncoder(res.Body).Encode(token); err != nil { + t.Logf("no error is expected, but got %v", err) + } + + return res.Result(), nil + }, + }, + } + + if err := tt.keyFlow.validate(); err != nil { + t.Errorf("no error is expected, but got %v", err) + } + + go continuousRefreshToken(tt.keyFlow) + + tokenCtx, tokenCancel := context.WithTimeout(context.Background(), 1*time.Second) + + token: + for { + select { + case <-tokenCtx.Done(): + t.Error(tokenCtx.Err()) + case <-time.After(50 * time.Millisecond): + tt.keyFlow.tokenMutex.RLock() + if tt.keyFlow.token != nil { + tt.keyFlow.tokenMutex.RUnlock() + tokenCancel() + break token + } + + tt.keyFlow.tokenMutex.RUnlock() + } + } + + server := httptest.NewServer(tt.handlerFn(t)) + t.Cleanup(server.Close) + + u, err := url.Parse(server.URL) + if err != nil { + t.Errorf("no error is expected, but got %v", err) + } + + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) + if err != nil { + t.Errorf("no error is expected, but got %v", err) + } + + httpClient := &http.Client{ + Transport: tt.keyFlow, + } + + res, err := httpClient.Do(req) + + if tt.wantErr { + if err == nil { + t.Errorf("error is expected, but got %v", err) + } + } else { + if err != nil { + t.Errorf("no error is expected, but got %v", err) + } + + if res.StatusCode != tt.want { + t.Errorf("expected status code %d, but got %d", tt.want, res.StatusCode) + } + + // Defer discard and close the body + t.Cleanup(func() { + if _, err := io.Copy(io.Discard, res.Body); err != nil { + t.Errorf("no error is expected, but got %v", err) + } + + if err := res.Body.Close(); err != nil { + t.Errorf("no error is expected, but got %v", err) + } + }) + } + }) + } +} + +type mockTransportFn struct { + fn func(req *http.Request) (*http.Response, error) +} + +func (m mockTransportFn) RoundTrip(req *http.Request) (*http.Response, error) { + return m.fn(req) +} diff --git a/core/clients/no_auth_flow.go b/core/clients/no_auth_flow.go index 4db1bf156..b0b2406ce 100644 --- a/core/clients/no_auth_flow.go +++ b/core/clients/no_auth_flow.go @@ -6,14 +6,15 @@ import ( ) type NoAuthFlow struct { - client *http.Client + rt http.RoundTripper config *NoAuthFlowConfig } // NoAuthFlowConfig holds the configuration for the unauthenticated flow type NoAuthFlowConfig struct { // Deprecated: retry options were removed to reduce complexity of the client. If this functionality is needed, you can provide your own custom HTTP client. - ClientRetry *RetryConfig + ClientRetry *RetryConfig + HTTPTransport http.RoundTripper } // GetConfig returns the flow configuration @@ -24,18 +25,21 @@ func (c *NoAuthFlow) GetConfig() NoAuthFlowConfig { return *c.config } -func (c *NoAuthFlow) Init(_ NoAuthFlowConfig) error { +func (c *NoAuthFlow) Init(cfg NoAuthFlowConfig) error { c.config = &NoAuthFlowConfig{} - c.client = &http.Client{ - Timeout: DefaultClientTimeout, + + if c.rt = cfg.HTTPTransport; c.rt == nil { + c.rt = http.DefaultTransport } + return nil } -// Roundtrip performs the request +// RoundTrip performs the request func (c *NoAuthFlow) RoundTrip(req *http.Request) (*http.Response, error) { - if c.client == nil { + if c.rt == nil { return nil, fmt.Errorf("please run Init()") } - return c.client.Do(req) + + return c.rt.RoundTrip(req) } diff --git a/core/clients/no_auth_flow_test.go b/core/clients/no_auth_flow_test.go index c3d6f15ad..5d7abc293 100644 --- a/core/clients/no_auth_flow_test.go +++ b/core/clients/no_auth_flow_test.go @@ -21,6 +21,7 @@ func TestNoAuthFlow_Init(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), NoAuthFlowConfig{}}, false}, + {"with transport", args{context.Background(), NoAuthFlowConfig{HTTPTransport: http.DefaultTransport}}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -33,74 +34,134 @@ func TestNoAuthFlow_Init(t *testing.T) { } func TestNoAuthFlow_Do(t *testing.T) { - type fields struct { - client *http.Client - } - type args struct{} + t.Parallel() + tests := []struct { - name string - fields fields - args args - want int - wantErr bool + name string + noAuthFlow *NoAuthFlow + handlerFn func(tb testing.TB) http.HandlerFunc + want int + wantErr bool }{ { - name: "fail", - fields: fields{nil}, - args: args{}, - want: 0, - wantErr: true, + name: "success with rt", + noAuthFlow: &NoAuthFlow{http.DefaultTransport, &NoAuthFlowConfig{}}, + handlerFn: func(_ testing.TB) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } + }, + want: http.StatusOK, + wantErr: false, }, { - name: "success", - fields: fields{ - &http.Client{}, + name: "success with code 500", + noAuthFlow: &NoAuthFlow{http.DefaultTransport, &NoAuthFlowConfig{}}, + handlerFn: func(_ testing.TB) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintln(w, `Internal Server Error`) + } + }, + want: http.StatusInternalServerError, + wantErr: false, + }, + { + name: "success with custom transport", + noAuthFlow: &NoAuthFlow{ + mockTransportFn{ + fn: func(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", "custom_transport") + + return http.DefaultTransport.RoundTrip(req) + }, + }, + &NoAuthFlowConfig{}, + }, + handlerFn: func(tb testing.TB) http.HandlerFunc { + tb.Helper() + + return func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("User-Agent") != "custom_transport" { + tb.Errorf("expected User-Agent header to be 'custom_transport', but got %s", r.Header.Get("User-Agent")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } }, - args: args{}, want: http.StatusOK, wantErr: false, }, + { + name: "fail with custom proxy", + noAuthFlow: &NoAuthFlow{ + &http.Transport{ + Proxy: func(_ *http.Request) (*url.URL, error) { + return nil, fmt.Errorf("proxy error") + }, + }, + &NoAuthFlowConfig{}, + }, + handlerFn: func(testing.TB) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } + }, + want: 0, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &NoAuthFlow{ - client: tt.fields.client, - } - handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprintln(w, `{"status":"ok"}`) - }) - server := httptest.NewServer(handler) - defer server.Close() + server := httptest.NewServer(tt.handlerFn(t)) + t.Cleanup(server.Close) + u, err := url.Parse(server.URL) if err != nil { - t.Error(err) - return + t.Errorf("no error is expected, but got %v", err) } + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) if err != nil { - t.Error(err) - return + t.Errorf("no error is expected, but got %v", err) } - got, err := c.RoundTrip(req) - if err == nil { + + httpClient := &http.Client{ + Transport: tt.noAuthFlow, + } + + res, err := httpClient.Do(req) + + if tt.wantErr { + if err == nil { + t.Errorf("error is expected, but got %v", err) + } + } else { + if err != nil { + t.Errorf("no error is expected, but got %v", err) + } + + if res.StatusCode != tt.want { + t.Errorf("expected status code %d, but got %d", tt.want, res.StatusCode) + } + // Defer discard and close the body - defer func() { - if _, discardErr := io.Copy(io.Discard, got.Body); discardErr != nil && err == nil { - err = discardErr + t.Cleanup(func() { + if _, err := io.Copy(io.Discard, res.Body); err != nil { + t.Errorf("no error is expected, but got %v", err) } - if closeErr := got.Body.Close(); closeErr != nil && err == nil { - err = closeErr + + if err := res.Body.Close(); err != nil { + t.Errorf("no error is expected, but got %v", err) } - }() - } - if (err != nil) != tt.wantErr { - t.Errorf("NoAuthFlow.Do() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != nil && got.StatusCode != tt.want { - t.Errorf("NoAuthFlow.Do() = %v, want %v", got.StatusCode, tt.want) + }) } }) } diff --git a/core/clients/token_flow.go b/core/clients/token_flow.go index 3748c4dc8..ac1ff779a 100644 --- a/core/clients/token_flow.go +++ b/core/clients/token_flow.go @@ -13,7 +13,7 @@ const ( // TokenFlow handles auth with SA static token type TokenFlow struct { - client *http.Client + rt http.RoundTripper config *TokenFlowConfig } @@ -23,7 +23,8 @@ type TokenFlowConfig struct { ServiceAccountEmail string ServiceAccountToken string // Deprecated: retry options were removed to reduce complexity of the client. If this functionality is needed, you can provide your own custom HTTP client. - ClientRetry *RetryConfig + ClientRetry *RetryConfig + HTTPTransport http.RoundTripper } // GetConfig returns the flow configuration @@ -36,15 +37,12 @@ func (c *TokenFlow) GetConfig() TokenFlowConfig { func (c *TokenFlow) Init(cfg *TokenFlowConfig) error { c.config = cfg - c.configureHTTPClient() - return c.validate() -} -// configureHTTPClient configures the HTTP client -func (c *TokenFlow) configureHTTPClient() { - client := &http.Client{} - client.Timeout = DefaultClientTimeout - c.client = client + if c.rt = cfg.HTTPTransport; c.rt == nil { + c.rt = http.DefaultTransport + } + + return c.validate() } // validate the client is configured well @@ -55,11 +53,11 @@ func (c *TokenFlow) validate() error { return nil } -// Roundtrip performs the request +// RoundTrip performs the request func (c *TokenFlow) RoundTrip(req *http.Request) (*http.Response, error) { - if c.client == nil { + if c.rt == nil { return nil, fmt.Errorf("please run Init()") } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.ServiceAccountToken)) - return c.client.Do(req) + return c.rt.RoundTrip(req) } diff --git a/core/clients/token_flow_test.go b/core/clients/token_flow_test.go index 8294709cd..9e389c91f 100644 --- a/core/clients/token_flow_test.go +++ b/core/clients/token_flow_test.go @@ -22,6 +22,10 @@ func TestTokenFlow_Init(t *testing.T) { {"ok", args{&TokenFlowConfig{ ServiceAccountToken: "efg", }}, false}, + {"with transport", args{&TokenFlowConfig{ + ServiceAccountToken: "efg", + HTTPTransport: http.DefaultTransport, + }}, false}, {"error 1", args{&TokenFlowConfig{ ServiceAccountToken: "", }}, true}, @@ -50,62 +54,152 @@ func TestTokenFlow_Init(t *testing.T) { } func TestTokenFlow_Do(t *testing.T) { - type fields struct { - client *http.Client - config *TokenFlowConfig - } - type args struct{} + t.Parallel() + tests := []struct { - name string - fields fields - args args - want int - wantErr bool + name string + tokenFlow *TokenFlow + handlerFn func(tb testing.TB) http.HandlerFunc + want int + wantErr bool }{ - {"fail", fields{nil, nil}, args{}, 0, true}, - {"success", fields{&http.Client{}, &TokenFlowConfig{}}, args{}, http.StatusOK, false}, + { + name: "success", + tokenFlow: &TokenFlow{http.DefaultTransport, &TokenFlowConfig{ + ServiceAccountToken: "efg", + }}, + handlerFn: func(tb testing.TB) http.HandlerFunc { + tb.Helper() + + return func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer efg" { + tb.Errorf("expected Authorization header to be 'Bearer efg', but got %s", r.Header.Get("Authorization")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } + }, + want: http.StatusOK, + wantErr: false, + }, + { + name: "success with code 500", + tokenFlow: &TokenFlow{http.DefaultTransport, &TokenFlowConfig{ + ServiceAccountToken: "efg", + }}, + handlerFn: func(testing.TB) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintln(w, `Internal Server Error`) + } + }, + want: http.StatusInternalServerError, + wantErr: false, + }, + { + name: "success with custom transport", + tokenFlow: &TokenFlow{ + mockTransportFn{ + fn: func(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", "custom_transport") + + return http.DefaultTransport.RoundTrip(req) + }, + }, + &TokenFlowConfig{ + ServiceAccountToken: "efg", + }, + }, + handlerFn: func(tb testing.TB) http.HandlerFunc { + tb.Helper() + + return func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer efg" { + tb.Errorf("expected Authorization header to be 'Bearer efg', but got %s", r.Header.Get("Authorization")) + } + + if r.Header.Get("User-Agent") != "custom_transport" { + tb.Errorf("expected User-Agent header to be 'custom_transport', but got %s", r.Header.Get("User-Agent")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } + }, + want: http.StatusOK, + wantErr: false, + }, + { + name: "fail with custom proxy", + tokenFlow: &TokenFlow{ + &http.Transport{ + Proxy: func(_ *http.Request) (*url.URL, error) { + return nil, fmt.Errorf("proxy error") + }, + }, + &TokenFlowConfig{ + ServiceAccountToken: "efg", + }, + }, + handlerFn: func(testing.TB) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintln(w, `{"status":"ok"}`) + } + }, + want: 0, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &TokenFlow{ - client: tt.fields.client, - config: tt.fields.config, - } - handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprintln(w, `{"status":"ok"}`) - }) - server := httptest.NewServer(handler) - defer server.Close() + server := httptest.NewServer(tt.handlerFn(t)) + t.Cleanup(server.Close) + u, err := url.Parse(server.URL) if err != nil { - t.Error(err) - return + t.Errorf("no error is expected, but got %v", err) } + req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) if err != nil { - t.Error(err) - return + t.Errorf("no error is expected, but got %v", err) + } + + httpClient := &http.Client{ + Transport: tt.tokenFlow, } - got, err := c.RoundTrip(req) - if err == nil { + + res, err := httpClient.Do(req) + + if tt.wantErr { + if err == nil { + t.Errorf("error is expected, but got %v", err) + } + } else { + if err != nil { + t.Errorf("no error is expected, but got %v", err) + } + + if res.StatusCode != tt.want { + t.Errorf("expected status code %d, but got %d", tt.want, res.StatusCode) + } + // Defer discard and close the body - defer func() { - if _, discardErr := io.Copy(io.Discard, got.Body); discardErr != nil && err == nil { - err = discardErr + t.Cleanup(func() { + if _, err := io.Copy(io.Discard, res.Body); err != nil { + t.Errorf("no error is expected, but got %v", err) } - if closeErr := got.Body.Close(); closeErr != nil && err == nil { - err = closeErr + + if err := res.Body.Close(); err != nil { + t.Errorf("no error is expected, but got %v", err) } - }() - } - if (err != nil) != tt.wantErr { - t.Errorf("TokenFlow.Do() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != nil && got.StatusCode != tt.want { - t.Errorf("TokenFlow.Do() = %v, want %v", got.StatusCode, tt.want) + }) } }) }