diff --git a/cmd/gitops/root/cmd.go b/cmd/gitops/root/cmd.go index a3d6d53662..379598db67 100644 --- a/cmd/gitops/root/cmd.go +++ b/cmd/gitops/root/cmd.go @@ -135,7 +135,7 @@ func RootCmd(client *resty.Client) *cobra.Command { rootCmd.AddCommand(uninstall.Cmd) rootCmd.AddCommand(version.Cmd) rootCmd.AddCommand(flux.Cmd) - rootCmd.AddCommand(ui.Command()) + rootCmd.AddCommand(ui.NewCommand()) rootCmd.AddCommand(get.GetCommand(&options.endpoint, client)) rootCmd.AddCommand(add.GetCommand(&options.endpoint, client)) rootCmd.AddCommand(delete.DeleteCommand(&options.endpoint, client)) diff --git a/cmd/gitops/ui/cmd.go b/cmd/gitops/ui/cmd.go index 0e4664119e..c217502585 100644 --- a/cmd/gitops/ui/cmd.go +++ b/cmd/gitops/ui/cmd.go @@ -5,8 +5,8 @@ import ( "github.com/weaveworks/weave-gitops/cmd/gitops/ui/run" ) -// Command returns the `ui` command and its subcommands. -func Command() *cobra.Command { +// NewCommand returns the `ui` command and its subcommands. +func NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "ui", Short: "Manages Gitops UI", @@ -17,7 +17,7 @@ func Command() *cobra.Command { Args: cobra.MinimumNArgs(1), } - cmd.AddCommand(run.Command()) + cmd.AddCommand(run.NewCommand()) return cmd } diff --git a/cmd/gitops/ui/run/cmd.go b/cmd/gitops/ui/run/cmd.go index 51bd7132a1..0cbfcee000 100644 --- a/cmd/gitops/ui/run/cmd.go +++ b/cmd/gitops/ui/run/cmd.go @@ -5,6 +5,7 @@ import ( "embed" "fmt" "io/fs" + "net" "net/http" "net/url" "os" @@ -41,13 +42,13 @@ type OIDCAuthenticationOptions struct { ClientID string ClientSecret string RedirectURL string - CookieDuration string + CookieDuration time.Duration } var options Options -// Command returns the `ui run` command -func Command() *cobra.Command { +// NewCommand returns the `ui run` command +func NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "run [--log]", Short: "Runs gitops ui", @@ -67,7 +68,7 @@ func Command() *cobra.Command { cmd.Flags().StringVar(&options.OIDC.ClientID, "oidc-client-id", "", "The client ID for the OpenID Connect client") cmd.Flags().StringVar(&options.OIDC.ClientSecret, "oidc-client-secret", "", "The client secret to use with OpenID Connect issuer") cmd.Flags().StringVar(&options.OIDC.RedirectURL, "oidc-redirect-url", "", "The OAuth2 redirect URL") - cmd.Flags().StringVar(&options.OIDC.CookieDuration, "oidc-cookie-duration", "1h", "The duration of the ID token cookie. It should be set in the format: number + time unit (s,m,h) e.g., 20m") + cmd.Flags().DurationVar(&options.OIDC.CookieDuration, "oidc-cookie-duration", time.Hour, "The duration of the ID token cookie. It should be set in the format: number + time unit (s,m,h) e.g., 20m") } return cmd @@ -111,7 +112,7 @@ func runCmd(cmd *cobra.Command, args []string) error { profilesConfig := server.NewProfilesConfig(rawClient, options.HelmRepoNamespace, options.HelmRepoName) - var authConfig *auth.AuthConfig + var authServer *auth.AuthServer if server.AuthEnabled() { _, err := url.Parse(options.OIDC.IssuerURL) @@ -129,27 +130,31 @@ func runCmd(cmd *cobra.Command, args []string) error { oidcIssueSecureCookies = true } - oidcCookieDuration, err := time.ParseDuration(options.OIDC.CookieDuration) + srv, err := auth.NewAuthServer(cmd.Context(), appConfig.Logger, http.DefaultClient, + auth.AuthConfig{ + OIDCConfig: auth.OIDCConfig{ + IssuerURL: options.OIDC.IssuerURL, + ClientID: options.OIDC.ClientID, + ClientSecret: options.OIDC.ClientSecret, + RedirectURL: options.OIDC.RedirectURL, + }, + CookieConfig: auth.CookieConfig{ + CookieDuration: options.OIDC.CookieDuration, + IssueSecureCookies: oidcIssueSecureCookies, + }, + }, + ) if err != nil { - return fmt.Errorf("invalid cookie duration: %w", err) + return fmt.Errorf("could not create auth server: %w", err) } - cfg, err := auth.NewAuthConfig(cmd.Context(), options.OIDC.IssuerURL, - options.OIDC.ClientID, options.OIDC.ClientSecret, options.OIDC.RedirectURL, - oidcIssueSecureCookies, oidcCookieDuration, http.DefaultClient, - appConfig.Logger) - if err != nil { - return fmt.Errorf("could not create auth config: %w", err) - } - - cfg.Logger().Info("Registering callback route") - // Register /callback handler with mux - auth.RegisterAuthHandler(mux, "/oauth2", cfg) + appConfig.Logger.Info("Registering callback route") + auth.RegisterAuthServer(mux, "/oauth2", srv) - authConfig = cfg + authServer = srv } - appAndProfilesHandlers, err := server.NewHandlers(context.Background(), &server.Config{AppConfig: appConfig, ProfilesConfig: profilesConfig, AuthConfig: authConfig}) + appAndProfilesHandlers, err := server.NewHandlers(context.Background(), &server.Config{AppConfig: appConfig, ProfilesConfig: profilesConfig, AuthServer: authServer}) if err != nil { return fmt.Errorf("could not create handler: %w", err) } @@ -164,7 +169,7 @@ func runCmd(cmd *cobra.Command, args []string) error { // Redirect all non-file requests to index.html, where the JS routing will take over. if extension == "" { if server.AuthEnabled() { - auth.WithWebAuth(redirector, authConfig).ServeHTTP(w, req) + auth.WithWebAuth(redirector, authServer).ServeHTTP(w, req) } else { redirector(w, req) } @@ -173,7 +178,7 @@ func runCmd(cmd *cobra.Command, args []string) error { assetHandler.ServeHTTP(w, req) })) - addr := "0.0.0.0:" + options.Port + addr := net.JoinHostPort("0.0.0.0", options.Port) srv := &http.Server{ Addr: addr, Handler: mux, diff --git a/pkg/server/auth/auth.go b/pkg/server/auth/auth.go index 2f7832e74f..5de31f6a59 100644 --- a/pkg/server/auth/auth.go +++ b/pkg/server/auth/auth.go @@ -19,16 +19,16 @@ const ( // token. RefreshTokenCookieName = "refresh_token" // ScopeProfile is the "profile" scope - ScopeProfile = "profile" + scopeProfile = "profile" // ScopeEmail is the "email" scope - ScopeEmail = "email" + scopeEmail = "email" ) -// RegisterAuthHandler registers the /callback route under a specified prefix. +// RegisterAuthServer registers the /callback route under a specified prefix. // This route is called by the OIDC Provider in order to pass back state after // the authentication flow completes. -func RegisterAuthHandler(mux *http.ServeMux, prefix string, cfg *AuthConfig) { - mux.Handle(prefix+"/callback", cfg.callback()) +func RegisterAuthServer(mux *http.ServeMux, prefix string, srv *AuthServer) { + mux.Handle(prefix+"/callback", srv) } type principalCtxKey struct{} @@ -47,17 +47,19 @@ func WithPrincipal(ctx context.Context, p *UserPrincipal) context.Context { // WithAPIAuth middleware adds auth validation to API handlers. // // Unauthorized requests will be denied with a 401 status code. -func WithAPIAuth(next http.Handler, cfg *AuthConfig) http.Handler { - cookieAuth := NewJWTCookiePrincipalGetter(cfg.logger, - cfg.verifier(), IDTokenCookieName) - headerAuth := NewJWTAuthorizationHeaderPrincipalGetter(cfg.logger, cfg.verifier()) +func WithAPIAuth(next http.Handler, srv *AuthServer) http.Handler { + cookieAuth := NewJWTCookiePrincipalGetter(srv.logger, + srv.verifier(), IDTokenCookieName) + headerAuth := NewJWTAuthorizationHeaderPrincipalGetter(srv.logger, srv.verifier()) + multi := MultiAuthPrincipal{cookieAuth, headerAuth} return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - principal, err := MultiAuthPrincipal{cookieAuth, headerAuth}.Principal(r) - if err != nil || principal == nil { - cfg.logger.Error(err, "failed to get principal") + principal, err := multi.Principal(r) + if err != nil { + srv.logger.Error(err, "failed to get principal") + } - rw.Header().Set("WWW-Authenticate", `Bearer realm="Weave GitOps"`) + if principal == nil || err != nil { http.Error(rw, "Authentication required", http.StatusUnauthorized) return } @@ -71,17 +73,20 @@ func WithAPIAuth(next http.Handler, cfg *AuthConfig) http.Handler { // Unauthorized requests will be redirected to the OIDC Provider. // It is meant to be used with routes that serve HTML content, // not API routes. -func WithWebAuth(next http.Handler, cfg *AuthConfig) http.Handler { - cookieAuth := NewJWTCookiePrincipalGetter(cfg.logger, - cfg.verifier(), IDTokenCookieName) - headerAuth := NewJWTAuthorizationHeaderPrincipalGetter(cfg.logger, cfg.verifier()) +func WithWebAuth(next http.Handler, srv *AuthServer) http.Handler { + cookieAuth := NewJWTCookiePrincipalGetter(srv.logger, + srv.verifier(), IDTokenCookieName) + headerAuth := NewJWTAuthorizationHeaderPrincipalGetter(srv.logger, srv.verifier()) + multi := MultiAuthPrincipal{cookieAuth, headerAuth} return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - principal, err := MultiAuthPrincipal{cookieAuth, headerAuth}.Principal(r) - if err != nil || principal == nil { - cfg.logger.Error(err, "failed to get principal") + principal, err := multi.Principal(r) + if err != nil { + srv.logger.Error(err, "failed to get principal") + } - startAuthFlow(rw, r, cfg) + if principal == nil || err != nil { + startAuthFlow(rw, r, srv) return } @@ -89,8 +94,8 @@ func WithWebAuth(next http.Handler, cfg *AuthConfig) http.Handler { }) } -func startAuthFlow(rw http.ResponseWriter, r *http.Request, cfg *AuthConfig) { - nonce, err := generateNonce(32) +func startAuthFlow(rw http.ResponseWriter, r *http.Request, srv *AuthServer) { + nonce, err := generateNonce() if err != nil { http.Error(rw, fmt.Sprintf("failed to generate nonce: %v", err), http.StatusInternalServerError) return @@ -109,17 +114,17 @@ func startAuthFlow(rw http.ResponseWriter, r *http.Request, cfg *AuthConfig) { var scopes []string // "openid", "offline_access" and "email" scopes added by default - scopes = append(scopes, ScopeProfile) - authCodeUrl := cfg.oauth2Config(scopes).AuthCodeURL(state) + scopes = append(scopes, scopeProfile) + authCodeUrl := srv.oauth2Config(scopes).AuthCodeURL(state) // Issue state cookie - http.SetCookie(rw, cfg.createCookie(StateCookieName, state)) + http.SetCookie(rw, srv.createCookie(StateCookieName, state)) http.Redirect(rw, r, authCodeUrl, http.StatusSeeOther) } -func generateNonce(n int) (string, error) { - b := make([]byte, n) +func generateNonce() (string, error) { + b := make([]byte, 32) _, err := rand.Read(b) if err != nil { diff --git a/pkg/server/auth/auth_test.go b/pkg/server/auth/auth_test.go index 578227717c..f8616cc30b 100644 --- a/pkg/server/auth/auth_test.go +++ b/pkg/server/auth/auth_test.go @@ -31,22 +31,34 @@ func TestWithAPIAuthReturns401ForUnauthenticatedRequests(t *testing.T) { fake := m.Config() mux := http.NewServeMux() - c, err := auth.NewAuthConfig(ctx, fake.Issuer, fake.ClientID, fake.ClientSecret, "", false, 20*time.Minute, http.DefaultClient, logr.Discard()) + srv, err := auth.NewAuthServer(ctx, logr.Discard(), http.DefaultClient, + auth.AuthConfig{ + auth.OIDCConfig{ + IssuerURL: fake.Issuer, + ClientID: fake.ClientID, + ClientSecret: fake.ClientSecret, + RedirectURL: "", + }, + auth.CookieConfig{ + CookieDuration: 20 * time.Minute, + IssueSecureCookies: false, + }, + }) if err != nil { t.Error("failed to create auth config") } - auth.RegisterAuthHandler(mux, "/oauth2", c) + auth.RegisterAuthServer(mux, "/oauth2", srv) s := httptest.NewServer(mux) defer s.Close() // Set the correct redirect URL now that we have a server running - c.SetRedirectURL(s.URL + "/oauth2/callback") + srv.SetRedirectURL(s.URL + "/oauth2/callback") res := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, s.URL, nil) - auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), c).ServeHTTP(res, req) + auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv).ServeHTTP(res, req) if res.Result().StatusCode != http.StatusUnauthorized { t.Errorf("expected status of %d but got %d", http.StatusUnauthorized, res.Result().StatusCode) @@ -68,29 +80,41 @@ func TestWithWebAuthRedirectsToOIDCIssuerForUnauthenticatedRequests(t *testing.T fake := m.Config() mux := http.NewServeMux() - c, err := auth.NewAuthConfig(ctx, fake.Issuer, fake.ClientID, fake.ClientSecret, "", false, 20*time.Minute, http.DefaultClient, logr.Discard()) + srv, err := auth.NewAuthServer(ctx, logr.Discard(), http.DefaultClient, + auth.AuthConfig{ + auth.OIDCConfig{ + IssuerURL: fake.Issuer, + ClientID: fake.ClientID, + ClientSecret: fake.ClientSecret, + RedirectURL: "", + }, + auth.CookieConfig{ + CookieDuration: 20 * time.Minute, + IssueSecureCookies: false, + }, + }) if err != nil { t.Error("failed to create auth config") } - auth.RegisterAuthHandler(mux, "/oauth2", c) + auth.RegisterAuthServer(mux, "/oauth2", srv) s := httptest.NewServer(mux) defer s.Close() // Set the correct redirect URL now that we have a server running redirectURL := s.URL + "/oauth2/callback" - c.SetRedirectURL(redirectURL) + srv.SetRedirectURL(redirectURL) res := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, s.URL, nil) - auth.WithWebAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), c).ServeHTTP(res, req) + auth.WithWebAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv).ServeHTTP(res, req) if res.Result().StatusCode != http.StatusSeeOther { t.Errorf("expected status of %d but got %d", http.StatusSeeOther, res.Result().StatusCode) } - authCodeURL := fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=%s", m.AuthorizationEndpoint(), fake.ClientID, url.QueryEscape(redirectURL), strings.Join([]string{auth.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeOfflineAccess, auth.ScopeEmail}, "+")) + authCodeURL := fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=%s", m.AuthorizationEndpoint(), fake.ClientID, url.QueryEscape(redirectURL), strings.Join([]string{"profile", oidc.ScopeOpenID, oidc.ScopeOfflineAccess, "email"}, "+")) if !strings.HasPrefix(res.Result().Header.Get("Location"), authCodeURL) { t.Errorf("expected Location header URL to include scopes %s but does not: %s", authCodeURL, res.Result().Header.Get("Location")) } diff --git a/pkg/server/auth/config.go b/pkg/server/auth/config.go deleted file mode 100644 index c5d7b45d22..0000000000 --- a/pkg/server/auth/config.go +++ /dev/null @@ -1,237 +0,0 @@ -package auth - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/coreos/go-oidc/v3/oidc" - "github.com/go-logr/logr" - "golang.org/x/oauth2" -) - -// AuthConfig holds auth configuration parameters. -type AuthConfig struct { - logger logr.Logger - provider *oidc.Provider - clientID string - clientSecret string - redirectURL string - issueSecureCookies bool - cookieDuration time.Duration - client *http.Client -} - -// NewAuthConfig creates a new AuthConfig object. -func NewAuthConfig(ctx context.Context, oidcIssuerURL, oidcClientId, oidcClientSecret, oidcRedirectURL string, oidcIssueSecureCookies bool, oidcCookieDuration time.Duration, client *http.Client, logger logr.Logger) (*AuthConfig, error) { - provider, err := oidc.NewProvider(ctx, oidcIssuerURL) - if err != nil { - return nil, fmt.Errorf("could not create provider: %w", err) - } - - return &AuthConfig{ - logger: logger, - provider: provider, - clientID: oidcClientId, - clientSecret: oidcClientSecret, - redirectURL: oidcRedirectURL, - issueSecureCookies: oidcIssueSecureCookies, - cookieDuration: oidcCookieDuration, - client: client, - }, nil -} - -// Logger returns the logger instance -func (c *AuthConfig) Logger() logr.Logger { - return c.logger -} - -// SetRedirectURL is used to set the redirect URL. This is meant to be used -// in unit tests only. -func (c *AuthConfig) SetRedirectURL(url string) { - c.redirectURL = url -} - -func (c *AuthConfig) verifier() *oidc.IDTokenVerifier { - return c.provider.Verifier(&oidc.Config{ClientID: c.clientID}) -} - -func (c *AuthConfig) oauth2Config(scopes []string) *oauth2.Config { - // Ensure "openid" scope is always present. - if !contains(scopes, oidc.ScopeOpenID) { - scopes = append(scopes, oidc.ScopeOpenID) - } - - // Request "offline_access" scope for refresh tokens. - if !contains(scopes, oidc.ScopeOfflineAccess) { - scopes = append(scopes, oidc.ScopeOfflineAccess) - } - - // Request "email" scope to get user's email address. - if !contains(scopes, ScopeEmail) { - scopes = append(scopes, ScopeEmail) - } - - return &oauth2.Config{ - ClientID: c.clientID, - ClientSecret: c.clientSecret, - Endpoint: c.provider.Endpoint(), - RedirectURL: c.redirectURL, - Scopes: scopes, - } -} - -func (c *AuthConfig) callback() http.HandlerFunc { - return func(rw http.ResponseWriter, r *http.Request) { - var ( - token *oauth2.Token - state SessionState - ) - - ctx := oidc.ClientContext(r.Context(), c.client) - - switch r.Method { - case http.MethodGet: - // Authorization redirect callback from OAuth2 auth flow. - if errMsg := r.FormValue("error"); errMsg != "" { - c.logger.Info("authz redirect callback failed", "error", errMsg, "error_description", r.FormValue("error_description")) - http.Error(rw, "", http.StatusBadRequest) - - return - } - - code := r.FormValue("code") - if code == "" { - c.logger.Info("code value was empty") - http.Error(rw, "", http.StatusBadRequest) - - return - } - - cookie, err := r.Cookie(StateCookieName) - if err != nil { - c.logger.Error(err, "cookie was not found in the request", "cookie", StateCookieName) - http.Error(rw, "", http.StatusBadRequest) - - return - } - - if state := r.FormValue("state"); state != cookie.Value { - c.logger.Info("cookie value does not match state value") - http.Error(rw, "", http.StatusBadRequest) - - return - } - - b, err := base64.StdEncoding.DecodeString(cookie.Value) - if err != nil { - c.logger.Error(err, "cannot base64 decode cookie", "cookie", StateCookieName, "cookie_value", cookie.Value) - http.Error(rw, "", http.StatusInternalServerError) - - return - } - - if err := json.Unmarshal(b, &state); err != nil { - c.logger.Error(err, "failed to unmarshal state to JSON") - http.Error(rw, "", http.StatusInternalServerError) - - return - } - - token, err = c.oauth2Config(nil).Exchange(ctx, code) - if err != nil { - c.logger.Error(err, "failed to exchange auth code for token") - http.Error(rw, "", http.StatusInternalServerError) - - return - } - default: - http.Error(rw, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) - - return - } - - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - http.Error(rw, "no id_token in token response", http.StatusInternalServerError) - return - } - - _, err := c.verifier().Verify(r.Context(), rawIDToken) - if err != nil { - http.Error(rw, fmt.Sprintf("failed to verify ID token: %v", err), http.StatusInternalServerError) - return - } - - // Issue ID token cookie - http.SetCookie(rw, c.createCookie(IDTokenCookieName, rawIDToken)) - - // Some OIDC providers may not include a refresh token - if token.RefreshToken != "" { - // Issue refresh token cookie - http.SetCookie(rw, c.createCookie(RefreshTokenCookieName, token.RefreshToken)) - } - - // Clear state cookie - http.SetCookie(rw, c.clearCookie(StateCookieName)) - - http.Redirect(rw, r, state.ReturnURL, http.StatusSeeOther) - } -} - -func (c *AuthConfig) createCookie(name, value string) *http.Cookie { - cookie := &http.Cookie{ - Name: name, - Value: value, - Path: "/", - Expires: time.Now().UTC().Add(c.cookieDuration), - HttpOnly: true, - } - - if c.issueSecureCookies { - cookie.Secure = true - } - - return cookie -} - -func (c *AuthConfig) clearCookie(name string) *http.Cookie { - cookie := &http.Cookie{ - Name: name, - Value: "", - Path: "/", - Expires: time.Unix(0, 0), - } - - return cookie -} - -// SessionState represents the state that needs to be persisted between -// the AuthN request from the Relying Party (RP) to the authorization -// endpoint of the OpenID Provider (OP) and the AuthN response back from -// the OP to the RP's callback URL. This state could be persisted server-side -// in a data store such as Redis but we prefer to operate stateless so we -// store this in a cookie instead. The cookie value and the value of the -// "state" parameter passed in the AuthN request are identical and set to -// the base64-encoded, JSON serialised state. -// -// https://openid.net/specs/openid-connect-core-1_0.html#Overview -// https://auth0.com/docs/configure/attack-protection/state-parameters#alternate-redirect-method -// https://community.auth0.com/t/state-parameter-and-user-redirection/8387/2 -type SessionState struct { - Nonce string `json:"n"` - ReturnURL string `json:"return_url"` -} - -func contains(ss []string, s string) bool { - for _, v := range ss { - if v == s { - return true - } - } - - return false -} diff --git a/pkg/server/auth/jwt.go b/pkg/server/auth/jwt.go index 4165729fef..1d4a01b5b0 100644 --- a/pkg/server/auth/jwt.go +++ b/pkg/server/auth/jwt.go @@ -102,9 +102,8 @@ func parseJWTToken(ctx context.Context, verifier *oidc.IDTokenVerifier, rawIDTok return &UserPrincipal{ID: claims.Email, Groups: claims.Groups}, nil } -// MultiAuth combines the JWTCookie and JWTAuthorizationHeader -// principal getters together to look for a principal first in -// a cookie and then in an Authorization header. +// MultiAuthPrincipal looks for a principal in an array of principal getters and +// if it finds an error or a principal it returns, otherwise it returns (nil,nil). type MultiAuthPrincipal []PrincipalGetter func (m MultiAuthPrincipal) Principal(r *http.Request) (*UserPrincipal, error) { diff --git a/pkg/server/auth/jwt_test.go b/pkg/server/auth/jwt_test.go new file mode 100644 index 0000000000..0e3fe8d9de --- /dev/null +++ b/pkg/server/auth/jwt_test.go @@ -0,0 +1,171 @@ +package auth_test + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-logr/logr" + "github.com/google/go-cmp/cmp" + "github.com/weaveworks/weave-gitops/pkg/server/auth" + "github.com/weaveworks/weave-gitops/pkg/testutils" +) + +func TestJWTCookiePrincipalGetter(t *testing.T) { + const cookieName = "auth-token" + + privKey := testutils.MakeRSAPrivateKey(t) + authTests := []struct { + name string + cookie string + want *auth.UserPrincipal + }{ + {"JWT ID Token", testutils.MakeJWToken(t, privKey, "example@example.com"), &auth.UserPrincipal{ID: "example@example.com", Groups: []string{"testing"}}}, + {"no cookie value", "", nil}, + } + + srv := testutils.MakeKeysetServer(t, privKey) + keySet := oidc.NewRemoteKeySet(oidc.ClientContext(context.TODO(), srv.Client()), srv.URL) + verifier := oidc.NewVerifier("http://127.0.0.1:5556/dex", keySet, &oidc.Config{ClientID: "test-service"}) + + for _, tt := range authTests { + t.Run(tt.name, func(t *testing.T) { + principal, err := auth.NewJWTCookiePrincipalGetter(logr.Discard(), verifier, cookieName).Principal(makeCookieRequest(cookieName, tt.cookie)) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(tt.want, principal); diff != "" { + t.Fatalf("failed to get principal:\n%s", diff) + } + }) + } +} + +func TestJWTAuthorizationHeaderPrincipalGetter(t *testing.T) { + privKey := testutils.MakeRSAPrivateKey(t) + authTests := []struct { + name string + authorization string + want *auth.UserPrincipal + }{ + {"JWT ID Token", "Bearer " + testutils.MakeJWToken(t, privKey, "example@example.com"), &auth.UserPrincipal{ID: "example@example.com", Groups: []string{"testing"}}}, + {"no auth header value", "", nil}, + } + + srv := testutils.MakeKeysetServer(t, privKey) + keySet := oidc.NewRemoteKeySet(oidc.ClientContext(context.TODO(), srv.Client()), srv.URL) + verifier := oidc.NewVerifier("http://127.0.0.1:5556/dex", keySet, &oidc.Config{ClientID: "test-service"}) + + for _, tt := range authTests { + t.Run(tt.name, func(t *testing.T) { + principal, err := auth.NewJWTAuthorizationHeaderPrincipalGetter(logr.Discard(), verifier).Principal(makeAuthenticatedRequest(tt.authorization)) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(tt.want, principal); diff != "" { + t.Fatalf("failed to get principal:\n%s", diff) + } + }) + } +} + +func makeCookieRequest(cookieName, token string) *http.Request { + req := httptest.NewRequest("GET", "http://example.com/", nil) + if token != "" { + req.AddCookie(&http.Cookie{ + Name: cookieName, + Value: token, + }) + } + + return req +} + +func makeAuthenticatedRequest(token string) *http.Request { + req := httptest.NewRequest("GET", "http://example.com/", nil) + if token != "" { + req.Header.Set("Authorization", token) + } + + return req +} + +func TestMultiAuth(t *testing.T) { + err := errors.New("oops") + multiAuthTests := []struct { + name string + auths []auth.PrincipalGetter + want *auth.UserPrincipal + err error + }{ + { + name: "no auths", + auths: []auth.PrincipalGetter{}, + want: nil, + }, + { + name: "no successful auths", + auths: []auth.PrincipalGetter{stubPrincipalGetter{}}, + want: nil, + }, + { + name: "one successful auth", + auths: []auth.PrincipalGetter{stubPrincipalGetter{id: "testing"}}, + want: &auth.UserPrincipal{ID: "testing"}, + }, + { + name: "two auths, one unsuccessful", + auths: []auth.PrincipalGetter{stubPrincipalGetter{}, stubPrincipalGetter{id: "testing"}}, + want: &auth.UserPrincipal{ID: "testing"}, + }, + { + name: "two auths, none successful", + auths: []auth.PrincipalGetter{stubPrincipalGetter{}, stubPrincipalGetter{}}, + want: nil, + }, + { + name: "error", + auths: []auth.PrincipalGetter{errorPrincipalGetter{err: err}}, + want: nil, + err: err, + }, + } + + for _, tt := range multiAuthTests { + t.Run(tt.name, func(t *testing.T) { + mg := auth.MultiAuthPrincipal(tt.auths) + req := httptest.NewRequest("GET", "http://example.com/", nil) + + principal, err := mg.Principal(req) + if err != tt.err { + t.Fatalf("got err %s, want %s", err, tt.err) + } + if diff := cmp.Diff(tt.want, principal); diff != "" { + t.Fatalf("failed to get principal:\n%s", diff) + } + }) + } +} + +type stubPrincipalGetter struct { + id string +} + +func (s stubPrincipalGetter) Principal(r *http.Request) (*auth.UserPrincipal, error) { + if s.id != "" { + return &auth.UserPrincipal{ID: s.id}, nil + } + + return nil, nil +} + +type errorPrincipalGetter struct { + err error +} + +func (s errorPrincipalGetter) Principal(r *http.Request) (*auth.UserPrincipal, error) { + return nil, s.err +} diff --git a/pkg/server/auth/server.go b/pkg/server/auth/server.go new file mode 100644 index 0000000000..0717fd3b90 --- /dev/null +++ b/pkg/server/auth/server.go @@ -0,0 +1,244 @@ +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-logr/logr" + "golang.org/x/oauth2" +) + +// OIDCConfig is used to configure an AuthServer to interact with +// an OIDC issuer. +type OIDCConfig struct { + IssuerURL string + ClientID string + ClientSecret string + RedirectURL string +} + +// CookieConfig is used to configure the cookies that get issued +// from the OIDC issuer once the OAuth2 process flow completes. +type CookieConfig struct { + CookieDuration time.Duration + IssueSecureCookies bool +} + +// AuthConfig is used to configure an AuthServer. +type AuthConfig struct { + OIDCConfig + CookieConfig +} + +// AuthServer interacts with an OIDC issuer to handle the OAuth2 process flow. +type AuthServer struct { + logger logr.Logger + client *http.Client + provider *oidc.Provider + config AuthConfig +} + +// NewAuthServer creates a new AuthServer object. +func NewAuthServer(ctx context.Context, logger logr.Logger, client *http.Client, config AuthConfig) (*AuthServer, error) { + provider, err := oidc.NewProvider(ctx, config.IssuerURL) + if err != nil { + return nil, fmt.Errorf("could not create provider: %w", err) + } + + return &AuthServer{ + logger: logger, + client: client, + provider: provider, + config: config, + }, nil +} + +// SetRedirectURL is used to set the redirect URL. This is meant to be used +// in unit tests only. +func (c *AuthServer) SetRedirectURL(url string) { + c.config.RedirectURL = url +} + +func (c *AuthServer) verifier() *oidc.IDTokenVerifier { + return c.provider.Verifier(&oidc.Config{ClientID: c.config.ClientID}) +} + +func (c *AuthServer) oauth2Config(scopes []string) *oauth2.Config { + // Ensure "openid" scope is always present. + if !contains(scopes, oidc.ScopeOpenID) { + scopes = append(scopes, oidc.ScopeOpenID) + } + + // Request "offline_access" scope for refresh tokens. + if !contains(scopes, oidc.ScopeOfflineAccess) { + scopes = append(scopes, oidc.ScopeOfflineAccess) + } + + // Request "email" scope to get user's email address. + if !contains(scopes, scopeEmail) { + scopes = append(scopes, scopeEmail) + } + + return &oauth2.Config{ + ClientID: c.config.ClientID, + ClientSecret: c.config.ClientSecret, + Endpoint: c.provider.Endpoint(), + RedirectURL: c.config.RedirectURL, + Scopes: scopes, + } +} + +func (c *AuthServer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + var ( + token *oauth2.Token + state SessionState + ) + + ctx := oidc.ClientContext(r.Context(), c.client) + + switch r.Method { + case http.MethodGet: + // Authorization redirect callback from OAuth2 auth flow. + if errMsg := r.FormValue("error"); errMsg != "" { + c.logger.Info("authz redirect callback failed", "error", errMsg, "error_description", r.FormValue("error_description")) + http.Error(rw, "", http.StatusBadRequest) + + return + } + + code := r.FormValue("code") + if code == "" { + c.logger.Info("code value was empty") + http.Error(rw, "", http.StatusBadRequest) + + return + } + + cookie, err := r.Cookie(StateCookieName) + if err != nil { + c.logger.Error(err, "cookie was not found in the request", "cookie", StateCookieName) + http.Error(rw, "", http.StatusBadRequest) + + return + } + + if state := r.FormValue("state"); state != cookie.Value { + c.logger.Info("cookie value does not match state value") + http.Error(rw, "", http.StatusBadRequest) + + return + } + + b, err := base64.StdEncoding.DecodeString(cookie.Value) + if err != nil { + c.logger.Error(err, "cannot base64 decode cookie", "cookie", StateCookieName, "cookie_value", cookie.Value) + http.Error(rw, "", http.StatusInternalServerError) + + return + } + + if err := json.Unmarshal(b, &state); err != nil { + c.logger.Error(err, "failed to unmarshal state to JSON") + http.Error(rw, "", http.StatusInternalServerError) + + return + } + + token, err = c.oauth2Config(nil).Exchange(ctx, code) + if err != nil { + c.logger.Error(err, "failed to exchange auth code for token") + http.Error(rw, "", http.StatusInternalServerError) + + return + } + default: + http.Error(rw, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) + + return + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + http.Error(rw, "no id_token in token response", http.StatusInternalServerError) + return + } + + _, err := c.verifier().Verify(r.Context(), rawIDToken) + if err != nil { + http.Error(rw, fmt.Sprintf("failed to verify ID token: %v", err), http.StatusInternalServerError) + return + } + + // Issue ID token cookie + http.SetCookie(rw, c.createCookie(IDTokenCookieName, rawIDToken)) + + // Some OIDC providers may not include a refresh token + if token.RefreshToken != "" { + // Issue refresh token cookie + http.SetCookie(rw, c.createCookie(RefreshTokenCookieName, token.RefreshToken)) + } + + // Clear state cookie + http.SetCookie(rw, c.clearCookie(StateCookieName)) + + http.Redirect(rw, r, state.ReturnURL, http.StatusSeeOther) +} + +func (c *AuthServer) createCookie(name, value string) *http.Cookie { + cookie := &http.Cookie{ + Name: name, + Value: value, + Path: "/", + Expires: time.Now().UTC().Add(c.config.CookieDuration), + HttpOnly: true, + } + + if c.config.IssueSecureCookies { + cookie.Secure = true + } + + return cookie +} + +func (c *AuthServer) clearCookie(name string) *http.Cookie { + cookie := &http.Cookie{ + Name: name, + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + } + + return cookie +} + +// SessionState represents the state that needs to be persisted between +// the AuthN request from the Relying Party (RP) to the authorization +// endpoint of the OpenID Provider (OP) and the AuthN response back from +// the OP to the RP's callback URL. This state could be persisted server-side +// in a data store such as Redis but we prefer to operate stateless so we +// store this in a cookie instead. The cookie value and the value of the +// "state" parameter passed in the AuthN request are identical and set to +// the base64-encoded, JSON serialised state. +// +// https://openid.net/specs/openid-connect-core-1_0.html#Overview +// https://auth0.com/docs/configure/attack-protection/state-parameters#alternate-redirect-method +// https://community.auth0.com/t/state-parameter-and-user-redirection/8387/2 +type SessionState struct { + Nonce string `json:"n"` + ReturnURL string `json:"return_url"` +} + +func contains(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + + return false +} diff --git a/pkg/server/handler.go b/pkg/server/handler.go index 5c45787ab3..e80576e668 100644 --- a/pkg/server/handler.go +++ b/pkg/server/handler.go @@ -24,7 +24,7 @@ func AuthEnabled() bool { type Config struct { AppConfig *ApplicationsConfig ProfilesConfig ProfilesConfig - AuthConfig *auth.AuthConfig + AuthServer *auth.AuthServer } func NewHandlers(ctx context.Context, cfg *Config) (http.Handler, error) { @@ -33,7 +33,7 @@ func NewHandlers(ctx context.Context, cfg *Config) (http.Handler, error) { httpHandler = middleware.WithProviderToken(cfg.AppConfig.JwtClient, httpHandler, cfg.AppConfig.Logger) if AuthEnabled() { - httpHandler = auth.WithAPIAuth(httpHandler, cfg.AuthConfig) + httpHandler = auth.WithAPIAuth(httpHandler, cfg.AuthServer) } appsSrv := NewApplicationsServer(cfg.AppConfig) diff --git a/pkg/testutils/testutils.go b/pkg/testutils/testutils.go index 346c44b937..0f6ff449ee 100644 --- a/pkg/testutils/testutils.go +++ b/pkg/testutils/testutils.go @@ -1,25 +1,33 @@ package testutils import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" "fmt" "io/ioutil" "log" + "net/http" + "net/http/httptest" "os" "path/filepath" "strings" + "testing" + "time" "github.com/go-logr/logr" + wego "github.com/weaveworks/weave-gitops/api/v1alpha1" "github.com/weaveworks/weave-gitops/pkg/flux" "github.com/weaveworks/weave-gitops/pkg/kube" "github.com/weaveworks/weave-gitops/pkg/osys/osysfakes" "github.com/weaveworks/weave-gitops/pkg/runner" fakelogr "github.com/weaveworks/weave-gitops/pkg/vendorfakes/logr" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/envtest" - wego "github.com/weaveworks/weave-gitops/api/v1alpha1" - ctrl "sigs.k8s.io/controller-runtime" - "github.com/fluxcd/go-git-providers/gitprovider" kustomizev2 "github.com/fluxcd/kustomize-controller/api/v1beta2" sourcev1 "github.com/fluxcd/source-controller/api/v1beta1" @@ -215,3 +223,73 @@ func Setenv(k, v string) func() { } } } + +// MakeRSAPrivateKey generates and returns an RSA Private Key. +func MakeRSAPrivateKey(t *testing.T) *rsa.PrivateKey { + t.Helper() + + k, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + t.Fatal(err) + } + + return k +} + +// MakeJWToken creates and signs a token with the provided key. +func MakeJWToken(t *testing.T, key *rsa.PrivateKey, email string) string { + t.Helper() + + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil) + if err != nil { + t.Fatal(err) + } + + maxAgeSecondsAuthCookie := time.Second * 600 + notBefore := time.Now().Add(-time.Second * 60) + claims := jwt.Claims{ + Issuer: "http://127.0.0.1:5556/dex", + Subject: "testing", + Audience: jwt.Audience{"test-service"}, + NotBefore: jwt.NewNumericDate(notBefore), + IssuedAt: jwt.NewNumericDate(notBefore), + Expiry: jwt.NewNumericDate(notBefore.Add(time.Duration(maxAgeSecondsAuthCookie))), + } + githubClaims := struct { + Groups []string `json:"groups"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + }{ + []string{"testing"}, + email, + "example", + } + + signed, err := jwt.Signed(signer).Claims(claims).Claims(githubClaims).CompactSerialize() + if err != nil { + t.Fatal(err) + } + + return signed +} + +// MakeKeysetServer starts an HTTP server that can serve JSONWebKey sets. +func MakeKeysetServer(t *testing.T, key *rsa.PrivateKey) *httptest.Server { + t.Helper() + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + var keys jose.JSONWebKeySet + keys.Keys = []jose.JSONWebKey{ + { + Key: key.Public(), + Use: "sig", + Algorithm: "RS256", + }, + } + _ = json.NewEncoder(w).Encode(keys) + })) + t.Cleanup(ts.Close) + + return ts +}