Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.24.0
require (
github.com/fatih/color v1.18.0
github.com/goccy/go-yaml v1.19.2
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf
Expand All @@ -15,7 +15,7 @@ require (
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.10
github.com/spf13/viper v1.21.0
github.com/stackitcloud/stackit-sdk-go/core v0.20.1
github.com/stackitcloud/stackit-sdk-go/core v0.21.1
github.com/stackitcloud/stackit-sdk-go/services/alb v0.9.0
github.com/stackitcloud/stackit-sdk-go/services/authorization v0.11.0
github.com/stackitcloud/stackit-sdk-go/services/cdn v1.9.1
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down Expand Up @@ -600,8 +600,8 @@ github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/ssgreg/nlreturn/v2 v2.2.1 h1:X4XDI7jstt3ySqGU86YGAURbxw3oTDPK9sPEi6YEwQ0=
github.com/ssgreg/nlreturn/v2 v2.2.1/go.mod h1:E/iiPB78hV7Szg2YfRgyIrk1AD6JVMTRkkxBiELzh2I=
github.com/stackitcloud/stackit-sdk-go/core v0.20.1 h1:odiuhhRXmxvEvnVTeZSN9u98edvw2Cd3DcnkepncP3M=
github.com/stackitcloud/stackit-sdk-go/core v0.20.1/go.mod h1:fqto7M82ynGhEnpZU6VkQKYWYoFG5goC076JWXTUPRQ=
github.com/stackitcloud/stackit-sdk-go/core v0.21.1 h1:Y/PcAgM7DPYMNqum0MLv4n1mF9ieuevzcCIZYQfm3Ts=
github.com/stackitcloud/stackit-sdk-go/core v0.21.1/go.mod h1:osMglDby4csGZ5sIfhNyYq1bS1TxIdPY88+skE/kkmI=
github.com/stackitcloud/stackit-sdk-go/services/alb v0.9.0 h1:P24WoKPt14dfUiUJ4czIv+IiVmdCFQGrKgVtw23fxNg=
github.com/stackitcloud/stackit-sdk-go/services/alb v0.9.0/go.mod h1:63XvbCslxdfWEp+0Q4OSzQrpbY4kvVODOiIEAEEVH8M=
github.com/stackitcloud/stackit-sdk-go/services/authorization v0.11.0 h1:4YFY5PG4vP/NiEP1uxCwh+kQHEU7iHG6syuFD7NPqcw=
Expand Down
118 changes: 54 additions & 64 deletions internal/pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -235,58 +238,28 @@ func TestAuthenticationConfig(t *testing.T) {

func TestInitKeyFlow(t *testing.T) {
tests := []struct {
description string
accessTokenSet bool
refreshToken string
saKey string
privateKeySet bool
tokenEndpoint string
isValid bool
description string
saKey string
privateKeySet bool
isValid bool
}{
{
description: "base",
accessTokenSet: true,
refreshToken: "refresh_token",
saKey: testServiceAccountKey,
privateKeySet: true,
tokenEndpoint: "token_url",
isValid: true,
description: "base",
saKey: testServiceAccountKey,
privateKeySet: true,
isValid: true,
},
{
description: "invalid_service_account_key",
accessTokenSet: true,
refreshToken: "refresh_token",
saKey: "",
privateKeySet: true,
tokenEndpoint: "token_url",
isValid: false,
},
{
description: "invalid_private_key",
accessTokenSet: true,
refreshToken: "refresh_token",
saKey: testServiceAccountKey,
privateKeySet: false,
tokenEndpoint: "token_url",
isValid: false,
},
{
description: "invalid_access_token",
accessTokenSet: false,
refreshToken: "refresh_token",
saKey: testServiceAccountKey,
privateKeySet: true,
tokenEndpoint: "token_url",
isValid: false,
description: "invalid_service_account_key",
saKey: "",
privateKeySet: true,
isValid: false,
},
{
description: "empty_refresh_token",
accessTokenSet: false,
refreshToken: "",
saKey: testServiceAccountKey,
privateKeySet: true,
tokenEndpoint: "token_url",
isValid: false,
description: "no_private_key_set",
saKey: testServiceAccountKey,
privateKeySet: false,
isValid: false,
},
}

Expand All @@ -297,13 +270,11 @@ func TestInitKeyFlow(t *testing.T) {
authFields := make(map[authFieldKey]string)
var accessToken string
var err error
if tt.accessTokenSet {
accessTokenJWT := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(timestamp)})
accessToken, err = accessTokenJWT.SignedString(testSigningKey)
if err != nil {
t.Fatalf("Get test access token as string: %s", err)
}
accessTokenJWT := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(timestamp)})
accessToken, err = accessTokenJWT.SignedString(testSigningKey)
if err != nil {
t.Fatalf("Get test access token as string: %s", err)
}
if tt.privateKeySet {
privateKey, err := generatePrivateKey()
Expand All @@ -313,16 +284,42 @@ func TestInitKeyFlow(t *testing.T) {
authFields[PRIVATE_KEY] = string(privateKey)
}
authFields[ACCESS_TOKEN] = accessToken
authFields[REFRESH_TOKEN] = tt.refreshToken
authFields[SERVICE_ACCOUNT_KEY] = tt.saKey
authFields[TOKEN_CUSTOM_ENDPOINT] = tt.tokenEndpoint

// Mock server to avoid HTTP calls
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
resp := clients.TokenResponseBody{
AccessToken: accessToken,
ExpiresIn: 3600,
TokenType: "Bearer",
}
jsonResp, err := json.Marshal(resp)
if err != nil {
t.Fatalf("Failed to marshal json: %v", err)
}
_, err = w.Write(jsonResp)
if err != nil {
t.Fatalf("Failed to write response: %v", err)
}
}))
defer server.Close()
authFields[TOKEN_CUSTOM_ENDPOINT] = server.URL

err = SetAuthFieldMap(authFields)
if err != nil {
t.Fatalf("Failed to set in auth storage: %v", err)
}

keyFlowWithStorage, err := initKeyFlowWithStorage()
if err != nil {
if !tt.isValid {
return
}
t.Fatalf("Expected no error but error was returned: %v", err)
}

getAccessToken, err := keyFlowWithStorage.keyFlow.GetAccessToken()
if !tt.isValid {
if err == nil {
t.Fatalf("Expected error but no error was returned")
Expand All @@ -331,15 +328,8 @@ func TestInitKeyFlow(t *testing.T) {
if err != nil {
t.Fatalf("Expected no error but error was returned: %v", err)
}
expectedToken := &clients.TokenResponseBody{
AccessToken: accessToken,
ExpiresIn: int(timestamp.Unix()),
RefreshToken: tt.refreshToken,
Scope: "",
TokenType: "Bearer",
}
if !cmp.Equal(*expectedToken, keyFlowWithStorage.keyFlow.GetToken()) {
t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedToken, keyFlowWithStorage.keyFlow.GetToken())
if !cmp.Equal(accessToken, getAccessToken) {
t.Errorf("The returned result is wrong. Expected %+v, got %+v", accessToken, getAccessToken)
}
}
})
Expand Down
35 changes: 13 additions & 22 deletions internal/pkg/auth/service_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
type keyFlowInterface interface {
GetAccessToken() (string, error)
GetConfig() clients.KeyFlowConfig
GetToken() clients.TokenResponseBody
RoundTrip(*http.Request) (*http.Response, error)
}

Expand All @@ -32,7 +31,7 @@ var _ http.RoundTripper = &keyFlowWithStorage{}

// AuthenticateServiceAccount checks the type of the provided roundtripper,
// authenticates the CLI accordingly and store the credentials.
// For the key flow, it fetches an access and refresh token from the Service Account API.
// For the key flow, it fetches an access token from the Service Account API.
// For the token flow, it just stores the provided token and doesn't check if it is valid.
// It returns the email associated with the service account
// If disableWriting is set to true the credentials are not stored on disk (keyring, file).
Expand All @@ -56,7 +55,6 @@ func AuthenticateServiceAccount(p *print.Printer, rt http.RoundTripper, disableW
}

authFields[ACCESS_TOKEN] = accessToken
authFields[REFRESH_TOKEN] = flow.GetToken().RefreshToken
authFields[SERVICE_ACCOUNT_KEY] = string(saKeyBytes)
authFields[PRIVATE_KEY] = flow.GetConfig().PrivateKey
case tokenFlowInterface:
Expand Down Expand Up @@ -100,8 +98,6 @@ func AuthenticateServiceAccount(p *print.Printer, rt http.RoundTripper, disableW
// initKeyFlowWithStorage initializes the keyFlow from the SDK and creates a keyFlowWithStorage struct that uses that keyFlow
func initKeyFlowWithStorage() (*keyFlowWithStorage, error) {
authFields := map[authFieldKey]string{
ACCESS_TOKEN: "",
REFRESH_TOKEN: "",
SERVICE_ACCOUNT_KEY: "",
PRIVATE_KEY: "",
TOKEN_CUSTOM_ENDPOINT: "",
Expand All @@ -110,12 +106,6 @@ func initKeyFlowWithStorage() (*keyFlowWithStorage, error) {
if err != nil {
return nil, fmt.Errorf("get from auth storage: %w", err)
}
if authFields[ACCESS_TOKEN] == "" {
return nil, fmt.Errorf("access token not set")
}
if authFields[REFRESH_TOKEN] == "" {
return nil, fmt.Errorf("refresh token not set")
}

var serviceAccountKey = &clients.ServiceAccountKeyResponse{}
err = json.Unmarshal([]byte(authFields[SERVICE_ACCOUNT_KEY]), serviceAccountKey)
Expand All @@ -134,10 +124,6 @@ func initKeyFlowWithStorage() (*keyFlowWithStorage, error) {
if err != nil {
return nil, fmt.Errorf("initialize key flow: %w", err)
}
err = keyFlow.SetToken(authFields[ACCESS_TOKEN], authFields[REFRESH_TOKEN])
if err != nil {
return nil, fmt.Errorf("set access and refresh token: %w", err)
}

// create keyFlowWithStorage roundtripper that stores the credentials after executing a request
keyFlowWithStorage := &keyFlowWithStorage{
Expand All @@ -146,21 +132,26 @@ func initKeyFlowWithStorage() (*keyFlowWithStorage, error) {
return keyFlowWithStorage, nil
}

// The keyFlowWithStorage Roundtrip executes the keyFlow roundtrip and then stores the access and refresh tokens
// The keyFlowWithStorage Roundtrip executes the keyFlow roundtrip and then stores the access token
func (kf *keyFlowWithStorage) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := kf.keyFlow.RoundTrip(req)

token := kf.keyFlow.GetToken()
accessToken := token.AccessToken
refreshToken := token.RefreshToken
accessToken, getTokenErr := kf.keyFlow.GetAccessToken()
if getTokenErr != nil {
return nil, fmt.Errorf("get access token: %w", getTokenErr)
}

tokenValues := map[authFieldKey]string{
ACCESS_TOKEN: accessToken,
REFRESH_TOKEN: refreshToken,
ACCESS_TOKEN: accessToken,
}

storageErr := SetAuthFieldMap(tokenValues)
if storageErr != nil {
return nil, fmt.Errorf("set access and refresh token in the storage: %w", err)
// If the request was successful, but storing the token failed we still return the response and a nil error
if err == nil {
return resp, nil
}
return nil, fmt.Errorf("set access token in the storage: %w", err)
}

return resp, err
Expand Down
Loading