diff --git a/rtls/config.go b/rtls/config.go index 108a16774..87d012602 100644 --- a/rtls/config.go +++ b/rtls/config.go @@ -30,15 +30,16 @@ const ( // should appear in the returned certificate. If noverify is true, the client will not verify // the server's certificate. func CreateClientConfig(certFile, keyFile, caCertFile, serverName string, noverify bool) (*tls.Config, error) { - var err error - config := createBaseTLSConfig(serverName, noverify) if certFile != "" && keyFile != "" { - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + mc, err := newMonitoredCertificate(certFile, keyFile) if err != nil { return nil, err } + config.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + // Get() will never return nil + return mc.Get(), nil + } } if caCertFile != "" { asn1Data, err := os.ReadFile(caCertFile) @@ -61,14 +62,15 @@ func CreateClientConfig(certFile, keyFile, caCertFile, serverName string, noveri // client. If mtls is MTLSStateEnabled, the server will require the client to present a // valid certificate. func CreateServerConfig(certFile, keyFile, caCertFile string, mtls MTLSState) (*tls.Config, error) { - var err error - config := createBaseTLSConfig(NoServerName, false) - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + mc, err := newMonitoredCertificate(certFile, keyFile) if err != nil { return nil, err } + config.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // Get() will never return nil + return mc.Get(), nil + } if caCertFile != "" { asn1Data, err := os.ReadFile(caCertFile) if err != nil { diff --git a/rtls/monitor.go b/rtls/monitor.go new file mode 100644 index 000000000..3a74b8edf --- /dev/null +++ b/rtls/monitor.go @@ -0,0 +1,83 @@ +package rtls + +import ( + "crypto/tls" + "log" + "os" + "sync/atomic" + "time" +) + +// monitorFiles watches the given files for changes, calling the supplied callback when a +// change is detected. The callback is passed a slice of the files that changed. The loop +// never terminates, so it's expected the caller will invoke this in a goroutine. +func monitorFiles(files []string, cb func([]string)) { + mtimes := make([]time.Time, len(files)) + for { + var changed []string + for n, file := range files { + info, err := os.Stat(file) + if err != nil { + continue + } + if !mtimes[n].IsZero() && info.ModTime() != mtimes[n] { + changed = append(changed, file) + } + mtimes[n] = info.ModTime() + } + if changed != nil { + cb(changed) + } + time.Sleep(2 * time.Second) + } +} + +// monitoredCertificate watches the given certificate and key files for changes, reloading +// the certificate if a change is detected. +type monitoredCertificate struct { + certFile string + keyFile string + // This holds a *tls.Certificate which can be updated from another goroutine, so use + // an atomic value to synchronize access + cert atomic.Value +} + +// newMonitoredCertificate loads the certificate and key and spawns a goroutine to watch +// the files for changes. +func newMonitoredCertificate(certFile, keyFile string) (*monitoredCertificate, error) { + mc := &monitoredCertificate{ + certFile: certFile, + keyFile: keyFile, + } + // Prime the cached certificate and induce any errors immediately + if err := mc.load(); err != nil { + return nil, err + } + // Now start watching for changes + go monitorFiles([]string{certFile, keyFile}, func(changed []string) { + if err := mc.load(); err != nil { + log.Printf("error reloading certificate %s, not replacing: %s", certFile, err) + } else { + log.Printf("reloaded certificate %s", certFile) + } + }) + return mc, nil +} + +// load reads the certificate and key files from disk and caches the result for Get() +func (mc *monitoredCertificate) load() error { + cert, err := tls.LoadX509KeyPair(mc.certFile, mc.keyFile) + if err != nil { + return err + } + mc.cert.Store(&cert) + return nil +} + +// Get returns the certificate. +// +// This will never return nil provided the monitoredCertificate object was created via +// newMonitoredCertificate(), +func (mc *monitoredCertificate) Get() *tls.Certificate { + return mc.cert.Load().(*tls.Certificate) +}