Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fsnotify-backed cache for reading TLS PKI material #1256

Merged
merged 1 commit into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
98 changes: 88 additions & 10 deletions cmd/app/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import (
"net"
"os"
"runtime"
"sync"

"github.com/fsnotify/fsnotify"
"github.com/goadesign/goa/grpc/middleware"
ctclient "github.com/google/certificate-transparency-go/client"
grpcmw "github.com/grpc-ecosystem/go-grpc-middleware"
Expand Down Expand Up @@ -50,6 +52,7 @@ type grpcServer struct {
*grpc.Server
grpcServerEndpoint string
caService gw.CAServer
tlsCertWatcher *fsnotify.Watcher
}

func PassFulcioConfigThruContext(cfg *config.FulcioConfig) grpc.UnaryServerInterceptor {
Expand All @@ -67,16 +70,85 @@ func PassFulcioConfigThruContext(cfg *config.FulcioConfig) grpc.UnaryServerInter
}
}

func createGRPCCreds(certPath, keyPath string) (grpc.ServerOption, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
type cachedTLSCert struct {
sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
Watcher *fsnotify.Watcher
}

func newCachedTLSCert(certPath, keyPath string) (*cachedTLSCert, error) {
cachedTLSCert := &cachedTLSCert{
certPath: certPath,
keyPath: keyPath,
}
if err := cachedTLSCert.UpdateCertificate(); err != nil {
return nil, err
}
var err error
cachedTLSCert.Watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, err
}

go func() {
for {
select {
case event, ok := <-cachedTLSCert.Watcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) {
log.Logger.Info("fsnotify grpc-tls-certificate write event detected")
if err := cachedTLSCert.UpdateCertificate(); err != nil {
log.Logger.Error(err)
}
}
case err, ok := <-cachedTLSCert.Watcher.Errors:
if !ok {
return
}
log.Logger.Error("fsnotify grpc-tls-certificate error:", err)
}
}
}()

// Add a path.
if err = cachedTLSCert.Watcher.Add(certPath); err != nil {
return nil, err
}
return cachedTLSCert, nil
}

func (c *cachedTLSCert) GetCertificate() *tls.Certificate {
// get reader lock
c.RLock()
bobcallaway marked this conversation as resolved.
Show resolved Hide resolved
defer c.RUnlock()
return c.cert
}

func (c *cachedTLSCert) UpdateCertificate() error {
// get writer lock
c.Lock()
defer c.Unlock()

cert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath)
if err != nil {
return nil, fmt.Errorf("loading GRPC tls certificate and key file: %w", err)
return fmt.Errorf("loading GRPC tls certificate and key file: %w", err)
}

c.cert = &cert
return nil
}

func (c *cachedTLSCert) GRPCCreds() grpc.ServerOption {
return grpc.Creds(credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
})), nil
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
return c.GetCertificate(), nil
},
MinVersion: tls.VersionTLS13,
}))
}

func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, baseca ca.CertificateAuthority, ip identity.IssuerPool) (*grpcServer, error) {
Expand All @@ -94,12 +166,15 @@ func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, ba
grpc.MaxRecvMsgSize(int(maxMsgSize)),
}

var tlsCertWatcher *fsnotify.Watcher
if viper.IsSet("grpc-tls-certificate") && viper.IsSet("grpc-tls-key") {
creds, err := createGRPCCreds(viper.GetString("grpc-tls-certificate"), viper.GetString("grpc-tls-key"))
cachedTLSCert, err := newCachedTLSCert(viper.GetString("grpc-tls-certificate"), viper.GetString("grpc-tls-key"))
if err != nil {
return nil, err
}
serverOpts = append(serverOpts, creds)

tlsCertWatcher = cachedTLSCert.Watcher
serverOpts = append(serverOpts, cachedTLSCert.GRPCCreds())
}

myServer := grpc.NewServer(serverOpts...)
Expand All @@ -109,7 +184,7 @@ func createGRPCServer(cfg *config.FulcioConfig, ctClient *ctclient.LogClient, ba
gw.RegisterCAServer(myServer, grpcCAServer)

grpcServerEndpoint := fmt.Sprintf("%s:%s", viper.GetString("grpc-host"), viper.GetString("grpc-port"))
return &grpcServer{myServer, grpcServerEndpoint, grpcCAServer}, nil
return &grpcServer{myServer, grpcServerEndpoint, grpcCAServer, tlsCertWatcher}, nil
}

func (g *grpcServer) setupPrometheus(reg *prometheus.Registry) {
Expand All @@ -129,6 +204,9 @@ func (g *grpcServer) startTCPListener() {
g.grpcServerEndpoint = lis.Addr().String()
log.Logger.Infof("listening on grpc at %s", g.grpcServerEndpoint)
go func() {
if g.tlsCertWatcher != nil {
defer g.tlsCertWatcher.Close()
}
if err := g.Server.Serve(lis); err != nil {
log.Logger.Errorf("error shutting down grpcServer: %w", err)
}
Expand Down Expand Up @@ -179,7 +257,7 @@ func createLegacyGRPCServer(cfg *config.FulcioConfig, v2Server gw.CAServer) (*gr
// Register your gRPC service implementations.
gw_legacy.RegisterCAServer(myServer, legacyGRPCCAServer)

return &grpcServer{myServer, LegacyUnixDomainSocket, v2Server}, nil
return &grpcServer{myServer, LegacyUnixDomainSocket, v2Server, nil}, nil
}

func panicRecoveryHandler(ctx context.Context, p interface{}) error {
Expand Down
61 changes: 59 additions & 2 deletions cmd/app/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package app

import (
"bytes"
"os"
"path/filepath"
"testing"
"time"
)

const keyPEM = `-----BEGIN PRIVATE KEY-----
Expand Down Expand Up @@ -108,7 +110,41 @@ tSLmTWsb+j/Oxljalf+rAlItYk297HN0xMvlkHkB80O5Un6OMCHAjJmfOVZal2Y5
o4ZDR+PzKEbU8eUQbooS
-----END CERTIFICATE-----`

func TestCreateGRPCCreds(t *testing.T) {
const renewedCertPEM = `-----BEGIN CERTIFICATE-----
MIIFqzCCA5OgAwIBAgIUYKKd201v0q4S0FVSdKvRvrC9JQIwDQYJKoZIhvcNAQEL
BQAwZTELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAldBMREwDwYDVQQHDAhLaXJrbGFu
ZDERMA8GA1UECgwIU2lnc3RvcmUxDzANBgNVBAsMBkZ1bGNpbzESMBAGA1UEAwwJ
bG9jYWxob3N0MB4XDTIzMDcwMTE4NTUzM1oXDTMzMDYyODE4NTUzM1owZTELMAkG
A1UEBhMCVVMxCzAJBgNVBAgMAldBMREwDwYDVQQHDAhLaXJrbGFuZDERMA8GA1UE
CgwIU2lnc3RvcmUxDzANBgNVBAsMBkZ1bGNpbzESMBAGA1UEAwwJbG9jYWxob3N0
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAvc3MNQYRO5ytG+8FsPn5
0Z8koUG4sYPf6ZLTMP78+sRYfz2ggZaP46Hl3f571AB8nXSBRbIc9byDgOVpOs9t
zRnYA6tyv9cgtOCtgFfuctHFdeWpJXQr4wWhB0oUspmu66cmFKYfdwrfnrjvRkZ2
+33fGK0hC6EXtNIX7sg+Y98jT1iW0AIiBZxMBf8p5d5fTEfodVN/NZ0FN58/TOPO
jmkStqta+fZrPL02TVZ/IdDx7RSWVqD+KcqOJdSneuCt5qniQgcLwBMIk5ymyMLP
+yKl1GQXHfagxh3e84HlE93XgVCThk9XK8HSQerUmH0oklIx2PgcSS0FPjsz49Us
Kd7QweZAzcaaaoapK0QkXRvUZLSwPgOjriJtd6Pi5S7xN91DnGmyDHA7BGgCtew7
1BUgW2AzWXJq3EX0kHjMaEgHCL59SUw/pOlMiNXMC6hnUSH5lnY2isNs9+DUU+Xa
/Z9ME+B0SiRCaRGq7ZUdXiHuaN+DiRj3hX1VO96wVjvZAh0JklI6pVB7cz6HvOwx
iAtSiXxqIQZkyac3lP939tAFzLVvpqSqaHzUF8bqBSkWxy8iZVW9EJiIb8wAVE8R
Sl4WssnUrneMfXjxsyQ271H6DIDLWP4BHtorqcN0vGnOE37N6DjrOJaaTyaThn2q
Kjmt6ghqvTY0CRVpyQz3szUCAwEAAaNTMFEwHQYDVR0OBBYEFGgoph9DIwXUHUT0
8y7CtcviGmPhMB8GA1UdIwQYMBaAFGgoph9DIwXUHUT08y7CtcviGmPhMA8GA1Ud
EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggIBAJp6m3a6nMlpzZ9doLh9WOmj
dCz68oStA0HMflHGkGgtZ9BDNWQI5EL8rnsdj4l5h4nte7s4AQrtoHGdT/QkIwUu
1Zv3XBypubeof2iUJ+UhIz2Lm9vqt1hOT1HM3U7/4HCf/730y7LqwWIm580PLw1v
73C4kizrzg4T7C/4Jy5/lF5PzmJOVPWd0LDomWzbpw0pM2h8cYUY02HVEaeVOfHP
lCoTW7cL36q3I/0RAyjGmK8nNGNpJTB4xjdTT4TJRotkhcKEsTVfIrVJFGlVMDoT
Fe9T6rAnYQ2TlnPRhp9tDqiRb3Y027nJouddjulgGTRudUAzNkg7lVWJSDc4PwO2
7gm7I5/mbil6bI1r4djV0FPJZZI7EHgOM8OmKKqo5sLN0WQigZ8GSH1KHOuR1d2j
m6GJOdUJ7+ZQ+tej0pwNERMSl0+OY+FtsFMusLXoIUUTyaOs1cpFO0ifqZSp9eUP
50QDoeDGYS0T/0RicNDXMTltE24G3L7mHiPa5rr4tlYvVHYeoez7qFtG5LvDvBH8
OVZzfoJGB4MbgrFxymai/9i+hdYadt2UHL9BfnUUInDhi/2l0MB1a45DYOzf5Zf0
IYjZh058kDhlL7WOkAhPdvm2wD9KAm4FDInw49PYxpasmDKaOTf1WeWfNa5CVOYp
wpLmyTovgQl/NcXO7caU
-----END CERTIFICATE-----`

func TestCachedTLSCert(t *testing.T) {
dir := t.TempDir()

// not PKI material
Expand Down Expand Up @@ -166,10 +202,31 @@ func TestCreateGRPCCreds(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
_, err := createGRPCCreds(tc.certPath, tc.keyPath)
cachedCert, err := newCachedTLSCert(tc.certPath, tc.keyPath)
if tc.success != (err == nil) {
t.Errorf("unexpected result: %v", err)
}
if tc.success {
cert := cachedCert.GetCertificate()
if cert == nil {
t.Fatal("unexpected error reading tls.Certificate object")
}

// update the cert on disk to a renewed value (representing same public/private keypair)
os.WriteFile(tc.certPath, []byte(renewedCertPEM), 0644)

// sleep for a second to let goroutine fire for fsnotify event
time.Sleep(1 * time.Second)

renewedCert := cachedCert.GetCertificate()
if renewedCert == nil {
t.Fatal("unexpected error reading renewed tls.Certificate object")
}

if bytes.Equal(cert.Certificate[0], renewedCert.Certificate[0]) {
t.Fatal("got same certificate after overwriting renewed cert to same file")
}
}
})
}
}