diff --git a/cmd/app/grpc.go b/cmd/app/grpc.go index 51587f206..8ed1db727 100644 --- a/cmd/app/grpc.go +++ b/cmd/app/grpc.go @@ -22,7 +22,9 @@ import ( "net" "os" "runtime" + "sync" + "github.com/fsnotify/fsnotify" "github.com/goadesign/goa/grpc/middleware" ctclient "github.com/google/certificate-transparency-go/client" grpcmw "github.com/grpc-ecosystem/go-grpc-middleware" @@ -50,6 +52,7 @@ type grpcServer struct { *grpc.Server grpcServerEndpoint string caService gw.CAServer + tlsCertWatcher *fsnotify.Watcher } func PassFulcioConfigThruContext(cfg *config.FulcioConfig) grpc.UnaryServerInterceptor { @@ -67,16 +70,85 @@ func PassFulcioConfigThruContext(cfg *config.FulcioConfig) grpc.UnaryServerInter } } -func createGRPCCreds(certPath, keyPath string) (grpc.ServerOption, error) { - cert, err := tls.LoadX509KeyPair(certPath, keyPath) +type cachedTLSCert struct { + sync.RWMutex + certPath string + keyPath string + cert *tls.Certificate + Watcher *fsnotify.Watcher +} + +func newCachedTLSCert(certPath, keyPath string) (*cachedTLSCert, error) { + cachedTLSCert := &cachedTLSCert{ + certPath: certPath, + keyPath: keyPath, + } + if err := cachedTLSCert.UpdateCertificate(); err != nil { + return nil, err + } + var err error + cachedTLSCert.Watcher, err = fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + go func() { + for { + select { + case event, ok := <-cachedTLSCert.Watcher.Events: + if !ok { + return + } + if event.Has(fsnotify.Write) { + log.Logger.Info("fsnotify grpc-tls-certificate write event detected") + if err := cachedTLSCert.UpdateCertificate(); err != nil { + log.Logger.Error(err) + } + } + case err, ok := <-cachedTLSCert.Watcher.Errors: + if !ok { + return + } + log.Logger.Error("fsnotify grpc-tls-certificate error:", err) + } + } + }() + + // Add a path. + if err = cachedTLSCert.Watcher.Add(certPath); err != nil { + return nil, err + } + return cachedTLSCert, nil +} + +func (c *cachedTLSCert) GetCertificate() *tls.Certificate { + // get reader lock + c.RLock() + defer c.RUnlock() + return c.cert +} + +func (c *cachedTLSCert) UpdateCertificate() error { + // get writer lock + c.Lock() + defer c.Unlock() + + cert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath) if err != nil { - return nil, fmt.Errorf("loading GRPC tls certificate and key file: %w", err) + return fmt.Errorf("loading GRPC tls certificate and key file: %w", err) } + c.cert = &cert + return nil +} + +func (c *cachedTLSCert) GRPCCreds() grpc.ServerOption { return grpc.Creds(credentials.NewTLS(&tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS13, - })), nil + GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + return c.GetCertificate(), nil + }, + MinVersion: tls.VersionTLS13, + })) } func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, baseca ca.CertificateAuthority, ip identity.IssuerPool) (*grpcServer, error) { @@ -94,12 +166,15 @@ func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, ba grpc.MaxRecvMsgSize(int(maxMsgSize)), } + var tlsCertWatcher *fsnotify.Watcher if viper.IsSet("grpc-tls-certificate") && viper.IsSet("grpc-tls-key") { - creds, err := createGRPCCreds(viper.GetString("grpc-tls-certificate"), viper.GetString("grpc-tls-key")) + cachedTLSCert, err := newCachedTLSCert(viper.GetString("grpc-tls-certificate"), viper.GetString("grpc-tls-key")) if err != nil { return nil, err } - serverOpts = append(serverOpts, creds) + + tlsCertWatcher = cachedTLSCert.Watcher + serverOpts = append(serverOpts, cachedTLSCert.GRPCCreds()) } myServer := grpc.NewServer(serverOpts...) @@ -109,7 +184,7 @@ func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, ba gw.RegisterCAServer(myServer, grpcCAServer) grpcServerEndpoint := fmt.Sprintf("%s:%s", viper.GetString("grpc-host"), viper.GetString("grpc-port")) - return &grpcServer{myServer, grpcServerEndpoint, grpcCAServer}, nil + return &grpcServer{myServer, grpcServerEndpoint, grpcCAServer, tlsCertWatcher}, nil } func (g *grpcServer) setupPrometheus(reg *prometheus.Registry) { @@ -129,6 +204,9 @@ func (g *grpcServer) startTCPListener() { g.grpcServerEndpoint = lis.Addr().String() log.Logger.Infof("listening on grpc at %s", g.grpcServerEndpoint) go func() { + if g.tlsCertWatcher != nil { + defer g.tlsCertWatcher.Close() + } if err := g.Server.Serve(lis); err != nil { log.Logger.Errorf("error shutting down grpcServer: %w", err) } @@ -179,7 +257,7 @@ func createLegacyGRPCServer(cfg *config.FulcioConfig, v2Server gw.CAServer) (*gr // Register your gRPC service implementations. gw_legacy.RegisterCAServer(myServer, legacyGRPCCAServer) - return &grpcServer{myServer, LegacyUnixDomainSocket, v2Server}, nil + return &grpcServer{myServer, LegacyUnixDomainSocket, v2Server, nil}, nil } func panicRecoveryHandler(ctx context.Context, p interface{}) error { diff --git a/cmd/app/grpc_test.go b/cmd/app/grpc_test.go index 26dadd686..ea41980b3 100644 --- a/cmd/app/grpc_test.go +++ b/cmd/app/grpc_test.go @@ -16,9 +16,11 @@ package app import ( + "bytes" "os" "path/filepath" "testing" + "time" ) const keyPEM = `-----BEGIN PRIVATE KEY----- @@ -108,7 +110,41 @@ tSLmTWsb+j/Oxljalf+rAlItYk297HN0xMvlkHkB80O5Un6OMCHAjJmfOVZal2Y5 o4ZDR+PzKEbU8eUQbooS -----END CERTIFICATE-----` -func TestCreateGRPCCreds(t *testing.T) { +const renewedCertPEM = `-----BEGIN CERTIFICATE----- +MIIFqzCCA5OgAwIBAgIUYKKd201v0q4S0FVSdKvRvrC9JQIwDQYJKoZIhvcNAQEL +BQAwZTELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAldBMREwDwYDVQQHDAhLaXJrbGFu +ZDERMA8GA1UECgwIU2lnc3RvcmUxDzANBgNVBAsMBkZ1bGNpbzESMBAGA1UEAwwJ +bG9jYWxob3N0MB4XDTIzMDcwMTE4NTUzM1oXDTMzMDYyODE4NTUzM1owZTELMAkG +A1UEBhMCVVMxCzAJBgNVBAgMAldBMREwDwYDVQQHDAhLaXJrbGFuZDERMA8GA1UE +CgwIU2lnc3RvcmUxDzANBgNVBAsMBkZ1bGNpbzESMBAGA1UEAwwJbG9jYWxob3N0 +MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAvc3MNQYRO5ytG+8FsPn5 +0Z8koUG4sYPf6ZLTMP78+sRYfz2ggZaP46Hl3f571AB8nXSBRbIc9byDgOVpOs9t +zRnYA6tyv9cgtOCtgFfuctHFdeWpJXQr4wWhB0oUspmu66cmFKYfdwrfnrjvRkZ2 ++33fGK0hC6EXtNIX7sg+Y98jT1iW0AIiBZxMBf8p5d5fTEfodVN/NZ0FN58/TOPO +jmkStqta+fZrPL02TVZ/IdDx7RSWVqD+KcqOJdSneuCt5qniQgcLwBMIk5ymyMLP ++yKl1GQXHfagxh3e84HlE93XgVCThk9XK8HSQerUmH0oklIx2PgcSS0FPjsz49Us +Kd7QweZAzcaaaoapK0QkXRvUZLSwPgOjriJtd6Pi5S7xN91DnGmyDHA7BGgCtew7 +1BUgW2AzWXJq3EX0kHjMaEgHCL59SUw/pOlMiNXMC6hnUSH5lnY2isNs9+DUU+Xa +/Z9ME+B0SiRCaRGq7ZUdXiHuaN+DiRj3hX1VO96wVjvZAh0JklI6pVB7cz6HvOwx +iAtSiXxqIQZkyac3lP939tAFzLVvpqSqaHzUF8bqBSkWxy8iZVW9EJiIb8wAVE8R +Sl4WssnUrneMfXjxsyQ271H6DIDLWP4BHtorqcN0vGnOE37N6DjrOJaaTyaThn2q +Kjmt6ghqvTY0CRVpyQz3szUCAwEAAaNTMFEwHQYDVR0OBBYEFGgoph9DIwXUHUT0 +8y7CtcviGmPhMB8GA1UdIwQYMBaAFGgoph9DIwXUHUT08y7CtcviGmPhMA8GA1Ud +EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggIBAJp6m3a6nMlpzZ9doLh9WOmj +dCz68oStA0HMflHGkGgtZ9BDNWQI5EL8rnsdj4l5h4nte7s4AQrtoHGdT/QkIwUu +1Zv3XBypubeof2iUJ+UhIz2Lm9vqt1hOT1HM3U7/4HCf/730y7LqwWIm580PLw1v +73C4kizrzg4T7C/4Jy5/lF5PzmJOVPWd0LDomWzbpw0pM2h8cYUY02HVEaeVOfHP +lCoTW7cL36q3I/0RAyjGmK8nNGNpJTB4xjdTT4TJRotkhcKEsTVfIrVJFGlVMDoT +Fe9T6rAnYQ2TlnPRhp9tDqiRb3Y027nJouddjulgGTRudUAzNkg7lVWJSDc4PwO2 +7gm7I5/mbil6bI1r4djV0FPJZZI7EHgOM8OmKKqo5sLN0WQigZ8GSH1KHOuR1d2j +m6GJOdUJ7+ZQ+tej0pwNERMSl0+OY+FtsFMusLXoIUUTyaOs1cpFO0ifqZSp9eUP +50QDoeDGYS0T/0RicNDXMTltE24G3L7mHiPa5rr4tlYvVHYeoez7qFtG5LvDvBH8 +OVZzfoJGB4MbgrFxymai/9i+hdYadt2UHL9BfnUUInDhi/2l0MB1a45DYOzf5Zf0 +IYjZh058kDhlL7WOkAhPdvm2wD9KAm4FDInw49PYxpasmDKaOTf1WeWfNa5CVOYp +wpLmyTovgQl/NcXO7caU +-----END CERTIFICATE-----` + +func TestCachedTLSCert(t *testing.T) { dir := t.TempDir() // not PKI material @@ -166,10 +202,31 @@ func TestCreateGRPCCreds(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - _, err := createGRPCCreds(tc.certPath, tc.keyPath) + cachedCert, err := newCachedTLSCert(tc.certPath, tc.keyPath) if tc.success != (err == nil) { t.Errorf("unexpected result: %v", err) } + if tc.success { + cert := cachedCert.GetCertificate() + if cert == nil { + t.Fatal("unexpected error reading tls.Certificate object") + } + + // update the cert on disk to a renewed value (representing same public/private keypair) + os.WriteFile(tc.certPath, []byte(renewedCertPEM), 0644) + + // sleep for a second to let goroutine fire for fsnotify event + time.Sleep(1 * time.Second) + + renewedCert := cachedCert.GetCertificate() + if renewedCert == nil { + t.Fatal("unexpected error reading renewed tls.Certificate object") + } + + if bytes.Equal(cert.Certificate[0], renewedCert.Certificate[0]) { + t.Fatal("got same certificate after overwriting renewed cert to same file") + } + } }) } }