diff --git a/authorize/grpc.go b/authorize/grpc.go index a0d53a74bd6..4a785c23f2c 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -151,7 +151,7 @@ func getCheckRequestURL(req *envoy_service_auth_v3.CheckRequest) url.URL { Scheme: h.GetScheme(), Host: h.GetHost(), } - u.Host = urlutil.GetDomainsForURL(u)[0] + u.Host = urlutil.GetDomainsForURL(&u)[0] // envoy sends the query string as part of the path path := h.GetPath() if idx := strings.Index(path, "?"); idx != -1 { diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index 88d0d76b83e..a390762c0ff 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -620,26 +620,28 @@ func getAllServerNames(cfg *config.Config, addr string) ([]string, error) { serverNames := sets.NewSorted[string]() serverNames.Add("*") - routeableHosts, err := getAllRouteableHosts(cfg.Options, addr) + certs, err := cfg.AllCertificates() if err != nil { return nil, err } - for _, hp := range routeableHosts { - if h, _, err := net.SplitHostPort(hp); err == nil { - serverNames.Add(h) - } else { - serverNames.Add(hp) - } + for i := range certs { + serverNames.Add(cryptutil.GetCertificateServerNames(&certs[i])...) } - certs, err := cfg.AllCertificates() - if err != nil { - return nil, err + if addr == cfg.Options.Addr { + sns, err := cfg.Options.GetAllRouteableHTTPServerNames() + if err != nil { + return nil, err + } + serverNames.Add(sns...) } - for i := range certs { - for _, domain := range cryptutil.GetCertificateServerNames(&certs[i]) { - serverNames.Add(domain) + + if addr == cfg.Options.GetGRPCAddr() { + sns, err := cfg.Options.GetAllRouteableGRPCServerNames() + if err != nil { + return nil, err } + serverNames.Add(sns...) } return serverNames.ToSlice(), nil @@ -655,30 +657,12 @@ func urlsMatchHost(urls []*url.URL, host string) bool { } func urlMatchesHost(u *url.URL, host string) bool { - if u == nil { - return false - } - - var defaultPort string - if u.Scheme == "http" { - defaultPort = "80" - } else { - defaultPort = "443" - } - - h1, p1, err := net.SplitHostPort(u.Host) - if err != nil { - h1 = u.Host - p1 = defaultPort - } - - h2, p2, err := net.SplitHostPort(host) - if err != nil { - h2 = host - p2 = defaultPort + for _, h := range urlutil.GetDomainsForURL(u) { + if h == host { + return true + } } - - return h1 == h2 && p1 == p2 + return false } func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy { diff --git a/config/envoyconfig/listeners_test.go b/config/envoyconfig/listeners_test.go index ab44e74e033..44d2c4b0dbe 100644 --- a/config/envoyconfig/listeners_test.go +++ b/config/envoyconfig/listeners_test.go @@ -984,6 +984,7 @@ func Test_getAllDomains(t *testing.T) { {Source: &config.StringURL{URL: mustParseURL(t, "http://a.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}}, + {Source: &config.StringURL{URL: mustParseURL(t, "https://d.unknown.example.com")}}, }, Cert: base64.StdEncoding.EncodeToString(certPEM), Key: base64.StdEncoding.EncodeToString(keyPEM), @@ -1001,6 +1002,8 @@ func Test_getAllDomains(t *testing.T) { "b.example.com:443", "c.example.com", "c.example.com:443", + "d.unknown.example.com", + "d.unknown.example.com:443", } assert.Equal(t, expect, actual) }) @@ -1029,6 +1032,8 @@ func Test_getAllDomains(t *testing.T) { "c.example.com", "c.example.com:443", "cache.example.com:9001", + "d.unknown.example.com", + "d.unknown.example.com:443", } assert.Equal(t, expect, actual) }) @@ -1044,6 +1049,7 @@ func Test_getAllDomains(t *testing.T) { "authenticate.example.com", "b.example.com", "c.example.com", + "d.unknown.example.com", } assert.Equal(t, expect, actual) }) diff --git a/config/options.go b/config/options.go index 662dd3a914e..8aba5076c2e 100644 --- a/config/options.go +++ b/config/options.go @@ -1026,7 +1026,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) { return nil, err } for _, u := range authorizeURLs { - hosts.Add(urlutil.GetDomainsForURL(*u)...) + hosts.Add(urlutil.GetDomainsForURL(u)...) } } else if IsAuthorize(o.Services) { authorizeURLs, err := o.GetInternalAuthorizeURLs() @@ -1034,7 +1034,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) { return nil, err } for _, u := range authorizeURLs { - hosts.Add(urlutil.GetDomainsForURL(*u)...) + hosts.Add(urlutil.GetDomainsForURL(u)...) } } @@ -1045,7 +1045,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) { return nil, err } for _, u := range dataBrokerURLs { - hosts.Add(urlutil.GetDomainsForURL(*u)...) + hosts.Add(urlutil.GetDomainsForURL(u)...) } } else if IsDataBroker(o.Services) { dataBrokerURLs, err := o.GetInternalDataBrokerURLs() @@ -1053,7 +1053,52 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) { return nil, err } for _, u := range dataBrokerURLs { - hosts.Add(urlutil.GetDomainsForURL(*u)...) + hosts.Add(urlutil.GetDomainsForURL(u)...) + } + } + + return hosts.ToSlice(), nil +} + +// GetAllRouteableGRPCServerNames returns all the possible gRPC server names handled by the Pomerium options. +func (o *Options) GetAllRouteableGRPCServerNames() ([]string, error) { + hosts := sets.NewSorted[string]() + + // authorize urls + if IsAll(o.Services) { + authorizeURLs, err := o.GetAuthorizeURLs() + if err != nil { + return nil, err + } + for _, u := range authorizeURLs { + hosts.Add(urlutil.GetServerNamesForURL(u)...) + } + } else if IsAuthorize(o.Services) { + authorizeURLs, err := o.GetInternalAuthorizeURLs() + if err != nil { + return nil, err + } + for _, u := range authorizeURLs { + hosts.Add(urlutil.GetServerNamesForURL(u)...) + } + } + + // databroker urls + if IsAll(o.Services) { + dataBrokerURLs, err := o.GetDataBrokerURLs() + if err != nil { + return nil, err + } + for _, u := range dataBrokerURLs { + hosts.Add(urlutil.GetServerNamesForURL(u)...) + } + } else if IsDataBroker(o.Services) { + dataBrokerURLs, err := o.GetInternalDataBrokerURLs() + if err != nil { + return nil, err + } + for _, u := range dataBrokerURLs { + hosts.Add(urlutil.GetServerNamesForURL(u)...) } } @@ -1068,22 +1113,22 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) { if err != nil { return nil, err } - hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...) + hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...) authenticateURL, err = o.GetAuthenticateURL() if err != nil { return nil, err } - hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...) + hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...) } // policy urls if IsProxy(o.Services) { for _, policy := range o.GetAllPolicies() { - hosts.Add(urlutil.GetDomainsForURL(*policy.Source.URL)...) + hosts.Add(urlutil.GetDomainsForURL(policy.Source.URL)...) if policy.TLSDownstreamServerName != "" { tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName}) - hosts.Add(urlutil.GetDomainsForURL(*tlsURL)...) + hosts.Add(urlutil.GetDomainsForURL(tlsURL)...) } } } @@ -1091,6 +1136,37 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) { return hosts.ToSlice(), nil } +// GetAllRouteableHTTPServerNames returns all the possible HTTP server names handled by the Pomerium options. +func (o *Options) GetAllRouteableHTTPServerNames() ([]string, error) { + serverNames := sets.NewSorted[string]() + if IsAuthenticate(o.Services) { + authenticateURL, err := o.GetInternalAuthenticateURL() + if err != nil { + return nil, err + } + serverNames.Add(urlutil.GetServerNamesForURL(authenticateURL)...) + + authenticateURL, err = o.GetAuthenticateURL() + if err != nil { + return nil, err + } + serverNames.Add(urlutil.GetServerNamesForURL(authenticateURL)...) + } + + // policy urls + if IsProxy(o.Services) { + for _, policy := range o.GetAllPolicies() { + serverNames.Add(urlutil.GetServerNamesForURL(policy.Source.URL)...) + if policy.TLSDownstreamServerName != "" { + tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName}) + serverNames.Add(urlutil.GetServerNamesForURL(tlsURL)...) + } + } + } + + return serverNames.ToSlice(), nil +} + // GetClientSecret gets the client secret. func (o *Options) GetClientSecret() (string, error) { if o == nil { diff --git a/config/policy.go b/config/policy.go index b7faef28fc9..238ebb5f682 100644 --- a/config/policy.go +++ b/config/policy.go @@ -599,7 +599,12 @@ func (p *Policy) Matches(requestURL url.URL) bool { return false } - if p.Source.Host != requestURL.Host { + // make sure one of the host domains matches the incoming url + found := false + for _, host := range urlutil.GetDomainsForURL(p.Source.URL) { + found = found || host == requestURL.Host + } + if !found { return false } diff --git a/config/policy_test.go b/config/policy_test.go index 287831147c8..63461578e9a 100644 --- a/config/policy_test.go +++ b/config/policy_test.go @@ -269,4 +269,13 @@ func TestPolicy_Matches(t *testing.T) { assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`))) assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`))) }) + t.Run("tcp", func(t *testing.T) { + p := &Policy{ + From: "tcp+https://proxy.example.com/redis.example.com:6379", + To: mustParseWeightedURLs(t, "tcp://localhost:6379"), + } + assert.NoError(t, p.Validate()) + + assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://redis.example.com:6379`))) + }) } diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go index 405635724ee..32cb48db010 100644 --- a/internal/urlutil/url.go +++ b/internal/urlutil/url.go @@ -94,13 +94,37 @@ func GetAbsoluteURL(r *http.Request) *url.URL { return u } +// GetServerNamesForURL returns the TLS server names for the given URL. The server name is the +// URL hostname. +func GetServerNamesForURL(u *url.URL) []string { + if u == nil { + return nil + } + + return []string{u.Hostname()} +} + // GetDomainsForURL returns the available domains for given url. // // For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:`. // Otherwise, return the URL.Host value. -func GetDomainsForURL(u url.URL) []string { - if IsTCP(&u) { - return []string{u.Host} +func GetDomainsForURL(u *url.URL) []string { + if u == nil { + return nil + } + + // tcp+https://ssh.example.com:22 + // => ssh.example.com:22 + // tcp+https://proxy.example.com/ssh.example.com:22 + // => ssh.example.com:22 + if strings.HasPrefix(u.Scheme, "tcp+") { + hosts := strings.Split(u.Path, "/")[1:] + // if there are no domains in the path part of the URL, use the host + if len(hosts) == 0 { + return []string{u.Host} + } + // otherwise use the path parts of the URL as the hosts + return hosts } var defaultPort string diff --git a/internal/urlutil/url_test.go b/internal/urlutil/url_test.go index a253a3e7d3f..bbfa3167c3d 100644 --- a/internal/urlutil/url_test.go +++ b/internal/urlutil/url_test.go @@ -136,6 +136,31 @@ func TestGetAbsoluteURL(t *testing.T) { } } +func TestGetServerNamesForURL(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + name string + u *url.URL + want []string + }{ + {"http", &url.URL{Scheme: "http", Host: "example.com"}, []string{"example.com"}}, + {"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com"}}, + {"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com"}}, + {"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com"}}, + {"tcp", &url.URL{Scheme: "tcp+https", Host: "example.com:1234"}, []string{"example.com"}}, + {"tcp with path", &url.URL{Scheme: "tcp+https", Host: "proxy.example.com", Path: "/ssh.example.com:1234"}, []string{"proxy.example.com"}}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := GetServerNamesForURL(tc.u) + if diff := cmp.Diff(got, tc.want); diff != "" { + t.Errorf("GetServerNamesForURL() = %v", diff) + } + }) + } +} + func TestGetDomainsForURL(t *testing.T) { t.Parallel() tests := []struct { @@ -147,12 +172,14 @@ func TestGetDomainsForURL(t *testing.T) { {"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com:443"}}, {"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com", "example.com:443"}}, {"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com:1234"}}, + {"tcp", &url.URL{Scheme: "tcp+https", Host: "example.com:1234"}, []string{"example.com:1234"}}, + {"tcp with path", &url.URL{Scheme: "tcp+https", Host: "proxy.example.com", Path: "/ssh.example.com:1234"}, []string{"ssh.example.com:1234"}}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - got := GetDomainsForURL(*tc.u) + got := GetDomainsForURL(tc.u) if diff := cmp.Diff(got, tc.want); diff != "" { t.Errorf("GetDomainsForURL() = %v", diff) }