Skip to content

Commit

Permalink
Move to a JSON configuration hosts file to support more features easi…
Browse files Browse the repository at this point in the history
…er, such as custom ssh user (#16)

* start RuntimeConfig support

TODO:
*) use per host ssh user
*) update integration tests to ensure a payload is overrode when asked

* add integration test to ensure payload override works

* respect given ssh user in config

* save a pointer to a RuntimeConfig in Instance

* update README

* one more doc update
  • Loading branch information
sgsullivan committed Aug 12, 2022
1 parent b48510e commit f84ad17
Show file tree
Hide file tree
Showing 16 changed files with 316 additions and 277 deletions.
42 changes: 32 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,45 @@
# befehl
## Run arbitrary commands over ssh in mass

- run payload.sh in PWD on hosts in targets file in PWD.. up to 2000 at a time.
- run the given payload(s) in PWD on host(s) in config.json.. up to 2000 at a time.

`./befehl execute --hosts targets --payload payload.sh -routines 2000`
`./befehl execute --runconfig config.json --routines 2000`

Output of each payload run for every node will be in the log directory (by default, its `$HOME/befehl/logs`) in a file named after the machine it ran on.

The targets file should be a plain text file (shown below) containing all hosts to run the payload on, separated by a new line. If the host has an alternate ssh port (aka not port 22) then specify the alternate port like `192.168.0.2:2222`. An example host list is shown below specifying alternate ssh port:
An example runconfig (`config.json` shown above) is shown below:

```json
{
"payload": "integration_tests/examples/payload",
"user": "root",
"hosts": [{
"host": "127.0.0.1",
"port": 1000
},
{
"host": "127.0.0.1",
"port": 1001,
"user": "snowflake",
"payload": "integration_tests/examples/payload-override"
},
{
"host": "127.0.0.1",
"port": 1002
},
{
"host": "127.0.0.1",
"port": 1003
},
{
"host": "127.0.0.1",
"port": 1004
}
]
}
```
192.168.0.2
192.168.0.3:1000
192.168.0.4:22
```

In this example, the connection attempt to 192.16.0.2 will be attempted on port 22 because the port wasn't specified.

As you can see, you can override the `payload` from the default, as `127.0.0.1:1001` is doing in the example above.
## Configuration

You can configure befehl with a config file (~/.befehl.[toml|json|yaml]) any serialization format that upstream viper supports befehl supports for the config file. Valid configuration options:
Expand All @@ -26,7 +49,6 @@ You can configure befehl with a config file (~/.befehl.[toml|json|yaml]) any ser
logdir = "/home/ssullivan/log-special"
[ssh]
privatekeyfile = "/home/ssullivan/alt/id_rsa"
user = "eingeben"
knownhostspath = "/home/asullivan/alt/.ssh/known_hosts"
hostkeyverificationenabled = true
```
Expand Down
66 changes: 40 additions & 26 deletions befehl.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,54 +16,68 @@ import (
"github.com/sgsullivan/befehl/queue"
)

func New(options *Options) *Instance {
return &Instance{
options: options,
func New(options *Options) (*Instance, error) {
runtimeConfig, err := GetRuntimeConfig(options.RunConfigPath)
if err != nil {
return nil, err
}

return &Instance{
options: options,
runtimeConfig: &runtimeConfig,
}, nil
}

func (instance *Instance) Execute(hostsFile, payload string, routines int) error {
if bytePayload, readFileErr := filesystem.ReadFile(payload); readFileErr == nil {
if instance.sshKey == nil {
if err := instance.populateSshKey(); err != nil {
return err
}
func (instance *Instance) Execute(routines int) error {
if instance.sshKey == nil {
if err := instance.populateSshKey(); err != nil {
return err
}
return instance.executePayloadOnHosts(bytePayload, hostsFile, routines)
} else {
return readFileErr
}
}

func (instance *Instance) executePayloadOnHosts(payload []byte, hostsFilePath string, routines int) error {
hostsList, err := instance.buildHostList(hostsFilePath)
if err != nil {
return err
}
return instance.executePayloadOnHosts(routines)
}

func (instance *Instance) executePayloadOnHosts(routines int) error {
var wg sync.WaitGroup
hostCnt := len(hostsList)
hostCnt := len(instance.runtimeConfig.Hosts)
wg.Add(hostCnt)
hostsChan := make(chan int, routines)
queueInstance := new(queue.Queue).New(int64(hostCnt))

sshConfig, err := instance.getSshClientConfig()
defaultSshConfig, err := instance.getDefaultSshClientConfig()
if err != nil {
return err
}

for _, hostEntry := range hostsList {
hostname, port, err := instance.transformHostFromHostEntry(hostEntry)
for _, hostEntry := range instance.runtimeConfig.Hosts {
hostsChan <- 1

hostEntry := hostEntry

chosenPayloadPath := instance.runtimeConfig.Payload
if hostEntry.Payload != "" {
chosenPayloadPath = hostEntry.Payload
}
chosenPayload, err := filesystem.ReadFile(chosenPayloadPath)
if err != nil {
return err
}
hostsChan <- 1
go func() {
instance.runPayload(&wg, hostname, port, payload, sshConfig)

sshConfig := defaultSshConfig
if hostEntry.User != "" {
sshConfig, err = instance.getSshUserClientConfig(hostEntry.User)
if err != nil {
return err
}
}

go func(hostEntry *RuntimeConfigHost) {
instance.runPayload(&wg, hostEntry.Host, hostEntry.Port, chosenPayload, sshConfig)
<-hostsChan
remaining := queueInstance.DecrementCounter()
color.Magenta(fmt.Sprintf("Remaining: %d / %d\n", remaining, hostCnt))
}()
}(&hostEntry)
}

if waitgroup.WgTimeout(&wg, time.Duration(1800)*time.Second) {
Expand Down
49 changes: 31 additions & 18 deletions cmd/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,67 @@ import (

var executeCmd = &cobra.Command{
Use: "execute",
Short: "Execute the given payload against the given hosts list",
Long: `Executes the given payload on each host in the hosts list. Hosts in the hosts
list should be separated by a new line. You can control how many payloads run concurrently by
passing the routines flag.
Short: "Execute the given payload(s) against the given host(s) from configuration",
Long: `Execute the given payload(s) against the given host(s) from configuration specified
by the runconfig flag. Below is an example runconfig:
{
"payload": "integration_tests/examples/payload",
"user": "root",
"hosts": [{
"host": "127.0.0.1",
"port": 1000
},
{
"host": "127.0.0.1",
"port": 1001,
"user": "snowflake",
"payload": "integration_tests/examples/payload-override"
}
]
}
You can control how many payloads are executed concurrently by passing the routines flag.
By default befehl will use the private key in $HOME/.ssh/id_rsa. This can be overrode by
specifying auth.privatekeyfile in ~/.befehl.[toml|json|yaml].
By default befehl will write the output of each payload for each host in $HOME/befehl/logs. This
can be overrode by specifying general.logdir in ~/.befehl.[toml|json|yaml].
By default befehl will attempt to ssh as root. This can be overrode by specifying auth.sshuser
in ~/.befehl.[toml|json|yaml].
Heres an example specifying all of the above mentioned options:
Heres an example specifying all supported options:
[general]
logdir = "/home/ssullivan/log-special"
[ssh]
privatekeyfile = "/home/ssullivan/alt/id_rsa"
user = "eingeben"
knownhostspath = "/home/asullivan/alt/.ssh/known_hosts"
hostkeyverificationenabled = true
`,
Run: func(cmd *cobra.Command, args []string) {
hostsFile, _ := cmd.Flags().GetString("hosts")
payload, _ := cmd.Flags().GetString("payload")
runConfig, _ := cmd.Flags().GetString("runconfig")
routines, _ := cmd.Flags().GetInt("routines")

if routines == 0 {
color.Yellow("--routines not given, defaulting to 30..\n")
routines = 30
}

instance := befehl.New(&befehl.Options{
instance, err := befehl.New(&befehl.Options{
PrivateKeyFile: Config.GetString("ssh.privatekeyfile"),
SshUser: Config.GetString("ssh.sshuser"),
LogDir: Config.GetString("general.logdir"),
SshHostKeyConfig: befehl.SshHostKeyConfig{
Enabled: Config.GetBool("ssh.hostkeyverificationenabled"),
KnownHostsPath: Config.GetString("ssh.knownhostspath"),
},
RunConfigPath: runConfig,
})
if err != nil {
panic(err)
}

if err := instance.Execute(hostsFile, payload, routines); err != nil {
if err := instance.Execute(routines); err != nil {
panic(err)
}
},
Expand All @@ -63,10 +78,8 @@ hostkeyverificationenabled = true
func init() {
RootCmd.AddCommand(executeCmd)

executeCmd.Flags().String("payload", "", "file location to the payload, which contains the commands to execute on the remote hosts")
executeCmd.Flags().String("hosts", "", "file location to hosts list, which contains all hosts (separated by newline) to run the payload on")
executeCmd.Flags().String("runconfig", "", "file location to the runtime configuration")
executeCmd.Flags().Int("routines", 0, "maximum number of payloads that will run at once (defaults to 30)")

executeCmd.MarkFlagRequired("payload")
executeCmd.MarkFlagRequired("hosts")
executeCmd.MarkFlagRequired("runconfig")
}
31 changes: 26 additions & 5 deletions getters.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package befehl

import (
"encoding/json"
"io/ioutil"
"os"
"time"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)

func (instance *Instance) getSshUser() string {
if instance.options.SshUser != "" {
return instance.options.SshUser
func (instance *Instance) getDefaultSshUser() string {
if instance.runtimeConfig.User != "" {
return instance.runtimeConfig.User
}
return "root"
}
Expand All @@ -35,10 +37,10 @@ func (instance *Instance) getSshHostKeyCallback() (hostKeyCallback ssh.HostKeyCa
return
}

func (instance *Instance) getSshClientConfig() (*ssh.ClientConfig, error) {
func (instance *Instance) getSshClientConfig(getSshUser func() string) (*ssh.ClientConfig, error) {
if hostKeyCallback, err := instance.getSshHostKeyCallback(); err == nil {
return &ssh.ClientConfig{
User: instance.getSshUser(),
User: getSshUser(),
Auth: []ssh.AuthMethod{
ssh.PublicKeys(instance.sshKey),
},
Expand All @@ -50,6 +52,14 @@ func (instance *Instance) getSshClientConfig() (*ssh.ClientConfig, error) {
}
}

func (instance *Instance) getDefaultSshClientConfig() (*ssh.ClientConfig, error) {
return instance.getSshClientConfig(func() string { return instance.getDefaultSshUser() })
}

func (instance *Instance) getSshUserClientConfig(sshUser string) (*ssh.ClientConfig, error) {
return instance.getSshClientConfig(func() string { return sshUser })
}

func (instance *Instance) getLogDir() string {
if instance.options.LogDir != "" {
return instance.options.LogDir
Expand All @@ -68,3 +78,14 @@ func (instance *Instance) getPrivKeyFile() string {

return os.Getenv("HOME") + "/.ssh/id_rsa"
}

func GetRuntimeConfig(pathToRuntimeConfig string) (config RuntimeConfig, err error) {
configBytes, err := ioutil.ReadFile(pathToRuntimeConfig)
if err != nil {
return
}

err = json.Unmarshal(configBytes, &config)

return
}
48 changes: 29 additions & 19 deletions getters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,35 @@ import (
)

func getZeroValOpts() *Instance {
return New(&Options{
if i, e := New(&Options{
PrivateKeyFile: "",
SshUser: "",
LogDir: "",
SshHostKeyConfig: SshHostKeyConfig{
Enabled: false,
KnownHostsPath: "",
},
RunConfigPath: "unit-test-resources/zero-hosts.json",
}); e != nil {
panic(e)
} else {
return i
}
}

func getNonZeroValOpts() *Instance {
i, err := New(&Options{
PrivateKeyFile: "foo",
LogDir: "baz",
SshHostKeyConfig: SshHostKeyConfig{
Enabled: true,
KnownHostsPath: defaultKnownHosts,
},
RunConfigPath: "unit-test-resources/hosts.json",
})
if err != nil {
panic(err)
}
return i
}

var defaultSshPath = os.Getenv("HOME") + "/.ssh"
Expand All @@ -40,24 +60,14 @@ func init() {
}
}

func getNonZeroValOpts() *Instance {
return New(&Options{
PrivateKeyFile: "foo",
SshUser: "bar",
LogDir: "baz",
SshHostKeyConfig: SshHostKeyConfig{
Enabled: true,
KnownHostsPath: defaultKnownHosts,
},
})
}

func TestGetSshUser(t *testing.T) {
if getZeroValOpts().getSshUser() != "root" {
t.Fatal("SshUser for zeroval is unexpected")
zuser := getZeroValOpts().getDefaultSshUser()
if zuser != "root" {
t.Fatalf("User [%s] for zeroval is unexpected", zuser)
}
if getNonZeroValOpts().getSshUser() != "bar" {
t.Fatal("SshUser for nonzeroval is unexpected")
nuser := getNonZeroValOpts().getDefaultSshUser()
if nuser != "r00t" {
t.Fatalf("User [%s] for nonzeroval is unexpected", nuser)
}
}

Expand Down Expand Up @@ -90,7 +100,7 @@ func TestGetPrivKeyFile(t *testing.T) {
}

func TestGetSshClientConfig(t *testing.T) {
got, err := getNonZeroValOpts().getSshClientConfig()
got, err := getNonZeroValOpts().getSshClientConfig(func() string { return getNonZeroValOpts().getDefaultSshUser() })
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit f84ad17

Please sign in to comment.