Skip to content

Commit

Permalink
add fsnotify-backed cache for reading TLS PKI material (#1256)
Browse files Browse the repository at this point in the history
Signed-off-by: Bob Callaway <bcallaway@google.com>
  • Loading branch information
bobcallaway committed Jul 8, 2023
1 parent 12aa925 commit 489d73a
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 12 deletions.
98 changes: 88 additions & 10 deletions cmd/app/grpc.go
Expand Up @@ -22,7 +22,9 @@ import (
"net"
"os"
"runtime"
"sync"

"github.com/fsnotify/fsnotify"
"github.com/goadesign/goa/grpc/middleware"
ctclient "github.com/google/certificate-transparency-go/client"
grpcmw "github.com/grpc-ecosystem/go-grpc-middleware"
Expand Down Expand Up @@ -50,6 +52,7 @@ type grpcServer struct {
*grpc.Server
grpcServerEndpoint string
caService gw.CAServer
tlsCertWatcher *fsnotify.Watcher
}

func PassFulcioConfigThruContext(cfg *config.FulcioConfig) grpc.UnaryServerInterceptor {
Expand All @@ -67,16 +70,85 @@ func PassFulcioConfigThruContext(cfg *config.FulcioConfig) grpc.UnaryServerInter
}
}

func createGRPCCreds(certPath, keyPath string) (grpc.ServerOption, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
type cachedTLSCert struct {
sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
Watcher *fsnotify.Watcher
}

func newCachedTLSCert(certPath, keyPath string) (*cachedTLSCert, error) {
cachedTLSCert := &cachedTLSCert{
certPath: certPath,
keyPath: keyPath,
}
if err := cachedTLSCert.UpdateCertificate(); err != nil {
return nil, err
}
var err error
cachedTLSCert.Watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, err
}

go func() {
for {
select {
case event, ok := <-cachedTLSCert.Watcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) {
log.Logger.Info("fsnotify grpc-tls-certificate write event detected")
if err := cachedTLSCert.UpdateCertificate(); err != nil {
log.Logger.Error(err)
}
}
case err, ok := <-cachedTLSCert.Watcher.Errors:
if !ok {
return
}
log.Logger.Error("fsnotify grpc-tls-certificate error:", err)
}
}
}()

// Add a path.
if err = cachedTLSCert.Watcher.Add(certPath); err != nil {
return nil, err
}
return cachedTLSCert, nil
}

func (c *cachedTLSCert) GetCertificate() *tls.Certificate {
// get reader lock
c.RLock()
defer c.RUnlock()
return c.cert
}

func (c *cachedTLSCert) UpdateCertificate() error {
// get writer lock
c.Lock()
defer c.Unlock()

cert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath)
if err != nil {
return nil, fmt.Errorf("loading GRPC tls certificate and key file: %w", err)
return fmt.Errorf("loading GRPC tls certificate and key file: %w", err)
}

c.cert = &cert
return nil
}

func (c *cachedTLSCert) GRPCCreds() grpc.ServerOption {
return grpc.Creds(credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
})), nil
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
return c.GetCertificate(), nil
},
MinVersion: tls.VersionTLS13,
}))
}

func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, baseca ca.CertificateAuthority, ip identity.IssuerPool) (*grpcServer, error) {
Expand All @@ -94,12 +166,15 @@ func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, ba
grpc.MaxRecvMsgSize(int(maxMsgSize)),
}

var tlsCertWatcher *fsnotify.Watcher
if viper.IsSet("grpc-tls-certificate") && viper.IsSet("grpc-tls-key") {
creds, err := createGRPCCreds(viper.GetString("grpc-tls-certificate"), viper.GetString("grpc-tls-key"))
cachedTLSCert, err := newCachedTLSCert(viper.GetString("grpc-tls-certificate"), viper.GetString("grpc-tls-key"))
if err != nil {
return nil, err
}
serverOpts = append(serverOpts, creds)

tlsCertWatcher = cachedTLSCert.Watcher
serverOpts = append(serverOpts, cachedTLSCert.GRPCCreds())
}

myServer := grpc.NewServer(serverOpts...)
Expand All @@ -109,7 +184,7 @@ func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, ba
gw.RegisterCAServer(myServer, grpcCAServer)

grpcServerEndpoint := fmt.Sprintf("%s:%s", viper.GetString("grpc-host"), viper.GetString("grpc-port"))
return &grpcServer{myServer, grpcServerEndpoint, grpcCAServer}, nil
return &grpcServer{myServer, grpcServerEndpoint, grpcCAServer, tlsCertWatcher}, nil
}

func (g *grpcServer) setupPrometheus(reg *prometheus.Registry) {
Expand All @@ -129,6 +204,9 @@ func (g *grpcServer) startTCPListener() {
g.grpcServerEndpoint = lis.Addr().String()
log.Logger.Infof("listening on grpc at %s", g.grpcServerEndpoint)
go func() {
if g.tlsCertWatcher != nil {
defer g.tlsCertWatcher.Close()
}
if err := g.Server.Serve(lis); err != nil {
log.Logger.Errorf("error shutting down grpcServer: %w", err)
}
Expand Down Expand Up @@ -179,7 +257,7 @@ func createLegacyGRPCServer(cfg *config.FulcioConfig, v2Server gw.CAServer) (*gr
// Register your gRPC service implementations.
gw_legacy.RegisterCAServer(myServer, legacyGRPCCAServer)

return &grpcServer{myServer, LegacyUnixDomainSocket, v2Server}, nil
return &grpcServer{myServer, LegacyUnixDomainSocket, v2Server, nil}, nil
}

func panicRecoveryHandler(ctx context.Context, p interface{}) error {
Expand Down
61 changes: 59 additions & 2 deletions cmd/app/grpc_test.go
Expand Up @@ -16,9 +16,11 @@
package app

import (
"bytes"
"os"
"path/filepath"
"testing"
"time"
)

const keyPEM = `-----BEGIN PRIVATE KEY-----
Expand Down Expand Up @@ -108,7 +110,41 @@ tSLmTWsb+j/Oxljalf+rAlItYk297HN0xMvlkHkB80O5Un6OMCHAjJmfOVZal2Y5
o4ZDR+PzKEbU8eUQbooS
-----END CERTIFICATE-----`

func TestCreateGRPCCreds(t *testing.T) {
const renewedCertPEM = `-----BEGIN CERTIFICATE-----
MIIFqzCCA5OgAwIBAgIUYKKd201v0q4S0FVSdKvRvrC9JQIwDQYJKoZIhvcNAQEL
BQAwZTELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAldBMREwDwYDVQQHDAhLaXJrbGFu
ZDERMA8GA1UECgwIU2lnc3RvcmUxDzANBgNVBAsMBkZ1bGNpbzESMBAGA1UEAwwJ
bG9jYWxob3N0MB4XDTIzMDcwMTE4NTUzM1oXDTMzMDYyODE4NTUzM1owZTELMAkG
A1UEBhMCVVMxCzAJBgNVBAgMAldBMREwDwYDVQQHDAhLaXJrbGFuZDERMA8GA1UE
CgwIU2lnc3RvcmUxDzANBgNVBAsMBkZ1bGNpbzESMBAGA1UEAwwJbG9jYWxob3N0
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAvc3MNQYRO5ytG+8FsPn5
0Z8koUG4sYPf6ZLTMP78+sRYfz2ggZaP46Hl3f571AB8nXSBRbIc9byDgOVpOs9t
zRnYA6tyv9cgtOCtgFfuctHFdeWpJXQr4wWhB0oUspmu66cmFKYfdwrfnrjvRkZ2
+33fGK0hC6EXtNIX7sg+Y98jT1iW0AIiBZxMBf8p5d5fTEfodVN/NZ0FN58/TOPO
jmkStqta+fZrPL02TVZ/IdDx7RSWVqD+KcqOJdSneuCt5qniQgcLwBMIk5ymyMLP
+yKl1GQXHfagxh3e84HlE93XgVCThk9XK8HSQerUmH0oklIx2PgcSS0FPjsz49Us
Kd7QweZAzcaaaoapK0QkXRvUZLSwPgOjriJtd6Pi5S7xN91DnGmyDHA7BGgCtew7
1BUgW2AzWXJq3EX0kHjMaEgHCL59SUw/pOlMiNXMC6hnUSH5lnY2isNs9+DUU+Xa
/Z9ME+B0SiRCaRGq7ZUdXiHuaN+DiRj3hX1VO96wVjvZAh0JklI6pVB7cz6HvOwx
iAtSiXxqIQZkyac3lP939tAFzLVvpqSqaHzUF8bqBSkWxy8iZVW9EJiIb8wAVE8R
Sl4WssnUrneMfXjxsyQ271H6DIDLWP4BHtorqcN0vGnOE37N6DjrOJaaTyaThn2q
Kjmt6ghqvTY0CRVpyQz3szUCAwEAAaNTMFEwHQYDVR0OBBYEFGgoph9DIwXUHUT0
8y7CtcviGmPhMB8GA1UdIwQYMBaAFGgoph9DIwXUHUT08y7CtcviGmPhMA8GA1Ud
EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggIBAJp6m3a6nMlpzZ9doLh9WOmj
dCz68oStA0HMflHGkGgtZ9BDNWQI5EL8rnsdj4l5h4nte7s4AQrtoHGdT/QkIwUu
1Zv3XBypubeof2iUJ+UhIz2Lm9vqt1hOT1HM3U7/4HCf/730y7LqwWIm580PLw1v
73C4kizrzg4T7C/4Jy5/lF5PzmJOVPWd0LDomWzbpw0pM2h8cYUY02HVEaeVOfHP
lCoTW7cL36q3I/0RAyjGmK8nNGNpJTB4xjdTT4TJRotkhcKEsTVfIrVJFGlVMDoT
Fe9T6rAnYQ2TlnPRhp9tDqiRb3Y027nJouddjulgGTRudUAzNkg7lVWJSDc4PwO2
7gm7I5/mbil6bI1r4djV0FPJZZI7EHgOM8OmKKqo5sLN0WQigZ8GSH1KHOuR1d2j
m6GJOdUJ7+ZQ+tej0pwNERMSl0+OY+FtsFMusLXoIUUTyaOs1cpFO0ifqZSp9eUP
50QDoeDGYS0T/0RicNDXMTltE24G3L7mHiPa5rr4tlYvVHYeoez7qFtG5LvDvBH8
OVZzfoJGB4MbgrFxymai/9i+hdYadt2UHL9BfnUUInDhi/2l0MB1a45DYOzf5Zf0
IYjZh058kDhlL7WOkAhPdvm2wD9KAm4FDInw49PYxpasmDKaOTf1WeWfNa5CVOYp
wpLmyTovgQl/NcXO7caU
-----END CERTIFICATE-----`

func TestCachedTLSCert(t *testing.T) {
dir := t.TempDir()

// not PKI material
Expand Down Expand Up @@ -166,10 +202,31 @@ func TestCreateGRPCCreds(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
_, err := createGRPCCreds(tc.certPath, tc.keyPath)
cachedCert, err := newCachedTLSCert(tc.certPath, tc.keyPath)
if tc.success != (err == nil) {
t.Errorf("unexpected result: %v", err)
}
if tc.success {
cert := cachedCert.GetCertificate()
if cert == nil {
t.Fatal("unexpected error reading tls.Certificate object")
}

// update the cert on disk to a renewed value (representing same public/private keypair)
os.WriteFile(tc.certPath, []byte(renewedCertPEM), 0644)

// sleep for a second to let goroutine fire for fsnotify event
time.Sleep(1 * time.Second)

renewedCert := cachedCert.GetCertificate()
if renewedCert == nil {
t.Fatal("unexpected error reading renewed tls.Certificate object")
}

if bytes.Equal(cert.Certificate[0], renewedCert.Certificate[0]) {
t.Fatal("got same certificate after overwriting renewed cert to same file")
}
}
})
}
}

0 comments on commit 489d73a

Please sign in to comment.