diff --git a/cmd/spot/main.go b/cmd/spot/main.go index 0951978c..f0a84c34 100644 --- a/cmd/spot/main.go +++ b/cmd/spot/main.go @@ -315,7 +315,7 @@ func makePlaybook(opts options, inventory string) (*config.PlayBook, error) { } func makeRunner(opts options, pbook *config.PlayBook) (*runner.Process, error) { - sshKey, err := sshKey(opts.SSHKey, pbook) + sshKey, err := sshKey(opts.SSHAgent, opts.SSHKey, pbook) if err != nil { return nil, fmt.Errorf("can't get ssh key: %w", err) } @@ -381,7 +381,7 @@ func targetsForTask(targets []string, taskName string, pbook runner.Playbook) [] } // get ssh key from cli or playbook. if no key is provided, use default ~/.ssh/id_rsa -func sshKey(sshKey string, pbook *config.PlayBook) (key string, err error) { +func sshKey(sshAgent bool, sshKey string, pbook *config.PlayBook) (key string, err error) { if sshKey == "" && (pbook == nil || pbook.SSHKey != "") { // no key provided in cli sshKey = pbook.SSHKey // use playbook's ssh_key } @@ -394,7 +394,9 @@ func sshKey(sshKey string, pbook *config.PlayBook) (key string, err error) { if err != nil { return "", fmt.Errorf("can't get current user: %w", err) } - sshKey = filepath.Join(u.HomeDir, ".ssh", "id_rsa") + if !sshAgent { + sshKey = filepath.Join(u.HomeDir, ".ssh", "id_rsa") + } } log.Printf("[INFO] ssh key: %s", sshKey) diff --git a/cmd/spot/main_test.go b/cmd/spot/main_test.go index df69e090..81f1d33f 100644 --- a/cmd/spot/main_test.go +++ b/cmd/spot/main_test.go @@ -420,6 +420,21 @@ func Test_sshUserAndKey(t *testing.T) { expectedUser: osUser.Username, expectedKey: filepath.Join(osUser.HomeDir, ".ssh", "id_rsa"), }, + { + name: "SSHAgent set no key in playbook and command line", + opts: options{ + TaskNames: []string{"test_task"}, + SSHUser: "cmd_user", + SSHAgent: true, + }, + conf: config.PlayBook{ + Tasks: []config.Task{ + {Name: "test_task"}, + }, + }, + expectedUser: "cmd_user", + expectedKey: "", + }, { name: "tilde expansion in key path", opts: options{ @@ -441,7 +456,7 @@ func Test_sshUserAndKey(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - key, err := sshKey(tc.opts.SSHKey, &tc.conf) + key, err := sshKey(tc.opts.SSHAgent, tc.opts.SSHKey, &tc.conf) require.NoError(t, err, "sshKey should not return an error") assert.Equal(t, tc.expectedKey, key, "key should match expected key") sshUser, err := sshUser(tc.opts.SSHUser, &tc.conf)