diff --git a/cmd/changefromtfplan.go b/cmd/changefromtfplan.go index dd21f3d3..fd796e54 100644 --- a/cmd/changefromtfplan.go +++ b/cmd/changefromtfplan.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "errors" "fmt" "os" "os/signal" @@ -15,14 +16,18 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/encoding/protojson" + "nhooyr.io/websocket" + "nhooyr.io/websocket/wspb" ) // changeFromTfplanCmd represents the change-from-tfplan command var changeFromTfplanCmd = &cobra.Command{ Use: "change-from-tfplan [--title TITLE] [--description DESCRIPTION] [--ticket-link URL] [--tfplan FILE]", - Short: "Creates a new Change from a given terraform plan (in JSON format)", + Short: "Creates a new Change from a given terraform plan file", PreRun: func(cmd *cobra.Command, args []string) { // Bind these to viper err := viper.BindPFlags(cmd.PersistentFlags()) @@ -41,19 +46,43 @@ var changeFromTfplanCmd = &cobra.Command{ }, } +// test data var ( - affecting_resource *sdp.Reference = &sdp.Reference{ - Type: "elbv2-load-balancer", - UniqueAttributeValue: "ingress", - Scope: "944651592624.eu-west-2", + affecting_uuid uuid.UUID = uuid.New() + affecting_resource *sdp.Query = &sdp.Query{ + Type: "elbv2-load-balancer", + Method: sdp.QueryMethod_GET, + Query: "ingress", + RecursionBehaviour: &sdp.Query_RecursionBehaviour{ + LinkDepth: 0, + }, + Scope: "944651592624.eu-west-2", + UUID: affecting_uuid[:], } - safe_resource *sdp.Reference = &sdp.Reference{ - Type: "ec2-security-group", - UniqueAttributeValue: "sg-09533c300cd1a41c1", - Scope: "944651592624.eu-west-2", + + safe_uuid uuid.UUID = uuid.New() + safe_resource *sdp.Query = &sdp.Query{ + Type: "ec2-security-group", + Method: sdp.QueryMethod_GET, + Query: "sg-09533c300cd1a41c1", + RecursionBehaviour: &sdp.Query_RecursionBehaviour{ + LinkDepth: 0, + }, + Scope: "944651592624.eu-west-2", + UUID: safe_uuid[:], } ) +func changingItemQueriesFromTfplan() []*sdp.Query { + var changing_items []*sdp.Query + if viper.GetBool("test-affecting") { + changing_items = []*sdp.Query{affecting_resource} + } else { + changing_items = []*sdp.Query{safe_resource} + } + return changing_items +} + func ChangeFromTfplan(signals chan os.Signal, ready chan bool) int { timeout, err := time.ParseDuration(viper.GetString("timeout")) if err != nil { @@ -69,11 +98,13 @@ func ChangeFromTfplan(signals chan os.Signal, ready chan bool) int { // Connect to the websocket log.WithContext(ctx).Debugf("Connecting to overmind API: %v", viper.GetString("url")) + lf := log.Fields{ + "url": viper.GetString("url"), + } + ctx, err = ensureToken(ctx, signals) if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": viper.GetString("url"), - }).Error("failed to authenticate") + log.WithContext(ctx).WithError(err).WithFields(lf).Error("failed to authenticate") return 1 } @@ -94,34 +125,181 @@ func ChangeFromTfplan(signals chan os.Signal, ready chan bool) int { }, }) if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": viper.GetString("url"), - }).Error("failed to create change") + log.WithContext(ctx).WithError(err).WithFields(lf).Error("failed to create change") return 1 } - log.WithContext(ctx).WithFields(log.Fields{ - "url": viper.GetString("url"), - "change": createResponse.Msg.Change.Metadata.GetUUIDParsed(), - }).Info("created a new change") - var changing_items []*sdp.Reference - if viper.GetBool("test-affecting") { - changing_items = []*sdp.Reference{affecting_resource} - } else { - changing_items = []*sdp.Reference{safe_resource} + lf["change"] = createResponse.Msg.Change.Metadata.GetUUIDParsed() + log.WithContext(ctx).WithFields(lf).Info("created a new change") + + log.WithContext(ctx).WithFields(lf).Info("resolving items from terraform plan") + queries := changingItemQueriesFromTfplan() + + options := &websocket.DialOptions{ + HTTPClient: NewAuthenticatedClient(ctx, otelhttp.DefaultClient), + } + + c, _, err := websocket.Dial(ctx, viper.GetString("url"), options) + if err != nil { + log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to connect to overmind API") + return 1 + } + defer c.Close(websocket.StatusGoingAway, "") + + // the default, 32kB is too small for cert bundles and rds-db-cluster-parameter-groups + c.SetReadLimit(2 * 1024 * 1024) + + queriesSentChan := make(chan struct{}) + go func() { + for _, q := range queries { + req := sdp.GatewayRequest{ + MinStatusInterval: minStatusInterval, + RequestType: &sdp.GatewayRequest_Query{ + Query: q, + }, + } + err = wspb.Write(ctx, c, &req) + if err != nil { + log.WithContext(ctx).WithFields(lf).WithError(err).WithField("req", &req).Error("Failed to send request") + continue + } + } + queriesSentChan <- struct{}{} + }() + + responses := make(chan *sdp.GatewayResponse) + + // Start a goroutine that reads responses + go func() { + for { + res := new(sdp.GatewayResponse) + + err = wspb.Read(ctx, c, res) + + if err != nil { + var e websocket.CloseError + if errors.As(err, &e) { + log.WithContext(ctx).WithFields(lf).WithFields(log.Fields{ + "code": e.Code.String(), + "reason": e.Reason, + }).Info("Websocket closing") + return + } + log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to read response") + return + } + + responses <- res + } + }() + + activeQueries := make(map[uuid.UUID]bool) + queriesSent := false + + receivedItems := []*sdp.Reference{} + + // Read the responses +responses: + for { + select { + case <-queriesSentChan: + queriesSent = true + + case <-signals: + log.WithContext(ctx).WithFields(lf).Info("Received interrupt, exiting") + return 1 + + case <-ctx.Done(): + log.WithContext(ctx).WithFields(lf).Info("Context cancelled, exiting") + return 1 + + case resp := <-responses: + switch resp.ResponseType.(type) { + + case *sdp.GatewayResponse_Status: + status := resp.GetStatus() + statusFields := log.Fields{ + "summary": status.Summary, + "responders": status.Summary.Responders, + "queriesSent": queriesSent, + "post_processing_complete": status.PostProcessingComplete, + } + + if status.Done() { + // fall through from all "final" query states, check if there's still queries in progress; + // only break from the loop if all queries have already been sent + // TODO: see above, still needs DefaultStartTimeout implemented to account for slow sources + allDone := allDone(ctx, activeQueries, lf) + statusFields["allDone"] = allDone + if allDone && queriesSent { + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders and queries done") + break responses + } else { + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders done, with unfinished queries") + } + } else { + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("still waiting for responders") + } + + case *sdp.GatewayResponse_QueryStatus: + status := resp.GetQueryStatus() + statusFields := log.Fields{ + "status": status.Status.String(), + } + queryUuid := status.GetUUIDParsed() + if queryUuid == nil { + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("Received QueryStatus with nil UUID") + continue responses + } + statusFields["query"] = queryUuid + + switch status.Status { + case sdp.QueryStatus_STARTED: + activeQueries[*queryUuid] = true + case sdp.QueryStatus_FINISHED: + activeQueries[*queryUuid] = false + case sdp.QueryStatus_ERRORED: + activeQueries[*queryUuid] = false + case sdp.QueryStatus_CANCELLED: + activeQueries[*queryUuid] = false + default: + statusFields["unexpected_status"] = true + } + + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("query status update") + + case *sdp.GatewayResponse_NewItem: + item := resp.GetNewItem() + log.WithContext(ctx).WithFields(lf).WithField("item", item.GloballyUniqueName()).Infof("new item") + + receivedItems = append(receivedItems, item.Reference()) + + case *sdp.GatewayResponse_NewEdge: + log.WithContext(ctx).WithFields(lf).Debug("ignored edge") + + case *sdp.GatewayResponse_QueryError: + err := resp.GetQueryError() + log.WithContext(ctx).WithFields(lf).WithError(err).Errorf("Error from %v(%v)", err.ResponderName, err.SourceName) + + case *sdp.GatewayResponse_Error: + err := resp.GetError() + log.WithContext(ctx).WithFields(lf).WithField(log.ErrorKey, err).Errorf("generic error") + + default: + j := protojson.Format(resp) + log.WithContext(ctx).WithFields(lf).Infof("Unknown %T Response:\n%v", resp.ResponseType, j) + } + } } resultStream, err := client.UpdateChangingItems(ctx, &connect.Request[sdp.UpdateChangingItemsRequest]{ Msg: &sdp.UpdateChangingItemsRequest{ ChangeUUID: createResponse.Msg.Change.Metadata.UUID, - ChangingItems: changing_items, + ChangingItems: receivedItems, }, }) if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": viper.GetString("url"), - "change": createResponse.Msg.Change.Metadata.GetUUIDParsed(), - }).Error("failed to update changing items") + log.WithContext(ctx).WithFields(lf).WithError(err).Error("failed to update changing items") return 1 } @@ -129,33 +307,24 @@ func ChangeFromTfplan(signals chan os.Signal, ready chan bool) int { first_log := true for resultStream.Receive() { if resultStream.Err() != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": viper.GetString("url"), - "change": createResponse.Msg.Change.Metadata.GetUUIDParsed(), - }).Error("error streaming results") + log.WithContext(ctx).WithFields(lf).WithError(err).Error("error streaming results") return 1 } msg := resultStream.Msg() - // log the first message and at most every 500ms during discovery + // log the first message and at most every 250ms during discovery // to avoid spanning the cli output time_since_last_log := time.Since(last_log) if first_log || msg.State != sdp.CalculateBlastRadiusResponse_STATE_DISCOVERING || time_since_last_log > 250*time.Millisecond { - log.WithContext(ctx).WithFields(log.Fields{ - "url": viper.GetString("url"), - "change": createResponse.Msg.Change.Metadata.GetUUIDParsed(), - "msg": msg, - }).Info("status update") + log.WithContext(ctx).WithFields(lf).WithField("msg", msg).Info("status update") last_log = time.Now() first_log = false } } - log.WithContext(ctx).WithFields(log.Fields{ - "change": createResponse.Msg.Change.Metadata.GetUUIDParsed(), - "change-url": fmt.Sprintf("%v/changes/%v", viper.GetString("frontend"), createResponse.Msg.Change.Metadata.GetUUIDParsed()), - }).Info("change ready") + changeUrl := fmt.Sprintf("%v/changes/%v", viper.GetString("frontend"), createResponse.Msg.Change.Metadata.GetUUIDParsed()) + log.WithContext(ctx).WithFields(lf).WithField("change-url", changeUrl).Info("change ready") fetchResponse, err := client.GetChange(ctx, &connect.Request[sdp.GetChangeRequest]{ Msg: &sdp.GetChangeRequest{ @@ -163,23 +332,18 @@ func ChangeFromTfplan(signals chan os.Signal, ready chan bool) int { }, }) if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": viper.GetString("url"), - }).Error("failed to get updated change") + log.WithContext(ctx).WithFields(lf).WithError(err).Error("failed to get updated change") return 1 } + for _, a := range fetchResponse.Msg.Change.Properties.AffectedAppsUUID { appUuid, err := uuid.FromBytes(a) if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": viper.GetString("url"), - "value": a, - }).Error("received invalid app uuid") + log.WithContext(ctx).WithFields(lf).WithError(err).WithField("app", a).Error("received invalid app uuid") continue } - log.WithContext(ctx).WithFields(log.Fields{ - "change": createResponse.Msg.Change.Metadata.GetUUIDParsed(), - "change-url": fmt.Sprintf("%v/changes/%v", viper.GetString("frontend"), createResponse.Msg.Change.Metadata.GetUUIDParsed()), + log.WithContext(ctx).WithFields(lf).WithFields(log.Fields{ + "change-url": changeUrl, "app": appUuid, "app-url": fmt.Sprintf("%v/apps/%v", viper.GetString("frontend"), appUuid), }).Info("affected app") @@ -188,11 +352,23 @@ func ChangeFromTfplan(signals chan os.Signal, ready chan bool) int { return 0 } +func allDone(ctx context.Context, activeQueries map[uuid.UUID]bool, lf log.Fields) bool { + allDone := true + for q := range activeQueries { + if activeQueries[q] { + log.WithContext(ctx).WithFields(lf).WithField("query", q).Debugf("query still active") + allDone = false + break + } + } + return allDone +} + func init() { rootCmd.AddCommand(changeFromTfplanCmd) - changeFromTfplanCmd.PersistentFlags().String("changes-url", "https://api.prod.overmind.tech/", "The changes service API endpoint") - changeFromTfplanCmd.PersistentFlags().String("frontend", "https://app.overmind.tech/", "The frontend base URL") + changeFromTfplanCmd.PersistentFlags().String("changes-url", "https://api.prod.overmind.tech", "The changes service API endpoint") + changeFromTfplanCmd.PersistentFlags().String("frontend", "https://app.overmind.tech", "The frontend base URL") changeFromTfplanCmd.PersistentFlags().String("terraform", "terraform", "The binary to use for calling terraform. Will be looked up in the system PATH.") changeFromTfplanCmd.PersistentFlags().String("tfplan", "./tfplan", "Parse changing items from this terraform plan file.") diff --git a/cmd/request.go b/cmd/request.go index 0a9a50cc..3f8698c0 100644 --- a/cmd/request.go +++ b/cmd/request.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "os" "os/signal" "syscall" @@ -67,14 +66,16 @@ func Request(signals chan os.Signal, ready chan bool) int { return 1 } + lf := log.Fields{ + "url": viper.GetString("url"), + } + // Connect to the websocket log.WithContext(ctx).Debugf("Connecting to overmind API: %v", viper.GetString("url")) ctx, err = ensureToken(ctx, signals) if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": viper.GetString("url"), - }).Error("failed to authenticate") + log.WithContext(ctx).WithFields(lf).WithError(err).Error("failed to authenticate") return 1 } @@ -83,41 +84,36 @@ func Request(signals chan os.Signal, ready chan bool) int { defer cancel() options := &websocket.DialOptions{ - HTTPClient: otelhttp.DefaultClient, - } - - if viper.GetString("token") != "" { - log.Info("Setting authorization token") - options.HTTPHeader = make(http.Header) - options.HTTPHeader.Set("Authorization", fmt.Sprintf("Bearer %v", viper.GetString("token"))) + HTTPClient: NewAuthenticatedClient(ctx, otelhttp.DefaultClient), } c, _, err := websocket.Dial(ctx, viper.GetString("url"), options) if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": viper.GetString("url"), - }).Error("Failed to connect to overmind API") + log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to connect to overmind API") return 1 } defer c.Close(websocket.StatusGoingAway, "") + // the default, 32kB is too small for cert bundles and rds-db-cluster-parameter-groups + c.SetReadLimit(2 * 1024 * 1024) + // Log the request in JSON b, err := json.MarshalIndent(req, "", " ") if err != nil { - log.WithContext(ctx).WithError(err).Error("Failed to marshal request") + log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to marshal request") return 1 } - log.WithContext(ctx).Infof("Request:\n%v", string(b)) + log.WithContext(ctx).WithFields(lf).Infof("Request:\n%v", string(b)) err = wspb.Write(ctx, c, req) if err != nil { - log.WithContext(ctx).WithFields(log.Fields{ - "error": err, - }).Error("Failed to send request") + log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to send request") return 1 } + queriesSent := true + responses := make(chan *sdp.GatewayResponse) // Start a goroutine that reads responses @@ -153,27 +149,55 @@ responses: for { select { case <-signals: - log.WithContext(ctx).Info("Received interrupt, exiting") + log.WithContext(ctx).WithFields(lf).Info("Received interrupt, exiting") return 1 + case <-ctx.Done(): - log.WithContext(ctx).Info("Context cancelled, exiting") + log.WithContext(ctx).WithFields(lf).Info("Context cancelled, exiting") return 1 + case resp := <-responses: switch resp.ResponseType.(type) { + case *sdp.GatewayResponse_Status: + status := resp.GetStatus() + statusFields := log.Fields{ + "summary": status.Summary, + "responders": status.Summary.Responders, + "queriesSent": queriesSent, + "post_processing_complete": status.PostProcessingComplete, + } + + if status.Done() { + // fall through from all "final" query states, check if there's still queries in progress; + // only break from the loop if all queries have already been sent + // TODO: see above, still needs DefaultStartTimeout implemented to account for slow sources + allDone := allDone(ctx, activeQueries, lf) + statusFields["allDone"] = allDone + if allDone && queriesSent { + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders and queries done") + break responses + } else { + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders done, with unfinished queries") + } + } else { + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("still waiting for responders") + } + case *sdp.GatewayResponse_QueryStatus: status := resp.GetQueryStatus() + statusFields := log.Fields{ + "status": status.Status.String(), + } queryUuid := status.GetUUIDParsed() if queryUuid == nil { - log.WithContext(ctx).Debugf("Received QueryStatus with nil UUID: %v", status.Status.String()) + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("Received QueryStatus with nil UUID") continue responses } - - log.WithContext(ctx).Debugf("Status for %v: %v", queryUuid, status.Status.String()) + statusFields["query"] = queryUuid switch status.Status { case sdp.QueryStatus_STARTED: activeQueries[*queryUuid] = true - continue responses case sdp.QueryStatus_FINISHED: activeQueries[*queryUuid] = false case sdp.QueryStatus_ERRORED: @@ -181,44 +205,33 @@ responses: case sdp.QueryStatus_CANCELLED: activeQueries[*queryUuid] = false default: - log.WithContext(ctx).Debugf("unexpected status %v: %v", queryUuid, status.Status.String()) - continue responses + statusFields["unexpected_status"] = true } - // fall through from all "final" query states, check if there's still queries in progress - // TODO: needs DefaultStartTimeout implemented to account for slow sources - allDone := true - active: - for q := range activeQueries { - if activeQueries[q] { - log.WithContext(ctx).Debugf("%v still active", q) - allDone = false - break active - } - } + log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("query status update") - if allDone { - break responses - } case *sdp.GatewayResponse_NewItem: item := resp.GetNewItem() + log.WithContext(ctx).WithFields(lf).WithField("item", item.GloballyUniqueName()).Infof("new item") - log.WithContext(ctx).Infof("New item: %v", item.GloballyUniqueName()) case *sdp.GatewayResponse_NewEdge: edge := resp.GetNewEdge() + log.WithContext(ctx).WithFields(lf).WithFields(log.Fields{ + "from": edge.From.GloballyUniqueName(), + "to": edge.To.GloballyUniqueName(), + }).Info("new edge") - log.WithContext(ctx).Infof("New edge: %v->%v", edge.From.GloballyUniqueName(), edge.To.GloballyUniqueName()) case *sdp.GatewayResponse_QueryError: err := resp.GetQueryError() + log.WithContext(ctx).WithFields(lf).Errorf("Error from %v(%v): %v", err.ResponderName, err.SourceName, err) - log.WithContext(ctx).Errorf("Error from %v(%v): %v", err.ResponderName, err.SourceName, err) case *sdp.GatewayResponse_Error: err := resp.GetError() - log.WithContext(ctx).Errorf("generic error: %v", err) + log.WithContext(ctx).WithFields(lf).Errorf("generic error: %v", err) + default: j := protojson.Format(resp) - - log.WithContext(ctx).Infof("Unknown %T Response:\n%v", resp.ResponseType, j) + log.WithContext(ctx).WithFields(lf).Infof("Unknown %T Response:\n%v", resp.ResponseType, j) } } } @@ -227,6 +240,7 @@ responses: log.WithContext(ctx).Info("Starting snapshot") msgId := uuid.New() snapReq := &sdp.GatewayRequest{ + MinStatusInterval: minStatusInterval, RequestType: &sdp.GatewayRequest_StoreSnapshot{ StoreSnapshot: &sdp.StoreSnapshot{ Name: viper.GetString("snapshot-name"), @@ -274,7 +288,9 @@ responses: } func createInitialRequest() (*sdp.GatewayRequest, error) { - req := new(sdp.GatewayRequest) + req := &sdp.GatewayRequest{ + MinStatusInterval: minStatusInterval, + } u := uuid.New() switch viper.GetString("request-type") { diff --git a/cmd/root.go b/cmd/root.go index 7cee773b..6e89dda2 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "strings" + "time" "github.com/google/uuid" "github.com/overmindtech/ovm-cli/tracing" @@ -17,11 +18,14 @@ import ( "github.com/spf13/viper" "github.com/uptrace/opentelemetry-go-extra/otellogrus" "golang.org/x/oauth2" + "google.golang.org/protobuf/types/known/durationpb" ) var cfgFile string var logLevel string +var minStatusInterval = durationpb.New(250 * time.Millisecond) + // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ Use: "ovm-cli", @@ -145,7 +149,6 @@ func ensureToken(ctx context.Context, signals chan os.Signal) (context.Context, } // Set the token - viper.Set("token", token.AccessToken) return context.WithValue(ctx, sdp.UserTokenContextKey{}, token.AccessToken), nil } return ctx, fmt.Errorf("no token configured and target URL (%v) is insecure", gatewayURL)