diff --git a/judge/handler.go b/judge/handler.go index 0e28377080..42f91056fe 100644 --- a/judge/handler.go +++ b/judge/handler.go @@ -100,7 +100,8 @@ func (h *Handler) judge(w http.ResponseWriter, r *http.Request) { return } - if err := h.RequestHandler.HandleRequest(r, rl); err != nil { + headers, err := h.RequestHandler.HandleRequest(r, rl) + if err != nil { h.Logger.WithError(err). WithField("granted", false). WithField("access_url", r.URL.String()). @@ -114,6 +115,9 @@ func (h *Handler) judge(w http.ResponseWriter, r *http.Request) { WithField("access_url", r.URL.String()). Warn("Access request granted") - w.Header().Set("Authorization", r.Header.Get("Authorization")) + for k := range headers { + w.Header().Set(k, headers.Get(k)) + } + w.WriteHeader(http.StatusOK) } diff --git a/judge/handler_test.go b/judge/handler_test.go index d2364f2b10..0e9a0308f0 100644 --- a/judge/handler_test.go +++ b/judge/handler_test.go @@ -194,7 +194,7 @@ func TestJudge(t *testing.T) { require.NoError(t, err) defer res.Body.Close() - assert.Equal(t, res.Header.Get("Authorization"), tc.authz) + assert.Equal(t, tc.authz, res.Header.Get("Authorization")) assert.Equal(t, tc.code, res.StatusCode) }) } diff --git a/proxy/credentials_issuer.go b/proxy/credentials_issuer.go index 18ef2c7e0c..8613c63f1c 100644 --- a/proxy/credentials_issuer.go +++ b/proxy/credentials_issuer.go @@ -28,6 +28,6 @@ import ( ) type CredentialsIssuer interface { - Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error + Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) GetID() string } diff --git a/proxy/credentials_issuer_broken.go b/proxy/credentials_issuer_broken.go index 8a66aed340..ce18371242 100644 --- a/proxy/credentials_issuer_broken.go +++ b/proxy/credentials_issuer_broken.go @@ -38,6 +38,6 @@ func (a *CredentialsIssuerBroken) GetID() string { return "broken" } -func (a *CredentialsIssuerBroken) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error { - return errors.New("forced denial of credentials") +func (a *CredentialsIssuerBroken) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) { + return nil, errors.New("forced denial of credentials") } diff --git a/proxy/credentials_issuer_broken_test.go b/proxy/credentials_issuer_broken_test.go index 0e87e25728..84abb0dce8 100644 --- a/proxy/credentials_issuer_broken_test.go +++ b/proxy/credentials_issuer_broken_test.go @@ -32,5 +32,6 @@ func TestCredentialsIssuerBroken(t *testing.T) { assert.NotNil(t, b) assert.NotEmpty(t, b.GetID()) - require.Error(t, b.Issue(nil, nil, nil, nil)) + _, err := b.Issue(nil, nil, nil, nil) + require.Error(t, err) } diff --git a/proxy/credentials_issuer_cookies.go b/proxy/credentials_issuer_cookies.go index fe79ce0125..f5b9a4f403 100644 --- a/proxy/credentials_issuer_cookies.go +++ b/proxy/credentials_issuer_cookies.go @@ -39,7 +39,7 @@ func (a *CredentialsCookies) GetID() string { return "cookies" } -func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error { +func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) { if len(config) == 0 { config = []byte("{}") } @@ -47,8 +47,7 @@ func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSessi // Cache request cookies requestCookies := r.Cookies() - // Remove existing cookies - r.Header.Del("Cookie") + req := http.Request{Header: map[string][]string{}} // Keep track of rule cookies in a map cookies := map[string]bool{} @@ -57,7 +56,7 @@ func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSessi d := json.NewDecoder(bytes.NewBuffer(config)) d.DisallowUnknownFields() if err := d.Decode(&cfg); err != nil { - return errors.WithStack(err) + return nil, errors.WithStack(err) } for cookie, templateString := range cfg.Cookies { @@ -69,17 +68,17 @@ func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSessi if tmpl == nil { tmpl, err = a.RulesCache.New(templateId).Parse(templateString) if err != nil { - return errors.Wrapf(err, `error parsing cookie template "%s" in rule "%s"`, templateString, rl.ID) + return nil, errors.Wrapf(err, `error parsing cookie template "%s" in rule "%s"`, templateString, rl.ID) } } cookieValue := bytes.Buffer{} err = tmpl.Execute(&cookieValue, session) if err != nil { - return errors.Wrapf(err, `error executing cookie template "%s" in rule "%s"`, templateString, rl.ID) + return nil, errors.Wrapf(err, `error executing cookie template "%s" in rule "%s"`, templateString, rl.ID) } - r.AddCookie(&http.Cookie{ + req.AddCookie(&http.Cookie{ Name: cookie, Value: cookieValue.String(), }) @@ -92,9 +91,9 @@ func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSessi // Test if cookie is handled by rule if _, ok := cookies[cookie.Name]; !ok { // Re-add cookie if not handled by rule - r.AddCookie(cookie) + req.AddCookie(cookie) } } - return nil + return req.Header, nil } diff --git a/proxy/credentials_issuer_cookies_test.go b/proxy/credentials_issuer_cookies_test.go index 12b4f2f198..dafa86059c 100644 --- a/proxy/credentials_issuer_cookies_test.go +++ b/proxy/credentials_issuer_cookies_test.go @@ -131,13 +131,15 @@ func TestCredentialsIssuerCookies(t *testing.T) { // Issuer must return non-empty ID assert.NotEmpty(t, issuer.GetID()) + header, err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule) if specs.Err == nil { - require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)) + // Issuer must run without error + require.NoError(t, err) } else { - err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule) assert.Equal(t, specs.Err.Error(), err.Error()) } + specs.Request.Header = header assert.Equal(t, serializeCookies(specs.Match), serializeCookies(specs.Request.Cookies())) }) } @@ -154,7 +156,7 @@ func TestCredentialsIssuerCookies(t *testing.T) { d := json.NewDecoder(bytes.NewBuffer(specs.Config)) d.Decode(&cfg) - for cookie, _ := range cfg.Cookies { + for cookie := range cfg.Cookies { templateId := fmt.Sprintf("%s:%s", specs.Rule.ID, cookie) cache.New(templateId).Parse("override") overrideCookies = append(overrideCookies, &http.Cookie{Name: cookie, Value: "override"}) @@ -162,13 +164,15 @@ func TestCredentialsIssuerCookies(t *testing.T) { issuer.RulesCache = cache + header, err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule) if specs.Err == nil { - require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)) + // Issuer must run without error + require.NoError(t, err) } else { - err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule) assert.Equal(t, specs.Err.Error(), err.Error()) } + specs.Request.Header = header assert.Equal(t, serializeCookies(overrideCookies), serializeCookies(specs.Request.Cookies())) } }) diff --git a/proxy/credentials_issuer_headers.go b/proxy/credentials_issuer_headers.go index 67c61b086c..aa5f2dbddc 100644 --- a/proxy/credentials_issuer_headers.go +++ b/proxy/credentials_issuer_headers.go @@ -39,7 +39,7 @@ func (a *CredentialsHeaders) GetID() string { return "headers" } -func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error { +func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) { if len(config) == 0 { config = []byte("{}") } @@ -48,9 +48,10 @@ func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSessi d := json.NewDecoder(bytes.NewBuffer(config)) d.DisallowUnknownFields() if err := d.Decode(&cfg); err != nil { - return errors.WithStack(err) + return nil, errors.WithStack(err) } + headers := http.Header{} for hdr, templateString := range cfg.Headers { var tmpl *template.Template var err error @@ -60,17 +61,17 @@ func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSessi if tmpl == nil { tmpl, err = a.RulesCache.New(templateId).Parse(templateString) if err != nil { - return errors.Wrapf(err, `error parsing header template "%s" in rule "%s"`, templateString, rl.ID) + return nil, errors.Wrapf(err, `error parsing headers template "%s" in rule "%s"`, templateString, rl.ID) } } headerValue := bytes.Buffer{} err = tmpl.Execute(&headerValue, session) if err != nil { - return errors.Wrapf(err, `error executing header template "%s" in rule "%s"`, templateString, rl.ID) + return nil, errors.Wrapf(err, `error executing headers template "%s" in rule "%s"`, templateString, rl.ID) } - r.Header.Set(hdr, headerValue.String()) + headers.Set(hdr, headerValue.String()) } - return nil + return headers, nil } diff --git a/proxy/credentials_issuer_headers_test.go b/proxy/credentials_issuer_headers_test.go index 72cf52371a..a242d73f17 100644 --- a/proxy/credentials_issuer_headers_test.go +++ b/proxy/credentials_issuer_headers_test.go @@ -120,14 +120,19 @@ func TestCredentialsIssuerHeaders(t *testing.T) { // Issuer must return non-empty ID assert.NotEmpty(t, issuer.GetID()) + header, err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule) if specs.Err == nil { // Issuer must run without error - require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)) + require.NoError(t, err) } else { - err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule) assert.Equal(t, specs.Err.Error(), err.Error()) } + specs.Request.Header = header + if header == nil { + specs.Request.Header = http.Header{} + } + // Output request headers must match test specs assert.Equal(t, specs.Match, specs.Request.Header) }) @@ -145,7 +150,7 @@ func TestCredentialsIssuerHeaders(t *testing.T) { d := json.NewDecoder(bytes.NewBuffer(specs.Config)) d.Decode(&cfg) - for hdr, _ := range cfg.Headers { + for hdr := range cfg.Headers { templateId := fmt.Sprintf("%s:%s", specs.Rule.ID, hdr) cache.New(templateId).Parse("override") overrideHeaders.Add(hdr, "override") @@ -153,14 +158,19 @@ func TestCredentialsIssuerHeaders(t *testing.T) { issuer.RulesCache = cache + header, err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule) if specs.Err == nil { // Issuer must run without error - require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)) + require.NoError(t, err) } else { - err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule) assert.Equal(t, specs.Err.Error(), err.Error()) } + specs.Request.Header = header + if header == nil { + specs.Request.Header = http.Header{} + } + assert.Equal(t, overrideHeaders, specs.Request.Header) } }) diff --git a/proxy/credentials_issuer_id_token.go b/proxy/credentials_issuer_id_token.go index 6a8e14f465..f832a86da2 100644 --- a/proxy/credentials_issuer_id_token.go +++ b/proxy/credentials_issuer_id_token.go @@ -78,10 +78,10 @@ func (c *Claims) Valid() error { return nil } -func (a *CredentialsIDToken) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error { +func (a *CredentialsIDToken) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) { privateKey, err := a.km.PrivateKey() if err != nil { - return errors.WithStack(err) + return nil, errors.WithStack(err) } if len(config) == 0 { config = []byte("{}") @@ -91,7 +91,7 @@ func (a *CredentialsIDToken) Issue(r *http.Request, session *AuthenticationSessi d := json.NewDecoder(bytes.NewBuffer(config)) d.DisallowUnknownFields() if err := d.Decode(&cc); err != nil { - return errors.WithStack(err) + return nil, errors.WithStack(err) } now := time.Now().UTC() @@ -120,16 +120,17 @@ func (a *CredentialsIDToken) Issue(r *http.Request, session *AuthenticationSessi case "HS256": token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims) default: - return errors.Errorf("Encountered unknown signing algorithm %s while signing ID Token", a.km.Algorithm()) + return nil, errors.Errorf("Encountered unknown signing algorithm %s while signing ID Token", a.km.Algorithm()) } token.Header["kid"] = a.km.PublicKeyID() signed, err := token.SignedString(privateKey) if err != nil { - return errors.WithStack(err) + return nil, errors.WithStack(err) } - r.Header.Set("Authorization", "Bearer "+signed) - return nil + headers := http.Header{} + headers.Set("Authorization", "Bearer "+signed) + return headers, nil } diff --git a/proxy/credentials_issuer_id_token_test.go b/proxy/credentials_issuer_id_token_test.go index 94b3d82e56..f028a9ba5e 100644 --- a/proxy/credentials_issuer_id_token_test.go +++ b/proxy/credentials_issuer_id_token_test.go @@ -52,7 +52,10 @@ func TestCredentialsIssuerIDToken(t *testing.T) { r := &http.Request{Header: http.Header{}} s := &AuthenticationSession{Subject: "foo"} - require.NoError(t, b.Issue(r, s, json.RawMessage([]byte(`{ "aud": ["foo", "bar"] }`)), nil)) + + header, err := b.Issue(r, s, json.RawMessage([]byte(`{ "aud": ["foo", "bar"] }`)), nil) + require.NoError(t, err) + r.Header = header generated := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", 1) token, err := jwt.ParseWithClaims(generated, new(Claims), func(token *jwt.Token) (interface{}, error) { diff --git a/proxy/credentials_issuer_noop.go b/proxy/credentials_issuer_noop.go index 4da8b82e69..7abfe24590 100644 --- a/proxy/credentials_issuer_noop.go +++ b/proxy/credentials_issuer_noop.go @@ -37,6 +37,6 @@ func (a *CredentialsIssuerNoOp) GetID() string { return "noop" } -func (a *CredentialsIssuerNoOp) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error { - return nil +func (a *CredentialsIssuerNoOp) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) { + return r.Header, nil } diff --git a/proxy/credentials_issuer_noop_test.go b/proxy/credentials_issuer_noop_test.go index 489d19d6b4..0b4c776706 100644 --- a/proxy/credentials_issuer_noop_test.go +++ b/proxy/credentials_issuer_noop_test.go @@ -21,6 +21,7 @@ package proxy import ( + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -31,5 +32,6 @@ func TestCredentialsIssuerNoOp(t *testing.T) { assert.NotNil(t, NewCredentialsIssuerNoOp()) assert.NotEmpty(t, NewCredentialsIssuerNoOp().GetID()) - require.NoError(t, NewCredentialsIssuerNoOp().Issue(nil, nil, nil, nil)) + _, err := NewCredentialsIssuerNoOp().Issue(&http.Request{Header: map[string][]string{}}, nil, nil, nil) + require.NoError(t, err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 41decc787e..7a151d003a 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -118,11 +118,16 @@ func (d *Proxy) Director(r *http.Request) { return } - if err := d.RequestHandler.HandleRequest(r, rl); err != nil { + headers, err := d.RequestHandler.HandleRequest(r, rl) + if err != nil { *r = *r.WithContext(context.WithValue(r.Context(), director, err)) return } + for h := range headers { + r.Header.Set(h, headers.Get(h)) + } + if err := configureBackendURL(r, rl); err != nil { *r = *r.WithContext(context.WithValue(r.Context(), director, err)) return diff --git a/proxy/request_handler.go b/proxy/request_handler.go index 4e8d903c94..bee8dc2c44 100644 --- a/proxy/request_handler.go +++ b/proxy/request_handler.go @@ -70,7 +70,7 @@ func NewRequestHandler( return j } -func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) error { +func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) (http.Header, error) { var err error var session *AuthenticationSession var found bool @@ -82,7 +82,7 @@ func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) error { WithField("access_url", r.URL.String()). WithField("reason_id", "authentication_handler_missing"). Warn("No authentication handler was set in the rule") - return err + return nil, err } for _, a := range rl.Authenticators { @@ -94,7 +94,7 @@ func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) error { WithField("authentication_handler", a.Handler). WithField("reason_id", "unknown_authentication_handler"). Warn("Unknown authentication handler requested") - return errors.New("Unknown authentication handler requested") + return nil, errors.New("Unknown authentication handler requested") } session, err = anh.Authenticate(r, a.Config, rl) @@ -114,7 +114,7 @@ func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) error { WithField("authentication_handler", a.Handler). WithField("reason_id", "authentication_handler_error"). Warn("The authentication handler encountered an error") - return err + return nil, err } } else { // The first authenticator that matches must return the session @@ -130,7 +130,7 @@ func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) error { WithField("access_url", r.URL.String()). WithField("reason_id", "authentication_handler_no_match"). Warn("No authentication handler was responsible for handling the authentication request") - return err + return nil, err } azh, ok := d.AuthorizationHandlers[rl.Authorizer.Handler] @@ -141,7 +141,7 @@ func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) error { WithField("authorization_handler", rl.Authorizer.Handler). WithField("reason_id", "unknown_authorization_handler"). Warn("Unknown authentication handler requested") - return errors.New("Unknown authorization handler requested") + return nil, errors.New("Unknown authorization handler requested") } if err := azh.Authorize(r, session, rl.Authorizer.Config, rl); err != nil { @@ -152,7 +152,7 @@ func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) error { WithField("authorization_handler", rl.Authorizer.Handler). WithField("reason_id", "authorization_handler_error"). Warn("The authorization handler encountered an error") - return err + return nil, err } sh, ok := d.CredentialIssuers[rl.CredentialsIssuer.Handler] @@ -163,12 +163,13 @@ func (d *RequestHandler) HandleRequest(r *http.Request, rl *rule.Rule) error { WithField("session_handler", rl.CredentialsIssuer.Handler). WithField("reason_id", "unknown_credential_issuer"). Warn("Unknown credential issuer requested") - return errors.New("Unknown credential issuer requested") + return nil, errors.New("Unknown credential issuer requested") } - if err := sh.Issue(r, session, rl.CredentialsIssuer.Config, rl); err != nil { - return err + headers, err := sh.Issue(r, session, rl.CredentialsIssuer.Config, rl) + if err != nil { + return nil, err } - return nil + return headers, nil } diff --git a/proxy/request_handler_test.go b/proxy/request_handler_test.go index 6f54e6c409..976457a39a 100644 --- a/proxy/request_handler_test.go +++ b/proxy/request_handler_test.go @@ -123,10 +123,11 @@ func TestRequestHandler(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + _, err := tc.j.HandleRequest(tc.r, &tc.rule) if tc.expectErr { - require.Error(t, tc.j.HandleRequest(tc.r, &tc.rule)) + require.Error(t, err) } else { - require.NoError(t, tc.j.HandleRequest(tc.r, &tc.rule)) + require.NoError(t, err) } }) }