diff --git a/internal_test.go b/internal_test.go index 979b0f9..db3c75e 100644 --- a/internal_test.go +++ b/internal_test.go @@ -4,34 +4,85 @@ package tskagent import ( + "crypto" + "crypto/ed25519" + crand "crypto/rand" + "crypto/rsa" + "encoding/pem" "testing" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" - - _ "embed" ) var _ agent.Agent = &Server{} -// The test data key is a throwaway generated for testing, and is not used -// anywhere else. To generate a new test key, run: -// -// ssh-keygen -C "Dummy key for testing" -t ed25519 -f testdata/test.key - -//go:embed testdata/test.key -var testPrivKey []byte - func TestKeyParse(t *testing.T) { - key, err := parseStoredKey("foo", 1, testPrivKey) - if err != nil { - t.Fatalf("Parsing stored key: %v", err) + tests := []struct { + name string + input []byte + comment string + keyType string + }{ + { + name: "ED2559/Comment", + input: mustGenerateKey(t, genED25519, "elliptic justice"), + comment: "elliptic justice", + keyType: "ssh-ed25519", + }, + { + name: "ED2559/NoComment", + input: mustGenerateKey(t, genED25519, ""), + comment: "", + keyType: "ssh-ed25519", + }, + { + name: "RSA/Comment", + input: mustGenerateKey(t, genRSA, "what year is it"), + comment: "what year is it", + keyType: "ssh-rsa", + }, + { + name: "RSA/NoComment", + input: mustGenerateKey(t, genRSA, ""), + comment: "", + keyType: "ssh-rsa", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + key, err := parseStoredKey(tc.name, 1, tc.input) + if err != nil { + t.Fatalf("parsing stored key: %v", err) + } + if key.Comment != tc.comment { + t.Errorf("Comment: got %q, want %q", key.Comment, tc.comment) + } + if got := key.Signer.PublicKey().Type(); got != tc.keyType { + t.Errorf("Key type: got %q, want %q", got, tc.keyType) + } + }) } +} - const wantComment = "Dummy key for testing" - if key.Comment != wantComment { - t.Errorf("Comment: got %q, want %q", key.Comment, wantComment) +func mustGenerateKey(t *testing.T, gen func() (crypto.PrivateKey, error), comment string) []byte { + t.Helper() + key, err := gen() + if err != nil { + t.Fatalf("Generating key: %v", err) } - if got, want := key.Signer.PublicKey().Type(), "ssh-ed25519"; got != want { - t.Errorf("Key type: got %q, want %q", got, want) + enc, err := ssh.MarshalPrivateKey(key, comment) + if err != nil { + t.Fatalf("Marshaling key: %v", err) } + return pem.EncodeToMemory(enc) +} + +func genED25519() (crypto.PrivateKey, error) { + _, key, err := ed25519.GenerateKey(crand.Reader) + return key, err +} + +func genRSA() (crypto.PrivateKey, error) { + return rsa.GenerateKey(crand.Reader, 1024) } diff --git a/tskagent.go b/tskagent.go index 2d3e1dc..f44064c 100644 --- a/tskagent.go +++ b/tskagent.go @@ -303,43 +303,121 @@ func parseStoredKey(name string, version api.SecretVersion, data []byte) (*sshKe func parseComment(key []byte) string { blk, _ := pem.Decode(key) - // The OpenSSH key format begins with a header followed by a public and a - // private key. Cut off the headers and skip the public key to find the - // private key, where the comment resides. The header is separated from the - // keys by a hard-coded uint32 key count of 1 (big-endian). - _, keys, ok := bytes.Cut(blk.Bytes, []byte("\x00\x00\x00\x01")) - if !ok { + // See: https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key + s := newScanner(blk.Bytes) + + // Check magic format header. + if err := s.scanLiteral("openssh-key-v1\x00"); err != nil { + return "" // not a key file, or some antique version + } + cipher, err := s.scanString() + if err != nil || string(cipher) != "none" { + return "" // encrypted contents, we can't read them + } + // Skip kdfname, kdfoptions, which we don't care about. + if err := s.skipStrings(2); err != nil { + return "" + } + // The next field is the number of keys. This could in theory be any value, + // but OpenSSH hardcodes it to 1. + if nk, err := s.scanUint32(); err != nil || nk != 1 { + return "" + } + // Skip the public keys, as the comment (if any) is with the private key. + if err := s.skipStrings(1); err != nil { return "" } - // Skip the public key... - pubLen := int(binary.BigEndian.Uint32(keys)) - if 4+pubLen > len(keys) { + // The rest of the packet should be a bundle of private keys. + // Because we know cipher is "none", it is plaintext, but there may + // be some padding at the end. + pkeys, err := s.scanString() + if err != nil { return "" } - keys = keys[4+pubLen:] - // Extract the private key... - privLen := int(binary.BigEndian.Uint32(keys)) - if 4+privLen > len(keys) { + pk := newScanner(pkeys) + // Skip the two 32-bit validity nonces. + if err := pk.skipBytes(8); err != nil { return "" } + // The rest of the bundle depends on what type of key this is, but + // the last string field will be the comment (if any). + var last string + for !pk.atEOF() { + s, err := pk.scanString() + if err != nil { + break + } + last = string(s) + } + return last +} + +// A scanner is a minimal scanner for a slice of bytes representing an OpenSSH +// key file. The methods of this type alias (but do not modify) the input. +type scanner struct { + buf []byte +} + +func newScanner(data []byte) *scanner { + return &scanner{buf: data} +} - // Remove padding at end (pad bytes are 0x01-0x07) - for n := len(keys) - 1; keys[n] < 0x08; n-- { - keys = keys[:n] +// atEOF reports whether s has any further contents. +func (s *scanner) atEOF() bool { return len(s.buf) == 0 } + +// skipBytes advances past the first n bytes of the input. +func (s *scanner) skipBytes(n int) error { + if len(s.buf) < n { + return fmt.Errorf("got %d bytes, want %d", len(s.buf), n) } - keys = keys[4:] // remove length prefix (checked above) - keys = keys[8:] // remove checksum (not used) - - // The comment is the last length-prefixed field of the private key. - // Skip past all the others. - for len(keys) >= 4 { - n := int(binary.BigEndian.Uint32(keys)) - if 4+n == len(keys) { - return string(keys[4:]) + s.buf = s.buf[n:] + return nil +} + +// skipStrings advances past the next n length-prefixed strings. +func (s *scanner) skipStrings(n int) error { + for n > 0 { + if _, err := s.scanString(); err != nil { + return err } - keys = keys[4+n:] + n-- + } + return nil +} + +// scanLiteral advances past the specified prefix of the input. +func (s *scanner) scanLiteral(want string) error { + rest, ok := bytes.CutPrefix(s.buf, []byte(want)) + if !ok { + return fmt.Errorf("missing %q", want) } - return "" + s.buf = rest + return nil +} + +// scanString consumes and returns a length-prefixed string. +func (s *scanner) scanString() ([]byte, error) { + n32, err := s.scanUint32() + if err != nil { + return nil, err + } + n := int(n32) + if n > len(s.buf) { + return nil, fmt.Errorf("got %d bytes, want %d", len(s.buf), n) + } + out := s.buf[:n] + s.buf = s.buf[n:] + return out, nil +} + +// scanUint32 consumes and returns a big-endian 32-bit integer. +func (s *scanner) scanUint32() (uint32, error) { + if len(s.buf) < 4 { + return 0, fmt.Errorf("got %d bytes, want 4", len(s.buf)) + } + out := binary.BigEndian.Uint32(s.buf) + s.buf = s.buf[4:] + return out, nil }