Skip to content

Commit

Permalink
fix: copy response headers when auth failed.
Browse files Browse the repository at this point in the history
  • Loading branch information
ldez authored and traefiker committed Apr 23, 2018
1 parent 76dcbe3 commit e741974
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
12 changes: 6 additions & 6 deletions middlewares/auth/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,20 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
return http.ErrUseLastResponse
},
}

if config.TLS != nil {
tlsConfig, err := config.TLS.CreateTLSConfig()
if err != nil {
tracing.SetErrorAndDebugLog(r, "Unable to configure TLS to call %s. Cause %s", config.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}

httpClient.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}
}

forwardReq, err := http.NewRequest(http.MethodGet, config.Address, nil)
tracing.LogRequest(tracing.GetSpan(r), forwardReq)
if err != nil {
Expand Down Expand Up @@ -68,6 +71,8 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode)

utils.CopyHeaders(w.Header(), forwardResponse.Header)

// Grab the location header, if any.
redirectURL, err := forwardResponse.Location()

Expand All @@ -79,12 +84,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
}
} else if redirectURL.String() != "" {
// Set the location in our response if one was sent back.
w.Header().Add("Location", redirectURL.String())
}

// Pass any Set-Cookie headers the forward auth server provides
for _, cookie := range forwardResponse.Cookies() {
w.Header().Add("Set-Cookie", cookie.String())
w.Header().Set("Location", redirectURL.String())
}

tracing.LogResponseCode(tracing.GetSpan(r), forwardResponse.StatusCode)
Expand Down
22 changes: 18 additions & 4 deletions middlewares/auth/forward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)

Expand Down Expand Up @@ -110,7 +111,6 @@ func TestForwardAuthRedirect(t *testing.T) {
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")

location, err := res.Location()

assert.NoError(t, err, "there should be no error")
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")

Expand All @@ -119,10 +119,11 @@ func TestForwardAuthRedirect(t *testing.T) {
assert.NotEmpty(t, string(body), "there should be something in the body")
}

func TestForwardAuthCookie(t *testing.T) {
func TestForwardAuthFailResponseHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
http.SetCookie(w, cookie)
w.Header().Add("X-Foo", "bar")
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer authTs.Close()
Expand All @@ -142,23 +143,36 @@ func TestForwardAuthCookie(t *testing.T) {
ts := httptest.NewServer(n)
defer ts.Close()

client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
client := &http.Client{}
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")

require.Len(t, res.Cookies(), 1)
for _, cookie := range res.Cookies() {
assert.Equal(t, "testing", cookie.Value, "they should be equal")
}

expectedHeaders := http.Header{
"Content-Length": []string{"10"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"X-Foo": []string{"bar"},
"Set-Cookie": []string{"example=testing; Path=/"},
"X-Content-Type-Options": []string{"nosniff"},
}

assert.Len(t, res.Header, 6)
for key, value := range expectedHeaders {
assert.Equal(t, value, res.Header[key])
}

body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
}

func Test_writeHeader(t *testing.T) {

testCases := []struct {
name string
headers map[string]string
Expand Down

0 comments on commit e741974

Please sign in to comment.