diff --git a/go.mod b/go.mod index 018971e..dc85298 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/kevinburke/ssh_config v1.2.0 github.com/manifoldco/promptui v0.9.0 github.com/pkg/errors v0.9.1 + github.com/skeema/knownhosts v1.2.1 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.1 github.com/vimiix/tablewriter v0.0.0-20231207073205-aad9e2006284 diff --git a/go.sum b/go.sum index cccf7e1..43c7207 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ= +github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/ssx/client.go b/ssx/client.go index c4c79ab..8d666be 100644 --- a/ssx/client.go +++ b/ssx/client.go @@ -153,7 +153,10 @@ func dialContext(ctx context.Context, network, addr string, config *ssh.ClientCo func (c *Client) login(ctx context.Context) error { network := "tcp" addr := net.JoinHostPort(c.entry.Host, c.entry.Port) - clientConfig := c.entry.GenSSHConfig() + clientConfig, err := c.entry.GenSSHConfig() + if err != nil { + return err + } lg.Info("connecting to %s", c.entry.String()) cli, err := dialContext(ctx, network, addr, clientConfig) if err == nil { diff --git a/ssx/entry/entry.go b/ssx/entry/entry.go index a9201f4..9ba3c60 100644 --- a/ssx/entry/entry.go +++ b/ssx/entry/entry.go @@ -3,11 +3,15 @@ package entry import ( "bufio" "fmt" + "log" + "net" "os" "os/user" "path/filepath" "time" + "github.com/pkg/errors" + "github.com/skeema/knownhosts" "golang.org/x/crypto/ssh" "github.com/vimiix/ssx/internal/lg" @@ -58,15 +62,58 @@ func getConnectTimeout() time.Duration { return d } -func (e *Entry) GenSSHConfig() *ssh.ClientConfig { +func (e *Entry) GenSSHConfig() (*ssh.ClientConfig, error) { + cb, err := e.sshHostKeyCallback() + if err != nil { + return nil, err + } cfg := &ssh.ClientConfig{ User: e.User, Auth: e.AuthMethods(), - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: cb, Timeout: getConnectTimeout(), } cfg.SetDefaults() - return cfg + return cfg, nil +} + +func (e *Entry) sshHostKeyCallback() (ssh.HostKeyCallback, error) { + khPath := utils.ExpandHomeDir("~/.ssh/known_hosts") + if !utils.FileExists(khPath) { + f, err := os.OpenFile(khPath, os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + return nil, err + } + _ = f.Close() + } + kh, err := knownhosts.New(khPath) + if err != nil { + lg.Error("failed to read known_hosts: ", err) + return nil, err + } + // Create a custom permissive hostkey callback which still errors on hosts + // with changed keys, but allows unknown hosts and adds them to known_hosts + cb := ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { + err := kh(hostname, remote, key) + if knownhosts.IsHostKeyChanged(err) { + lg.Error("REMOTE HOST IDENTIFICATION HAS CHANGED for host %s! This may indicate a MitM attack.", hostname) + return errors.Errorf("host key changed for host %s", hostname) + } else if knownhosts.IsHostUnknown(err) { + f, ferr := os.OpenFile(khPath, os.O_APPEND|os.O_WRONLY, 0600) + if ferr == nil { + defer f.Close() + ferr = knownhosts.WriteKnownHost(f, hostname, remote, key) + } + if ferr == nil { + log.Printf("Added host %s to known_hosts\n", hostname) + } else { + log.Printf("Failed to add host %s to known_hosts: %v\n", hostname, ferr) + } + return nil + } + return err + }) + return cb, nil } func (e *Entry) Tidy() error {