Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 69 additions & 18 deletions internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,85 @@
package tskagent

import (
"crypto"
"crypto/ed25519"
crand "crypto/rand"
"crypto/rsa"
"encoding/pem"
"testing"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"

_ "embed"
)

var _ agent.Agent = &Server{}

// The test data key is a throwaway generated for testing, and is not used
// anywhere else. To generate a new test key, run:
//
// ssh-keygen -C "Dummy key for testing" -t ed25519 -f testdata/test.key

//go:embed testdata/test.key
var testPrivKey []byte

func TestKeyParse(t *testing.T) {
key, err := parseStoredKey("foo", 1, testPrivKey)
if err != nil {
t.Fatalf("Parsing stored key: %v", err)
tests := []struct {
name string
input []byte
comment string
keyType string
}{
{
name: "ED2559/Comment",
input: mustGenerateKey(t, genED25519, "elliptic justice"),
comment: "elliptic justice",
keyType: "ssh-ed25519",
},
{
name: "ED2559/NoComment",
input: mustGenerateKey(t, genED25519, ""),
comment: "",
keyType: "ssh-ed25519",
},
{
name: "RSA/Comment",
input: mustGenerateKey(t, genRSA, "what year is it"),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤣

comment: "what year is it",
keyType: "ssh-rsa",
},
{
name: "RSA/NoComment",
input: mustGenerateKey(t, genRSA, ""),
comment: "",
keyType: "ssh-rsa",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
key, err := parseStoredKey(tc.name, 1, tc.input)
if err != nil {
t.Fatalf("parsing stored key: %v", err)
}
if key.Comment != tc.comment {
t.Errorf("Comment: got %q, want %q", key.Comment, tc.comment)
}
if got := key.Signer.PublicKey().Type(); got != tc.keyType {
t.Errorf("Key type: got %q, want %q", got, tc.keyType)
}
})
}
}

const wantComment = "Dummy key for testing"
if key.Comment != wantComment {
t.Errorf("Comment: got %q, want %q", key.Comment, wantComment)
func mustGenerateKey(t *testing.T, gen func() (crypto.PrivateKey, error), comment string) []byte {
t.Helper()
key, err := gen()
if err != nil {
t.Fatalf("Generating key: %v", err)
}
if got, want := key.Signer.PublicKey().Type(), "ssh-ed25519"; got != want {
t.Errorf("Key type: got %q, want %q", got, want)
enc, err := ssh.MarshalPrivateKey(key, comment)
if err != nil {
t.Fatalf("Marshaling key: %v", err)
}
return pem.EncodeToMemory(enc)
}

func genED25519() (crypto.PrivateKey, error) {
_, key, err := ed25519.GenerateKey(crand.Reader)
return key, err
}

func genRSA() (crypto.PrivateKey, error) {
return rsa.GenerateKey(crand.Reader, 1024)
}
132 changes: 105 additions & 27 deletions tskagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,43 +303,121 @@ func parseStoredKey(name string, version api.SecretVersion, data []byte) (*sshKe
func parseComment(key []byte) string {
blk, _ := pem.Decode(key)

// The OpenSSH key format begins with a header followed by a public and a
// private key. Cut off the headers and skip the public key to find the
// private key, where the comment resides. The header is separated from the
// keys by a hard-coded uint32 key count of 1 (big-endian).
_, keys, ok := bytes.Cut(blk.Bytes, []byte("\x00\x00\x00\x01"))
if !ok {
// See: https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
s := newScanner(blk.Bytes)

// Check magic format header.
if err := s.scanLiteral("openssh-key-v1\x00"); err != nil {
return "" // not a key file, or some antique version
}
cipher, err := s.scanString()
if err != nil || string(cipher) != "none" {
return "" // encrypted contents, we can't read them
}
// Skip kdfname, kdfoptions, which we don't care about.
if err := s.skipStrings(2); err != nil {
return ""
}
// The next field is the number of keys. This could in theory be any value,
// but OpenSSH hardcodes it to 1.
if nk, err := s.scanUint32(); err != nil || nk != 1 {
return ""
}
// Skip the public keys, as the comment (if any) is with the private key.
if err := s.skipStrings(1); err != nil {
return ""
}

// Skip the public key...
pubLen := int(binary.BigEndian.Uint32(keys))
if 4+pubLen > len(keys) {
// The rest of the packet should be a bundle of private keys.
// Because we know cipher is "none", it is plaintext, but there may
// be some padding at the end.
pkeys, err := s.scanString()
if err != nil {
return ""
}
keys = keys[4+pubLen:]

// Extract the private key...
privLen := int(binary.BigEndian.Uint32(keys))
if 4+privLen > len(keys) {
pk := newScanner(pkeys)
// Skip the two 32-bit validity nonces.
if err := pk.skipBytes(8); err != nil {
return ""
}
Comment thread
andrew-d marked this conversation as resolved.
// The rest of the bundle depends on what type of key this is, but
// the last string field will be the comment (if any).
var last string
for !pk.atEOF() {
s, err := pk.scanString()
if err != nil {
break
}
last = string(s)
}
return last
}

// A scanner is a minimal scanner for a slice of bytes representing an OpenSSH
// key file. The methods of this type alias (but do not modify) the input.
type scanner struct {
buf []byte
}

func newScanner(data []byte) *scanner {
return &scanner{buf: data}
}

// Remove padding at end (pad bytes are 0x01-0x07)
for n := len(keys) - 1; keys[n] < 0x08; n-- {
keys = keys[:n]
// atEOF reports whether s has any further contents.
func (s *scanner) atEOF() bool { return len(s.buf) == 0 }

// skipBytes advances past the first n bytes of the input.
func (s *scanner) skipBytes(n int) error {
if len(s.buf) < n {
return fmt.Errorf("got %d bytes, want %d", len(s.buf), n)
}
keys = keys[4:] // remove length prefix (checked above)
keys = keys[8:] // remove checksum (not used)

// The comment is the last length-prefixed field of the private key.
// Skip past all the others.
for len(keys) >= 4 {
n := int(binary.BigEndian.Uint32(keys))
if 4+n == len(keys) {
return string(keys[4:])
s.buf = s.buf[n:]
return nil
}

// skipStrings advances past the next n length-prefixed strings.
func (s *scanner) skipStrings(n int) error {
for n > 0 {
if _, err := s.scanString(); err != nil {
return err
}
keys = keys[4+n:]
n--
}
return nil
}

// scanLiteral advances past the specified prefix of the input.
func (s *scanner) scanLiteral(want string) error {
rest, ok := bytes.CutPrefix(s.buf, []byte(want))
if !ok {
return fmt.Errorf("missing %q", want)
}
return ""
s.buf = rest
return nil
}

// scanString consumes and returns a length-prefixed string.
func (s *scanner) scanString() ([]byte, error) {
n32, err := s.scanUint32()
if err != nil {
return nil, err
}
n := int(n32)
if n > len(s.buf) {
return nil, fmt.Errorf("got %d bytes, want %d", len(s.buf), n)
}
out := s.buf[:n]
s.buf = s.buf[n:]
return out, nil
}

// scanUint32 consumes and returns a big-endian 32-bit integer.
func (s *scanner) scanUint32() (uint32, error) {
if len(s.buf) < 4 {
return 0, fmt.Errorf("got %d bytes, want 4", len(s.buf))
}
out := binary.BigEndian.Uint32(s.buf)
s.buf = s.buf[4:]
return out, nil
}
Loading