diff --git a/kms/yubikey/yubikey.go b/kms/yubikey/yubikey.go index 357c5bbf..5cb8db55 100644 --- a/kms/yubikey/yubikey.go +++ b/kms/yubikey/yubikey.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "encoding/asn1" "encoding/hex" + "fmt" "io" "net/url" "strconv" @@ -42,6 +43,7 @@ type pivKey interface { GenerateKey(key [24]byte, slot piv.Slot, opts piv.Key) (crypto.PublicKey, error) PrivateKey(slot piv.Slot, public crypto.PublicKey, auth piv.KeyAuth) (crypto.PrivateKey, error) Attest(slot piv.Slot) (*x509.Certificate, error) + Serial() (uint32, error) Close() error } @@ -141,8 +143,8 @@ func New(_ context.Context, opts apiv1.Options) (*YubiKey, error) { // Attempt to locate the yubikey with the given serial. for _, name := range cards { if k, err := openCard(name); err == nil { - if cert, err := k.Attest(piv.SlotAuthentication); err == nil { - if serial == getSerialNumber(cert) { + if s, err := k.Serial(); err == nil { + if serial == strconv.FormatUint(uint64(s), 10) { yk = k card = name break @@ -353,10 +355,22 @@ func (k *YubiKey) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1 Certificate: cert, CertificateChain: []*x509.Certificate{cert, intermediate}, PublicKey: cert.PublicKey, - PermanentIdentifier: getSerialNumber(cert), + PermanentIdentifier: getAttestedSerial(cert), }, nil } +// Serial returns the serial number of the PIV card or and empty +// string if retrieval fails +func (k *YubiKey) Serial() (string, error) { + serial, err := k.yk.Serial() + + if err != nil { + return "", fmt.Errorf("error getting Yubikey's serial number: %w", err) + } + + return strconv.FormatUint(uint64(serial), 10), nil +} + // Close releases the connection to the YubiKey. func (k *YubiKey) Close() error { if err := k.yk.Close(); err != nil { @@ -505,10 +519,10 @@ func getPolicies(req *apiv1.CreateKeyRequest) (piv.PINPolicy, piv.TouchPolicy) { return pin, touch } -// getSerialNumber returns the serial number from an attestation certificate. It +// getAttestedSerial returns the serial number from an attestation certificate. It // will return an empty string if the serial number extension does not exist // or if it is malformed. -func getSerialNumber(cert *x509.Certificate) string { +func getAttestedSerial(cert *x509.Certificate) string { for _, ext := range cert.Extensions { if ext.Id.Equal(oidYubicoSerialNumber) { var serialNumber int @@ -519,6 +533,7 @@ func getSerialNumber(cert *x509.Certificate) string { return strconv.Itoa(serialNumber) } } + return "" } diff --git a/kms/yubikey/yubikey_test.go b/kms/yubikey/yubikey_test.go index 369a3dd1..8052f87b 100644 --- a/kms/yubikey/yubikey_test.go +++ b/kms/yubikey/yubikey_test.go @@ -36,6 +36,8 @@ type stubPivKey struct { certMap map[piv.Slot]*x509.Certificate signerMap map[piv.Slot]interface{} keyOptionsMap map[piv.Slot]piv.Key + serial uint32 + serialErr error closeErr error } @@ -93,7 +95,8 @@ func newStubPivKey(t *testing.T, alg symmetricAlgorithm) *stubPivKey { t.Fatal(errors.New("unknown alg")) } - serialNumber, err := asn1.Marshal(112233) + sn := 112233 + snAsn1, err := asn1.Marshal(sn) if err != nil { t.Fatal(err) } @@ -101,7 +104,7 @@ func newStubPivKey(t *testing.T, alg symmetricAlgorithm) *stubPivKey { Subject: pkix.Name{CommonName: "attested certificate"}, PublicKey: attSigner.Public(), ExtraExtensions: []pkix.Extension{ - {Id: oidYubicoSerialNumber, Value: serialNumber}, + {Id: oidYubicoSerialNumber, Value: snAsn1}, }, }) if err != nil { @@ -132,6 +135,7 @@ func newStubPivKey(t *testing.T, alg symmetricAlgorithm) *stubPivKey { piv.SlotSignature: userSigner, // 9c }, keyOptionsMap: map[piv.Slot]piv.Key{}, + serial: uint32(sn), } } @@ -220,6 +224,13 @@ func (s *stubPivKey) Close() error { return s.closeErr } +func (s *stubPivKey) Serial() (uint32, error) { + if s.serialErr != nil { + return 0, s.serialErr + } + return s.serial, nil +} + func TestRegister(t *testing.T) { pCards := pivCards t.Cleanup(func() { @@ -1029,6 +1040,37 @@ func TestYubiKey_CreateAttestation(t *testing.T) { } } +func TestYubiKey_Serial(t *testing.T) { + yk1 := newStubPivKey(t, RSA) + yk2 := newStubPivKey(t, RSA) + yk2.serialErr = errors.New("some error") + + tests := []struct { + name string + yk pivKey + want string + wantErr bool + }{ + {"ok", yk1, "112233", false}, + {"fail", yk2, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &YubiKey{ + yk: tt.yk, + } + got, err := k.Serial() + if (err != nil) != tt.wantErr { + t.Errorf("YubiKey.Serial() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("YubiKey.Serial() = %v, want %v", got, tt.want) + } + }) + } +} + func TestYubiKey_Close(t *testing.T) { yk1 := newStubPivKey(t, ECDSA) yk2 := newStubPivKey(t, RSA) @@ -1061,7 +1103,7 @@ func TestYubiKey_Close(t *testing.T) { } } -func Test_getSerialNumber(t *testing.T) { +func Test_getAttestedSerial(t *testing.T) { serialNumber, err := asn1.Marshal(112233) if err != nil { t.Fatal(err) @@ -1107,8 +1149,8 @@ func Test_getSerialNumber(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := getSerialNumber(tt.args.cert); got != tt.want { - t.Errorf("getSerialNumber() = %v, want %v", got, tt.want) + if got := getAttestedSerial(tt.args.cert); got != tt.want { + t.Errorf("getAttestedSerial() = %v, want %v", got, tt.want) } }) }