Skip to content

Commit

Permalink
Dynamically reload cert/keys if changed on disk
Browse files Browse the repository at this point in the history
  • Loading branch information
jtackaberry committed Jan 13, 2024
1 parent 2d6c89d commit f5a95a0
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 8 deletions.
18 changes: 10 additions & 8 deletions rtls/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ const (
// should appear in the returned certificate. If noverify is true, the client will not verify
// the server's certificate.
func CreateClientConfig(certFile, keyFile, caCertFile, serverName string, noverify bool) (*tls.Config, error) {
var err error

config := createBaseTLSConfig(serverName, noverify)
if certFile != "" && keyFile != "" {
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
mc, err := newMonitoredCertificate(certFile, keyFile)
if err != nil {
return nil, err
}
config.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Get() will never return nil
return mc.Get(), nil
}
}
if caCertFile != "" {
asn1Data, err := os.ReadFile(caCertFile)
Expand All @@ -61,14 +62,15 @@ func CreateClientConfig(certFile, keyFile, caCertFile, serverName string, noveri
// client. If mtls is MTLSStateEnabled, the server will require the client to present a
// valid certificate.
func CreateServerConfig(certFile, keyFile, caCertFile string, mtls MTLSState) (*tls.Config, error) {
var err error

config := createBaseTLSConfig(NoServerName, false)
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
mc, err := newMonitoredCertificate(certFile, keyFile)
if err != nil {
return nil, err
}
config.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Get() will never return nil
return mc.Get(), nil
}
if caCertFile != "" {
asn1Data, err := os.ReadFile(caCertFile)
if err != nil {
Expand Down
83 changes: 83 additions & 0 deletions rtls/monitor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package rtls

import (
"crypto/tls"
"log"
"os"
"sync/atomic"
"time"
)

// monitorFiles watches the given files for changes, calling the supplied callback when a
// change is detected. The callback is passed a slice of the files that changed. The loop
// never terminates, so it's expected the caller will invoke this in a goroutine.
func monitorFiles(files []string, cb func([]string)) {
mtimes := make(map[string]time.Time, len(files))
for {
var changed []string
for _, file := range files {
info, err := os.Stat(file)
if err != nil {
continue
}
if !mtimes[file].IsZero() && info.ModTime() != mtimes[file] {
changed = append(changed, file)
}
mtimes[file] = info.ModTime()
}
if changed != nil {
cb(changed)
}
time.Sleep(2 * time.Second)
}
}

// monitoredCertificate watches the given certificate and key files for changes, reloading
// the certificate if a change is detected.
type monitoredCertificate struct {
certFile string
keyFile string
// This holds a *tls.Certificate which can be updated from another goroutine, so use
// an atomic value to synchronize access
cert atomic.Value
}

// newMonitoredCertificate loads the certificate and key and spawns a goroutine to watch
// the files for changes.
func newMonitoredCertificate(certFile, keyFile string) (*monitoredCertificate, error) {
mc := &monitoredCertificate{
certFile: certFile,
keyFile: keyFile,
}
// Prime the cached certificate and induce any errors immediately
if err := mc.load(); err != nil {
return nil, err
}
// Now start watching for changes
go monitorFiles([]string{certFile, keyFile}, func(changed []string) {
if err := mc.load(); err != nil {
log.Printf("error reloading certificate %s, not replacing: %s", certFile, err)
} else {
log.Printf("reloaded certificate %s", certFile)
}
})
return mc, nil
}

// load reads the certificate and key files from disk and caches the result for Get()
func (mc *monitoredCertificate) load() error {
cert, err := tls.LoadX509KeyPair(mc.certFile, mc.keyFile)
if err != nil {
return err
}
mc.cert.Store(&cert)
return nil
}

// Get returns the certificate.
//
// This will never return nil provided the monitoredCertificate object was created via
// newMonitoredCertificate(),
func (mc *monitoredCertificate) Get() *tls.Certificate {
return mc.cert.Load().(*tls.Certificate)
}

0 comments on commit f5a95a0

Please sign in to comment.