Skip to content

Commit

Permalink
Fix Location response header http to https when SSL
Browse files Browse the repository at this point in the history
  • Loading branch information
unrolled committed Sep 13, 2019
1 parent 48ade6b commit 624f918
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
7 changes: 7 additions & 0 deletions secure.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,13 @@ func (s *Secure) isSSL(r *http.Request) bool {
// Used by http.ReverseProxy.
func (s *Secure) ModifyResponseHeaders(res *http.Response) error {
if res != nil && res.Request != nil {
// Fix Location response header http to https when SSL is enabled.
location := res.Header.Get("Location")
if s.isSSL(res.Request) && strings.Contains(location, "http:") {
location = strings.Replace(location, "http:", "https:", 1)
res.Header.Set("Location", location)
}

responseHeader := res.Request.Context().Value(ctxSecureHeaderKey)
if responseHeader != nil {
for header, values := range responseHeader.(http.Header) {
Expand Down
37 changes: 37 additions & 0 deletions secure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,43 @@ func TestSSLForceHostTemporaryRedirect(t *testing.T) {
expect(t, res.Code, http.StatusTemporaryRedirect)
}

func TestModifyResponseHeadersNoSSL(t *testing.T) {
s := New(Options{
SSLRedirect: false,
})

res := &http.Response{}
res.Header = http.Header{"Location": []string{"http://example.com"}}

err := s.ModifyResponseHeaders(res)
expect(t, err, nil)

expect(t, res.Header.Get("Location"), "http://example.com")
}

func TestModifyResponseHeadersWithSSL(t *testing.T) {
s := New(Options{
SSLRedirect: true,
SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"},
})

req, _ := http.NewRequest("GET", "/foo", nil)
req.Host = "www.example.com"
req.URL.Scheme = "http"
req.Header.Add("X-Forwarded-Proto", "https")

res := &http.Response{}
res.Header = http.Header{"Location": []string{"http://example.com"}}
res.Request = req

expect(t, res.Header.Get("Location"), "http://example.com")

err := s.ModifyResponseHeaders(res)
expect(t, err, nil)

expect(t, res.Header.Get("Location"), "https://example.com")
}

/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
Expand Down

0 comments on commit 624f918

Please sign in to comment.