Skip to content

Commit

Permalink
judge: Set request headers for credential issuers
Browse files Browse the repository at this point in the history
Closes #120
Closes #133

Signed-off-by: aeneasr <aeneas@ory.sh>
  • Loading branch information
aeneasr committed Apr 6, 2019
1 parent f9fdefb commit 548ad67
Show file tree
Hide file tree
Showing 16 changed files with 89 additions and 57 deletions.
8 changes: 6 additions & 2 deletions judge/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ func (h *Handler) judge(w http.ResponseWriter, r *http.Request) {
return
}

if err := h.RequestHandler.HandleRequest(r, rl); err != nil {
headers, err := h.RequestHandler.HandleRequest(r, rl)
if err != nil {
h.Logger.WithError(err).
WithField("granted", false).
WithField("access_url", r.URL.String()).
Expand All @@ -114,6 +115,9 @@ func (h *Handler) judge(w http.ResponseWriter, r *http.Request) {
WithField("access_url", r.URL.String()).
Warn("Access request granted")

w.Header().Set("Authorization", r.Header.Get("Authorization"))
for k := range headers {
w.Header().Set(k, headers.Get(k))
}

w.WriteHeader(http.StatusOK)
}
2 changes: 1 addition & 1 deletion judge/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func TestJudge(t *testing.T) {
require.NoError(t, err)
defer res.Body.Close()

assert.Equal(t, res.Header.Get("Authorization"), tc.authz)
assert.Equal(t, tc.authz, res.Header.Get("Authorization"))
assert.Equal(t, tc.code, res.StatusCode)
})
}
Expand Down
2 changes: 1 addition & 1 deletion proxy/credentials_issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ import (
)

type CredentialsIssuer interface {
Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error
Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error)
GetID() string
}
4 changes: 2 additions & 2 deletions proxy/credentials_issuer_broken.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ func (a *CredentialsIssuerBroken) GetID() string {
return "broken"
}

func (a *CredentialsIssuerBroken) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error {
return errors.New("forced denial of credentials")
func (a *CredentialsIssuerBroken) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) {
return nil, errors.New("forced denial of credentials")
}
3 changes: 2 additions & 1 deletion proxy/credentials_issuer_broken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ func TestCredentialsIssuerBroken(t *testing.T) {
assert.NotNil(t, b)
assert.NotEmpty(t, b.GetID())

require.Error(t, b.Issue(nil, nil, nil, nil))
_, err := b.Issue(nil, nil, nil, nil)
require.Error(t, err)
}
17 changes: 8 additions & 9 deletions proxy/credentials_issuer_cookies.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,15 @@ func (a *CredentialsCookies) GetID() string {
return "cookies"
}

func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error {
func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) {
if len(config) == 0 {
config = []byte("{}")
}

// Cache request cookies
requestCookies := r.Cookies()

// Remove existing cookies
r.Header.Del("Cookie")
req := http.Request{Header: map[string][]string{}}

// Keep track of rule cookies in a map
cookies := map[string]bool{}
Expand All @@ -57,7 +56,7 @@ func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSessi
d := json.NewDecoder(bytes.NewBuffer(config))
d.DisallowUnknownFields()
if err := d.Decode(&cfg); err != nil {
return errors.WithStack(err)
return nil, errors.WithStack(err)
}

for cookie, templateString := range cfg.Cookies {
Expand All @@ -69,17 +68,17 @@ func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSessi
if tmpl == nil {
tmpl, err = a.RulesCache.New(templateId).Parse(templateString)
if err != nil {
return errors.Wrapf(err, `error parsing cookie template "%s" in rule "%s"`, templateString, rl.ID)
return nil, errors.Wrapf(err, `error parsing cookie template "%s" in rule "%s"`, templateString, rl.ID)
}
}

cookieValue := bytes.Buffer{}
err = tmpl.Execute(&cookieValue, session)
if err != nil {
return errors.Wrapf(err, `error executing cookie template "%s" in rule "%s"`, templateString, rl.ID)
return nil, errors.Wrapf(err, `error executing cookie template "%s" in rule "%s"`, templateString, rl.ID)
}

r.AddCookie(&http.Cookie{
req.AddCookie(&http.Cookie{
Name: cookie,
Value: cookieValue.String(),
})
Expand All @@ -92,9 +91,9 @@ func (a *CredentialsCookies) Issue(r *http.Request, session *AuthenticationSessi
// Test if cookie is handled by rule
if _, ok := cookies[cookie.Name]; !ok {
// Re-add cookie if not handled by rule
r.AddCookie(cookie)
req.AddCookie(cookie)
}
}

return nil
return req.Header, nil
}
14 changes: 9 additions & 5 deletions proxy/credentials_issuer_cookies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,15 @@ func TestCredentialsIssuerCookies(t *testing.T) {
// Issuer must return non-empty ID
assert.NotEmpty(t, issuer.GetID())

header, err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)
if specs.Err == nil {
require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule))
// Issuer must run without error
require.NoError(t, err)
} else {
err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)
assert.Equal(t, specs.Err.Error(), err.Error())
}

specs.Request.Header = header
assert.Equal(t, serializeCookies(specs.Match), serializeCookies(specs.Request.Cookies()))
})
}
Expand All @@ -154,21 +156,23 @@ func TestCredentialsIssuerCookies(t *testing.T) {
d := json.NewDecoder(bytes.NewBuffer(specs.Config))
d.Decode(&cfg)

for cookie, _ := range cfg.Cookies {
for cookie := range cfg.Cookies {
templateId := fmt.Sprintf("%s:%s", specs.Rule.ID, cookie)
cache.New(templateId).Parse("override")
overrideCookies = append(overrideCookies, &http.Cookie{Name: cookie, Value: "override"})
}

issuer.RulesCache = cache

header, err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)
if specs.Err == nil {
require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule))
// Issuer must run without error
require.NoError(t, err)
} else {
err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)
assert.Equal(t, specs.Err.Error(), err.Error())
}

specs.Request.Header = header
assert.Equal(t, serializeCookies(overrideCookies), serializeCookies(specs.Request.Cookies()))
}
})
Expand Down
13 changes: 7 additions & 6 deletions proxy/credentials_issuer_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (a *CredentialsHeaders) GetID() string {
return "headers"
}

func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error {
func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) {
if len(config) == 0 {
config = []byte("{}")
}
Expand All @@ -48,9 +48,10 @@ func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSessi
d := json.NewDecoder(bytes.NewBuffer(config))
d.DisallowUnknownFields()
if err := d.Decode(&cfg); err != nil {
return errors.WithStack(err)
return nil, errors.WithStack(err)
}

headers := http.Header{}
for hdr, templateString := range cfg.Headers {
var tmpl *template.Template
var err error
Expand All @@ -60,17 +61,17 @@ func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSessi
if tmpl == nil {
tmpl, err = a.RulesCache.New(templateId).Parse(templateString)
if err != nil {
return errors.Wrapf(err, `error parsing header template "%s" in rule "%s"`, templateString, rl.ID)
return nil, errors.Wrapf(err, `error parsing headers template "%s" in rule "%s"`, templateString, rl.ID)
}
}

headerValue := bytes.Buffer{}
err = tmpl.Execute(&headerValue, session)
if err != nil {
return errors.Wrapf(err, `error executing header template "%s" in rule "%s"`, templateString, rl.ID)
return nil, errors.Wrapf(err, `error executing headers template "%s" in rule "%s"`, templateString, rl.ID)
}
r.Header.Set(hdr, headerValue.String())
headers.Set(hdr, headerValue.String())
}

return nil
return headers, nil
}
20 changes: 15 additions & 5 deletions proxy/credentials_issuer_headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,19 @@ func TestCredentialsIssuerHeaders(t *testing.T) {
// Issuer must return non-empty ID
assert.NotEmpty(t, issuer.GetID())

header, err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)
if specs.Err == nil {
// Issuer must run without error
require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule))
require.NoError(t, err)
} else {
err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)
assert.Equal(t, specs.Err.Error(), err.Error())
}

specs.Request.Header = header
if header == nil {
specs.Request.Header = http.Header{}
}

// Output request headers must match test specs
assert.Equal(t, specs.Match, specs.Request.Header)
})
Expand All @@ -145,22 +150,27 @@ func TestCredentialsIssuerHeaders(t *testing.T) {
d := json.NewDecoder(bytes.NewBuffer(specs.Config))
d.Decode(&cfg)

for hdr, _ := range cfg.Headers {
for hdr := range cfg.Headers {
templateId := fmt.Sprintf("%s:%s", specs.Rule.ID, hdr)
cache.New(templateId).Parse("override")
overrideHeaders.Add(hdr, "override")
}

issuer.RulesCache = cache

header, err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)
if specs.Err == nil {
// Issuer must run without error
require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule))
require.NoError(t, err)
} else {
err := issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)
assert.Equal(t, specs.Err.Error(), err.Error())
}

specs.Request.Header = header
if header == nil {
specs.Request.Header = http.Header{}
}

assert.Equal(t, overrideHeaders, specs.Request.Header)
}
})
Expand Down
15 changes: 8 additions & 7 deletions proxy/credentials_issuer_id_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ func (c *Claims) Valid() error {
return nil
}

func (a *CredentialsIDToken) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error {
func (a *CredentialsIDToken) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) {
privateKey, err := a.km.PrivateKey()
if err != nil {
return errors.WithStack(err)
return nil, errors.WithStack(err)
}
if len(config) == 0 {
config = []byte("{}")
Expand All @@ -91,7 +91,7 @@ func (a *CredentialsIDToken) Issue(r *http.Request, session *AuthenticationSessi
d := json.NewDecoder(bytes.NewBuffer(config))
d.DisallowUnknownFields()
if err := d.Decode(&cc); err != nil {
return errors.WithStack(err)
return nil, errors.WithStack(err)
}

now := time.Now().UTC()
Expand Down Expand Up @@ -120,16 +120,17 @@ func (a *CredentialsIDToken) Issue(r *http.Request, session *AuthenticationSessi
case "HS256":
token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
default:
return errors.Errorf("Encountered unknown signing algorithm %s while signing ID Token", a.km.Algorithm())
return nil, errors.Errorf("Encountered unknown signing algorithm %s while signing ID Token", a.km.Algorithm())
}

token.Header["kid"] = a.km.PublicKeyID()

signed, err := token.SignedString(privateKey)
if err != nil {
return errors.WithStack(err)
return nil, errors.WithStack(err)
}

r.Header.Set("Authorization", "Bearer "+signed)
return nil
headers := http.Header{}
headers.Set("Authorization", "Bearer "+signed)
return headers, nil
}
5 changes: 4 additions & 1 deletion proxy/credentials_issuer_id_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ func TestCredentialsIssuerIDToken(t *testing.T) {

r := &http.Request{Header: http.Header{}}
s := &AuthenticationSession{Subject: "foo"}
require.NoError(t, b.Issue(r, s, json.RawMessage([]byte(`{ "aud": ["foo", "bar"] }`)), nil))

header, err := b.Issue(r, s, json.RawMessage([]byte(`{ "aud": ["foo", "bar"] }`)), nil)
require.NoError(t, err)
r.Header = header

generated := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", 1)
token, err := jwt.ParseWithClaims(generated, new(Claims), func(token *jwt.Token) (interface{}, error) {
Expand Down
4 changes: 2 additions & 2 deletions proxy/credentials_issuer_noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ func (a *CredentialsIssuerNoOp) GetID() string {
return "noop"
}

func (a *CredentialsIssuerNoOp) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) error {
return nil
func (a *CredentialsIssuerNoOp) Issue(r *http.Request, session *AuthenticationSession, config json.RawMessage, rl *rule.Rule) (http.Header, error) {
return r.Header, nil
}
4 changes: 3 additions & 1 deletion proxy/credentials_issuer_noop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package proxy

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -31,5 +32,6 @@ func TestCredentialsIssuerNoOp(t *testing.T) {
assert.NotNil(t, NewCredentialsIssuerNoOp())
assert.NotEmpty(t, NewCredentialsIssuerNoOp().GetID())

require.NoError(t, NewCredentialsIssuerNoOp().Issue(nil, nil, nil, nil))
_, err := NewCredentialsIssuerNoOp().Issue(&http.Request{Header: map[string][]string{}}, nil, nil, nil)
require.NoError(t,err)
}
7 changes: 6 additions & 1 deletion proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,16 @@ func (d *Proxy) Director(r *http.Request) {
return
}

if err := d.RequestHandler.HandleRequest(r, rl); err != nil {
headers, err := d.RequestHandler.HandleRequest(r, rl)
if err != nil {
*r = *r.WithContext(context.WithValue(r.Context(), director, err))
return
}

for h := range headers {
r.Header.Set(h, headers.Get(h))
}

if err := configureBackendURL(r, rl); err != nil {
*r = *r.WithContext(context.WithValue(r.Context(), director, err))
return
Expand Down
Loading

0 comments on commit 548ad67

Please sign in to comment.