diff --git a/docs/content/routing/routers/index.md b/docs/content/routing/routers/index.md index fe123bd371..74a32c66da 100644 --- a/docs/content/routing/routers/index.md +++ b/docs/content/routing/routers/index.md @@ -839,11 +839,12 @@ If the rule is verified, the router becomes active, calls middlewares, and then The table below lists all the available matchers: -| Rule | Description | -|---------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------| -| ```HostSNI(`domain-1`, ...)``` | Check if the Server Name Indication corresponds to the given `domains`. | -| ```HostSNIRegexp(`example.com`, `{subdomain:[a-z]+}.example.com`, ...)``` | Check if the Server Name Indication matches the given regular expressions. See "Regexp Syntax" below. | -| ```ClientIP(`10.0.0.0/16`, `::1`)``` | Check if the request client IP is one of the given IP/CIDR. It accepts IPv4, IPv6 and CIDR formats. | +| Rule | Description | +|---------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------| +| ```HostSNI(`domain-1`, ...)``` | Checks if the Server Name Indication corresponds to the given `domains`. | +| ```HostSNIRegexp(`example.com`, `{subdomain:[a-z]+}.example.com`, ...)``` | Checks if the Server Name Indication matches the given regular expressions. See "Regexp Syntax" below. | +| ```ClientIP(`10.0.0.0/16`, `::1`)``` | Checks if the connection client IP is one of the given IP/CIDR. It accepts IPv4, IPv6 and CIDR formats. | +| ```ALPN(`mqtt`, `h2c`)``` | Checks if any of the connection ALPN protocols is one of the given protocols. | !!! important "Non-ASCII Domain Names" @@ -879,6 +880,13 @@ The table below lists all the available matchers: The rule is evaluated "before" any middleware has the opportunity to work, and "before" the request is forwarded to the service. +!!! important "ALPN ACME-TLS/1" + + It would be a security issue to let a user-defined router catch the response to + an ACME TLS challenge previously initiated by Traefik. + For this reason, the `ALPN` matcher is not allowed to match the `ACME-TLS/1` + protocol, and Traefik returns an error if this is attempted. + ### Priority To avoid path overlap, routes are sorted, by default, in descending order using rules length. diff --git a/pkg/muxer/tcp/mux.go b/pkg/muxer/tcp/mux.go index a70e5fd34a..e21dfc75de 100644 --- a/pkg/muxer/tcp/mux.go +++ b/pkg/muxer/tcp/mux.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" + "github.com/go-acme/lego/v4/challenge/tlsalpn01" "github.com/traefik/traefik/v2/pkg/ip" "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/rules" @@ -22,6 +23,7 @@ var tcpFuncs = map[string]func(*matchersTree, ...string) error{ "HostSNI": hostSNI, "HostSNIRegexp": hostSNIRegexp, "ClientIP": clientIP, + "ALPN": alpn, } // ParseHostSNI extracts the HostSNIs declared in a rule. @@ -54,10 +56,11 @@ func ParseHostSNI(rule string) ([]string, error) { type ConnData struct { serverName string remoteIP string + alpnProtos []string } // NewConnData builds a connData struct from the given parameters. -func NewConnData(serverName string, conn tcp.WriteCloser) (ConnData, error) { +func NewConnData(serverName string, conn tcp.WriteCloser, alpnProtos []string) (ConnData, error) { remoteIP, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err != nil { return ConnData{}, fmt.Errorf("error while parsing remote address %q: %w", conn.RemoteAddr().String(), err) @@ -71,6 +74,7 @@ func NewConnData(serverName string, conn tcp.WriteCloser) (ConnData, error) { return ConnData{ serverName: types.CanonicalDomain(serverName), remoteIP: remoteIP, + alpnProtos: alpnProtos, }, nil } @@ -284,6 +288,33 @@ func clientIP(tree *matchersTree, clientIPs ...string) error { return nil } +// alpn checks if any of the connection ALPN protocols matches one of the matcher protocols. +func alpn(tree *matchersTree, protos ...string) error { + if len(protos) == 0 { + return errors.New("empty value for \"ALPN\" matcher is not allowed") + } + + for _, proto := range protos { + if proto == tlsalpn01.ACMETLS1Protocol { + return fmt.Errorf("invalid protocol value for \"ALPN\" matcher, %q is not allowed", proto) + } + } + + tree.matcher = func(meta ConnData) bool { + for _, proto := range meta.alpnProtos { + for _, filter := range protos { + if proto == filter { + return true + } + } + } + + return false + } + + return nil +} + var almostFQDN = regexp.MustCompile(`^[[:alnum:]\.-]+$`) // hostSNI checks if the SNI Host of the connection match the matcher host. diff --git a/pkg/muxer/tcp/mux_test.go b/pkg/muxer/tcp/mux_test.go index 82119319ea..50b8938cf7 100644 --- a/pkg/muxer/tcp/mux_test.go +++ b/pkg/muxer/tcp/mux_test.go @@ -1,10 +1,12 @@ package tcp import ( + "fmt" "net" "testing" "time" + "github.com/go-acme/lego/v4/challenge/tlsalpn01" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/traefik/traefik/v2/pkg/tcp" @@ -58,6 +60,7 @@ func Test_addTCPRoute(t *testing.T) { rule string serverName string remoteAddr string + protos []string routeErr bool matchErr bool }{ @@ -436,6 +439,66 @@ func Test_addTCPRoute(t *testing.T) { serverName: "bar", remoteAddr: "10.0.0.1:80", }, + { + desc: "Invalid ALPN rule matching ACME-TLS/1", + rule: fmt.Sprintf("ALPN(`%s`)", tlsalpn01.ACMETLS1Protocol), + protos: []string{"foo"}, + routeErr: true, + }, + { + desc: "Valid ALPN rule matching single protocol", + rule: "ALPN(`foo`)", + protos: []string{"foo"}, + }, + { + desc: "Valid ALPN rule matching ACME-TLS/1 protocol", + rule: "ALPN(`foo`)", + protos: []string{tlsalpn01.ACMETLS1Protocol}, + matchErr: true, + }, + { + desc: "Valid ALPN rule not matching single protocol", + rule: "ALPN(`foo`)", + protos: []string{"bar"}, + matchErr: true, + }, + { + desc: "Valid alternative case ALPN rule matching single protocol without another being supported", + rule: "ALPN(`foo`) && !alpn(`h2`)", + protos: []string{"foo", "bar"}, + }, + { + desc: "Valid alternative case ALPN rule not matching single protocol because of another being supported", + rule: "ALPN(`foo`) && !alpn(`h2`)", + protos: []string{"foo", "h2", "bar"}, + matchErr: true, + }, + { + desc: "Valid complex alternative case ALPN and HostSNI rule", + rule: "ALPN(`foo`) && (!alpn(`h2`) || hostsni(`foo`))", + protos: []string{"foo", "bar"}, + serverName: "foo", + }, + { + desc: "Valid complex alternative case ALPN and HostSNI rule not matching by SNI", + rule: "ALPN(`foo`) && (!alpn(`h2`) || hostsni(`foo`))", + protos: []string{"foo", "bar", "h2"}, + serverName: "bar", + matchErr: true, + }, + { + desc: "Valid complex alternative case ALPN and HostSNI rule matching by ALPN", + rule: "ALPN(`foo`) && (!alpn(`h2`) || hostsni(`foo`))", + protos: []string{"foo", "bar"}, + serverName: "bar", + }, + { + desc: "Valid complex alternative case ALPN and HostSNI rule not matching by protos", + rule: "ALPN(`foo`) && (!alpn(`h2`) || hostsni(`foo`))", + protos: []string{"h2", "bar"}, + serverName: "bar", + matchErr: true, + }, } for _, test := range testCases { @@ -471,7 +534,7 @@ func Test_addTCPRoute(t *testing.T) { remoteAddr: fakeAddr{addr: addr}, } - connData, err := NewConnData(test.serverName, conn) + connData, err := NewConnData(test.serverName, conn, test.protos) require.NoError(t, err) matchingHandler, _ := router.Match(connData) @@ -918,6 +981,75 @@ func Test_ClientIP(t *testing.T) { } } +func Test_ALPN(t *testing.T) { + testCases := []struct { + desc string + ruleALPNProtos []string + connProto string + buildErr bool + matchErr bool + }{ + { + desc: "Empty", + buildErr: true, + }, + { + desc: "ACME TLS proto", + ruleALPNProtos: []string{tlsalpn01.ACMETLS1Protocol}, + buildErr: true, + }, + { + desc: "Not matching empty proto", + ruleALPNProtos: []string{"h2"}, + matchErr: true, + }, + { + desc: "Not matching ALPN", + ruleALPNProtos: []string{"h2"}, + connProto: "mqtt", + matchErr: true, + }, + { + desc: "Matching ALPN", + ruleALPNProtos: []string{"h2"}, + connProto: "h2", + }, + { + desc: "Not matching multiple ALPNs", + ruleALPNProtos: []string{"h2", "mqtt"}, + connProto: "h2c", + matchErr: true, + }, + { + desc: "Matching multiple ALPNs", + ruleALPNProtos: []string{"h2", "h2c", "mqtt"}, + connProto: "h2c", + }, + } + + for _, test := range testCases { + test := test + + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + matchersTree := &matchersTree{} + err := alpn(matchersTree, test.ruleALPNProtos...) + if test.buildErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + meta := ConnData{ + alpnProtos: []string{test.connProto}, + } + + assert.Equal(t, test.matchErr, !matchersTree.match(meta)) + }) + } +} + func Test_Priority(t *testing.T) { testCases := []struct { desc string diff --git a/pkg/server/router/tcp/router.go b/pkg/server/router/tcp/router.go index 3bcdbcacc0..078d1553c0 100644 --- a/pkg/server/router/tcp/router.go +++ b/pkg/server/router/tcp/router.go @@ -83,10 +83,10 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { // Handling Non-TLS TCP connection early if there is neither HTTP(S) nor TLS // routers on the entryPoint, and if there is at least one non-TLS TCP router. // In the case of a non-TLS TCP client (that does not "send" first), we would - // block forever on clientHelloServerName, which is why we want to detect and + // block forever on clientHelloInfo, which is why we want to detect and // handle that case first and foremost. if r.muxerTCP.HasRoutes() && !r.muxerTCPTLS.HasRoutes() && !r.muxerHTTPS.HasRoutes() { - connData, err := tcpmuxer.NewConnData("", conn) + connData, err := tcpmuxer.NewConnData("", conn, nil) if err != nil { log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err) conn.Close() @@ -108,7 +108,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { // FIXME -- Check if ProxyProtocol changes the first bytes of the request br := bufio.NewReader(conn) - serverName, tls, peeked, err := clientHelloServerName(br) + hello, err := clientHelloInfo(br) if err != nil { conn.Close() return @@ -125,20 +125,20 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { log.WithoutContext().Errorf("Error while setting write deadline: %v", err) } - connData, err := tcpmuxer.NewConnData(serverName, conn) + connData, err := tcpmuxer.NewConnData(hello.serverName, conn, hello.protos) if err != nil { log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err) conn.Close() return } - if !tls { + if !hello.isTLS { handler, _ := r.muxerTCP.Match(connData) switch { case handler != nil: - handler.ServeTCP(r.GetConn(conn, peeked)) + handler.ServeTCP(r.GetConn(conn, hello.peeked)) case r.httpForwarder != nil: - r.httpForwarder.ServeTCP(r.GetConn(conn, peeked)) + r.httpForwarder.ServeTCP(r.GetConn(conn, hello.peeked)) default: conn.Close() } @@ -155,14 +155,14 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { // In order not to depart from the behavior in 2.6, we only allow an HTTPS router // to take precedence over a TCP-TLS router if it is _not_ an HostSNI(*) router (so // basically any router that has a specific HostSNI based rule). - handlerHTTPS.ServeTCP(r.GetConn(conn, peeked)) + handlerHTTPS.ServeTCP(r.GetConn(conn, hello.peeked)) return } // Contains also TCP TLS passthrough routes. handlerTCPTLS, catchAllTCPTLS := r.muxerTCPTLS.Match(connData) if handlerTCPTLS != nil && !catchAllTCPTLS { - handlerTCPTLS.ServeTCP(r.GetConn(conn, peeked)) + handlerTCPTLS.ServeTCP(r.GetConn(conn, hello.peeked)) return } @@ -170,19 +170,19 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { // We end up here for e.g. an HTTPS router that only has a PathPrefix rule, // which under the scenes is counted as an HostSNI(*) rule. if handlerHTTPS != nil { - handlerHTTPS.ServeTCP(r.GetConn(conn, peeked)) + handlerHTTPS.ServeTCP(r.GetConn(conn, hello.peeked)) return } // Fallback on TCP TLS catchAll. if handlerTCPTLS != nil { - handlerTCPTLS.ServeTCP(r.GetConn(conn, peeked)) + handlerTCPTLS.ServeTCP(r.GetConn(conn, hello.peeked)) return } // needed to handle 404s for HTTPS, as well as all non-Host (e.g. PathPrefix) matches. if r.httpsForwarder != nil { - r.httpsForwarder.ServeTCP(r.GetConn(conn, peeked)) + r.httpsForwarder.ServeTCP(r.GetConn(conn, hello.peeked)) return } @@ -300,18 +300,24 @@ func (c *Conn) Read(p []byte) (n int, err error) { return c.WriteCloser.Read(p) } -// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +type clientHello struct { + serverName string // SNI server name + protos []string // ALPN protocols list + isTLS bool // whether we are a TLS handshake + peeked string // the bytes peeked from the hello while getting the info +} + +// clientHelloInfo returns various data from the clientHello handshake, // without consuming any bytes from br. -// On any error, the empty string is returned. -func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) { +// It returns an error if it can't peek the first byte from the connection. +func clientHelloInfo(br *bufio.Reader) (*clientHello, error) { hdr, err := br.Peek(1) if err != nil { var opErr *net.OpError if !errors.Is(err, io.EOF) && (!errors.As(err, &opErr) || opErr.Timeout()) { log.WithoutContext().Errorf("Error while Peeking first byte: %s", err) } - - return "", false, "", err + return nil, err } // No valid TLS record has a type of 0x80, however SSLv2 handshakes @@ -323,16 +329,23 @@ func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) { if hdr[0] != recordTypeHandshake { if hdr[0] == recordTypeSSLv2 { // we consider SSLv2 as TLS and it will be refused by real TLS handshake. - return "", true, getPeeked(br), nil + return &clientHello{ + isTLS: true, + peeked: getPeeked(br), + }, nil } - return "", false, getPeeked(br), nil // Not TLS. + return &clientHello{ + peeked: getPeeked(br), + }, nil // Not TLS. } const recordHeaderLen = 5 hdr, err = br.Peek(recordHeaderLen) if err != nil { log.Errorf("Error while Peeking hello: %s", err) - return "", false, getPeeked(br), nil + return &clientHello{ + peeked: getPeeked(br), + }, nil } recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] @@ -344,19 +357,29 @@ func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) { helloBytes, err := br.Peek(recordHeaderLen + recLen) if err != nil { log.Errorf("Error while Hello: %s", err) - return "", true, getPeeked(br), nil + return &clientHello{ + isTLS: true, + peeked: getPeeked(br), + }, nil } sni := "" - server := tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ + var protos []string + server := tls.Server(helloSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { sni = hello.ServerName + protos = hello.SupportedProtos return nil, nil }, }) _ = server.Handshake() - return sni, true, getPeeked(br), nil + return &clientHello{ + serverName: sni, + isTLS: true, + peeked: getPeeked(br), + protos: protos, + }, nil } func getPeeked(br *bufio.Reader) string { @@ -368,15 +391,15 @@ func getPeeked(br *bufio.Reader) string { return string(peeked) } -// sniSniffConn is a net.Conn that reads from r, fails on Writes, +// helloSniffConn is a net.Conn that reads from r, fails on Writes, // and crashes otherwise. -type sniSniffConn struct { +type helloSniffConn struct { r io.Reader net.Conn // nil; crash on any unexpected use } // Read reads from the underlying reader. -func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (c helloSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } // Write crashes all the time. -func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } +func (helloSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }