Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 23 additions & 176 deletions cmd/manualchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cmd

import (
"context"
"errors"
"fmt"
"os"
"os/signal"
Expand All @@ -13,15 +12,13 @@ import (
"github.com/google/uuid"
"github.com/overmindtech/ovm-cli/tracing"
"github.com/overmindtech/sdp-go"
"github.com/overmindtech/sdp-go/sdpws"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/encoding/protojson"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wspb"
)

// manualChangeCmd is the equivalent to submit-plan for manual changes
Expand Down Expand Up @@ -126,8 +123,6 @@ func ManualChange(ctx context.Context, ready chan bool) int {
log.WithContext(ctx).WithFields(lf).Info("re-using change")
}

receivedItems := []*sdp.Reference{}

method, err := methodFromString(viper.GetString("query-method"))
if err != nil {
log.WithContext(ctx).WithError(err).WithFields(lf).Error("can't parse --query-method")
Expand All @@ -145,192 +140,44 @@ func ManualChange(ctx context.Context, ready chan bool) int {
log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to wake up sources")
return 1
}
u := uuid.New()

queries := []*sdp.Query{
{
UUID: u[:],
Method: method,
Scope: viper.GetString("query-scope"),
Type: viper.GetString("query-type"),
Query: viper.GetString("query"),
IgnoreCache: true,
},
}
options := &websocket.DialOptions{
HTTPClient: NewAuthenticatedClient(ctx, otelhttp.DefaultClient),
}

log.WithContext(ctx).WithFields(lf).WithField("item_count", len(queries)).Info("identifying items")
// nolint: bodyclose // nhooyr.io/websocket reads the body internally
c, _, err := websocket.Dial(ctx, viper.GetString("gateway-url"), options)
ws, err := sdpws.Dial(ctx, gatewayUrl, otelhttp.DefaultClient, nil)
if err != nil {
log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to connect to overmind API")
log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to connect to gateway")
return 1
}
defer c.Close(websocket.StatusGoingAway, "")

// the default, 32kB is too small for cert bundles and rds-db-cluster-parameter-groups
c.SetReadLimit(2 * 1024 * 1024)

queriesSentChan := make(chan struct{})
go func() {
for _, q := range queries {
req := sdp.GatewayRequest{
MinStatusInterval: minStatusInterval,
RequestType: &sdp.GatewayRequest_Query{
Query: q,
},
}
err = wspb.Write(ctx, c, &req)

if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"scope": q.Scope,
"type": q.Type,
"query": q.Query,
"method": q.Method.String(),
"uuid": q.ParseUuid().String(),
}).Trace("Started query")
}
if err != nil {
log.WithContext(ctx).WithFields(lf).WithError(err).WithField("req", &req).Error("Failed to send request")
continue
}
}
queriesSentChan <- struct{}{}
}()

responses := make(chan *sdp.GatewayResponse)

// Start a goroutine that reads responses
go func() {
for {
res := new(sdp.GatewayResponse)

err = wspb.Read(ctx, c, res)

if err != nil {
var e websocket.CloseError
if errors.As(err, &e) {
log.WithContext(ctx).WithFields(lf).WithFields(log.Fields{
"code": e.Code.String(),
"reason": e.Reason,
}).Info("Websocket closing")
return
}
log.WithContext(ctx).WithFields(lf).WithError(err).Error("Failed to read response")
return
}

responses <- res
}
}()

activeQueries := make(map[uuid.UUID]bool)
queriesSent := false

// Read the responses
responses:
for {
select {
case <-queriesSentChan:
queriesSent = true

case <-ctx.Done():
log.WithContext(ctx).WithFields(lf).Info("Context cancelled, exiting")
return 1
u := uuid.New()
q := &sdp.Query{
UUID: u[:],
Method: method,
Scope: viper.GetString("query-scope"),
Type: viper.GetString("query-type"),
Query: viper.GetString("query"),
IgnoreCache: true,
}

case resp := <-responses:
switch resp.ResponseType.(type) {

case *sdp.GatewayResponse_Status:
status := resp.GetStatus()
statusFields := log.Fields{
"summary": status.Summary,
"responders": status.Summary.Responders,
"queriesSent": queriesSent,
"post_processing_complete": status.PostProcessingComplete,
}

if status.Done() {
// fall through from all "final" query states, check if there's still queries in progress;
// only break from the loop if all queries have already been sent
// TODO: see above, still needs DefaultStartTimeout implemented to account for slow sources
allDone := allDone(ctx, activeQueries, lf)
statusFields["allDone"] = allDone
if allDone && queriesSent {
log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders and queries done")
break responses
} else {
log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("all responders done, with unfinished queries")
}
} else {
log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Info("still waiting for responders")
}

case *sdp.GatewayResponse_QueryStatus:
status := resp.GetQueryStatus()
statusFields := log.Fields{
"status": status.Status.String(),
}
queryUuid := status.GetUUIDParsed()
if queryUuid == nil {
log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("Received QueryStatus with nil UUID")
continue responses
}
statusFields["query"] = queryUuid

switch status.Status {
case sdp.QueryStatus_UNSPECIFIED:
statusFields["unexpected_status"] = true
case sdp.QueryStatus_STARTED:
activeQueries[*queryUuid] = true
case sdp.QueryStatus_FINISHED:
activeQueries[*queryUuid] = false
case sdp.QueryStatus_ERRORED:
activeQueries[*queryUuid] = false
case sdp.QueryStatus_CANCELLED:
activeQueries[*queryUuid] = false
default:
statusFields["unexpected_status"] = true
}

log.WithContext(ctx).WithFields(lf).WithFields(statusFields).Debugf("query status update")

case *sdp.GatewayResponse_NewItem:
item := resp.GetNewItem()
log.WithContext(ctx).WithFields(lf).WithField("item", item.GloballyUniqueName()).Infof("new item")

receivedItems = append(receivedItems, item.Reference())

case *sdp.GatewayResponse_NewEdge:
log.WithContext(ctx).WithFields(lf).Debug("ignored edge")

case *sdp.GatewayResponse_QueryError:
err := resp.GetQueryError()
log.WithContext(ctx).WithFields(lf).WithError(err).Errorf("Error from %v(%v)", err.ResponderName, err.SourceName)

case *sdp.GatewayResponse_Error:
err := resp.GetError()
log.WithContext(ctx).WithFields(lf).WithField(log.ErrorKey, err).Errorf("generic error")

default:
j := protojson.Format(resp)
log.WithContext(ctx).WithFields(lf).Infof("Unknown %T Response:\n%v", resp.ResponseType, j)
}
}
log.WithContext(ctx).WithFields(lf).WithField("item_count", 1).Info("identifying items")
receivedItems, err := ws.Query(ctx, q)
if err != nil {
log.WithContext(ctx).WithFields(lf).WithError(err).Error("failed to send query")
return 1
}

if len(receivedItems) > 0 {
log.WithContext(ctx).WithFields(lf).WithField("received_items", len(receivedItems)).Info("updating changing items on the change record")
} else {
log.WithContext(ctx).WithFields(lf).WithField("received_items", len(receivedItems)).Info("updating change record with no items")
}

references := make([]*sdp.Reference, len(receivedItems))
for i, item := range receivedItems {
references[i] = item.Reference()
}
resultStream, err := client.UpdateChangingItems(ctx, &connect.Request[sdp.UpdateChangingItemsRequest]{
Msg: &sdp.UpdateChangingItemsRequest{
ChangeUUID: changeUuid[:],
ChangingItems: receivedItems,
ChangingItems: references,
},
})
if err != nil {
Expand Down
Loading