Skip to content

Commit

Permalink
config: add support for extended TCP route URLs (#3845)
Browse files Browse the repository at this point in the history
* config: add support for extended TCP route URLs

* nevermind, add duplicate names
  • Loading branch information
calebdoxsey committed Dec 27, 2022
1 parent 67e1210 commit 271b078
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 50 deletions.
2 changes: 1 addition & 1 deletion authorize/grpc.go
Expand Up @@ -151,7 +151,7 @@ func getCheckRequestURL(req *envoy_service_auth_v3.CheckRequest) url.URL {
Scheme: h.GetScheme(),
Host: h.GetHost(),
}
u.Host = urlutil.GetDomainsForURL(u)[0]
u.Host = urlutil.GetDomainsForURL(&u)[0]
// envoy sends the query string as part of the path
path := h.GetPath()
if idx := strings.Index(path, "?"); idx != -1 {
Expand Down
56 changes: 20 additions & 36 deletions config/envoyconfig/listeners.go
Expand Up @@ -620,26 +620,28 @@ func getAllServerNames(cfg *config.Config, addr string) ([]string, error) {
serverNames := sets.NewSorted[string]()
serverNames.Add("*")

routeableHosts, err := getAllRouteableHosts(cfg.Options, addr)
certs, err := cfg.AllCertificates()
if err != nil {
return nil, err
}
for _, hp := range routeableHosts {
if h, _, err := net.SplitHostPort(hp); err == nil {
serverNames.Add(h)
} else {
serverNames.Add(hp)
}
for i := range certs {
serverNames.Add(cryptutil.GetCertificateServerNames(&certs[i])...)
}

certs, err := cfg.AllCertificates()
if err != nil {
return nil, err
if addr == cfg.Options.Addr {
sns, err := cfg.Options.GetAllRouteableHTTPServerNames()
if err != nil {
return nil, err
}
serverNames.Add(sns...)
}
for i := range certs {
for _, domain := range cryptutil.GetCertificateServerNames(&certs[i]) {
serverNames.Add(domain)

if addr == cfg.Options.GetGRPCAddr() {
sns, err := cfg.Options.GetAllRouteableGRPCServerNames()
if err != nil {
return nil, err
}
serverNames.Add(sns...)
}

return serverNames.ToSlice(), nil
Expand All @@ -655,30 +657,12 @@ func urlsMatchHost(urls []*url.URL, host string) bool {
}

func urlMatchesHost(u *url.URL, host string) bool {
if u == nil {
return false
}

var defaultPort string
if u.Scheme == "http" {
defaultPort = "80"
} else {
defaultPort = "443"
}

h1, p1, err := net.SplitHostPort(u.Host)
if err != nil {
h1 = u.Host
p1 = defaultPort
}

h2, p2, err := net.SplitHostPort(host)
if err != nil {
h2 = host
p2 = defaultPort
for _, h := range urlutil.GetDomainsForURL(u) {
if h == host {
return true
}
}

return h1 == h2 && p1 == p2
return false
}

func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy {
Expand Down
6 changes: 6 additions & 0 deletions config/envoyconfig/listeners_test.go
Expand Up @@ -984,6 +984,7 @@ func Test_getAllDomains(t *testing.T) {
{Source: &config.StringURL{URL: mustParseURL(t, "http://a.example.com")}},
{Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}},
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
{Source: &config.StringURL{URL: mustParseURL(t, "https://d.unknown.example.com")}},
},
Cert: base64.StdEncoding.EncodeToString(certPEM),
Key: base64.StdEncoding.EncodeToString(keyPEM),
Expand All @@ -1001,6 +1002,8 @@ func Test_getAllDomains(t *testing.T) {
"b.example.com:443",
"c.example.com",
"c.example.com:443",
"d.unknown.example.com",
"d.unknown.example.com:443",
}
assert.Equal(t, expect, actual)
})
Expand Down Expand Up @@ -1029,6 +1032,8 @@ func Test_getAllDomains(t *testing.T) {
"c.example.com",
"c.example.com:443",
"cache.example.com:9001",
"d.unknown.example.com",
"d.unknown.example.com:443",
}
assert.Equal(t, expect, actual)
})
Expand All @@ -1044,6 +1049,7 @@ func Test_getAllDomains(t *testing.T) {
"authenticate.example.com",
"b.example.com",
"c.example.com",
"d.unknown.example.com",
}
assert.Equal(t, expect, actual)
})
Expand Down
92 changes: 84 additions & 8 deletions config/options.go
Expand Up @@ -1026,15 +1026,15 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err
}
for _, u := range authorizeURLs {
hosts.Add(urlutil.GetDomainsForURL(*u)...)
hosts.Add(urlutil.GetDomainsForURL(u)...)
}
} else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
hosts.Add(urlutil.GetDomainsForURL(*u)...)
hosts.Add(urlutil.GetDomainsForURL(u)...)
}
}

Expand All @@ -1045,15 +1045,60 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err
}
for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetDomainsForURL(*u)...)
hosts.Add(urlutil.GetDomainsForURL(u)...)
}
} else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetDomainsForURL(*u)...)
hosts.Add(urlutil.GetDomainsForURL(u)...)
}
}

return hosts.ToSlice(), nil
}

// GetAllRouteableGRPCServerNames returns all the possible gRPC server names handled by the Pomerium options.
func (o *Options) GetAllRouteableGRPCServerNames() ([]string, error) {
hosts := sets.NewSorted[string]()

// authorize urls
if IsAll(o.Services) {
authorizeURLs, err := o.GetAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
hosts.Add(urlutil.GetServerNamesForURL(u)...)
}
} else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
hosts.Add(urlutil.GetServerNamesForURL(u)...)
}
}

// databroker urls
if IsAll(o.Services) {
dataBrokerURLs, err := o.GetDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetServerNamesForURL(u)...)
}
} else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetServerNamesForURL(u)...)
}
}

Expand All @@ -1068,29 +1113,60 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
if err != nil {
return nil, err
}
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)

authenticateURL, err = o.GetAuthenticateURL()
if err != nil {
return nil, err
}
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
}

// policy urls
if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() {
hosts.Add(urlutil.GetDomainsForURL(*policy.Source.URL)...)
hosts.Add(urlutil.GetDomainsForURL(policy.Source.URL)...)
if policy.TLSDownstreamServerName != "" {
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
hosts.Add(urlutil.GetDomainsForURL(*tlsURL)...)
hosts.Add(urlutil.GetDomainsForURL(tlsURL)...)
}
}
}

return hosts.ToSlice(), nil
}

// GetAllRouteableHTTPServerNames returns all the possible HTTP server names handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPServerNames() ([]string, error) {
serverNames := sets.NewSorted[string]()
if IsAuthenticate(o.Services) {
authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
serverNames.Add(urlutil.GetServerNamesForURL(authenticateURL)...)

authenticateURL, err = o.GetAuthenticateURL()
if err != nil {
return nil, err
}
serverNames.Add(urlutil.GetServerNamesForURL(authenticateURL)...)
}

// policy urls
if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() {
serverNames.Add(urlutil.GetServerNamesForURL(policy.Source.URL)...)
if policy.TLSDownstreamServerName != "" {
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
serverNames.Add(urlutil.GetServerNamesForURL(tlsURL)...)
}
}
}

return serverNames.ToSlice(), nil
}

// GetClientSecret gets the client secret.
func (o *Options) GetClientSecret() (string, error) {
if o == nil {
Expand Down
7 changes: 6 additions & 1 deletion config/policy.go
Expand Up @@ -599,7 +599,12 @@ func (p *Policy) Matches(requestURL url.URL) bool {
return false
}

if p.Source.Host != requestURL.Host {
// make sure one of the host domains matches the incoming url
found := false
for _, host := range urlutil.GetDomainsForURL(p.Source.URL) {
found = found || host == requestURL.Host
}
if !found {
return false
}

Expand Down
9 changes: 9 additions & 0 deletions config/policy_test.go
Expand Up @@ -269,4 +269,13 @@ func TestPolicy_Matches(t *testing.T) {
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`)))
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`)))
})
t.Run("tcp", func(t *testing.T) {
p := &Policy{
From: "tcp+https://proxy.example.com/redis.example.com:6379",
To: mustParseWeightedURLs(t, "tcp://localhost:6379"),
}
assert.NoError(t, p.Validate())

assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://redis.example.com:6379`)))
})
}
30 changes: 27 additions & 3 deletions internal/urlutil/url.go
Expand Up @@ -94,13 +94,37 @@ func GetAbsoluteURL(r *http.Request) *url.URL {
return u
}

// GetServerNamesForURL returns the TLS server names for the given URL. The server name is the
// URL hostname.
func GetServerNamesForURL(u *url.URL) []string {
if u == nil {
return nil
}

return []string{u.Hostname()}
}

// GetDomainsForURL returns the available domains for given url.
//
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`.
// Otherwise, return the URL.Host value.
func GetDomainsForURL(u url.URL) []string {
if IsTCP(&u) {
return []string{u.Host}
func GetDomainsForURL(u *url.URL) []string {
if u == nil {
return nil
}

// tcp+https://ssh.example.com:22
// => ssh.example.com:22
// tcp+https://proxy.example.com/ssh.example.com:22
// => ssh.example.com:22
if strings.HasPrefix(u.Scheme, "tcp+") {
hosts := strings.Split(u.Path, "/")[1:]
// if there are no domains in the path part of the URL, use the host
if len(hosts) == 0 {
return []string{u.Host}
}
// otherwise use the path parts of the URL as the hosts
return hosts
}

var defaultPort string
Expand Down
29 changes: 28 additions & 1 deletion internal/urlutil/url_test.go
Expand Up @@ -136,6 +136,31 @@ func TestGetAbsoluteURL(t *testing.T) {
}
}

func TestGetServerNamesForURL(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
u *url.URL
want []string
}{
{"http", &url.URL{Scheme: "http", Host: "example.com"}, []string{"example.com"}},
{"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com"}},
{"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com"}},
{"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com"}},
{"tcp", &url.URL{Scheme: "tcp+https", Host: "example.com:1234"}, []string{"example.com"}},
{"tcp with path", &url.URL{Scheme: "tcp+https", Host: "proxy.example.com", Path: "/ssh.example.com:1234"}, []string{"proxy.example.com"}},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := GetServerNamesForURL(tc.u)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("GetServerNamesForURL() = %v", diff)
}
})
}
}

func TestGetDomainsForURL(t *testing.T) {
t.Parallel()
tests := []struct {
Expand All @@ -147,12 +172,14 @@ func TestGetDomainsForURL(t *testing.T) {
{"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com:443"}},
{"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com", "example.com:443"}},
{"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com:1234"}},
{"tcp", &url.URL{Scheme: "tcp+https", Host: "example.com:1234"}, []string{"example.com:1234"}},
{"tcp with path", &url.URL{Scheme: "tcp+https", Host: "proxy.example.com", Path: "/ssh.example.com:1234"}, []string{"ssh.example.com:1234"}},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := GetDomainsForURL(*tc.u)
got := GetDomainsForURL(tc.u)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("GetDomainsForURL() = %v", diff)
}
Expand Down

0 comments on commit 271b078

Please sign in to comment.