Skip to content

Commit

Permalink
expose ActionCommands
Browse files Browse the repository at this point in the history
Co-authored-by: maxlandon <maximelandon@gmail.com>
  • Loading branch information
rsteube and maxlandon committed Oct 3, 2023
1 parent 737173a commit 1e9a1db
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 54 deletions.
34 changes: 34 additions & 0 deletions defaultActions.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/rsteube/carapace/internal/common"
"github.com/rsteube/carapace/internal/config"
"github.com/rsteube/carapace/internal/env"
"github.com/rsteube/carapace/internal/export"
"github.com/rsteube/carapace/internal/man"
"github.com/rsteube/carapace/pkg/match"
Expand Down Expand Up @@ -480,3 +481,36 @@ func ActionPositional(cmd *cobra.Command) Action {
return a.Invoke(c).ToA()
})
}

// ActionCommands completes (sub)commands of given command.
// `Context.Args` is used to traverse the command tree further down - use `Action.Shift` to avoid this.
//
// carapace.Gen(helpCmd).PositionalAnyCompletion(
// carapace.ActionCommands(rootCmd),
// )
func ActionCommands(cmd *cobra.Command) Action {
return ActionCallback(func(c Context) Action {
if len(c.Args) > 0 {
for _, subCommand := range cmd.Commands() {
for _, name := range append(subCommand.Aliases, subCommand.Name()) {
if name == c.Args[0] { // cmd.Find is too lenient
return ActionCommands(subCommand).Shift(1)
}
}
}
return ActionMessage("unknown subcommand %#v for %#v", c.Args[0], cmd.Name())
}

batch := Batch()
for _, subcommand := range cmd.Commands() {
if (!subcommand.Hidden || env.Hidden()) && subcommand.Deprecated == "" {
group := common.Group{Cmd: subcommand}
batch = append(batch, ActionStyledValuesDescribed(subcommand.Name(), subcommand.Short, group.Style()).Tag(group.Tag()))
for _, alias := range subcommand.Aliases {
batch = append(batch, ActionStyledValuesDescribed(alias, subcommand.Short, group.Style()).Tag(group.Tag()))
}
}
}
return batch.ToA()
})
}
25 changes: 1 addition & 24 deletions internalActions.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"path/filepath"
"strings"

"github.com/rsteube/carapace/internal/common"
"github.com/rsteube/carapace/internal/env"
"github.com/rsteube/carapace/internal/pflagfork"
"github.com/rsteube/carapace/pkg/style"
Expand Down Expand Up @@ -133,22 +132,6 @@ func actionFlags(cmd *cobra.Command) Action {
}).Tag("flags")
}

func actionSubcommands(cmd *cobra.Command) Action {
return ActionCallback(func(c Context) Action {
batch := Batch()
for _, subcommand := range cmd.Commands() {
if (!subcommand.Hidden || env.Hidden()) && subcommand.Deprecated == "" {
group := common.Group{Cmd: subcommand}
batch = append(batch, ActionStyledValuesDescribed(subcommand.Name(), subcommand.Short, group.Style()).Tag(group.Tag()))
for _, alias := range subcommand.Aliases {
batch = append(batch, ActionStyledValuesDescribed(alias, subcommand.Short, group.Style()).Tag(group.Tag()))
}
}
}
return batch.ToA()
})
}

func initHelpCompletion(cmd *cobra.Command) {
helpCmd, _, err := cmd.Find([]string{"help"})
if err != nil {
Expand All @@ -162,12 +145,6 @@ func initHelpCompletion(cmd *cobra.Command) {
}

Gen(helpCmd).PositionalAnyCompletion(
ActionCallback(func(c Context) Action {
lastCmd, _, err := cmd.Find(c.Args)
if err != nil {
return ActionMessage(err.Error())
}
return actionSubcommands(lastCmd)
}),
ActionCommands(cmd),
)
}
60 changes: 30 additions & 30 deletions traverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ import (
"github.com/spf13/cobra"
)

func traverse(c *cobra.Command, args []string) (Action, Context) {
LOG.Printf("traverse called for %#v with args %#v\n", c.Name(), args)
storage.preRun(c, args)
func traverse(cmd *cobra.Command, args []string) (Action, Context) {
LOG.Printf("traverse called for %#v with args %#v\n", cmd.Name(), args)
storage.preRun(cmd, args)

if env.Lenient() {
LOG.Printf("allowing unknown flags")
c.FParseErrWhitelist.UnknownFlags = true
cmd.FParseErrWhitelist.UnknownFlags = true
}

inArgs := []string{} // args consumed by current command
inPositionals := []string{} // positionals consumed by current command
var inFlag *pflagfork.Flag // last encountered flag that still expects arguments
c.LocalFlags() // TODO force c.mergePersistentFlags() which is missing from c.Flags()
fs := pflagfork.FlagSet{FlagSet: c.Flags()}
cmd.LocalFlags() // TODO force c.mergePersistentFlags() which is missing from c.Flags()
fs := pflagfork.FlagSet{FlagSet: cmd.Flags()}

context := NewContext(args...)
loop:
Expand All @@ -47,7 +47,7 @@ loop:
break loop

// flag
case !c.DisableFlagParsing && strings.HasPrefix(arg, "-") && (fs.IsInterspersed() || len(inPositionals) == 0):
case !cmd.DisableFlagParsing && strings.HasPrefix(arg, "-") && (fs.IsInterspersed() || len(inPositionals) == 0):
LOG.Printf("arg %#v is a flag\n", arg)
inArgs = append(inArgs, arg)
inFlag = fs.LookupArg(arg)
Expand All @@ -58,22 +58,22 @@ loop:
continue

// subcommand
case subcommand(c, arg) != nil:
case subcommand(cmd, arg) != nil:
LOG.Printf("arg %#v is a subcommand\n", arg)

switch {
case c.DisableFlagParsing:
LOG.Printf("flag parsing disabled for %#v\n", c.Name())
case cmd.DisableFlagParsing:
LOG.Printf("flag parsing disabled for %#v\n", cmd.Name())

default:
LOG.Printf("parsing flags for %#v with args %#v\n", c.Name(), inArgs)
if err := c.ParseFlags(inArgs); err != nil {
LOG.Printf("parsing flags for %#v with args %#v\n", cmd.Name(), inArgs)
if err := cmd.ParseFlags(inArgs); err != nil {
return ActionMessage(err.Error()), context
}
context.Args = c.Flags().Args()
context.Args = cmd.Flags().Args()
}

return traverse(subcommand(c, arg), args[i+1:])
return traverse(subcommand(cmd, arg), args[i+1:])

// positional
default:
Expand Down Expand Up @@ -105,56 +105,56 @@ loop:

// TODO duplicated code
switch {
case c.DisableFlagParsing:
LOG.Printf("flag parsing is disabled for %#v\n", c.Name())
case cmd.DisableFlagParsing:
LOG.Printf("flag parsing is disabled for %#v\n", cmd.Name())

default:
LOG.Printf("parsing flags for %#v with args %#v\n", c.Name(), toParse)
if err := c.ParseFlags(toParse); err != nil {
LOG.Printf("parsing flags for %#v with args %#v\n", cmd.Name(), toParse)
if err := cmd.ParseFlags(toParse); err != nil {
return ActionMessage(err.Error()), context
}
context.Args = c.Flags().Args()
context.Args = cmd.Flags().Args()
}

switch {
// dash argument
case common.IsDash(c):
case common.IsDash(cmd):
LOG.Printf("completing dash for arg %#v\n", context.Value)
context.Args = c.Flags().Args()[c.ArgsLenAtDash():]
context.Args = cmd.Flags().Args()[cmd.ArgsLenAtDash():]
LOG.Printf("context: %#v\n", context.Args)

return storage.getPositional(c, len(context.Args)), context
return storage.getPositional(cmd, len(context.Args)), context

// flag argument
case inFlag != nil && inFlag.Consumes(context.Value):
LOG.Printf("completing flag argument of %#v for arg %#v\n", inFlag.Name, context.Value)
context.Parts = inFlag.Args
return storage.getFlag(c, inFlag.Name), context
return storage.getFlag(cmd, inFlag.Name), context

// flag
case !c.DisableFlagParsing && strings.HasPrefix(context.Value, "-") && (fs.IsInterspersed() || len(inPositionals) == 0):
case !cmd.DisableFlagParsing && strings.HasPrefix(context.Value, "-") && (fs.IsInterspersed() || len(inPositionals) == 0):
if f := fs.LookupArg(context.Value); f != nil && len(f.Args) > 0 {
LOG.Printf("completing optional flag argument for arg %#v with prefix %#v\n", context.Value, f.Prefix)

switch f.Value.Type() {
case "bool":
return ActionValues("true", "false").StyleF(style.ForKeyword).Usage(f.Usage).Prefix(f.Prefix), context
default:
return storage.getFlag(c, f.Name).Prefix(f.Prefix), context
return storage.getFlag(cmd, f.Name).Prefix(f.Prefix), context
}
} else if f != nil && fs.IsPosix() && !strings.HasPrefix(context.Value, "--") && !f.IsOptarg() && f.Prefix == context.Value {
LOG.Printf("completing attached flag argument for arg %#v with prefix %#v\n", context.Value, f.Prefix)
return storage.getFlag(c, f.Name).Prefix(f.Prefix), context
return storage.getFlag(cmd, f.Name).Prefix(f.Prefix), context
}
LOG.Printf("completing flags for arg %#v\n", context.Value)
return actionFlags(c), context
return actionFlags(cmd), context

// positional or subcommand
default:
LOG.Printf("completing positionals and subcommands for arg %#v\n", context.Value)
batch := Batch(storage.getPositional(c, len(context.Args)))
if c.HasAvailableSubCommands() && len(context.Args) == 0 {
batch = append(batch, actionSubcommands(c))
batch := Batch(storage.getPositional(cmd, len(context.Args)))
if cmd.HasAvailableSubCommands() && len(context.Args) == 0 {
batch = append(batch, ActionCommands(cmd))
}
return batch.ToA(), context
}
Expand Down

0 comments on commit 1e9a1db

Please sign in to comment.