Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove hop-by-hop headers defined in connection header before some middleware #8319

Merged
merged 3 commits into from Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Next
fix: remove hop-by-hop headers define in connection header.
  • Loading branch information
ldez committed Jul 29, 2021
commit cbaf86a93014a969b8accf39301932c17d0d73f9
3 changes: 2 additions & 1 deletion pkg/middlewares/auth/forward.go
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares"
"github.com/traefik/traefik/v2/pkg/middlewares/connectionheader"
"github.com/traefik/traefik/v2/pkg/tracing"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
Expand Down Expand Up @@ -89,7 +90,7 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
fa.authResponseHeadersRegex = re
}

return fa, nil
return connectionheader.Remove(fa), nil
ldez marked this conversation as resolved.
Show resolved Hide resolved
}

func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
Expand Down
44 changes: 44 additions & 0 deletions pkg/middlewares/connectionheader/connectionheader.go
@@ -0,0 +1,44 @@
package connectionheader

import (
"net/http"
"net/textproto"
"strings"

"golang.org/x/net/http/httpguts"
)

// Remove removes hop-by-hop headers listed in the "Connection" header of h.
// See RFC 7230, section 6.1.
func Remove(next http.Handler) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
reqUpType := upgradeType(req.Header)
removeConnectionHeaders(req.Header)

if reqUpType != "" {
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", reqUpType)
} else {
req.Header.Del("Connection")
}

next.ServeHTTP(rw, req)
}
}

func removeConnectionHeaders(h http.Header) {
for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
}

func upgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
ldez marked this conversation as resolved.
Show resolved Hide resolved
71 changes: 71 additions & 0 deletions pkg/middlewares/connectionheader/connectionheader_test.go
@@ -0,0 +1,71 @@
package connectionheader

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRemove(t *testing.T) {
testCases := []struct {
desc string
reqHeaders map[string]string
expected http.Header
}{
{
desc: "simple remove",
reqHeaders: map[string]string{
"Foo": "bar",
"Connection": "foo",
},
expected: http.Header{},
},
{
desc: "remove and Upgrade",
reqHeaders: map[string]string{
"Upgrade": "test",
"Foo": "bar",
"Connection": "Upgrade,foo",
},
expected: http.Header{
"Upgrade": []string{"test"},
"Connection": []string{"Upgrade"},
},
},
{
desc: "no remove",
reqHeaders: map[string]string{
"Foo": "bar",
"Connection": "fii",
},
expected: http.Header{
"Foo": []string{"bar"},
},
},
}

for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()

next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})

h := Remove(next)

req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)

for k, v := range test.reqHeaders {
req.Header.Set(k, v)
}

rw := httptest.NewRecorder()

h.ServeHTTP(rw, req)

assert.Equal(t, test.expected, req.Header)
})
}
}
6 changes: 4 additions & 2 deletions pkg/middlewares/headers/headers.go
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares"
"github.com/traefik/traefik/v2/pkg/middlewares/connectionheader"
"github.com/traefik/traefik/v2/pkg/tracing"
)

Expand Down Expand Up @@ -58,11 +59,12 @@ func New(ctx context.Context, next http.Handler, cfg dynamic.Headers, name strin

if hasCustomHeaders || hasCorsHeaders {
logger.Debugf("Setting up customHeaders/Cors from %v", cfg)
var err error
handler, err = NewHeader(nextHandler, cfg)
h, err := NewHeader(nextHandler, cfg)
if err != nil {
return nil, err
}

handler = connectionheader.Remove(h)
}

return &headers{
Expand Down