Skip to content

Commit

Permalink
add configurable known hosts checks
Browse files Browse the repository at this point in the history
  • Loading branch information
sgsullivan committed Aug 4, 2022
1 parent a97edaf commit ef66f96
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 16 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ logdir = "/home/ssullivan/log-special"
[auth]
privatekeyfile = "/home/ssullivan/alt/.ssh/id_rsa"
sshuser = "nonrootuser"
sshhostkeyverificationenabled = true
sshknownhostspath = "/home/asullivan/alt/.ssh/known_hosts"
```

These options should be self explanatory so I wont describe what each does here.
Unless enabled as shown above, ssh known host verification is disabled.

## Obtaining prebuilt binaries

Expand Down
7 changes: 5 additions & 2 deletions befehl.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ func (instance *Instance) executePayloadOnHosts(payload []byte, hostsFilePath st
hostsChan := make(chan int, routines)
queueInstance := new(queue.Queue).New(int64(hostCnt))

sshConfig := instance.getSshClientConfig()
sshConfig, err := instance.getSshClientConfig()
if err != nil {
return err
}

for _, hostEntry := range hostsList {
hostname, port, err := instance.transformHostFromHostEntry(hostEntry)
Expand All @@ -73,7 +76,7 @@ func (instance *Instance) executePayloadOnHosts(payload []byte, hostsFilePath st

func (instance *Instance) runPayload(wg *sync.WaitGroup, host string, port int, payload []byte, sshConfig *ssh.ClientConfig) {
defer wg.Done()
log.Printf("running payload on %s ..\n", host)
log.Printf("running payload on %s:%d ..\n", host, port)

// establish the connection
conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", host, port), sshConfig)
Expand Down
4 changes: 4 additions & 0 deletions cmd/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ sshuser = "eingeben"
PrivateKeyFile: Config.GetString("auth.privatekeyfile"),
SshUser: Config.GetString("auth.sshuser"),
LogDir: Config.GetString("general.logdir"),
SshHostKeyConfig: befehl.SshHostKeyConfig{
Enabled: Config.GetBool("auth.sshhostkeyverificationenabled"),
KnownHostsPath: Config.GetString("auth.sshknownhostspath"),
},
})

if err := instance.Execute(hostsFile, payload, routines); err != nil {
Expand Down
41 changes: 33 additions & 8 deletions getters.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)

func (instance *Instance) getSshUser() string {
Expand All @@ -14,14 +15,38 @@ func (instance *Instance) getSshUser() string {
return "root"
}

func (instance *Instance) getSshClientConfig() *ssh.ClientConfig {
return &ssh.ClientConfig{
User: instance.getSshUser(),
Auth: []ssh.AuthMethod{
ssh.PublicKeys(instance.sshKey),
},
Timeout: time.Duration(10) * time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
func (instance *Instance) getSshKnowHostsPath() string {
path := os.Getenv("HOME") + "/.ssh/known_hosts"

config := instance.options.SshHostKeyConfig
if config.Enabled && config.KnownHostsPath != "" {
path = config.KnownHostsPath
}

return path
}

func (instance *Instance) getSshHostKeyCallback() (hostKeyCallback ssh.HostKeyCallback, err error) {
hostKeyCallback = ssh.InsecureIgnoreHostKey()
if instance.options.SshHostKeyConfig.Enabled {
hostKeyCallback, err = knownhosts.New(instance.getSshKnowHostsPath())
}

return
}

func (instance *Instance) getSshClientConfig() (*ssh.ClientConfig, error) {
if hostKeyCallback, err := instance.getSshHostKeyCallback(); err == nil {
return &ssh.ClientConfig{
User: instance.getSshUser(),
Auth: []ssh.AuthMethod{
ssh.PublicKeys(instance.sshKey),
},
Timeout: time.Duration(10) * time.Second,
HostKeyCallback: hostKeyCallback,
}, nil
} else {
return nil, err
}
}

Expand Down
38 changes: 36 additions & 2 deletions getters_test.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,47 @@
package befehl

import (
"fmt"
"os"
"testing"
"time"

"github.com/sgsullivan/befehl/helpers/filesystem"
)

func getZeroValOpts() *Instance {
return New(&Options{
PrivateKeyFile: "",
SshUser: "",
LogDir: "",
SshHostKeyConfig: SshHostKeyConfig{
Enabled: false,
KnownHostsPath: "",
},
})
}

var defaultKnownHosts = os.Getenv("HOME") + "/.ssh/known_hosts"

func init() {
if !filesystem.FileExists(defaultKnownHosts) {
f, err := os.Create(defaultKnownHosts)
if err != nil {
panic(fmt.Sprintf("failed to create %s: %s", defaultKnownHosts, err))
}
f.Close()
}
}

func getNonZeroValOpts() *Instance {
return New(&Options{
PrivateKeyFile: "foo",
SshUser: "bar",
LogDir: "baz",
SshHostKeyConfig: SshHostKeyConfig{
Enabled: true,
KnownHostsPath: defaultKnownHosts,
},
})
}

Expand Down Expand Up @@ -57,12 +80,23 @@ func TestGetPrivKeyFile(t *testing.T) {
if getNonZeroValOpts().getPrivKeyFile() != "foo" {
t.Fatal("PrivateKeyFile for nonzeroval is unexpected")
}

}

func TestGetSshClientConfig(t *testing.T) {
got := getNonZeroValOpts().getSshClientConfig()
got, err := getNonZeroValOpts().getSshClientConfig()
if err != nil {
t.Fatal(err)
}
if got.Timeout != time.Duration(10)*time.Second {
t.Fatalf("returned timeout %s is unexpected", got.Timeout)
}
}

func TestGetSshKnowHostsPath(t *testing.T) {
if getZeroValOpts().getSshKnowHostsPath() != os.Getenv("HOME")+"/.ssh/known_hosts" {
t.Fatal("getSshKnowHostsPath for zeroval is unexpected")
}
if getNonZeroValOpts().getSshKnowHostsPath() != defaultKnownHosts {
t.Fatal("PrivateKeyFile for nonzeroval is unexpected")
}
}
5 changes: 5 additions & 0 deletions integration_tests/examples/hosts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
127.0.0.1:1000
127.0.0.1:1001
127.0.0.1:1002
127.0.0.1:1003
127.0.0.1:1004
3 changes: 3 additions & 0 deletions integration_tests/examples/payload
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

echo "Hello, world"
12 changes: 9 additions & 3 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ import (
"golang.org/x/crypto/ssh"
)

type SshHostKeyConfig struct {
Enabled bool
KnownHostsPath string
}

type Options struct {
PrivateKeyFile string
SshUser string
LogDir string
PrivateKeyFile string
SshUser string
LogDir string
SshHostKeyConfig SshHostKeyConfig
}

type Instance struct {
Expand Down

0 comments on commit ef66f96

Please sign in to comment.