Skip to content

Commit

Permalink
improve certificate matching performance (#4188)
Browse files Browse the repository at this point in the history
improve certificate matching performance (#4186)

Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
  • Loading branch information
backport-actions-token[bot] and calebdoxsey committed May 23, 2023
1 parent 9b78ae9 commit 45a577d
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 133 deletions.
26 changes: 5 additions & 21 deletions config/config.go
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
Expand Down Expand Up @@ -149,24 +148,9 @@ func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) {
}, nil
}

// GetCertificateForServerName gets the certificate for the server name. If no certificate is found and there
// is a derived CA one will be generated using that CA. If no derived CA is defined a self-signed certificate
// will be generated.
func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certificate, error) {
certificates, err := cfg.AllCertificates()
if err != nil {
return nil, err
}

// first try a direct name match
for i := range certificates {
if cryptutil.MatchesServerName(&certificates[i], serverName) {
return &certificates[i], nil
}
}

log.WarnNoTLSCertificate(serverName)

// GenerateCatchAllCertificate generates a catch-all certificate. If no derived CA is defined a
// self-signed certificate will be generated.
func (cfg *Config) GenerateCatchAllCertificate() (*tls.Certificate, error) {
if cfg.Options.DeriveInternalDomainCert != nil {
sharedKey, err := cfg.Options.GetSharedKey()
if err != nil {
Expand All @@ -178,7 +162,7 @@ func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certific
return nil, fmt.Errorf("failed to generate cert, invalid derived CA: %w", err)
}

pem, err := ca.NewServerCert([]string{serverName})
pem, err := ca.NewServerCert([]string{"*"})
if err != nil {
return nil, fmt.Errorf("failed to generate cert, error creating server certificate: %w", err)
}
Expand All @@ -196,7 +180,7 @@ func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certific
}

// finally fall back to a generated, self-signed certificate
return cryptutil.GenerateCertificate(sharedKey, serverName)
return cryptutil.GenerateCertificate(sharedKey, "*")
}

// WillHaveCertificateForServerName returns true if there will be a certificate for the given server name.
Expand Down
94 changes: 0 additions & 94 deletions config/config_test.go

This file was deleted.

3 changes: 2 additions & 1 deletion config/envoyconfig/listeners.go
Expand Up @@ -109,7 +109,8 @@ func getAllCertificates(cfg *config.Config) ([]tls.Certificate, error) {
if err != nil {
return nil, fmt.Errorf("error collecting all certificates: %w", err)
}
wc, err := cfg.GetCertificateForServerName("*")

wc, err := cfg.GenerateCatchAllCertificate()
if err != nil {
return nil, fmt.Errorf("error getting wildcard certificate: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions config/envoyconfig/route_configurations.go
Expand Up @@ -78,7 +78,7 @@ func (b *Builder) buildMainRouteConfiguration(

// if we're the proxy, add all the policy routes
if config.IsProxy(cfg.Options.Services) {
rs, err := b.buildRoutesForPoliciesWithHost(cfg, host)
rs, err := b.buildRoutesForPoliciesWithHost(cfg, certs, host)
if err != nil {
return nil, err
}
Expand All @@ -95,7 +95,7 @@ func (b *Builder) buildMainRouteConfiguration(
return nil, err
}
if config.IsProxy(cfg.Options.Services) {
rs, err := b.buildRoutesForPoliciesWithCatchAll(cfg)
rs, err := b.buildRoutesForPoliciesWithCatchAll(cfg, certs)
if err != nil {
return nil, err
}
Expand Down
18 changes: 9 additions & 9 deletions config/envoyconfig/routes.go
@@ -1,6 +1,7 @@
package envoyconfig

import (
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
Expand Down Expand Up @@ -193,6 +194,7 @@ func getClusterStatsName(policy *config.Policy) string {

func (b *Builder) buildRoutesForPoliciesWithHost(
cfg *config.Config,
certs []tls.Certificate,
host string,
) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route
Expand All @@ -207,7 +209,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost(
continue
}

policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i))
policyRoutes, err := b.buildRoutesForPolicy(cfg, certs, &policy, fmt.Sprintf("policy-%d", i))
if err != nil {
return nil, err
}
Expand All @@ -219,6 +221,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost(

func (b *Builder) buildRoutesForPoliciesWithCatchAll(
cfg *config.Config,
certs []tls.Certificate,
) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route
for i, p := range cfg.Options.GetAllPolicies() {
Expand All @@ -232,7 +235,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll(
continue
}

policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i))
policyRoutes, err := b.buildRoutesForPolicy(cfg, certs, &policy, fmt.Sprintf("policy-%d", i))
if err != nil {
return nil, err
}
Expand All @@ -244,6 +247,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll(

func (b *Builder) buildRoutesForPolicy(
cfg *config.Config,
certs []tls.Certificate,
policy *config.Policy,
name string,
) ([]*envoy_config_route_v3.Route, error) {
Expand All @@ -256,14 +260,14 @@ func (b *Builder) buildRoutesForPolicy(
if strings.Contains(fromURL.Host, "*") {
// we have to match '*.example.com' and '*.example.com:443', so there are two routes
for _, host := range urlutil.GetDomainsForURL(fromURL) {
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(policy, host))
route, err := b.buildRouteForPolicyAndMatch(cfg, certs, policy, name, mkRouteMatchForHost(policy, host))
if err != nil {
return nil, err
}
routes = append(routes, route)
}
} else {
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatch(policy))
route, err := b.buildRouteForPolicyAndMatch(cfg, certs, policy, name, mkRouteMatch(policy))
if err != nil {
return nil, err
}
Expand All @@ -274,6 +278,7 @@ func (b *Builder) buildRoutesForPolicy(

func (b *Builder) buildRouteForPolicyAndMatch(
cfg *config.Config,
certs []tls.Certificate,
policy *config.Policy,
name string,
match *envoy_config_route_v3.RouteMatch,
Expand All @@ -283,11 +288,6 @@ func (b *Builder) buildRouteForPolicyAndMatch(
return nil, err
}

certs, err := getAllCertificates(cfg)
if err != nil {
return nil, err
}

requireStrictTransportSecurity := cryptutil.HasCertificateForServerName(certs, fromURL.Hostname())

route := &envoy_config_route_v3.Route{
Expand Down
12 changes: 6 additions & 6 deletions config/envoyconfig/routes_test.go
Expand Up @@ -307,7 +307,7 @@ func TestTimeouts(t *testing.T) {
AllowWebsockets: tc.allowWebsockets,
},
},
}}, "example.com")
}}, nil, "example.com")
if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) {
continue
}
Expand Down Expand Up @@ -412,7 +412,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
UpstreamTimeout: &ten,
},
},
}}, "example.com")
}}, nil, "example.com")
require.NoError(t, err)

testutil.AssertProtoJSONEqual(t, `
Expand Down Expand Up @@ -918,7 +918,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
PassIdentityHeaders: true,
},
},
}}, "authenticate.example.com")
}}, nil, "authenticate.example.com")
require.NoError(t, err)

testutil.AssertProtoJSONEqual(t, `
Expand Down Expand Up @@ -1005,7 +1005,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
UpstreamTimeout: &ten,
},
},
}}, "example.com:22")
}}, nil, "example.com:22")
require.NoError(t, err)

testutil.AssertProtoJSONEqual(t, `
Expand Down Expand Up @@ -1151,7 +1151,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
From: "https://from.example.com",
},
},
}}, "from.example.com")
}}, nil, "from.example.com")
require.NoError(t, err)

testutil.AssertProtoJSONEqual(t, `
Expand Down Expand Up @@ -1272,7 +1272,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
HostPathRegexRewriteSubstitution: "\\1",
},
},
}}, "example.com")
}}, nil, "example.com")
require.NoError(t, err)

testutil.AssertProtoJSONEqual(t, `
Expand Down

0 comments on commit 45a577d

Please sign in to comment.