Skip to content

Commit

Permalink
Merge branch 'release/0.8.1'
Browse files Browse the repository at this point in the history
  • Loading branch information
nbari committed Feb 4, 2017
2 parents 4549e25 + 3c91cc3 commit bf7ffb4
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 27 deletions.
9 changes: 9 additions & 0 deletions a_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package sshvault

import (
"os"
"reflect"
"runtime"
"testing"
"time"
)

/* Test Helpers */
Expand All @@ -13,3 +15,10 @@ func expect(t *testing.T, a interface{}, b interface{}) {
t.Errorf("Expected: %v (type %v) Got: %v (type %v) in %s:%d", a, reflect.TypeOf(a), b, reflect.TypeOf(b), fn, line)
}
}

// PtyWriteback
func PtyWriteback(pty *os.File, msg string) {
time.Sleep(500 * time.Millisecond)
defer pty.Sync()
pty.Write([]byte(msg))
}
15 changes: 8 additions & 7 deletions cmd/ssh-vault/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func exit1(err error) {

func main() {
var (
k = flag.String("k", "~/.ssh/id_rsa.pub", "public `ssh key or index` when using option -u")
k = flag.String("k", "~/.ssh/id_rsa.pub", "Public `ssh key or index` when using option -u")
u = flag.String("u", "", "GitHub `username or URL`, optional [-k N] where N is the key index to use")
f = flag.Bool("f", false, "Print ssh key `fingerprint`")
options = []string{"create", "edit", "view"}
Expand All @@ -32,9 +32,9 @@ func main() {
fmt.Fprintf(os.Stderr, "Usage: %s [-k key] [-u user] [create|edit|view] vault\n\n%s\n%s\n%s\n%s\n\n",
os.Args[0],
" Options:",
" create creates a new vault",
" edit edit an existing vault",
" view view an existing vault")
" create Creates a new vault",
" edit Edit an existing vault",
" view View an existing vault")
flag.PrintDefaults()
}

Expand All @@ -45,6 +45,10 @@ func main() {
os.Exit(0)
}

if flag.NArg() < 1 && !*f {
exit1(fmt.Errorf("Missing option, use (\"%s -h\") for help.\n", os.Args[0]))
}

usr, _ := user.Current()
if len(*k) > 2 {
if (*k)[:2] == "~/" {
Expand All @@ -69,9 +73,6 @@ func main() {
}

// check options
if flag.NArg() < 1 {
exit1(fmt.Errorf("Missing option, use (\"%s -h\") for help.\n", os.Args[0]))
}
exit := true
for _, v := range options {
if flag.Arg(0) == v {
Expand Down
4 changes: 2 additions & 2 deletions get_password_darwin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestKeychain(t *testing.T) {
}
defer DeleteKeychainPassword(key_path) // clean up

_, tty, err := pty.Open()
pty, tty, err := pty.Open()
if err != nil {
t.Errorf("Unable to open pty: %s", err)
}
Expand All @@ -77,7 +77,7 @@ func TestKeychain(t *testing.T) {
syscall.Dup2(int(tty.Fd()), int(syscall.Stdin))
syscall.Dup2(int(tty.Fd()), int(syscall.Stdout))

// go PtyWriteback(pty, key_bad_pw)
go PtyWriteback(pty, key_bad_pw)

key_pw_test, err := vault.GetPassword()

Expand Down
2 changes: 1 addition & 1 deletion vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func New(k, u, o, v string) (*vault, error) {
return nil, err
}
} else if !cache.IsFile(keyPath) {
return nil, fmt.Errorf("key not found or unable to read")
return nil, fmt.Errorf("SSH key %q not found or unable to read", keyPath)
}
if o == "create" {
if cache.IsFile(v) {
Expand Down
26 changes: 9 additions & 17 deletions vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,12 @@ import (
"strings"
"syscall"
"testing"
"time"

"github.com/kr/pty"
"github.com/ssh-vault/crypto"
"github.com/ssh-vault/crypto/aead"
)

// zomg this is a race condition
func PtyWriteback(pty *os.File, msg string) {
time.Sleep(500 * time.Millisecond)
defer pty.Sync()
pty.Write([]byte(msg))
}

// These are done in one function to avoid declaring global variables in a test
// file.
func TestVaultFunctions(t *testing.T) {
Expand All @@ -41,7 +33,7 @@ func TestVaultFunctions(t *testing.T) {
t.Error(err)
}

key_pw := string("argle-bargle\n")
keyPw := string("argle-bargle\n")
pty, tty, err := pty.Open()
if err != nil {
t.Errorf("Unable to open pty: %s", err)
Expand All @@ -51,23 +43,23 @@ func TestVaultFunctions(t *testing.T) {
// from stdin. For the test, we save stdin to a spare FD,
// point stdin at the file, run the system under test, and
// finally restore the original stdin
old_stdin, _ := syscall.Dup(int(syscall.Stdin))
old_stdout, _ := syscall.Dup(int(syscall.Stdout))
oldStdin, _ := syscall.Dup(int(syscall.Stdin))
oldStdout, _ := syscall.Dup(int(syscall.Stdout))
syscall.Dup2(int(tty.Fd()), int(syscall.Stdin))
syscall.Dup2(int(tty.Fd()), int(syscall.Stdout))

go PtyWriteback(pty, key_pw)
go PtyWriteback(pty, keyPw)

key_pw_test, err := vault.GetPasswordPrompt()
keyPwTest, err := vault.GetPasswordPrompt()

syscall.Dup2(old_stdin, int(syscall.Stdin))
syscall.Dup2(old_stdout, int(syscall.Stdout))
syscall.Dup2(oldStdin, int(syscall.Stdin))
syscall.Dup2(oldStdout, int(syscall.Stdout))

if err != nil {
t.Error(err)
}
if string(strings.Trim(key_pw, "\n")) != string(key_pw_test) {
t.Errorf("password prompt: expected %s, got %s\n", key_pw, key_pw_test)
if string(strings.Trim(keyPw, "\n")) != string(keyPwTest) {
t.Errorf("password prompt: expected %s, got %s\n", keyPw, keyPwTest)
}

if err = vault.PKCS8(); err != nil {
Expand Down

0 comments on commit bf7ffb4

Please sign in to comment.