diff --git a/cmd/preflight-id/preflight-id.go b/cmd/preflight-id/preflight-id.go index 72a2bae..513e41d 100644 --- a/cmd/preflight-id/preflight-id.go +++ b/cmd/preflight-id/preflight-id.go @@ -33,6 +33,7 @@ func main() { ll = log.InfoLevel } log.SetLevel(ll) + preflightid.Logger = l.Logger var provider string if provider == "" { // infer provider from flags diff --git a/pkg/preflightid/aws.go b/pkg/preflightid/aws.go index f33c157..d7dec9c 100644 --- a/pkg/preflightid/aws.go +++ b/pkg/preflightid/aws.go @@ -2,6 +2,7 @@ package preflightid import ( "errors" + "fmt" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" @@ -14,9 +15,8 @@ type IDProviderAWS struct { func (p *IDProviderAWS) Run() error { l := log.WithFields(log.Fields{ - "app": "preflight-id", - "provider": "aws", - "fn": "p.Run", + "preflight": "id", + "provider": "aws", }) l.Debug("running preflight-id") if p.ARN == "" { @@ -34,9 +34,10 @@ func (p *IDProviderAWS) Run() error { return err } if *resp.Arn != p.ARN { - l.WithError(err).Errorf("ARN mismatch: %s != %s", *resp.Arn, p.ARN) - return errors.New("ARN mismatch") + failStr := fmt.Sprintf("failed - expected %s, got %s", p.ARN, *resp.Arn) + l.Error(failStr) + return errors.New(failStr) } - l.Info("ARN match") + l.Info("passed") return nil } diff --git a/pkg/preflightid/gcp.go b/pkg/preflightid/gcp.go index c749afc..b175979 100644 --- a/pkg/preflightid/gcp.go +++ b/pkg/preflightid/gcp.go @@ -18,9 +18,8 @@ type IDProviderGCP struct { func (p *IDProviderGCP) Run() error { l := log.WithFields(log.Fields{ - "app": "preflight-id", - "provider": "gcp", - "fn": "p.Run", + "preflight": "id", + "provider": "gcp", }) l.Debug("running preflight-id") if p.Email == "" { @@ -40,9 +39,12 @@ func (p *IDProviderGCP) Run() error { l.WithError(err).Error("Failed to retrieve authorized accounts") return err } + var accountList []string for _, account := range response.Accounts { + accountList = append(accountList, account.Email) if strings.EqualFold(account.Email, p.Email) { l.Debugf("Service Account match: %s", account.Email) + l.Info("passed") return nil } } @@ -52,11 +54,14 @@ func (p *IDProviderGCP) Run() error { l.WithError(err).Error("Failed to retrieve VM Identity") return err } + accountList = append(accountList, vmIdentity) if strings.EqualFold(vmIdentity, p.Email) { l.Debugf("VM Identity match: %s", vmIdentity) + l.Info("passed") return nil } } - l.Errorf("Service Account not found: %s", p.Email) - return errors.New(fmt.Sprintf("Service Account not found")) + failStr := fmt.Sprintf("failed - expected %s, got %v", p.Email, accountList) + l.Error(failStr) + return errors.New(failStr) } diff --git a/pkg/preflightid/kube.go b/pkg/preflightid/kube.go index b838f90..e59a6dd 100644 --- a/pkg/preflightid/kube.go +++ b/pkg/preflightid/kube.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "fmt" "os" "strings" @@ -16,9 +17,8 @@ type IDProviderKube struct { func (k *IDProviderKube) Run() error { l := log.WithFields(log.Fields{ - "app": "preflight-id", - "provider": "kube", - "fn": "k.Run", + "preflight": "id", + "provider": "kube", }) l.Debug("running preflight-id") if k.ServiceAccount == "" { @@ -58,12 +58,15 @@ func (k *IDProviderKube) Run() error { serviceAccountName := claims.Kubernetes.ServiceAccount.Name l.Debugf("service account name: %s", serviceAccountName) if serviceAccountName == "" { - return errors.New("invalid JWT token claims") + failStr := "failed - no service account name in JWT token" + l.Error(failStr) + return errors.New(failStr) } if k.ServiceAccount != "" && k.ServiceAccount != serviceAccountName { - l.WithError(err).Errorf("service account name mismatch: %s != %s", k.ServiceAccount, serviceAccountName) - return errors.New("service account name mismatch") + failStr := fmt.Sprintf("failed - expected %s, got %s", k.ServiceAccount, serviceAccountName) + l.Error(failStr) + return errors.New(failStr) } - l.Info("service account name match") + l.Info("passed") return nil } diff --git a/pkg/preflightid/preflightid.go b/pkg/preflightid/preflightid.go index d3ad7a3..758a618 100644 --- a/pkg/preflightid/preflightid.go +++ b/pkg/preflightid/preflightid.go @@ -9,6 +9,18 @@ import ( "gopkg.in/yaml.v3" ) +var ( + Logger *log.Logger +) + +func init() { + if Logger == nil { + Logger = log.New() + Logger.SetOutput(os.Stdout) + Logger.SetLevel(log.InfoLevel) + } +} + type IDProvider interface { Run() error } @@ -29,7 +41,7 @@ type PreflightID struct { } func LoadConfig(filepath string) (*PreflightID, error) { - l := log.WithFields(log.Fields{ + l := Logger.WithFields(log.Fields{ "fn": "LoadConfig", }) l.Debug("loading config") @@ -50,20 +62,37 @@ func LoadConfig(filepath string) (*PreflightID, error) { return pf, err } -func (p *PreflightID) Run() error { - l := log.WithFields(log.Fields{ - "app": "preflight-id", - "fn": "p.Run", +func NewPreflighter(provider Provider, config *PreflightID) (IDProvider, error) { + l := Logger.WithFields(log.Fields{ + "fn": "NewPreflighter", + "provider": provider, }) - l.Debug("running preflight-id") - switch p.Provider { + l.Debug("creating preflighter") + switch provider { case ProviderAWS: - return p.AWS.Run() + return config.AWS, nil case ProviderGCP: - return p.GCP.Run() + return config.GCP, nil case ProviderKube: - return p.Kube.Run() + return config.Kube, nil default: - return errors.New("invalid provider") + return nil, errors.New("invalid provider") + } +} + +func (p *PreflightID) Run() error { + l := Logger.WithFields(log.Fields{ + "preflight": "id", + }) + l.Debug("running preflight-id") + preflighter, err := NewPreflighter(p.Provider, p) + if err != nil { + l.WithError(err).Error("error creating preflighter") + return err + } + if err := preflighter.Run(); err != nil { + l.WithError(err).Error("error running preflighter") + return err } + return nil }