Skip to content

Commit

Permalink
Refactor and improve CSP token exchange and identify expired tokens
Browse files Browse the repository at this point in the history
Signed-off-by: Pete Wall <pwall@vmware.com>
  • Loading branch information
Pete Wall committed Jul 19, 2022
1 parent 14a8f8f commit a3d143d
Show file tree
Hide file tree
Showing 15 changed files with 248 additions and 161 deletions.
16 changes: 8 additions & 8 deletions cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ type TokenServices interface {
}

//go:generate counterfeiter . TokenServicesInitializer
type TokenServicesInitializer func(cspHost string) (TokenServices, error)
type TokenServicesInitializer func(cspHost string) TokenServices

var InitializeTokenServices TokenServicesInitializer = func(cspHost string) (TokenServices, error) {
return csp.NewTokenServices(cspHost, Client)
var InitializeTokenServices TokenServicesInitializer = func(cspHost string) TokenServices {
return &csp.TokenServices{
CSPHost: cspHost,
Client: Client,
}
}

func GetRefreshToken(cmd *cobra.Command, args []string) error {
tokenServices, err := InitializeTokenServices(viper.GetString("csp.host"))
if err != nil {
return fmt.Errorf("failed to initialize token services: %w", err)
}
tokenServices := InitializeTokenServices(viper.GetString("csp.host"))

apiToken := viper.GetString("csp.api-token")
if apiToken == "" {
Expand All @@ -36,7 +36,7 @@ func GetRefreshToken(cmd *cobra.Command, args []string) error {

claims, err := tokenServices.Redeem(apiToken)
if err != nil {
return fmt.Errorf("failed to exchange api token: %w", err)
return err
}

viper.Set("csp.refresh-token", claims.Token)
Expand Down
16 changes: 2 additions & 14 deletions cmd/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ var _ = Describe("Auth", func() {
tokenServices = &cmdfakes.FakeTokenServices{}

initializer = &cmdfakes.FakeTokenServicesInitializer{}
initializer.Returns(tokenServices, nil)
initializer.Returns(tokenServices)
InitializeTokenServices = initializer.Spy
})

Expand All @@ -50,18 +50,6 @@ var _ = Describe("Auth", func() {
Expect(tokenServices.RedeemArgsForCall(0)).To(Equal("my-csp-api-token"))
})

Context("fails to initialize token services", func() {
BeforeEach(func() {
initializer.Returns(nil, fmt.Errorf("initializer failed"))
})

It("returns an error", func() {
err := GetRefreshToken(nil, []string{})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("failed to initialize token services: initializer failed"))
})
})

Context("fails to exchange api token", func() {
BeforeEach(func() {
tokenServices.RedeemReturns(nil, fmt.Errorf("redeem failed"))
Expand All @@ -70,7 +58,7 @@ var _ = Describe("Auth", func() {
It("returns an error", func() {
err := GetRefreshToken(nil, []string{})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("failed to exchange api token: redeem failed"))
Expect(err.Error()).To(Equal("redeem failed"))
})
})
})
Expand Down
7 changes: 4 additions & 3 deletions cmd/cmdfakes/fake_token_services.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 13 additions & 16 deletions cmd/cmdfakes/fake_token_services_initializer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

86 changes: 28 additions & 58 deletions internal/csp/token_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,30 @@ package csp

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"

"github.com/golang-jwt/jwt"
"github.com/vmware-labs/marketplace-cli/v2/pkg"
)

type TokenServices struct {
keyfunc jwt.Keyfunc
keyPem string
CSPHost string
Client pkg.HTTPClient
}

type RedeemResponse struct {
AccessToken string `json:"access_token"`
AccessToken string `json:"access_token"`
StatusCode int `json:"statusCode,omitempty"`
ModuleCode int `json:"moduleCode,omitempty"`
Metadata interface{} `json:"metadata,omitempty"` // I don't know what the appropriate type for this field is
TraceID string `json:"traceId,omitempty"`
CSPErrorCode string `json:"cspErrorCode,omitempty"`
Message string `json:"message,omitempty"`
RequestID string `json:"requestId,omitempty"`
}

func (csp *TokenServices) Redeem(refreshToken string) (*Claims, error) {
Expand All @@ -30,79 +37,37 @@ func (csp *TokenServices) Redeem(refreshToken string) (*Claims, error) {
"refresh_token": []string{refreshToken},
}

retried := false
resp, err := csp.Client.PostForm(requestURL, formData)
if err != nil {
return nil, fmt.Errorf("failed to redeem token: %w", err)
}

if resp.StatusCode == http.StatusServiceUnavailable {
retried = true
resp, err = csp.Client.PostForm(requestURL, formData)
if err != nil {
return nil, fmt.Errorf("failed to redeem token on second attempt: %w", err)
}
}

if resp.StatusCode != http.StatusOK {
if !retried {
return nil, fmt.Errorf("failed to exchange refresh token for access token: %s", resp.Status)
}
return nil, fmt.Errorf("failed twice to exchange refresh token for access token: %s", resp.Status)
}

var body RedeemResponse
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
return nil, fmt.Errorf("failed to parse redeem response: %w", err)
}

claims := &Claims{}
_, _ = jwt.ParseWithClaims(body.AccessToken, claims, func(t *jwt.Token) (interface{}, error) {
// token was just retrieved, no need to validate
return "not a valid key anyway", nil
})
// err != nil here are the token validation has failed

claims.Token = body.AccessToken
return claims, nil
}

func (csp *TokenServices) Validate(jwtAccessToken string) (*Claims, error) {
claims := &Claims{}
_, err := jwt.ParseWithClaims(jwtAccessToken, claims, csp.keyfunc)
return claims, err
}

func (csp *TokenServices) VerificationKey() string {
return csp.keyPem
}

func NewTokenServices(cspHost string, client pkg.HTTPClient) (*TokenServices, error) {
keyData, err := fetchPublicKey(cspHost, client)
if err != nil {
return nil, err
if resp.StatusCode == http.StatusBadRequest && strings.Contains(body.Message, "invalid_grant: Invalid refresh token") {
return nil, errors.New("the CSP API token is invalid or expired")
}

publicKey, err := jwt.ParseRSAPublicKeyFromPEM(keyData)
if err != nil {
return nil, fmt.Errorf("failed to make public key structure: %w", err)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to exchange refresh token for access token: %s: %s", resp.Status, body.Message)
}

rsa := func(*jwt.Token) (interface{}, error) {
return publicKey, nil
claims := &Claims{}
token, err := jwt.ParseWithClaims(body.AccessToken, claims, csp.GetPublicKey)
if err != nil {
return nil, fmt.Errorf("invalid token returned from CSP: %w", err)
}

return &TokenServices{
CSPHost: cspHost,
Client: client,
keyfunc: rsa,
keyPem: string(keyData),
}, nil
claims.Token = token.Raw
return claims, nil
}

func fetchPublicKey(cspHost string, client pkg.HTTPClient) ([]byte, error) {
resp, err := client.Get(pkg.MakeURL(cspHost, "/csp/gateway/am/api/auth/token-public-key", nil))
func (csp *TokenServices) GetPublicKey(*jwt.Token) (interface{}, error) {
resp, err := csp.Client.Get(pkg.MakeURL(csp.CSPHost, "/csp/gateway/am/api/auth/token-public-key", nil))
if err != nil {
return nil, fmt.Errorf("failed to get CSP Public key: %w", err)
}
Expand All @@ -123,5 +88,10 @@ func fetchPublicKey(cspHost string, client pkg.HTTPClient) ([]byte, error) {
return nil, fmt.Errorf("public key value was not in the expected format")
}

return []byte(s), nil
publicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(s))
if err != nil {
return nil, fmt.Errorf("failed to parse CSP public key: %w", err)
}

return publicKey, nil
}
Loading

0 comments on commit a3d143d

Please sign in to comment.