-
Notifications
You must be signed in to change notification settings - Fork 917
/
Copy pathcertreloader.go
207 lines (176 loc) · 6.1 KB
/
certreloader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
package tlsconfig
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"runtime"
"sync"
"github.com/getsentry/sentry-go"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
)
// Command-line flag names referenced by this package.
const (
// OriginCAPoolFlag is the flag name quoted (as --origin-ca-pool) in the
// messages LoadOriginCA emits about the origin CA pool file.
OriginCAPoolFlag = "origin-ca-pool"
// CaCertFlag is the flag whose string value CreateTunnelConfig passes on as
// a root CA entry when it is non-empty.
CaCertFlag = "cacert"
)
// CertReloader can load and reload a TLS certificate from a particular filepath.
// Hooks into tls.Config's GetCertificate to allow a TLS server to update its certificate without restarting.
type CertReloader struct {
// Embedded mutex: Cert, ClientCert and LoadCert each hold it while they read
// or replace certificate. Being embedded, Lock and Unlock are also part of
// CertReloader's own method set.
sync.Mutex
// certificate is the key pair stored by the most recent successful LoadCert.
certificate *tls.Certificate
// certPath and keyPath are the file paths LoadCert hands to tls.LoadX509KeyPair.
certPath string
keyPath string
}
// NewCertReloader builds a CertReloader for the given certificate and key paths.
// The key pair is loaded once up front, so an unusable certPath or keyPath is
// reported here as an error rather than yielding a reloader with no certificate.
func NewCertReloader(certPath, keyPath string) (*CertReloader, error) {
	reloader := &CertReloader{
		certPath: certPath,
		keyPath:  keyPath,
	}
	// Initial load doubles as validation of both paths.
	err := reloader.LoadCert()
	if err != nil {
		return nil, err
	}
	return reloader, nil
}
// Cert hands back the certificate stored by the most recent successful load.
// Its signature matches tls.Config.GetCertificate, so it can be assigned to
// that field directly. clientHello is ignored and the error is always nil.
func (cr *CertReloader) Cert(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	// Snapshot the pointer under the lock so a concurrent LoadCert cannot race with the read.
	cr.Lock()
	current := cr.certificate
	cr.Unlock()
	return current, nil
}
// ClientCert hands back the certificate stored by the most recent successful load.
// Its signature matches tls.Config.GetClientCertificate, so it can be assigned
// to that field directly. certRequestInfo is ignored and the error is always nil.
func (cr *CertReloader) ClientCert(certRequestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	// Snapshot the pointer under the lock so a concurrent LoadCert cannot race with the read.
	cr.Lock()
	current := cr.certificate
	cr.Unlock()
	return current, nil
}
// LoadCert reads the key pair at the reloader's certPath/keyPath and, on
// success, makes it the certificate returned by Cert and ClientCert.
// Call it after a new certificate has been written to disk (e.g. after a renewal).
//
// If the key pair cannot be loaded, the previously stored certificate is left
// in place, the failure is reported to Sentry, and the load error is returned.
func (cr *CertReloader) LoadCert() error {
	// The lock is held across the disk read, so concurrent reloads are serialized.
	cr.Lock()
	defer cr.Unlock()

	loaded, loadErr := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
	if loadErr == nil {
		cr.certificate = &loaded
		return nil
	}

	// Keep the old certificate if there's a problem reading the new one.
	sentry.CaptureException(fmt.Errorf("Error parsing X509 key pair: %v", loadErr))
	return loadErr
}
// LoadOriginCA returns the certificate pool built by loadOriginCertPool. When
// originCAPoolFilename is non-empty, the file's contents are passed along as
// additional PEM data; a read failure is returned as an error.
//
// On Windows, when no filename was given, an informational message pointing
// at the --origin-ca-pool flag is logged.
func LoadOriginCA(originCAPoolFilename string, log *zerolog.Logger) (*x509.CertPool, error) {
	noCustomPool := originCAPoolFilename == ""

	// Stays nil when no file was requested.
	var customPEM []byte
	if !noCustomPool {
		data, err := os.ReadFile(originCAPoolFilename)
		if err != nil {
			return nil, errors.Wrapf(err, "unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag)
		}
		customPEM = data
	}

	pool, err := loadOriginCertPool(customPEM, log)
	if err != nil {
		return nil, errors.Wrap(err, "error loading the certificate pool")
	}

	// Windows users should be told that the flag is available to them.
	if noCustomPool && runtime.GOOS == "windows" {
		log.Info().Msgf("cloudflared does not support loading the system root certificate pool on Windows. Please use --%s <PATH> to specify the path to the certificate pool", OriginCAPoolFlag)
	}

	return pool, nil
}
// LoadCustomOriginCA returns a certificate pool made of the system pool (or an
// empty pool if the system pool cannot be obtained), the Cloudflare root CAs,
// and, when originCAFilename is non-empty, the PEM certificates in that file.
//
// It fails if the Cloudflare root CAs cannot be obtained, if the file cannot
// be read, or if no certificate from the file could be appended.
func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
	// Start from the system pool, falling back to an empty one.
	pool, sysErr := x509.SystemCertPool()
	if sysErr != nil {
		pool = x509.NewCertPool()
	}

	// Layer the Cloudflare CAs on top.
	cloudflareCAs, err := GetCloudflareRootCA()
	if err != nil {
		return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
	}
	for _, ca := range cloudflareCAs {
		pool.AddCert(ca)
	}

	// Finally, add the caller-supplied CA file if one was named.
	if originCAFilename != "" {
		pemBytes, err := os.ReadFile(originCAFilename)
		if err != nil {
			return nil, errors.Wrapf(err, "unable to read the file %s", originCAFilename)
		}
		if ok := pool.AppendCertsFromPEM(pemBytes); !ok {
			return nil, fmt.Errorf("error appending custom CA to cert pool")
		}
	}

	return pool, nil
}
// CreateTunnelConfig obtains a *tls.Config from GetConfig for the given
// serverName. If the CaCertFlag flag is set on c, its value is passed to
// GetConfig as the only root CA entry.
//
// When GetConfig leaves RootCAs nil, the system certificate pool extended with
// the Cloudflare root CAs is installed instead; failing to obtain either is an
// error. An error is also returned when the resulting config has neither a
// ServerName nor InsecureSkipVerify set.
func CreateTunnelConfig(c *cli.Context, serverName string) (*tls.Config, error) {
	var caFiles []string
	if caCert := c.String(CaCertFlag); caCert != "" {
		caFiles = []string{caCert}
	}

	cfg, err := GetConfig(&TLSParameters{RootCAs: caFiles, ServerName: serverName})
	if err != nil {
		return nil, err
	}

	// No explicit roots configured: default to system roots plus Cloudflare's.
	if cfg.RootCAs == nil {
		pool, err := x509.SystemCertPool()
		if err != nil {
			return nil, errors.Wrap(err, "unable to get x509 system cert pool")
		}

		cloudflareCAs, err := GetCloudflareRootCA()
		if err != nil {
			return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
		}
		for _, ca := range cloudflareCAs {
			pool.AddCert(ca)
		}

		cfg.RootCAs = pool
	}

	// The peer must either be verified against a name or verification must be explicitly off.
	if !(cfg.ServerName != "" || cfg.InsecureSkipVerify) {
		return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
	}

	return cfg, nil
}
// loadOriginCertPool returns the pool from loadGlobalCertPool with the PEM
// certificates in originCAPoolPEM appended. A nil originCAPoolPEM is skipped.
// If nothing from a non-nil originCAPoolPEM could be appended, that is only
// logged at info level and the pool is still returned.
func loadOriginCertPool(originCAPoolPEM []byte, log *zerolog.Logger) (*x509.CertPool, error) {
	pool, err := loadGlobalCertPool(log)
	if err != nil {
		return nil, err
	}

	// Add any custom origin CA data the caller supplied; a failed append is not fatal.
	if originCAPoolPEM != nil && !pool.AppendCertsFromPEM(originCAPoolPEM) {
		log.Info().Msg("could not append the provided origin CA to the cloudflared certificate pool")
	}

	return pool, nil
}
// loadGlobalCertPool assembles the base certificate pool: the system pool (or
// an empty pool if it cannot be obtained), plus the Cloudflare root CAs, plus
// the Hello server certificate.
//
// Failure to obtain the system pool is logged (except on Windows) and is not
// fatal; failure to obtain the Cloudflare CAs or the Hello certificate is
// returned as an error.
func loadGlobalCertPool(log *zerolog.Logger) (*x509.CertPool, error) {
	// Start from the system pool, falling back to an empty one.
	pool, sysErr := x509.SystemCertPool()
	if sysErr != nil {
		if runtime.GOOS != "windows" { // See https://github.com/golang/go/issues/16736
			log.Err(sysErr).Msg("error obtaining the system certificates")
		}
		pool = x509.NewCertPool()
	}

	// Layer the Cloudflare CAs on top.
	cloudflareCAs, err := GetCloudflareRootCA()
	if err != nil {
		return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
	}
	for _, ca := range cloudflareCAs {
		pool.AddCert(ca)
	}

	// The Hello certificate is self-signed, so it has to be trusted explicitly.
	hello, err := GetHelloCertificateX509()
	if err != nil {
		return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
	}
	pool.AddCert(hello)

	return pool, nil
}