New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Agent cache manager #233
Agent cache manager #233
Changes from 16 commits
2db10cd
d2e4922
42894a0
64e26dc
342beba
11fbabf
9b52977
46eece9
4cf4253
58fd922
3159f72
901c2ac
a80e56e
bd84271
95f1d22
cf5436e
b0f8e68
8c84d40
0ad18b2
e03f4de
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
BindAddress = "127.0.0.1" | ||
BindPort = "8088" | ||
DataDir = "." | ||
LogLevel = "INFO" | ||
LogLevel = "DEBUG" | ||
PluginDir = "conf/agent/plugin" | ||
ServerAddress = "127.0.0.1" | ||
ServerPort = "8081" | ||
SocketPath ="/tmp/agent.sock" | ||
TrustBundlePath = "conf/agent/dummy_root_ca.crt" | ||
TrustDomain = "example.org" | ||
Umask = "" | ||
JoinToken = "" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,6 @@ package agent | |
import ( | ||
"context" | ||
"crypto/ecdsa" | ||
"crypto/elliptic" | ||
"crypto/rand" | ||
"crypto/tls" | ||
"crypto/x509" | ||
|
@@ -57,9 +56,6 @@ type Config struct { | |
// A channel for receiving errors from agent goroutines | ||
ErrorCh chan error | ||
|
||
// A channel to trigger agent shutdown | ||
ShutdownCh chan struct{} | ||
|
||
// Trust domain and associated CA bundle | ||
TrustDomain url.URL | ||
TrustBundle *x509.CertPool | ||
|
@@ -77,17 +73,25 @@ type Agent struct { | |
BaseSVIDTTL int32 | ||
config *Config | ||
grpcServer *grpc.Server | ||
Cache cache.Cache | ||
CacheMgr cache.Manager | ||
Catalog catalog.Catalog | ||
serverCerts []*x509.Certificate | ||
ctx context.Context | ||
cancel context.CancelFunc | ||
} | ||
|
||
func New(c *Config) *Agent { | ||
func New(ctx context.Context, c *Config) *Agent { | ||
config := &catalog.Config{ | ||
ConfigDir: c.PluginDir, | ||
Log: c.Log.WithField("subsystem_name", "catalog"), | ||
} | ||
return &Agent{config: c, Catalog: catalog.New(config)} | ||
ctx, cancel := context.WithCancel(ctx) | ||
return &Agent{ | ||
config: c, | ||
Catalog: catalog.New(config), | ||
ctx: ctx, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use this context for gRPC calls where appropriate. |
||
cancel: cancel, | ||
} | ||
} | ||
|
||
// Run the agent | ||
|
@@ -96,8 +100,6 @@ func New(c *Config) *Agent { | |
func (a *Agent) Run() error { | ||
a.prepareUmask() | ||
|
||
a.Cache = cache.NewCache() | ||
|
||
err := a.initPlugins() | ||
if err != nil { | ||
return err | ||
|
@@ -118,8 +120,12 @@ func (a *Agent) Run() error { | |
for { | ||
select { | ||
case err = <-a.config.ErrorCh: | ||
e := a.Shutdown() | ||
if e != nil { | ||
a.config.Log.Debug(e) | ||
} | ||
return err | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. should |
||
case <-a.config.ShutdownCh: | ||
case <-a.ctx.Done(): | ||
return a.Shutdown() | ||
} | ||
} | ||
|
@@ -131,6 +137,7 @@ func (a *Agent) prepareUmask() { | |
} | ||
|
||
func (a *Agent) Shutdown() error { | ||
defer a.cancel() | ||
if a.Catalog != nil { | ||
a.Catalog.Stop() | ||
} | ||
|
@@ -169,7 +176,7 @@ func (a *Agent) initEndpoints() error { | |
log := a.config.Log.WithField("subsystem_name", "workload") | ||
ws := &workloadServer{ | ||
bundle: a.serverCerts[1].Raw, // TODO: Fix handling of serverCerts | ||
cache: a.Cache, | ||
cache: a.CacheMgr.Cache(), | ||
catalog: a.Catalog, | ||
l: log, | ||
maxTTL: maxWorkloadTTL, | ||
|
@@ -243,14 +250,36 @@ func (a *Agent) bootstrap() error { | |
a.baseSVIDKey = key | ||
|
||
// If we're here, we need to attest/Re-attest | ||
regEntryMap, err := a.attest() | ||
regEntries, err := a.attest() | ||
if err != nil { | ||
return err | ||
} | ||
err = a.FetchSVID(regEntryMap, a.BaseSVID, a.baseSVIDKey) | ||
if err != nil { | ||
return err | ||
serverId := url.URL{ | ||
Scheme: "spiffe", | ||
Host: a.config.TrustDomain.Host, | ||
Path: path.Join("spiffe", "cp"), | ||
} | ||
cmgrConfig := &cache.MgrConfig{ | ||
ServerCerts: a.serverCerts, | ||
ServerSPIFFEID: serverId.String(), | ||
ServerAddr: a.config.ServerAddress.String(), | ||
|
||
BaseSVID: a.BaseSVID, | ||
BaseSVIDKey: a.baseSVIDKey, | ||
BaseRegEntries: regEntries, | ||
Logger: a.config.Log, | ||
} | ||
|
||
a.CacheMgr, err = cache.NewManager(a.ctx, cmgrConfig) | ||
|
||
a.CacheMgr.Init() | ||
go func() { | ||
<-a.CacheMgr.Done() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will this unblock on an error condition too, or only on a clean shutdown? Should we shut down the agent? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes |
||
a.config.Log.Info("Cache Update Stopped") | ||
if a.CacheMgr.Err() != nil { | ||
a.config.Log.Warning(a.CacheMgr.Err()) | ||
} | ||
}() | ||
} | ||
|
||
a.config.Log.Info("Bootstrapping done") | ||
|
@@ -261,14 +290,16 @@ func (a *Agent) bootstrap() error { | |
// which is used to generate CSRs for non-base SVIDs and update the agent cache entries | ||
// | ||
// TODO: Refactor me for length, testability | ||
func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) { | ||
|
||
func (a *Agent) attest() ([]*common.RegistrationEntry, error) { | ||
var err error | ||
a.config.Log.Info("Preparing to attest against ", a.config.ServerAddress.String()) | ||
|
||
// Handle the join token separately, if defined | ||
pluginResponse := &nodeattestor.FetchAttestationDataResponse{} | ||
if a.config.JoinToken != "" { | ||
a.config.Log.Info("Preparing to attest this node against ", a.config.ServerAddress.String(), " using strategy 'join-token'") | ||
|
||
a.config.Log.Info("Preparing to attest this node against ", | ||
a.config.ServerAddress.String(), " using strategy 'join-token'") | ||
data := &common.AttestedData{ | ||
Type: "join_token", | ||
Data: []byte(a.config.JoinToken), | ||
|
@@ -278,7 +309,6 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) { | |
Host: a.config.TrustDomain.Host, | ||
Path: path.Join("spire", "agent", "join_token", a.config.JoinToken), | ||
} | ||
|
||
pluginResponse.AttestedData = data | ||
pluginResponse.SpiffeId = id.String() | ||
} else { | ||
|
@@ -309,7 +339,10 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) { | |
} | ||
|
||
// Since we are bootstrapping, this is explicitly _not_ mTLS | ||
conn := a.getNodeAPIClientConn(false, a.BaseSVID, a.baseSVIDKey) | ||
conn, err := a.getNodeAPIClientConn(false, a.BaseSVID, a.baseSVIDKey) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer conn.Close() | ||
nodeClient := node.NewNodeClient(conn) | ||
|
||
|
@@ -340,16 +373,11 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) { | |
return nil, fmt.Errorf("Base SVID not found in attestation response") | ||
} | ||
|
||
var registrationEntryMap = make(map[string]*common.RegistrationEntry) | ||
for _, entry := range serverResponse.SvidUpdate.RegistrationEntries { | ||
registrationEntryMap[entry.SpiffeId] = entry | ||
} | ||
|
||
a.BaseSVID = svid.SvidCert | ||
a.BaseSVIDTTL = svid.Ttl | ||
a.storeBaseSVID() | ||
a.config.Log.Info("Node attestation complete") | ||
return registrationEntryMap, nil | ||
a.config.Log.Info("Attestation complete") | ||
return serverResponse.SvidUpdate.RegistrationEntries, nil | ||
} | ||
|
||
// Generate a CSR for the given SPIFFE ID | ||
|
@@ -421,72 +449,7 @@ func (a *Agent) storeBaseSVID() { | |
return | ||
} | ||
|
||
func (a *Agent) FetchSVID(registrationEntryMap map[string]*common.RegistrationEntry, svidCert []byte, | ||
key *ecdsa.PrivateKey) (err error) { | ||
|
||
if len(registrationEntryMap) != 0 { | ||
Csrs, pkeyMap, err := a.generateCSRForRegistrationEntries(registrationEntryMap) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
conn := a.getNodeAPIClientConn(true, svidCert, key) | ||
defer conn.Close() | ||
nodeClient := node.NewNodeClient(conn) | ||
|
||
req := &node.FetchSVIDRequest{Csrs: Csrs} | ||
|
||
callOptPeer := new(peer.Peer) | ||
resp, err := nodeClient.FetchSVID(context.Background(), req, grpc.Peer(callOptPeer)) | ||
if err != nil { | ||
return err | ||
} | ||
if tlsInfo, ok := callOptPeer.AuthInfo.(credentials.TLSInfo); ok { | ||
a.serverCerts = tlsInfo.State.PeerCertificates | ||
} | ||
|
||
svidMap := resp.GetSvidUpdate().GetSvids() | ||
|
||
// TODO: Fetch the referenced federated bundles and | ||
// set them here | ||
bundles := make(map[string][]byte) | ||
for spiffeID, entry := range registrationEntryMap { | ||
svid, svidInMap := svidMap[spiffeID] | ||
pkey, pkeyInMap := pkeyMap[spiffeID] | ||
if svidInMap && pkeyInMap { | ||
svidCert, err := x509.ParseCertificate(svid.SvidCert) | ||
if err != nil { | ||
return fmt.Errorf("SVID for ID %s could not be parsed: %s", spiffeID, err) | ||
} | ||
|
||
entry := cache.CacheEntry{ | ||
RegistrationEntry: entry, | ||
SVID: svid, | ||
PrivateKey: pkey, | ||
Bundles: bundles, | ||
Expiry: svidCert.NotAfter, | ||
} | ||
a.Cache.SetEntry(entry) | ||
} | ||
} | ||
|
||
newRegistrationMap := make(map[string]*common.RegistrationEntry) | ||
|
||
if len(resp.SvidUpdate.RegistrationEntries) != 0 { | ||
for _, entry := range resp.SvidUpdate.RegistrationEntries { | ||
if _, ok := registrationEntryMap[entry.SpiffeId]; ok != true { | ||
newRegistrationMap[entry.SpiffeId] = entry | ||
} | ||
a.FetchSVID(newRegistrationMap, svidMap[entry.SpiffeId].SvidCert, pkeyMap[entry.SpiffeId]) | ||
|
||
} | ||
|
||
} | ||
} | ||
return | ||
} | ||
|
||
func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateKey) (conn *grpc.ClientConn) { | ||
func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateKey) (conn *grpc.ClientConn, err error) { | ||
|
||
serverID := a.config.TrustDomain | ||
serverID.Path = "spiffe/cp" | ||
|
@@ -516,35 +479,10 @@ func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateK | |
|
||
dialCreds := grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) | ||
|
||
conn, err := grpc.Dial(a.config.ServerAddress.String(), dialCreds) | ||
conn, err = grpc.DialContext(a.ctx, a.config.ServerAddress.String(), dialCreds) | ||
if err != nil { | ||
return | ||
} | ||
|
||
return | ||
|
||
} | ||
|
||
func (a *Agent) generateCSRForRegistrationEntries( | ||
regEntryMap map[string]*common.RegistrationEntry) (CSRs [][]byte, pkeyMap map[string]*ecdsa.PrivateKey, err error) { | ||
|
||
pkeyMap = make(map[string]*ecdsa.PrivateKey) | ||
for id, _ := range regEntryMap { | ||
|
||
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
spiffeid, err := url.Parse(id) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
csr, err := a.generateCSR(spiffeid, key) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
CSRs = append(CSRs, csr) | ||
pkeyMap[id] = key | ||
} | ||
return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A trick I've been using to allow both a graceful shutdown and an easy "really, f--ing die" is to stop handling signals after the first one. See here.
It might not be too applicable in a non-interactive program like this server, but I'm mentioning it as food for thought. It can make dev cycles easier (although Ctrl-\ is usually fine too).