Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

envoyconfig: add virtual host domains for certificates in addition to routes #3593

Merged
merged 4 commits into from Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 25 additions & 12 deletions config/envoyconfig/listeners.go
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"net/url"
"strings"
"time"

envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
Expand Down Expand Up @@ -118,7 +119,7 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
}
listenerFilters = append(listenerFilters, TLSInspectorFilter())

chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr,
chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain)
if err != nil {
Expand Down Expand Up @@ -235,23 +236,23 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
}

func (b *Builder) buildFilterChains(
options *config.Options, addr string,
cfg *config.Config, addr string,
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
) ([]*envoy_config_listener_v3.FilterChain, error) {
allDomains, err := getAllRouteableDomains(options, addr)
allDomains, err := getAllRouteableDomains(cfg.Options, addr)
if err != nil {
return nil, err
}

tlsDomains, err := getAllTLSDomains(options, addr)
tlsDomains, err := getAllTLSDomains(cfg, addr)
if err != nil {
return nil, err
}

var chains []*envoy_config_listener_v3.FilterChain
chains = append(chains, b.buildACMETLSALPNFilterChain())
for _, domain := range tlsDomains {
routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain)
routeableDomains, err := getRouteableDomainsForTLSServerName(cfg.Options, addr, domain)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -341,7 +342,9 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
LuaFilter(luascripts.CleanUpstream),
LuaFilter(luascripts.RewriteHeaders),
}
if tlsDomain != "" && tlsDomain != "*" {
// only return 421s for non-wildcard domains because the lua script doesn't understand how to
// parse wildcards properly
if tlsDomain != "" && !strings.Contains(tlsDomain, "*") {
desimone marked this conversation as resolved.
Show resolved Hide resolved
filters = append(filters, LuaFilter(fmt.Sprintf(luascripts.FixMisdirected, tlsDomain)))
}
filters = append(filters, HTTPRouterFilter())
Expand Down Expand Up @@ -438,7 +441,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
return li, nil
}

chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr,
chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter},
Expand Down Expand Up @@ -658,21 +661,31 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
return allDomains.ToSlice(), nil
}

func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
allDomains, err := getAllRouteableDomains(options, addr)
func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) {
domains := sets.NewSorted[string]()
calebdoxsey marked this conversation as resolved.
Show resolved Hide resolved

routeableDomains, err := getAllRouteableDomains(cfg.Options, addr)
if err != nil {
return nil, err
}

domains := sets.NewSorted[string]()
for _, hp := range allDomains {
for _, hp := range routeableDomains {
if d, _, err := net.SplitHostPort(hp); err == nil {
domains.Add(d)
} else {
domains.Add(hp)
}
}

certs, err := cfg.AllCertificates()
if err != nil {
return nil, err
}
for i := range certs {
for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) {
domains.Add(domain)
}
}

return domains.ToSlice(), nil
}

Expand Down
15 changes: 13 additions & 2 deletions config/envoyconfig/listeners_test.go
Expand Up @@ -2,6 +2,7 @@ package envoyconfig

import (
"context"
"encoding/base64"
"os"
"path/filepath"
"testing"
Expand All @@ -13,6 +14,7 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
)

const (
Expand Down Expand Up @@ -726,6 +728,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}

func Test_getAllDomains(t *testing.T) {
cert, err := cryptutil.GenerateSelfSignedCertificate("*.unknown.example.com")
require.NoError(t, err)
certPEM, keyPEM, err := cryptutil.EncodeCertificate(cert)
require.NoError(t, err)

options := &config.Options{
Addr: "127.0.0.1:9000",
GRPCAddr: "127.0.0.1:9001",
Expand All @@ -738,6 +745,8 @@ func Test_getAllDomains(t *testing.T) {
{Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}},
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
},
Cert: base64.StdEncoding.EncodeToString(certPEM),
Key: base64.StdEncoding.EncodeToString(keyPEM),
}
t.Run("routable", func(t *testing.T) {
t.Run("http", func(t *testing.T) {
Expand Down Expand Up @@ -786,9 +795,10 @@ func Test_getAllDomains(t *testing.T) {
})
t.Run("tls", func(t *testing.T) {
t.Run("http", func(t *testing.T) {
actual, err := getAllTLSDomains(options, "127.0.0.1:9000")
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{
"*.unknown.example.com",
"a.example.com",
"authenticate.example.com",
"b.example.com",
Expand All @@ -797,9 +807,10 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual)
})
t.Run("grpc", func(t *testing.T) {
actual, err := getAllTLSDomains(options, "127.0.0.1:9001")
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001")
require.NoError(t, err)
expect := []string{
"*.unknown.example.com",
"authorize.example.com",
"cache.example.com",
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/cryptutil/certificates.go
Expand Up @@ -219,6 +219,21 @@ func GenerateSelfSignedCertificate(domain string, configure ...func(*x509.Certif
return &cert, nil
}

// EncodeCertificate encodes a TLS certificate into PEM compatible byte slices.
calebdoxsey marked this conversation as resolved.
Show resolved Hide resolved
// Returns `nil`, `nil` if there is an error marshaling the PKCS8 private key.
func EncodeCertificate(cert *tls.Certificate) (pemCertificateBytes, pemKeyBytes []byte, err error) {
if cert == nil || len(cert.Certificate) == 0 {
return nil, nil, nil
}
publicKeyBytes := cert.Certificate[0]
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey)
if err != nil {
return nil, nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}),
pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}), nil
}

// ParsePEMCertificate parses a PEM encoded certificate block.
func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) {
data := raw
Expand Down
15 changes: 15 additions & 0 deletions pkg/cryptutil/certificates_test.go
Expand Up @@ -165,3 +165,18 @@ func TestPrivateKeyMarshaling(t *testing.T) {
t.Fatal("private key encoding did not match")
}
}

func TestEncodeCertificate(t *testing.T) {
t.Run("nil", func(t *testing.T) {
cert, key, err := EncodeCertificate(nil)
assert.NoError(t, err)
assert.Nil(t, cert)
assert.Nil(t, key)
})
t.Run("empty certificate", func(t *testing.T) {
cert, key, err := EncodeCertificate(&tls.Certificate{})
assert.NoError(t, err)
assert.Nil(t, cert)
assert.Nil(t, key)
})
}
24 changes: 24 additions & 0 deletions pkg/cryptutil/tls.go
Expand Up @@ -63,6 +63,30 @@ func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tl
return GenerateSelfSignedCertificate(domain)
}

// GetCertificateDomains gets all the certificate's matching domain names.
calebdoxsey marked this conversation as resolved.
Show resolved Hide resolved
// Will return an empty slice if certificate is nil, empty, or x509 parsing fails.
func GetCertificateDomains(cert *tls.Certificate) []string {
calebdoxsey marked this conversation as resolved.
Show resolved Hide resolved
if cert == nil || len(cert.Certificate) == 0 {
return nil
}

xcert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil
}

var domains []string
if xcert.Subject.CommonName != "" {
domains = append(domains, xcert.Subject.CommonName)
}
for _, dnsName := range xcert.DNSNames {
if dnsName != "" {
domains = append(domains, dnsName)
}
}
return domains
}

func matchesDomain(cert *tls.Certificate, domain string) bool {
if cert == nil || len(cert.Certificate) == 0 {
return false
Expand Down
7 changes: 7 additions & 0 deletions pkg/cryptutil/tls_test.go
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetCertificateForDomain(t *testing.T) {
Expand Down Expand Up @@ -62,3 +63,9 @@ func TestGetCertificateForDomain(t *testing.T) {
assert.NotNil(t, found)
})
}

func TestGetCertificateDomains(t *testing.T) {
cert, err := GenerateSelfSignedCertificate("www.example.com")
require.NoError(t, err)
assert.Equal(t, []string{"www.example.com"}, GetCertificateDomains(cert))
}