Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Agent cache manager #233

Merged
merged 20 commits into from Oct 19, 2017
18 changes: 9 additions & 9 deletions cmd/spire-agent/cli/command/run.go
Expand Up @@ -15,6 +15,7 @@ import (
"strconv"
"syscall"

"context"
"github.com/hashicorp/hcl"
"github.com/spiffe/spire/pkg/agent"
"github.com/spiffe/spire/pkg/common/log"
Expand Down Expand Up @@ -83,10 +84,11 @@ func (*RunCommand) Run(args []string) int {
if err != nil {
fmt.Println(err.Error())
}
ctx, cancel := context.WithCancel(context.Background())
signalListener(ctx, cancel)

signalListener(c.ShutdownCh)
agt := agent.New(ctx, c)

agt := agent.New(c)
err = agt.Run()
if err != nil {
c.Log.Error(err.Error())
Expand Down Expand Up @@ -270,8 +272,6 @@ func newDefaultConfig() *agent.Config {
Organization: []string{"SPIRE"},
}
errCh := make(chan error)
shutdownCh := make(chan struct{})

// log.NewLogger() cannot return error when using STDOUT
logger, _ := log.NewLogger(defaultLogLevel, "")
serverAddress := &net.TCPAddr{}
Expand All @@ -282,7 +282,6 @@ func newDefaultConfig() *agent.Config {
DataDir: defaultDataDir,
PluginDir: defaultPluginDir,
ErrorCh: errCh,
ShutdownCh: shutdownCh,
Log: logger,
ServerAddress: serverAddress,
Umask: defaultUmask,
Expand Down Expand Up @@ -310,15 +309,16 @@ func stringDefault(option string, defaultValue string) string {
return option
}

func signalListener(ch chan struct{}) {
func signalListener(ctx context.Context, cancel context.CancelFunc) {

go func() {
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)

var stop struct{}
defer signal.Stop(signalCh)
select {
case <-ctx.Done():
case <-signalCh:
ch <- stop
cancel()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A trick I've been doing to allow both graceful shutdown and easy "really, f--ing die" is to stop handling signals after the first one. See here.

It might not be too applicable in a non-interactive program like this server, but mentioning for food for thought. It can make dev cycles easier (although Ctrl-\ is usually fine too).

}
}()
return
Expand Down
3 changes: 2 additions & 1 deletion conf/agent/agent.conf
@@ -1,11 +1,12 @@
BindAddress = "127.0.0.1"
BindPort = "8088"
DataDir = "."
LogLevel = "INFO"
LogLevel = "DEBUG"
PluginDir = "conf/agent/plugin"
ServerAddress = "127.0.0.1"
ServerPort = "8081"
SocketPath ="/tmp/agent.sock"
TrustBundlePath = "conf/agent/dummy_root_ca.crt"
TrustDomain = "example.org"
Umask = ""
JoinToken = ""
2 changes: 1 addition & 1 deletion conf/server/server.conf
Expand Up @@ -3,6 +3,6 @@ BindPort = "8081"
BindHTTPPort = "8080"
TrustDomain = "example.org"
PluginDir = "conf/server/plugin"
LogLevel = "INFO"
LogLevel = "DEBUG"
BaseSpiffeIDTTL = 999999
Umask = ""
176 changes: 57 additions & 119 deletions pkg/agent/agent.go
Expand Up @@ -3,7 +3,6 @@ package agent
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -57,9 +56,6 @@ type Config struct {
// A channel for receiving errors from agent goroutines
ErrorCh chan error

// A channel to trigger agent shutdown
ShutdownCh chan struct{}

// Trust domain and associated CA bundle
TrustDomain url.URL
TrustBundle *x509.CertPool
Expand All @@ -77,17 +73,25 @@ type Agent struct {
BaseSVIDTTL int32
config *Config
grpcServer *grpc.Server
Cache cache.Cache
CacheMgr cache.Manager
Catalog catalog.Catalog
serverCerts []*x509.Certificate
ctx context.Context
cancel context.CancelFunc
}

func New(c *Config) *Agent {
func New(ctx context.Context, c *Config) *Agent {
config := &catalog.Config{
ConfigDir: c.PluginDir,
Log: c.Log.WithField("subsystem_name", "catalog"),
}
return &Agent{config: c, Catalog: catalog.New(config)}
ctx, cancel := context.WithCancel(ctx)
return &Agent{
config: c,
Catalog: catalog.New(config),
ctx: ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use this context for grpc calls where appropriate.

cancel: cancel,
}
}

// Run the agent
Expand All @@ -96,8 +100,6 @@ func New(c *Config) *Agent {
func (a *Agent) Run() error {
a.prepareUmask()

a.Cache = cache.NewCache()

err := a.initPlugins()
if err != nil {
return err
Expand All @@ -118,8 +120,12 @@ func (a *Agent) Run() error {
for {
select {
case err = <-a.config.ErrorCh:
e := a.Shutdown()
if e != nil {
a.config.Log.Debug(e)
}
return err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should Shutdown() be called here?

case <-a.config.ShutdownCh:
case <-a.ctx.Done():
return a.Shutdown()
}
}
Expand All @@ -131,6 +137,7 @@ func (a *Agent) prepareUmask() {
}

func (a *Agent) Shutdown() error {
defer a.cancel()
if a.Catalog != nil {
a.Catalog.Stop()
}
Expand Down Expand Up @@ -169,7 +176,7 @@ func (a *Agent) initEndpoints() error {
log := a.config.Log.WithField("subsystem_name", "workload")
ws := &workloadServer{
bundle: a.serverCerts[1].Raw, // TODO: Fix handling of serverCerts
cache: a.Cache,
cache: a.CacheMgr.Cache(),
catalog: a.Catalog,
l: log,
maxTTL: maxWorkloadTTL,
Expand Down Expand Up @@ -243,14 +250,36 @@ func (a *Agent) bootstrap() error {
a.baseSVIDKey = key

// If we're here, we need to attest/Re-attest
regEntryMap, err := a.attest()
regEntries, err := a.attest()
if err != nil {
return err
}
err = a.FetchSVID(regEntryMap, a.BaseSVID, a.baseSVIDKey)
if err != nil {
return err
serverId := url.URL{
Scheme: "spiffe",
Host: a.config.TrustDomain.Host,
Path: path.Join("spiffe", "cp"),
}
cmgrConfig := &cache.MgrConfig{
ServerCerts: a.serverCerts,
ServerSPIFFEID: serverId.String(),
ServerAddr: a.config.ServerAddress.String(),

BaseSVID: a.BaseSVID,
BaseSVIDKey: a.baseSVIDKey,
BaseRegEntries: regEntries,
Logger: a.config.Log,
}

a.CacheMgr, err = cache.NewManager(a.ctx, cmgrConfig)

a.CacheMgr.Init()
go func() {
<-a.CacheMgr.Done()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this unblock on an error condition too or only a clean shutdown? Should we shutdown the agent?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes CacheMgr.Done() will return when doneCh is closed which happens when the Init() method returns

a.config.Log.Info("Cache Update Stopped")
if a.CacheMgr.Err() != nil {
a.config.Log.Warning(a.CacheMgr.Err())
}
}()
}

a.config.Log.Info("Bootstrapping done")
Expand All @@ -261,14 +290,16 @@ func (a *Agent) bootstrap() error {
// which is used to generate CSRs for non-base SVIDs and update the agent cache entries
//
// TODO: Refactor me for length, testability
func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) {

func (a *Agent) attest() ([]*common.RegistrationEntry, error) {
var err error
a.config.Log.Info("Preparing to attest against ", a.config.ServerAddress.String())

// Handle the join token seperately, if defined
pluginResponse := &nodeattestor.FetchAttestationDataResponse{}
if a.config.JoinToken != "" {
a.config.Log.Info("Preparing to attest this node against ", a.config.ServerAddress.String(), " using strategy 'join-token'")

a.config.Log.Info("Preparing to attest this node against ",
a.config.ServerAddress.String(), " using strategy 'join-token'")
data := &common.AttestedData{
Type: "join_token",
Data: []byte(a.config.JoinToken),
Expand All @@ -278,7 +309,6 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) {
Host: a.config.TrustDomain.Host,
Path: path.Join("spire", "agent", "join_token", a.config.JoinToken),
}

pluginResponse.AttestedData = data
pluginResponse.SpiffeId = id.String()
} else {
Expand Down Expand Up @@ -309,7 +339,10 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) {
}

// Since we are bootstrapping, this is explicitly _not_ mTLS
conn := a.getNodeAPIClientConn(false, a.BaseSVID, a.baseSVIDKey)
conn, err := a.getNodeAPIClientConn(false, a.BaseSVID, a.baseSVIDKey)
if err != nil {
return nil, err
}
defer conn.Close()
nodeClient := node.NewNodeClient(conn)

Expand Down Expand Up @@ -340,16 +373,11 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) {
return nil, fmt.Errorf("Base SVID not found in attestation response")
}

var registrationEntryMap = make(map[string]*common.RegistrationEntry)
for _, entry := range serverResponse.SvidUpdate.RegistrationEntries {
registrationEntryMap[entry.SpiffeId] = entry
}

a.BaseSVID = svid.SvidCert
a.BaseSVIDTTL = svid.Ttl
a.storeBaseSVID()
a.config.Log.Info("Node attestation complete")
return registrationEntryMap, nil
a.config.Log.Info("Attestation complete")
return serverResponse.SvidUpdate.RegistrationEntries, nil
}

// Generate a CSR for the given SPIFFE ID
Expand Down Expand Up @@ -421,72 +449,7 @@ func (a *Agent) storeBaseSVID() {
return
}

func (a *Agent) FetchSVID(registrationEntryMap map[string]*common.RegistrationEntry, svidCert []byte,
key *ecdsa.PrivateKey) (err error) {

if len(registrationEntryMap) != 0 {
Csrs, pkeyMap, err := a.generateCSRForRegistrationEntries(registrationEntryMap)
if err != nil {
return err
}

conn := a.getNodeAPIClientConn(true, svidCert, key)
defer conn.Close()
nodeClient := node.NewNodeClient(conn)

req := &node.FetchSVIDRequest{Csrs: Csrs}

callOptPeer := new(peer.Peer)
resp, err := nodeClient.FetchSVID(context.Background(), req, grpc.Peer(callOptPeer))
if err != nil {
return err
}
if tlsInfo, ok := callOptPeer.AuthInfo.(credentials.TLSInfo); ok {
a.serverCerts = tlsInfo.State.PeerCertificates
}

svidMap := resp.GetSvidUpdate().GetSvids()

// TODO: Fetch the referenced federated bundles and
// set them here
bundles := make(map[string][]byte)
for spiffeID, entry := range registrationEntryMap {
svid, svidInMap := svidMap[spiffeID]
pkey, pkeyInMap := pkeyMap[spiffeID]
if svidInMap && pkeyInMap {
svidCert, err := x509.ParseCertificate(svid.SvidCert)
if err != nil {
return fmt.Errorf("SVID for ID %s could not be parsed: %s", spiffeID, err)
}

entry := cache.CacheEntry{
RegistrationEntry: entry,
SVID: svid,
PrivateKey: pkey,
Bundles: bundles,
Expiry: svidCert.NotAfter,
}
a.Cache.SetEntry(entry)
}
}

newRegistrationMap := make(map[string]*common.RegistrationEntry)

if len(resp.SvidUpdate.RegistrationEntries) != 0 {
for _, entry := range resp.SvidUpdate.RegistrationEntries {
if _, ok := registrationEntryMap[entry.SpiffeId]; ok != true {
newRegistrationMap[entry.SpiffeId] = entry
}
a.FetchSVID(newRegistrationMap, svidMap[entry.SpiffeId].SvidCert, pkeyMap[entry.SpiffeId])

}

}
}
return
}

func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateKey) (conn *grpc.ClientConn) {
func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateKey) (conn *grpc.ClientConn, err error) {

serverID := a.config.TrustDomain
serverID.Path = "spiffe/cp"
Expand Down Expand Up @@ -516,35 +479,10 @@ func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateK

dialCreds := grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))

conn, err := grpc.Dial(a.config.ServerAddress.String(), dialCreds)
conn, err = grpc.DialContext(a.ctx, a.config.ServerAddress.String(), dialCreds)
if err != nil {
return
}

return

}

func (a *Agent) generateCSRForRegistrationEntries(
regEntryMap map[string]*common.RegistrationEntry) (CSRs [][]byte, pkeyMap map[string]*ecdsa.PrivateKey, err error) {

pkeyMap = make(map[string]*ecdsa.PrivateKey)
for id, _ := range regEntryMap {

key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
return nil, nil, err
}
spiffeid, err := url.Parse(id)
if err != nil {
return nil, nil, err
}
csr, err := a.generateCSR(spiffeid, key)
if err != nil {
return nil, nil, err
}
CSRs = append(CSRs, csr)
pkeyMap[id] = key
}
return
}