/
tls.go
88 lines (75 loc) · 2.28 KB
/
tls.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package cryptutil
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io/ioutil"
"github.com/caddyserver/certmagic"
"github.com/pomerium/pomerium/internal/log"
)
// GetCertPool returns a certificate pool seeded with the system roots and,
// optionally, extended with one custom certificate authority bundle.
//
// If the system pool cannot be loaded, the failure is logged and an empty
// pool is used instead. When both ca and caFile are empty, that base pool is
// returned unchanged. Otherwise the CA bundle is taken from ca (a
// base64-encoded PEM bundle), which takes precedence over caFile (a path to a
// PEM file). An error is returned if the bundle cannot be decoded or read, or
// if it contains no PEM certificates that could be appended.
func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
	ctx := context.TODO()

	pool, err := x509.SystemCertPool()
	if err != nil {
		// Best effort: fall back to an empty pool rather than failing outright.
		log.Error(ctx).Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
		pool = x509.NewCertPool()
	}

	var pemBundle []byte
	switch {
	case ca != "":
		pemBundle, err = base64.StdEncoding.DecodeString(ca)
		if err != nil {
			return nil, fmt.Errorf("failed to decode base64-encoded certificate authority: %w", err)
		}
	case caFile != "":
		pemBundle, err = ioutil.ReadFile(caFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read certificate authority file (%s): %w", caFile, err)
		}
	default:
		// No custom CA configured: the base pool is the answer.
		return pool, nil
	}

	if !pool.AppendCertsFromPEM(pemBundle) {
		return nil, fmt.Errorf("failed to append any PEM-encoded certificates")
	}
	log.Debug(ctx).Msg("pkg/cryptutil: added custom certificate authority")
	return pool, nil
}
// GetCertificateForDomain returns the tls Certificate which matches the given domain name.
// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
// Finally if there are no matching certificates one will be generated.
func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tls.Certificate, error) {
// first try a direct name match
for i := range certificates {
if matchesDomain(&certificates[i], domain) {
return &certificates[i], nil
}
}
// next use the first cert
if len(certificates) > 0 {
return &certificates[0], nil
}
// finally fall back to a generated, self-signed certificate
return GenerateSelfSignedCertificate(domain)
}
// matchesDomain reports whether the leaf certificate of cert names domain,
// either through its subject CommonName or through one of its DNS subject
// alternative names. Each name is compared with certmagic.MatchWildcard, so
// wildcard names such as "*.example.com" can match.
//
// It returns false when domain is empty, when cert is nil or carries no DER
// certificates, or when the leaf cannot be parsed.
func matchesDomain(cert *tls.Certificate, domain string) bool {
	// An empty server name (e.g. a ClientHello without SNI) cannot name any
	// certificate. Without this guard, an equality-based comparison could
	// report a match against a leaf whose CommonName is also empty (common for
	// modern certificates that rely solely on SANs), defeating the caller's
	// "first certificate" fallback.
	if domain == "" || cert == nil || len(cert.Certificate) == 0 {
		return false
	}

	// Reuse the parsed leaf when the caller has populated it, so the DER does
	// not have to be re-parsed on every lookup; otherwise parse it here.
	xcert := cert.Leaf
	if xcert == nil {
		parsed, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			return false
		}
		xcert = parsed
	}

	if certmagic.MatchWildcard(domain, xcert.Subject.CommonName) {
		return true
	}
	for _, san := range xcert.DNSNames {
		if certmagic.MatchWildcard(domain, san) {
			return true
		}
	}
	return false
}