Skip to content

Commit

Permalink
ssh key parsing to separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
sgsullivan committed Aug 4, 2022
1 parent a9da272 commit f6a7529
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 59 deletions.
57 changes: 0 additions & 57 deletions befehl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package befehl
import (
"bufio"
"bytes"
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"os"
Expand All @@ -14,7 +12,6 @@ import (
"golang.org/x/crypto/ssh"

"github.com/fatih/color"
"github.com/howeyc/gopass"

"github.com/sgsullivan/befehl/helpers/filesystem"
"github.com/sgsullivan/befehl/helpers/waitgroup"
Expand All @@ -39,60 +36,6 @@ func (instance *Instance) Execute(hostsFile, payload string, routines int) error
}
}

func (instance *Instance) populateSshKeyEncrypted(privKeyBytes *pem.Block) error {
fmt.Printf("enter private key password: ")
password, err := gopass.GetPasswd()
if err != nil {
return fmt.Errorf("error when reading input: %v", err)
}

pwBuf, err := x509.DecryptPEMBlock(privKeyBytes, []byte(password))
if err != nil {
return fmt.Errorf("x509.DecryptPEMBlock failed: %v", err)
}

pk, err := x509.ParsePKCS1PrivateKey(pwBuf)
if err != nil {
return fmt.Errorf("x509.ParsePKCS1PrivateKey failed: %v", err)
}

signer, err := ssh.NewSignerFromKey(pk)
if err != nil {
return fmt.Errorf("ssh.NewSignerFromKey failed: %v", err)
}

instance.sshKey = signer

return nil
}

func (instance *Instance) populateSshKeyUnencrypted(rawKey []byte) error {
signer, err := ssh.ParsePrivateKey(rawKey)
if err != nil {
return fmt.Errorf("unable to parse private key: %v", err)
}

instance.sshKey = signer

return nil
}

func (instance *Instance) populateSshKey() error {
privKeyFile := instance.getPrivKeyFile()

if rawKey, readFileError := filesystem.ReadFile(privKeyFile); readFileError == nil {
privKeyBytes, _ := pem.Decode(rawKey)

if x509.IsEncryptedPEMBlock(privKeyBytes) {
return instance.populateSshKeyEncrypted(privKeyBytes)
} else {
return instance.populateSshKeyUnencrypted(rawKey)
}
} else {
return readFileError
}
}

func (instance *Instance) buildHostLists(hostsFilePath string) ([]string, error) {
hostsFile, err := os.Open(hostsFilePath)
if err != nil {
Expand Down
67 changes: 67 additions & 0 deletions crypto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package befehl

import (
"crypto/x509"
"encoding/pem"
"fmt"

"golang.org/x/crypto/ssh"

"github.com/howeyc/gopass"

"github.com/sgsullivan/befehl/helpers/filesystem"
)

func (instance *Instance) populateSshKeyEncrypted(privKeyBytes *pem.Block) error {
fmt.Printf("enter private key password: ")
password, err := gopass.GetPasswd()
if err != nil {
return fmt.Errorf("error when reading input: %v", err)
}

pwBuf, err := x509.DecryptPEMBlock(privKeyBytes, []byte(password))
if err != nil {
return fmt.Errorf("x509.DecryptPEMBlock failed: %v", err)
}

pk, err := x509.ParsePKCS1PrivateKey(pwBuf)
if err != nil {
return fmt.Errorf("x509.ParsePKCS1PrivateKey failed: %v", err)
}

signer, err := ssh.NewSignerFromKey(pk)
if err != nil {
return fmt.Errorf("ssh.NewSignerFromKey failed: %v", err)
}

instance.sshKey = signer

return nil
}

func (instance *Instance) populateSshKeyUnencrypted(rawKey []byte) error {
signer, err := ssh.ParsePrivateKey(rawKey)
if err != nil {
return fmt.Errorf("unable to parse private key: %v", err)
}

instance.sshKey = signer

return nil
}

func (instance *Instance) populateSshKey() error {
privKeyFile := instance.getPrivKeyFile()

if rawKey, readFileError := filesystem.ReadFile(privKeyFile); readFileError == nil {
privKeyBytes, _ := pem.Decode(rawKey)

if x509.IsEncryptedPEMBlock(privKeyBytes) {
return instance.populateSshKeyEncrypted(privKeyBytes)
} else {
return instance.populateSshKeyUnencrypted(rawKey)
}
} else {
return readFileError
}
}
4 changes: 2 additions & 2 deletions getters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestGetPrivKeyFile(t *testing.T) {

func TestGetSshClientConfig(t *testing.T) {
got := getNonZeroValOpts().getSshClientConfig()
if got.Timeout != time.Duration(10) * time.Second {
if got.Timeout != time.Duration(10)*time.Second {
t.Fatalf("returned timeout %s is unexpected", got.Timeout)
}
}
}

0 comments on commit f6a7529

Please sign in to comment.