From 2fa1d95cf7fa9baf41cdc6118ad2eaf0213d8b91 Mon Sep 17 00:00:00 2001 From: Reuben Miller Date: Wed, 15 May 2024 08:02:26 +0200 Subject: [PATCH] add generic worker for non-request based commands --- pkg/cmd/cmd.go | 3 + pkg/cmdutil/factory.go | 458 +++++++++++++++++++++++--- pkg/config/cliConfiguration.go | 8 + pkg/config/commonoptions.go | 9 + pkg/iterator/bound_iterator.go | 42 +++ pkg/mapbuilder/mapbuilder.go | 17 + pkg/mapbuilder/mapbuilder_iterator.go | 24 ++ pkg/request/request.go | 14 +- pkg/worker/generic_worker.go | 431 ++++++++++++++++++++++++ pkg/worker/worker.go | 10 + 10 files changed, 970 insertions(+), 46 deletions(-) create mode 100644 pkg/iterator/bound_iterator.go create mode 100644 pkg/mapbuilder/mapbuilder_iterator.go create mode 100644 pkg/worker/generic_worker.go diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index da5068b59..a5948d74e 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -399,6 +399,9 @@ func Initialize() (*root.CmdRoot, error) { rootCmd := root.NewCmdRoot(cmdFactory, buildVersion, "") + // Add reference to root command + cmdFactory.SetCommand(rootCmd.Command) + tableOptions := &console.TableOptions{ MinColumnWidth: configHandler.ViewColumnMinWidth(), MaxColumnWidth: configHandler.ViewColumnMaxWidth(), diff --git a/pkg/cmdutil/factory.go b/pkg/cmdutil/factory.go index 95b3781fc..bee660463 100644 --- a/pkg/cmdutil/factory.go +++ b/pkg/cmdutil/factory.go @@ -6,8 +6,10 @@ import ( "fmt" "io" "log" + "net/http" "path/filepath" "strings" + "time" "github.com/reubenmiller/go-c8y-cli/v2/pkg/activitylogger" "github.com/reubenmiller/go-c8y-cli/v2/pkg/cmderrors" @@ -18,8 +20,11 @@ import ( "github.com/reubenmiller/go-c8y-cli/v2/pkg/extensions" "github.com/reubenmiller/go-c8y-cli/v2/pkg/flags" "github.com/reubenmiller/go-c8y-cli/v2/pkg/iostreams" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/iterator" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/jsonUtilities" "github.com/reubenmiller/go-c8y-cli/v2/pkg/jsonformatter" "github.com/reubenmiller/go-c8y-cli/v2/pkg/logger" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/mapbuilder" "github.com/reubenmiller/go-c8y-cli/v2/pkg/mode" "github.com/reubenmiller/go-c8y-cli/v2/pkg/pathresolver" "github.com/reubenmiller/go-c8y-cli/v2/pkg/request" @@ -27,6 +32,7 @@ import ( "github.com/reubenmiller/go-c8y/pkg/c8y" "github.com/spf13/cobra" "github.com/tidwall/gjson" + "github.com/tidwall/pretty" ) type Browser interface { @@ -51,6 +57,15 @@ type Factory struct { // Executable is the path to the currently invoked binary Executable string + + // Command + Command *cobra.Command +} + +// Set reference to the cobra command +func (f *Factory) SetCommand(cmd *cobra.Command) *Factory { + f.Command = cmd + return f } // CreateModeEnabled create mode is enabled @@ -193,6 +208,68 @@ func (f *Factory) RunWithWorkers(client *c8y.Client, cmd *cobra.Command, req *c8 return w.ProcessRequestAndResponse(cmd, req, inputIterators) } +func (f *Factory) RunWithGenericWorkers(cmd *cobra.Command, inputIterators *flags.RequestInputIterators, iter iterator.Iterator, runFunc worker.Runner) error { + client, err := f.Client() + if err != nil { + return err + } + cfg, err := f.Config() + if err != nil { + return err + } + log, err := f.Logger() + if err != nil { + return err + } + + activityLogger, err := f.ActivityLogger() + if err != nil { + return err + } + // consol, err := f.Console() + // if err != nil { + // return err + // } + // dataview, err := f.DataView() + // if err != nil { + // return err + // } + + w, err := worker.NewGenericWorker(log, cfg, f.IOStreams, client, activityLogger, runFunc, f.CheckPostCommandError) + + if err != nil { + return err + } + return w.Run(cmd, iter, inputIterators) +} + +func (f *Factory) RunSequentiallyWithGenericWorkers(cmd *cobra.Command, iter iterator.Iterator, runFunc worker.Runner, inputIterators *flags.RequestInputIterators) error { + client, err := f.Client() + if err != nil { + return err + } + cfg, err := f.Config() + if err != nil { + return err + } + log, err := f.Logger() + if err != nil { + return err + } + + activityLogger, err := f.ActivityLogger() + if err != nil { + return err + } + + w, err := worker.NewGenericWorker(log, cfg, f.IOStreams, client, activityLogger, runFunc, f.CheckPostCommandError) + + if err != nil { + return err + } + return w.RunSequentially(cmd, iter, inputIterators) +} + // GetViewProperties Look up the view properties to display func (f *Factory) GetViewProperties(cfg *config.Config, cmd *cobra.Command, output []byte) ([]string, error) { dataView, err := f.DataView() @@ -250,45 +327,6 @@ func (f *Factory) GetViewProperties(cfg *config.Config, cmd *cobra.Command, outp return viewProperties, nil } -// WriteJSONToConsole writes given json output to the console supporting the common options of select, output etc. -func (f *Factory) WriteJSONToConsole(cfg *config.Config, cmd *cobra.Command, property string, output []byte) error { - consol, err := f.Console() - if err != nil { - return err - } - commonOptions, err := cfg.GetOutputCommonOptions(cmd) - if err != nil { - return err - } - - if len(commonOptions.Filters.Pluck) == 0 { - // don't fail if view properties fail - props, _ := f.GetViewProperties(cfg, cmd, output) - if len(props) > 0 { - commonOptions.Filters.Pluck = props - } - } - output, filterErr := commonOptions.Filters.Apply(string(output), property, false, consol.SetHeaderFromInput) - if filterErr != nil { - return filterErr - } - - output = bytes.ReplaceAll(output, []byte("\\u003c"), []byte("<")) - output = bytes.ReplaceAll(output, []byte("\\u003e"), []byte(">")) - output = bytes.ReplaceAll(output, []byte("\\u0026"), []byte("&")) - - jsonformatter.WithOutputFormatters( - consol, - output, - false, - jsonformatter.WithFileOutput(commonOptions.OutputFile != "", commonOptions.OutputFile, false), - jsonformatter.WithTrimSpace(true), - jsonformatter.WithJSONStreamOutput(true, consol.IsJSONStream(), consol.IsTextOutput()), - jsonformatter.WithSuffix(len(output) > 0, "\n"), - ) - return nil -} - func (f *Factory) CheckPostCommandError(err error) error { cfg, configErr := f.Config() if configErr != nil { @@ -445,3 +483,345 @@ func NewRequestInputIterators(cmd *cobra.Command, cfg *config.Config) (*flags.Re } return inputIter, err } + +type OutputContext struct { + Input any + Response *http.Response + Duration time.Duration +} + +func (f *Factory) ExecuteOutputTemplate(output []byte, params OutputContext, commonOptions *config.CommonCommandOptions) ([]byte, error) { + if commonOptions.OutputTemplate == "" { + return output, nil + } + + outputBuilder := mapbuilder.NewInitializedMapBuilder(true) + + if err := outputBuilder.AddLocalTemplateVariable("flags", commonOptions.CommandFlags); err != nil { + return nil, err + } + + requestData := make(map[string]interface{}) + responseData := make(map[string]interface{}) + + // Add request/response variables + if params.Response != nil { + resp := params.Response + requestData["path"] = resp.Request.URL.Path + requestData["pathEncoded"] = strings.Replace(resp.Request.URL.String(), resp.Request.URL.Scheme+"://"+resp.Request.URL.Host, "", 1) + requestData["host"] = resp.Request.URL.Host + requestData["url"] = resp.Request.URL.String() + requestData["query"] = request.TryUnescapeURL(resp.Request.URL.RawQuery) + requestData["queryParams"] = request.FlattenArrayMap(resp.Request.URL.Query()) + requestData["method"] = resp.Request.Method + // requestData["header"] = resp.Response.Request.Header + + // TODO: Add a response variable to included the status code, content type, + responseData["statusCode"] = resp.StatusCode + responseData["status"] = resp.Status + responseData["duration"] = params.Duration.Milliseconds() + responseData["contentLength"] = resp.ContentLength + responseData["contentType"] = resp.Header.Get("Content-Type") + responseData["header"] = request.FlattenArrayMap(resp.Header) + responseData["proto"] = resp.Proto + responseData["body"] = string(output) + } + + if err := outputBuilder.AddLocalTemplateVariable("request", requestData); err != nil { + return nil, err + } + + if err := outputBuilder.AddLocalTemplateVariable("response", responseData); err != nil { + return nil, err + } + + outputJSON := make(map[string]any) + if parseErr := jsonUtilities.ParseJSON(string(output), outputJSON); parseErr == nil { + if err := outputBuilder.AddLocalTemplateVariable("output", outputJSON); err != nil { + return nil, err + } + } else { + if err := outputBuilder.AddLocalTemplateVariable("output", string(output)); err != nil { + return nil, err + } + } + + outputBuilder.AppendTemplate(commonOptions.OutputTemplate) + out, outErr := outputBuilder.MarshalJSONWithInput(params.Input) + + if outErr != nil { + return out, outErr + } + return out, nil +} + +func (f *Factory) WriteOutputWithoutPropertyGuess(output []byte, params OutputContext) error { + cfg, err := f.Config() + if err != nil { + return err + } + commonOptions, err := cfg.GetOutputCommonOptions(f.Command) + if err != nil { + return err + } + + _, err = f.WriteOutputWithRows(output, params, commonOptions.DisableResultPropertyDetection()) + return err +} + +func (f *Factory) WriteOutput(output []byte, params OutputContext, commonOptions *config.CommonCommandOptions) error { + _, err := f.WriteOutputWithRows(output, params, commonOptions) + return err +} + +func (f *Factory) WriteOutputWithRows(output []byte, params OutputContext, commonOptions *config.CommonCommandOptions) (int, error) { + consol, err := f.Console() + if err != nil { + return 0, err + } + + cfg, err := f.Config() + if err != nil { + return 0, err + } + + dataView, err := f.DataView() + if err != nil { + return 0, err + } + + logg, err := f.Logger() + if err != nil { + return 0, err + } + + if commonOptions == nil { + if f.Command == nil { + return 0, fmt.Errorf("command output options are mandatory") + } + commonOptions = cfg.MustGetOutputCommonOptions(f.Command) + } + + unfilteredSize := 0 + outputJSON := gjson.ParseBytes(output) + + if len(output) > 0 || commonOptions.HasOutputTemplate() { + // estimate size based on utf8 encoding. 1 char is 1 byte + if params.Response != nil { + PrintResponseSize(logg, params.Response, output) + } + + var responseText []byte + isJSONResponse := jsonUtilities.IsValidJSON(output) + + dataProperty := "" + showRaw := cfg.RawOutput() || cfg.WithTotalPages() || cfg.WithTotalElements() + + dataProperty = commonOptions.ResultProperty + if dataProperty == "" { + dataProperty = f.GuessDataProperty(outputJSON) + } else if dataProperty == "-" { + dataProperty = "" + } + + if v := outputJSON.Get(dataProperty); v.Exists() && v.IsArray() { + unfilteredSize = len(v.Array()) + logg.Infof("Unfiltered array size. len=%d", unfilteredSize) + } + + // Apply output template (before the data is processed as the template can transform text to json or other way around) + if commonOptions.HasOutputTemplate() { + var tempBody []byte + if showRaw || dataProperty == "" { + tempBody = output + } else { + tempBody = []byte(outputJSON.Get(dataProperty).Raw) + } + dataProperty = "" + + tmplOutput, tmplErr := f.ExecuteOutputTemplate(tempBody, params, commonOptions) + if tmplErr != nil { + return unfilteredSize, tmplErr + } + + if jsonUtilities.IsValidJSON(tmplOutput) { + isJSONResponse = true + output = pretty.Ugly(tmplOutput) + outputJSON = gjson.ParseBytes(output) + } else { + isJSONResponse = false + // TODO: Is removing the quotes doing too much, what happens if someone is building csv, and it using quotes around some fields? + // e.g. `"my value",100`, that would get transformed to `my value",100` + // Trim any quotes wrapping the values + tmplOutput = bytes.TrimSpace(tmplOutput) + + output = pretty.Ugly(bytes.Trim(tmplOutput, "\"")) + outputJSON = gjson.ParseBytes([]byte("")) + } + } + + if isJSONResponse && commonOptions.Filters != nil { + if showRaw { + dataProperty = "" + } + + if cfg.RawOutput() { + logg.Infof("Raw mode active. In raw mode the following settings are forced, view=off, output=json") + } + view := cfg.ViewOption() + logg.Infof("View mode: %s", view) + + // Detect view (if no filters are given) + if len(commonOptions.Filters.Pluck) == 0 { + if len(output) > 0 && dataView != nil { + inputData := outputJSON + if dataProperty != "" { + inputData = outputJSON.Get(dataProperty) + } + + switch strings.ToLower(view) { + case config.ViewsOff: + // dont apply a view + if !showRaw { + commonOptions.Filters.Pluck = []string{"**"} + } + case config.ViewsAuto: + viewData := &dataview.ViewData{ + ResponseBody: &inputData, + } + + if params.Response != nil { + viewData.ContentType = params.Response.Header.Get("Content-Type") + viewData.Request = params.Response.Request + } + + props, err := dataView.GetView(viewData) + + if err != nil || len(props) == 0 { + if err != nil { + logg.Infof("No matching view detected. defaulting to '**'. %s", err) + } else { + logg.Info("No matching view detected. defaulting to '**'") + } + commonOptions.Filters.Pluck = []string{"**"} + } else { + logg.Infof("Detected view: %s", strings.Join(props, ", ")) + commonOptions.Filters.Pluck = props + } + default: + props, err := dataView.GetViewByName(view) + if err != nil || len(props) == 0 { + if err != nil { + logg.Warnf("no matching view found. %s, name=%s", err, view) + } else { + logg.Warnf("no matching view found. name=%s", view) + } + commonOptions.Filters.Pluck = []string{"**"} + } else { + logg.Infof("Detected view: %s", strings.Join(props, ", ")) + commonOptions.Filters.Pluck = props + } + } + } + } else { + logg.Debugf("using existing pluck values. %v", commonOptions.Filters.Pluck) + } + + if filterOutput, filterErr := commonOptions.Filters.Apply(string(output), dataProperty, false, consol.SetHeaderFromInput); filterErr != nil { + logg.Warnf("filter error. %s", filterErr) + responseText = filterOutput + } else { + responseText = filterOutput + } + + emptyArray := []byte("[]\n") + + if !showRaw { + if len(responseText) == len(emptyArray) && bytes.Equal(responseText, emptyArray) { + logg.Info("No matching results found. Empty response will be omitted") + responseText = []byte{} + } + } + + } else { + responseText = output + } + + // replace special escaped unicode sequences + responseText = bytes.ReplaceAll(responseText, []byte("\\u003c"), []byte("<")) + responseText = bytes.ReplaceAll(responseText, []byte("\\u003e"), []byte(">")) + responseText = bytes.ReplaceAll(responseText, []byte("\\u0026"), []byte("&")) + + // Wait for progress bar to finish before printing to console + // to prevent overriding the output + f.IOStreams.WaitForProgressIndicator() + + jsonformatter.WithOutputFormatters( + consol, + responseText, + !isJSONResponse, + jsonformatter.WithFileOutput(commonOptions.OutputFile != "", commonOptions.OutputFile, false), + jsonformatter.WithTrimSpace(true), + jsonformatter.WithJSONStreamOutput(isJSONResponse, consol.IsJSONStream(), consol.IsTextOutput()), + jsonformatter.WithSuffix(len(responseText) > 0, "\n"), + ) + } + return unfilteredSize, nil +} + +func (f *Factory) GuessDataProperty(output gjson.Result) string { + property := "" + arrayPropertes := []string{} + totalKeys := 0 + + logg, err := f.Logger() + if err != nil { + panic(err) + } + + if v := output.Get("id"); !v.Exists() { + // Find the property which is an array + output.ForEach(func(key, value gjson.Result) bool { + totalKeys++ + if value.IsArray() { + arrayPropertes = append(arrayPropertes, key.String()) + } + return true + }) + } + + if len(arrayPropertes) > 1 { + logg.Debugf("Could not detect property as more than 1 array like property detected: %v", arrayPropertes) + return "" + } + logg.Debugf("Array properties: %v", arrayPropertes) + + if len(arrayPropertes) == 0 { + return "" + } + + property = arrayPropertes[0] + + // if total keys is a high number, than it is most likely not an array of data + // i.e. for the /tenant/statistics + if property != "" && totalKeys > 10 { + return "" + } + + if property != "" && totalKeys < 10 { + logg.Debugf("Data property: %s", property) + } + return property +} + +func PrintResponseSize(l *logger.Logger, resp *http.Response, output []byte) { + if resp.ContentLength > -1 { + l.Infof("Response Length: %0.1fKB", float64(resp.ContentLength)/1024) + } else { + if resp.Uncompressed { + l.Infof("Response Length: %0.1fKB (uncompressed)", float64(len(output))/1024) + } else { + l.Infof("Response Length: %0.1fKB", float64(len(output))/1024) + } + } +} diff --git a/pkg/config/cliConfiguration.go b/pkg/config/cliConfiguration.go index 7305d2f00..0aac83738 100644 --- a/pkg/config/cliConfiguration.go +++ b/pkg/config/cliConfiguration.go @@ -1595,6 +1595,14 @@ func (c *Config) GetJSONSelect() []string { return allitems } +func (c *Config) MustGetOutputCommonOptions(cmd *cobra.Command) *CommonCommandOptions { + opts, err := c.GetOutputCommonOptions(cmd) + if err != nil { + panic(err) + } + return &opts +} + // GetOutputCommonOptions get common output options which controls how the output should be handled i.e. json filter, selects, csv etc. func (c *Config) GetOutputCommonOptions(cmd *cobra.Command) (CommonCommandOptions, error) { if c.commonOptions != nil { diff --git a/pkg/config/commonoptions.go b/pkg/config/commonoptions.go index db02fdde4..a63b42a62 100644 --- a/pkg/config/commonoptions.go +++ b/pkg/config/commonoptions.go @@ -108,3 +108,12 @@ func (options CommonCommandOptions) AddQueryParametersWithMapping(query *flags.Q } } } + +func (options CommonCommandOptions) HasOutputTemplate() bool { + return options.OutputTemplate != "" +} + +func (options *CommonCommandOptions) DisableResultPropertyDetection() *CommonCommandOptions { + options.ResultProperty = "-" + return options +} diff --git a/pkg/iterator/bound_iterator.go b/pkg/iterator/bound_iterator.go new file mode 100644 index 000000000..ae2b13a60 --- /dev/null +++ b/pkg/iterator/bound_iterator.go @@ -0,0 +1,42 @@ +package iterator + +import ( + "io" + "sync/atomic" +) + +// BoundIterator is generic iterator which executes a function on every iteration +type BoundIterator struct { + currentIndex int64 // access atomically (must be defined at the top) + endIndex int64 + iter Iterator +} + +// GetNext will count through the values and return them one by one +func (i *BoundIterator) GetNext() (line []byte, input interface{}, err error) { + nextIndex := atomic.AddInt64(&i.currentIndex, 1) + if i.endIndex != 0 && nextIndex > i.endIndex { + err = io.EOF + } else { + line, input, err = i.iter.GetNext() + } + return line, input, err +} + +// IsBound return true if the iterator is bound +func (i *BoundIterator) IsBound() bool { + return i.endIndex > 0 +} + +// MarshalJSON return the value in a json compatible value +func (i *BoundIterator) MarshalJSON() (line []byte, err error) { + return MarshalJSON(i) +} + +// NewBoundIterator return an iterator which makes an existing iterator bound +func NewBoundIterator(iter Iterator, max int64) *BoundIterator { + return &BoundIterator{ + iter: iter, + endIndex: max, + } +} diff --git a/pkg/mapbuilder/mapbuilder.go b/pkg/mapbuilder/mapbuilder.go index 8caa7385e..3b238bdf2 100644 --- a/pkg/mapbuilder/mapbuilder.go +++ b/pkg/mapbuilder/mapbuilder.go @@ -1159,6 +1159,23 @@ func (b *MapBuilder) Set(path string, value interface{}) error { return nil } +// Set multiple values +func (b *MapBuilder) SetTuple(path string, values ...interface{}) error { + for _, value := range values { + // store iterators separately so we can intercept the raw value which is otherwise lost during json marshalling + if it, ok := value.(iterator.Iterator); ok { + b.bodyIterators = append(b.bodyIterators, IteratorReference{path, it}) + Logger.Debugf("DEBUG: Found iterator. path=%s", path) + return nil + } + + if err := b.SetPath(path, value); err != nil { + return err + } + } + return nil +} + // MergeMaps merges a list of maps into the body. If the body does not already exists, // then it will be ignored. Only shallow merging is done. // Duplicate keys will be overwritten by maps later in the list diff --git a/pkg/mapbuilder/mapbuilder_iterator.go b/pkg/mapbuilder/mapbuilder_iterator.go new file mode 100644 index 000000000..43e695c0b --- /dev/null +++ b/pkg/mapbuilder/mapbuilder_iterator.go @@ -0,0 +1,24 @@ +package mapbuilder + +func NewMapBuilderIterator(b *MapBuilder) *MapBuilderIterator { + return &MapBuilderIterator{ + MapBuilder: b, + } +} + +type MapBuilderIterator struct { + MapBuilder *MapBuilder + Index int64 +} + +func (i *MapBuilderIterator) GetNext() (line []byte, input interface{}, err error) { + out, err := i.MapBuilder.MarshalJSON() + return out, out, err +} + +func (i *MapBuilderIterator) IsBound() bool { + if i.MapBuilder.TemplateIterator != nil { + return false + } + return i.MapBuilder.TemplateIterator.IsBound() +} diff --git a/pkg/request/request.go b/pkg/request/request.go index 8cadef005..2dc756adb 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -280,7 +280,7 @@ func (r *RequestHandler) PrintRequestDetails(w io.Writer, requestOptions *c8y.Re PathEncoded: strings.Replace(fullURL, req.URL.Scheme+"://"+req.URL.Host, "", 1), Method: req.Method, Headers: headers, - Query: tryUnescapeURL(req.URL.RawQuery), + Query: TryUnescapeURL(req.URL.RawQuery), Body: requestBody, Shell: shell, PowerShell: pwsh, @@ -315,7 +315,7 @@ func (r *RequestHandler) PrintRequestDetails(w io.Writer, requestOptions *c8y.Re // markdown sectionLabel.Fprintf(w, "What If: Sending [%s] request to [%s]\n", req.Method, req.URL) - label.Fprintf(w, "\n### %s %s", details.Method, tryUnescapeURL(details.PathEncoded)) + label.Fprintf(w, "\n### %s %s", details.Method, TryUnescapeURL(details.PathEncoded)) if len(req.Header) > 0 { // sort header names @@ -362,7 +362,7 @@ func (r *RequestHandler) PrintRequestDetails(w io.Writer, requestOptions *c8y.Re } } -func tryUnescapeURL(v string) string { +func TryUnescapeURL(v string) string { unescapedQuery, err := url.QueryUnescape(v) if err != nil { return v @@ -665,7 +665,7 @@ func optimizeManagedObjectsURL(u *url.URL, lastID string) *url.URL { return u } -func flattenArrayMap[K string, V []string](m map[K]V) map[K]any { +func FlattenArrayMap[K string, V []string](m map[K]V) map[K]any { out := make(map[K]any) for key, value := range m { if len(value) == 1 { @@ -690,8 +690,8 @@ func ExecuteTemplate(responseText []byte, resp *http.Response, input any, common requestData["pathEncoded"] = strings.Replace(resp.Request.URL.String(), resp.Request.URL.Scheme+"://"+resp.Request.URL.Host, "", 1) requestData["host"] = resp.Request.URL.Host requestData["url"] = resp.Request.URL.String() - requestData["query"] = tryUnescapeURL(resp.Request.URL.RawQuery) - requestData["queryParams"] = flattenArrayMap(resp.Request.URL.Query()) + requestData["query"] = TryUnescapeURL(resp.Request.URL.RawQuery) + requestData["queryParams"] = FlattenArrayMap(resp.Request.URL.Query()) requestData["method"] = resp.Request.Method // requestData["header"] = resp.Response.Request.Header if err := outputBuilder.AddLocalTemplateVariable("request", requestData); err != nil { @@ -705,7 +705,7 @@ func ExecuteTemplate(responseText []byte, resp *http.Response, input any, common responseData["duration"] = duration.Milliseconds() responseData["contentLength"] = resp.ContentLength responseData["contentType"] = resp.Header.Get("Content-Type") - responseData["header"] = flattenArrayMap(resp.Header) + responseData["header"] = FlattenArrayMap(resp.Header) responseData["proto"] = resp.Proto responseData["body"] = string(responseText) if err := outputBuilder.AddLocalTemplateVariable("response", responseData); err != nil { diff --git a/pkg/worker/generic_worker.go b/pkg/worker/generic_worker.go new file mode 100644 index 000000000..2a77c5820 --- /dev/null +++ b/pkg/worker/generic_worker.go @@ -0,0 +1,431 @@ +package worker + +import ( + "errors" + "fmt" + "io" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/reubenmiller/go-c8y-cli/v2/pkg/activitylogger" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/cmderrors" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/config" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/flags" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/iostreams" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/iterator" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/logger" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/progressbar" + "github.com/reubenmiller/go-c8y-cli/v2/pkg/prompt" + "github.com/reubenmiller/go-c8y/pkg/c8y" + "github.com/spf13/cobra" +) + +type Runner func(Job) (any, error) + +func NewGenericWorker(log *logger.Logger, cfg *config.Config, iostream *iostreams.IOStreams, client *c8y.Client, activityLog *activitylogger.ActivityLogger, runFunc Runner, checkError func(error) error) (*GenericWorker, error) { + return &GenericWorker{ + Config: cfg, + Logger: log, + IO: iostream, + ActivityLogger: activityLog, + Client: client, + Execute: runFunc, + CheckError: checkError, + }, nil +} + +type GenericWorker struct { + Config *config.Config + IO *iostreams.IOStreams + Logger *logger.Logger + Client *c8y.Client + ActivityLogger *activitylogger.ActivityLogger + CheckError func(error) error + Execute Runner +} + +// GetMaxWorkers maximum number of workers +func (w *GenericWorker) GetMaxWorkers() int { + if w.Config == nil { + return 5 + } + return w.Config.GetMaxWorkers() +} + +// GetMaxJob maximum number of jobs allowed +func (w *GenericWorker) GetMaxJobs() int64 { + if w.Config == nil { + return 100 + } + return w.Config.GetMaxJobs() +} + +func (w *GenericWorker) GetBatchOptions(cmd *cobra.Command) (*BatchOptions, error) { + options := &BatchOptions{ + AbortOnErrorCount: w.Config.AbortOnErrorCount(), + TotalWorkers: w.Config.GetWorkers(), + Delay: w.Config.WorkerDelay(), + DelayBefore: w.Config.WorkerDelayBefore(), + SemanticMethod: flags.GetSemanticMethodFromAnnotation(cmd), + } + + if v, err := cmd.Flags().GetInt("count"); err == nil { + options.NumJobs = v + } + + if v, err := cmd.Flags().GetInt("startIndex"); err == nil { + options.StartIndex = v + } + + return options, nil +} + +type Job struct { + ID int64 + Value any + CommonOptions config.CommonCommandOptions + Input any + Options BatchOptions +} + +func (w *GenericWorker) RunSequentially(cmd *cobra.Command, iter iterator.Iterator, inputIterators *flags.RequestInputIterators) error { + // TODO: How does an unbound iterator get caught here? + if inputIterators == nil { + return fmt.Errorf("missing input iterators") + } + + // get common options and batch settings + commonOptions, err := w.Config.GetOutputCommonOptions(cmd) + if err != nil { + return cmderrors.NewUserError(fmt.Sprintf("Failed to get common options. err=%s", err)) + } + + batchOptions, err := w.GetBatchOptions(cmd) + if err != nil { + return err + } + + // Configure post actions + batchOptions.PostActions = inputIterators.PipeOptions.PostActions + + out, input, err := iter.GetNext() + if err != nil { + return err + } + job := Job{ + ID: 1, + Value: out, + CommonOptions: commonOptions, + Options: *batchOptions, + Input: input, + } + _, executeErr := w.Execute(job) + + return executeErr +} + +func (w *GenericWorker) Run(cmd *cobra.Command, iter iterator.Iterator, inputIterators *flags.RequestInputIterators) error { + // TODO: How does an unbound iterator get caught here? + if inputIterators == nil { + return fmt.Errorf("missing input iterators") + } + + // get common options and batch settings + commonOptions, err := w.Config.GetOutputCommonOptions(cmd) + if err != nil { + return cmderrors.NewUserError(fmt.Sprintf("Failed to get common options. err=%s", err)) + } + + batchOptions, err := w.GetBatchOptions(cmd) + if err != nil { + return err + } + + // Configure post actions + batchOptions.PostActions = inputIterators.PipeOptions.PostActions + + return w.run(iter, commonOptions, *batchOptions) +} + +func (w *GenericWorker) run(iter iterator.Iterator, commonOptions config.CommonCommandOptions, batchOptions BatchOptions) error { + // Two channels - to send them work and to collect their results. + // buffer size does not really matter, it just needs to be high + // enough not to block the workers + + // TODO: how to detect when request iterator is finished when using the body iterator (total number of requests?) + if batchOptions.TotalWorkers < 1 { + batchOptions.TotalWorkers = 1 + } + jobs := make(chan Job, batchOptions.TotalWorkers-1) + results := make(chan error, batchOptions.TotalWorkers-1) + workers := sync.WaitGroup{} + + // don't start the progress bar until all confirmations are done + progbar := progressbar.NewMultiProgressBar(w.IO.ErrOut, 1, batchOptions.TotalWorkers, "requests", w.Config.ShowProgress()) + + for iWork := 1; iWork <= batchOptions.TotalWorkers; iWork++ { + w.Logger.Debugf("starting worker: %d", iWork) + workers.Add(1) + go w.StartWorker(iWork, jobs, results, progbar, &workers) + } + + jobID := int64(0) + skipConfirm := false + shouldConfirm := false + promptCount := int32(0) + promptWG := sync.WaitGroup{} + + maxJobs := w.GetMaxJobs() + tenantName := "" + if w.Client != nil { + tenantName = w.Client.TenantName + } + w.Logger.Infof("Max jobs: %d", maxJobs) + + // add jobs async + go func() { + defer close(jobs) + jobInputErrors := int64(0) + for { + jobID++ + w.Logger.Debugf("checking job iterator: %d", jobID) + + // check if iterator is exhausted + value, input, err := iter.GetNext() + + if errors.Is(err, io.EOF) { + // no more requests, decrement job id as the job was not started + jobID-- + break + } + + if maxJobs != 0 && jobID > maxJobs { + w.Logger.Infof("maximum jobs reached: limit=%d", maxJobs) + break + } + + if err != nil { + if errors.Is(err, io.EOF) { + // no more requests + break + } + jobInputErrors++ + + rootCauseErr := err + if errors.Is(err, cmderrors.ErrNoMatchesFound) { + rootCauseErr = err + } else if parentErr := errors.Unwrap(err); parentErr != nil { + rootCauseErr = parentErr + } + + w.Config.LogErrorF(rootCauseErr, "skipping job: %d. %s", jobID, rootCauseErr) + results <- err + + // Note: stop adding jobs if total errors are exceeded + // This is necessary as the worker still needs time to process + // the current job, so there can be a delay before the results are read. + if jobInputErrors >= int64(batchOptions.AbortOnErrorCount) { + break + } + + // move to next job + continue + } + w.Logger.Debugf("adding job: %d", jobID) + + if value != nil { + if batchOptions.SemanticMethod != "" { + // Use a custom method which controls how the request should be handled but is not the actual request + shouldConfirm = w.Config.ShouldConfirm(batchOptions.SemanticMethod) + } else { + // TODO: Allow a job to control if something should be confirmed or not + // shouldConfirm = w.Config.ShouldConfirm(request.Method) + } + } + + // confirm action + if !skipConfirm && shouldConfirm { + // wait for any other previous prompted jobs to finish + promptWG.Wait() + + operation := "Execute command" + if commonOptions.ConfirmText != "" { + operation = commonOptions.ConfirmText + } else if len(os.Args[1:]) > 1 { + // build confirm text from cmd structure + operation = fmt.Sprintf("%s %s", os.Args[2], strings.TrimRight(os.Args[1], "s")) + } + promptMessage, _ := batchOptions.GetConfirmationMessage(operation, value, input) + confirmResult, err := prompt.Confirm(fmt.Sprintf("(job: %d)", jobID), promptMessage, "tenant "+tenantName, prompt.ConfirmYes.String(), false) + + switch confirmResult { + case prompt.ConfirmYesToAll: + skipConfirm = true + case prompt.ConfirmYes: + // confirmed + case prompt.ConfirmNo: + w.Logger.Warningf("skipping job: %d. %s", jobID, err) + if w.ActivityLogger != nil { + // TODO: Let batching control custom log message + // w.ActivityLogger.LogCustom(err.Error() + ". " + request.Path) + } + results <- err + continue + case prompt.ConfirmNoToAll: + w.Logger.Infof("skipping job: %d. %s", jobID, err) + if w.ActivityLogger != nil { + // TODO: Let batching control custom log message + // w.ActivityLogger.LogCustom(err.Error() + ". " + request.Path) + } + w.Logger.Infof("cancelling all remaining jobs") + results <- err + } + if confirmResult == prompt.ConfirmNoToAll { + break + } + + promptWG.Add(1) + atomic.AddInt32(&promptCount, 1) + } + + if skipConfirm || !shouldConfirm { + progbar.Start(float64(batchOptions.Delay * 2 / time.Millisecond)) + } + + jobs <- Job{ + ID: jobID, + Options: batchOptions, + Input: input, + Value: value, + CommonOptions: commonOptions, + } + } + + w.Logger.Debugf("finished adding jobs. lastJobID=%d", jobID) + }() + + // collect all the results of the work. + totalErrors := make([]error, 0) + + // close the results when the works are finished, but don't block reading the results + wasCancelled := int32(0) + go func() { + workers.Wait() + time.Sleep(200 * time.Microsecond) + + // prevent closing channel twice + if atomic.AddInt32(&wasCancelled, 1) == 1 { + close(results) + } + }() + + for err := range results { + if err == nil { + w.Logger.Debugf("job successful") + } else { + w.Logger.Infof("job error. %s", err) + } + + if err != nil && err != io.EOF { + + // overwrite error + err = w.CheckError(err) + if err != nil { + totalErrors = append(totalErrors, err) + } + } + // exit early + if batchOptions.AbortOnErrorCount != 0 && len(totalErrors) >= batchOptions.AbortOnErrorCount { + if atomic.AddInt32(&wasCancelled, 1) == 1 { + close(results) + } + return cmderrors.NewUserErrorWithExitCode(cmderrors.ExitAbortedWithErrors, fmt.Sprintf("aborted batch as error count has been exceeded. totalErrors=%d", batchOptions.AbortOnErrorCount)) + } + + // communicate that the prompt has received a result + pendingPrompts := atomic.AddInt32(&promptCount, -1) + if pendingPrompts+1 > 0 { + promptWG.Done() + } + } + if progbar.IsEnabled() && progbar.IsRunning() && jobID > 1 { + // wait for progress bar to update last increment + time.Sleep(progbar.RefreshRate()) + } + + maxJobsReached := maxJobs != 0 && jobID > maxJobs + if total := len(totalErrors); total > 0 { + if total == 1 { + // return only error + return totalErrors[0] + } + // aggregate error + message := fmt.Sprintf("jobs completed with %d errors", total) + if maxJobsReached { + message += fmt.Sprintf(". job limit exceeded=%v", maxJobsReached) + } + return cmderrors.NewUserErrorWithExitCode(cmderrors.ExitCompletedWithErrors, message) + } + if maxJobsReached { + return cmderrors.NewUserErrorWithExitCode(cmderrors.ExitJobLimitExceeded, fmt.Sprintf("max job limit exceeded. limit=%d", maxJobs)) + } + return nil +} + +// These workers will receive work on the `jobs` channel and send the corresponding +// results on `results` +func (w *GenericWorker) StartWorker(id int, jobs <-chan Job, results chan<- error, prog *progressbar.ProgressBar, wg *sync.WaitGroup) { + var err error + onStartup := true + + var total int64 + + defer wg.Done() + for job := range jobs { + total++ + workerStart := prog.StartJob(id, total) + + if job.Options.DelayBefore > 0 { + w.Logger.Infof("worker %d: sleeping %s before starting job", id, job.Options.DelayBefore) + time.Sleep(job.Options.DelayBefore) + } + + if !onStartup { + if !errors.Is(err, io.EOF) && job.Options.Delay > 0 { + w.Logger.Infof("worker %d: sleeping %s before fetching next job", id, job.Options.Delay) + time.Sleep(job.Options.Delay) + } + } + onStartup = false + + w.Logger.Infof("worker %d: started job %d", id, job.ID) + startTime := time.Now().UnixNano() + + result, resultErr := w.Execute(job) + // Handle post request actions (only if original response was ok) + // and stop actions if an error is encountered + if resultErr == nil { + for i, action := range job.Options.PostActions { + w.Logger.Debugf("Executing action: %d", i) + runOutput, runErr := action.Run(result) + if runErr != nil { + resultErr = runErr + w.Logger.Warningf("Action failed. output=%#v, err=%s", runOutput, runErr) + break + } + } + } + + elapsedMS := (time.Now().UnixNano() - startTime) / 1000.0 / 1000.0 + + w.Logger.Infof("worker %d: finished job %d in %dms", id, job.ID, elapsedMS) + prog.FinishedJob(id, workerStart) + + // return result before delay, so errors can be handled before the sleep + results <- resultErr + } + prog.WorkerCompleted(id) +} diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go index 18fa6283f..4ad19fce2 100644 --- a/pkg/worker/worker.go +++ b/pkg/worker/worker.go @@ -42,6 +42,16 @@ type BatchOptions struct { InputData []string inputIndex int + + // Generic values + ConfirmationMessage func(string, any, any) (string, error) +} + +func (b *BatchOptions) GetConfirmationMessage(operation string, value any, input any) (string, error) { + if b.ConfirmationMessage != nil { + return b.ConfirmationMessage(operation, value, input) + } + return operation, nil } func (b *BatchOptions) GetItem() (string, error) {