Skip to content

Commit

Permalink
Agent cache manager (#233)
Browse files Browse the repository at this point in the history
* cache manager:
- Updates cache with new and expired entries
  • Loading branch information
walmav committed Oct 19, 2017
1 parent ce8c15f commit 277769d
Show file tree
Hide file tree
Showing 23 changed files with 1,315 additions and 309 deletions.
18 changes: 9 additions & 9 deletions cmd/spire-agent/cli/command/run.go
Expand Up @@ -15,6 +15,7 @@ import (
"strconv"
"syscall"

"context"
"github.com/hashicorp/hcl"
"github.com/spiffe/spire/pkg/agent"
"github.com/spiffe/spire/pkg/common/log"
Expand Down Expand Up @@ -83,10 +84,11 @@ func (*RunCommand) Run(args []string) int {
if err != nil {
fmt.Println(err.Error())
}
ctx, cancel := context.WithCancel(context.Background())
signalListener(ctx, cancel)

signalListener(c.ShutdownCh)
agt := agent.New(ctx, c)

agt := agent.New(c)
err = agt.Run()
if err != nil {
c.Log.Error(err.Error())
Expand Down Expand Up @@ -270,8 +272,6 @@ func newDefaultConfig() *agent.Config {
Organization: []string{"SPIRE"},
}
errCh := make(chan error)
shutdownCh := make(chan struct{})

// log.NewLogger() cannot return error when using STDOUT
logger, _ := log.NewLogger(defaultLogLevel, "")
serverAddress := &net.TCPAddr{}
Expand All @@ -282,7 +282,6 @@ func newDefaultConfig() *agent.Config {
DataDir: defaultDataDir,
PluginDir: defaultPluginDir,
ErrorCh: errCh,
ShutdownCh: shutdownCh,
Log: logger,
ServerAddress: serverAddress,
Umask: defaultUmask,
Expand Down Expand Up @@ -310,15 +309,16 @@ func stringDefault(option string, defaultValue string) string {
return option
}

func signalListener(ch chan struct{}) {
func signalListener(ctx context.Context, cancel context.CancelFunc) {

go func() {
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)

var stop struct{}
defer signal.Stop(signalCh)
select {
case <-ctx.Done():
case <-signalCh:
ch <- stop
cancel()
}
}()
return
Expand Down
3 changes: 2 additions & 1 deletion conf/agent/agent.conf
@@ -1,11 +1,12 @@
BindAddress = "127.0.0.1"
BindPort = "8088"
DataDir = "."
LogLevel = "INFO"
LogLevel = "DEBUG"
PluginDir = "conf/agent/plugin"
ServerAddress = "127.0.0.1"
ServerPort = "8081"
SocketPath ="/tmp/agent.sock"
TrustBundlePath = "conf/agent/dummy_root_ca.crt"
TrustDomain = "example.org"
Umask = ""
JoinToken = ""
2 changes: 1 addition & 1 deletion conf/server/server.conf
Expand Up @@ -3,6 +3,6 @@ BindPort = "8081"
BindHTTPPort = "8080"
TrustDomain = "example.org"
PluginDir = "conf/server/plugin"
LogLevel = "INFO"
LogLevel = "DEBUG"
BaseSpiffeIDTTL = 999999
Umask = ""
Binary file added doc/images/cacheMgr.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
176 changes: 57 additions & 119 deletions pkg/agent/agent.go
Expand Up @@ -3,7 +3,6 @@ package agent
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -57,9 +56,6 @@ type Config struct {
// A channel for receiving errors from agent goroutines
ErrorCh chan error

// A channel to trigger agent shutdown
ShutdownCh chan struct{}

// Trust domain and associated CA bundle
TrustDomain url.URL
TrustBundle *x509.CertPool
Expand All @@ -77,17 +73,25 @@ type Agent struct {
BaseSVIDTTL int32
config *Config
grpcServer *grpc.Server
Cache cache.Cache
CacheMgr cache.Manager
Catalog catalog.Catalog
serverCerts []*x509.Certificate
ctx context.Context
cancel context.CancelFunc
}

func New(c *Config) *Agent {
func New(ctx context.Context, c *Config) *Agent {
config := &catalog.Config{
ConfigDir: c.PluginDir,
Log: c.Log.WithField("subsystem_name", "catalog"),
}
return &Agent{config: c, Catalog: catalog.New(config)}
ctx, cancel := context.WithCancel(ctx)
return &Agent{
config: c,
Catalog: catalog.New(config),
ctx: ctx,
cancel: cancel,
}
}

// Run the agent
Expand All @@ -96,8 +100,6 @@ func New(c *Config) *Agent {
func (a *Agent) Run() error {
a.prepareUmask()

a.Cache = cache.NewCache()

err := a.initPlugins()
if err != nil {
return err
Expand All @@ -118,8 +120,12 @@ func (a *Agent) Run() error {
for {
select {
case err = <-a.config.ErrorCh:
e := a.Shutdown()
if e != nil {
a.config.Log.Debug(e)
}
return err
case <-a.config.ShutdownCh:
case <-a.ctx.Done():
return a.Shutdown()
}
}
Expand All @@ -131,6 +137,7 @@ func (a *Agent) prepareUmask() {
}

func (a *Agent) Shutdown() error {
defer a.cancel()
if a.Catalog != nil {
a.Catalog.Stop()
}
Expand Down Expand Up @@ -169,7 +176,7 @@ func (a *Agent) initEndpoints() error {
log := a.config.Log.WithField("subsystem_name", "workload")
ws := &workloadServer{
bundle: a.serverCerts[1].Raw, // TODO: Fix handling of serverCerts
cache: a.Cache,
cache: a.CacheMgr.Cache(),
catalog: a.Catalog,
l: log,
maxTTL: maxWorkloadTTL,
Expand Down Expand Up @@ -243,14 +250,36 @@ func (a *Agent) bootstrap() error {
a.baseSVIDKey = key

// If we're here, we need to attest/Re-attest
regEntryMap, err := a.attest()
regEntries, err := a.attest()
if err != nil {
return err
}
err = a.FetchSVID(regEntryMap, a.BaseSVID, a.baseSVIDKey)
if err != nil {
return err
serverId := url.URL{
Scheme: "spiffe",
Host: a.config.TrustDomain.Host,
Path: path.Join("spiffe", "cp"),
}
cmgrConfig := &cache.MgrConfig{
ServerCerts: a.serverCerts,
ServerSPIFFEID: serverId.String(),
ServerAddr: a.config.ServerAddress.String(),

BaseSVID: a.BaseSVID,
BaseSVIDKey: a.baseSVIDKey,
BaseRegEntries: regEntries,
Logger: a.config.Log,
}

a.CacheMgr, err = cache.NewManager(a.ctx, cmgrConfig)

a.CacheMgr.Init()
go func() {
<-a.CacheMgr.Done()
a.config.Log.Info("Cache Update Stopped")
if a.CacheMgr.Err() != nil {
a.config.Log.Warning(a.CacheMgr.Err())
}
}()
}

a.config.Log.Info("Bootstrapping done")
Expand All @@ -261,14 +290,16 @@ func (a *Agent) bootstrap() error {
// which is used to generate CSRs for non-base SVIDs and update the agent cache entries
//
// TODO: Refactor me for length, testability
func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) {

func (a *Agent) attest() ([]*common.RegistrationEntry, error) {
var err error
a.config.Log.Info("Preparing to attest against ", a.config.ServerAddress.String())

// Handle the join token seperately, if defined
pluginResponse := &nodeattestor.FetchAttestationDataResponse{}
if a.config.JoinToken != "" {
a.config.Log.Info("Preparing to attest this node against ", a.config.ServerAddress.String(), " using strategy 'join-token'")

a.config.Log.Info("Preparing to attest this node against ",
a.config.ServerAddress.String(), " using strategy 'join-token'")
data := &common.AttestedData{
Type: "join_token",
Data: []byte(a.config.JoinToken),
Expand All @@ -278,7 +309,6 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) {
Host: a.config.TrustDomain.Host,
Path: path.Join("spire", "agent", "join_token", a.config.JoinToken),
}

pluginResponse.AttestedData = data
pluginResponse.SpiffeId = id.String()
} else {
Expand Down Expand Up @@ -309,7 +339,10 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) {
}

// Since we are bootstrapping, this is explicitly _not_ mTLS
conn := a.getNodeAPIClientConn(false, a.BaseSVID, a.baseSVIDKey)
conn, err := a.getNodeAPIClientConn(false, a.BaseSVID, a.baseSVIDKey)
if err != nil {
return nil, err
}
defer conn.Close()
nodeClient := node.NewNodeClient(conn)

Expand Down Expand Up @@ -340,16 +373,11 @@ func (a *Agent) attest() (map[string]*common.RegistrationEntry, error) {
return nil, fmt.Errorf("Base SVID not found in attestation response")
}

var registrationEntryMap = make(map[string]*common.RegistrationEntry)
for _, entry := range serverResponse.SvidUpdate.RegistrationEntries {
registrationEntryMap[entry.SpiffeId] = entry
}

a.BaseSVID = svid.SvidCert
a.BaseSVIDTTL = svid.Ttl
a.storeBaseSVID()
a.config.Log.Info("Node attestation complete")
return registrationEntryMap, nil
a.config.Log.Info("Attestation complete")
return serverResponse.SvidUpdate.RegistrationEntries, nil
}

// Generate a CSR for the given SPIFFE ID
Expand Down Expand Up @@ -421,72 +449,7 @@ func (a *Agent) storeBaseSVID() {
return
}

func (a *Agent) FetchSVID(registrationEntryMap map[string]*common.RegistrationEntry, svidCert []byte,
key *ecdsa.PrivateKey) (err error) {

if len(registrationEntryMap) != 0 {
Csrs, pkeyMap, err := a.generateCSRForRegistrationEntries(registrationEntryMap)
if err != nil {
return err
}

conn := a.getNodeAPIClientConn(true, svidCert, key)
defer conn.Close()
nodeClient := node.NewNodeClient(conn)

req := &node.FetchSVIDRequest{Csrs: Csrs}

callOptPeer := new(peer.Peer)
resp, err := nodeClient.FetchSVID(context.Background(), req, grpc.Peer(callOptPeer))
if err != nil {
return err
}
if tlsInfo, ok := callOptPeer.AuthInfo.(credentials.TLSInfo); ok {
a.serverCerts = tlsInfo.State.PeerCertificates
}

svidMap := resp.GetSvidUpdate().GetSvids()

// TODO: Fetch the referenced federated bundles and
// set them here
bundles := make(map[string][]byte)
for spiffeID, entry := range registrationEntryMap {
svid, svidInMap := svidMap[spiffeID]
pkey, pkeyInMap := pkeyMap[spiffeID]
if svidInMap && pkeyInMap {
svidCert, err := x509.ParseCertificate(svid.SvidCert)
if err != nil {
return fmt.Errorf("SVID for ID %s could not be parsed: %s", spiffeID, err)
}

entry := cache.CacheEntry{
RegistrationEntry: entry,
SVID: svid,
PrivateKey: pkey,
Bundles: bundles,
Expiry: svidCert.NotAfter,
}
a.Cache.SetEntry(entry)
}
}

newRegistrationMap := make(map[string]*common.RegistrationEntry)

if len(resp.SvidUpdate.RegistrationEntries) != 0 {
for _, entry := range resp.SvidUpdate.RegistrationEntries {
if _, ok := registrationEntryMap[entry.SpiffeId]; ok != true {
newRegistrationMap[entry.SpiffeId] = entry
}
a.FetchSVID(newRegistrationMap, svidMap[entry.SpiffeId].SvidCert, pkeyMap[entry.SpiffeId])

}

}
}
return
}

func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateKey) (conn *grpc.ClientConn) {
func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateKey) (conn *grpc.ClientConn, err error) {

serverID := a.config.TrustDomain
serverID.Path = "spiffe/cp"
Expand Down Expand Up @@ -516,35 +479,10 @@ func (a *Agent) getNodeAPIClientConn(mtls bool, svid []byte, key *ecdsa.PrivateK

dialCreds := grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))

conn, err := grpc.Dial(a.config.ServerAddress.String(), dialCreds)
conn, err = grpc.DialContext(a.ctx, a.config.ServerAddress.String(), dialCreds)
if err != nil {
return
}

return

}

func (a *Agent) generateCSRForRegistrationEntries(
regEntryMap map[string]*common.RegistrationEntry) (CSRs [][]byte, pkeyMap map[string]*ecdsa.PrivateKey, err error) {

pkeyMap = make(map[string]*ecdsa.PrivateKey)
for id, _ := range regEntryMap {

key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
return nil, nil, err
}
spiffeid, err := url.Parse(id)
if err != nil {
return nil, nil, err
}
csr, err := a.generateCSR(spiffeid, key)
if err != nil {
return nil, nil, err
}
CSRs = append(CSRs, csr)
pkeyMap[id] = key
}
return
}

0 comments on commit 277769d

Please sign in to comment.