Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

httputil : wrap handlers for additional context #413

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
103 changes: 46 additions & 57 deletions authenticate/handlers.go
Expand Up @@ -32,12 +32,12 @@ func (a *Authenticate) Handler() http.Handler {
csrf.UnsafePaths([]string{callbackPath}), // enforce CSRF on "safe" handler
csrf.FormValueName("state"), // rfc6749 section-10.12
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieOptions.Name)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)),
))

r.HandleFunc("/robots.txt", a.RobotsTxt).Methods(http.MethodGet)
r.Path("/robots.txt").HandlerFunc(a.RobotsTxt).Methods(http.MethodGet)
// Identity Provider (IdP) endpoints
r.HandleFunc("/oauth2/callback", a.OAuthCallback).Methods(http.MethodGet)
r.Path("/oauth2/callback").Handler(httputil.HandlerFunc(a.OAuthCallback)).Methods(http.MethodGet)

// Proxy service endpoints
v := r.PathPrefix("/.pomerium").Subrouter()
Expand All @@ -53,37 +53,36 @@ func (a *Authenticate) Handler() http.Handler {
v.Use(middleware.ValidateSignature(a.sharedKey))
v.Use(sessions.RetrieveSession(a.sessionLoaders...))
v.Use(a.VerifySession)
v.HandleFunc("/sign_in", a.SignIn)
v.HandleFunc("/sign_out", a.SignOut)
v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn))
v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))

// programmatic access api endpoint
api := r.PathPrefix("/api").Subrouter()
api.Use(sessions.RetrieveSession(a.sessionLoaders...))
api.HandleFunc("/v1/refresh", a.RefreshAPI)
api.Path("/v1/refresh").Handler(httputil.HandlerFunc(a.RefreshAPI))

return r
}

// VerifySession is the middleware used to enforce a valid authentication
// session state is attached to the users's request context.
func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := sessions.FromContext(r.Context())
if errors.Is(err, sessions.ErrExpired) {
if err := a.refresh(w, r, state); err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
a.reauthenticateOrFail(w, r, err)
return
return a.reauthenticateOrFail(w, r, err)
}
// redirect to restart middleware-chain following refresh
httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
return
return nil
} else if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
a.reauthenticateOrFail(w, r, err)
return
return a.reauthenticateOrFail(w, r, err)
}
next.ServeHTTP(w, r)
return nil
})
}

Expand All @@ -106,11 +105,10 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
}

// SignIn handles to authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}

jwtAudience := []string{a.RedirectURL.Hostname(), redirectURL.Hostname()}
Expand All @@ -120,8 +118,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" {
callbackURL, err = urlutil.ParseAndValidateURL(callbackStr)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}
jwtAudience = append(jwtAudience, callbackURL.Hostname())
} else {
Expand All @@ -138,16 +135,14 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {

s, err := sessions.FromContext(r.Context())
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}

// user impersonation
if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" {
s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
if err := a.sessionStore.SaveSession(w, r, s); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}
}

Expand All @@ -159,8 +154,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
newSession.Programmatic = true
encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)

}
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
Expand All @@ -169,8 +164,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
// sign the route session, as a JWT
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}

// encrypt our route-based token JWT avoiding any accidental logging
Expand All @@ -187,28 +181,28 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
// proxy's callback URL which is responsible for setting our new route-session
uri := urlutil.NewSignedURL(a.sharedKey, callbackURL)
httputil.Redirect(w, r, uri.String(), http.StatusFound)
return nil
}

// SignOut signs the user out and attempts to revoke the user's identity session
// Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
session, err := sessions.FromContext(r.Context())
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}
a.sessionStore.ClearSession(w, r)
err = a.provider.Revoke(r.Context(), session.AccessToken)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)

}
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
return nil
}

// reauthenticateOrFail starts the authenticate process by redirecting the
Expand All @@ -221,11 +215,10 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
// https://tools.ietf.org/html/rfc6749#section-4.2.1
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) {
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error {
// If request AJAX/XHR request, return a 401 instead .
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
return
return httputil.NewError(http.StatusUnauthorized, err)
}
a.sessionStore.ClearSession(w, r)
redirectURL := a.RedirectURL.ResolveReference(r.URL)
Expand All @@ -236,32 +229,33 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b)
httputil.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
return nil
}

// OAuthCallback handles the callback from the identity provider.
//
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error {
redirect, err := a.getOAuthCallback(w, r)
if err != nil {
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
return
return fmt.Errorf("oauth callback : %w", err)
}
httputil.Redirect(w, r, redirect.String(), http.StatusFound)
return nil
}

func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
//
// first, check if the identity provider returned an error
if idpError := r.FormValue("error"); idpError != "" {
return nil, httputil.Error(idpError, http.StatusBadRequest, fmt.Errorf("identity provider: %v", idpError))
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("identity provider: %v", idpError))
}
// fail if no session redemption code is returned
code := r.FormValue("code")
if code == "" {
return nil, httputil.Error("identity provider returned empty code", http.StatusBadRequest, nil)
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("identity provider returned empty code"))
}

// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
Expand All @@ -274,20 +268,19 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// state includes a csrf nonce (validated by middleware) and redirect uri
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
if err != nil {
return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
return nil, httputil.NewError(http.StatusBadRequest, err)
}

// split state into concat'd components
// (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts))
statePayload := strings.SplitN(string(bytes), "|", 3)
if len(statePayload) != 3 {
return nil, httputil.Error("'state' is malformed", http.StatusBadRequest,
fmt.Errorf("state malformed, size: %d", len(statePayload)))
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("state malformed, size: %d", len(statePayload)))
}

// verify that the returned timestamp is valid
if err := cryptutil.ValidTimestamp(statePayload[1]); err != nil {
return nil, httputil.Error(err.Error(), http.StatusBadRequest, err)
return nil, httputil.NewError(http.StatusBadRequest, err)
}

// Use our AEAD construct to enforce secrecy and authenticity:
Expand All @@ -296,12 +289,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|"))
redirectString, err := cryptutil.Decrypt(a.cookieCipher, []byte(statePayload[2]), b)
if err != nil {
return nil, httputil.Error("'state' has invalid hmac", http.StatusBadRequest, err)
return nil, httputil.NewError(http.StatusBadRequest, err)
}

redirectURL, err := urlutil.ParseAndValidateURL(string(redirectString))
if err != nil {
return nil, httputil.Error("'state' has invalid redirect uri", http.StatusBadRequest, err)
return nil, httputil.NewError(http.StatusBadRequest, err)
}

// OK. Looks good so let's persist our user session
Expand All @@ -314,29 +307,25 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// RefreshAPI loads a global state, and attempts to refresh the session's access
// tokens and state with the identity provider. If successful, a new signed JWT
// and refresh token (`refresh_token`) are returned as JSON
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) {
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context())
if err != nil && !errors.Is(err, sessions.ErrExpired) {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}
newSession, err := a.provider.Refresh(r.Context(), s)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
return
return err
}
newSession = newSession.NewSession(s.Issuer, s.Audience)

encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
return
return err
}

signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
return
return err
}
var response struct {
JWT string `json:"jwt"`
Expand All @@ -347,9 +336,9 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) {

jsonResponse, err := json.Marshal(&response)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
return httputil.NewError(http.StatusBadRequest, err)
}
w.Header().Set("Content-Type", "application/json")
w.Write(jsonResponse)
return nil
}
20 changes: 9 additions & 11 deletions authenticate/handlers_test.go
Expand Up @@ -11,6 +11,8 @@ import (
"testing"
"time"

"github.com/pomerium/pomerium/internal/httputil"

"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
Expand Down Expand Up @@ -154,8 +156,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
r = r.WithContext(ctx)

w := httptest.NewRecorder()

a.SignIn(w, r)
httputil.HandlerFunc(a.SignIn).ServeHTTP(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v %s", status, tt.wantCode, uri)
t.Errorf("\n%+v", w.Body)
Expand Down Expand Up @@ -186,9 +187,9 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantBody string
}{
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"could not revoke user session\"}\n"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"Bad Request\"}\n"},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -211,8 +212,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
r.Header.Set("Accept", "application/json")

w := httptest.NewRecorder()

a.SignOut(w, r)
httputil.HandlerFunc(a.SignOut).ServeHTTP(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
Expand Down Expand Up @@ -299,8 +299,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
r := httptest.NewRequest(tt.method, u.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()

a.OAuthCallback(w, r)
httputil.HandlerFunc(a.OAuthCallback).ServeHTTP(w, r)
if w.Result().StatusCode != tt.wantCode {
t.Errorf("Authenticate.OAuthCallback() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantCode, w.Body.String())
return
Expand Down Expand Up @@ -366,7 +365,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())

}
})
}
Expand Down Expand Up @@ -417,7 +415,7 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
r.Header.Set("Accept", "application/json")

w := httptest.NewRecorder()
a.RefreshAPI(w, r)
httputil.HandlerFunc(a.RefreshAPI).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())

Expand Down