Skip to content

Commit

Permalink
Add COSE_Key support (#146)
Browse files Browse the repository at this point in the history
Signed-off-by: Sergei Trofimov <sergei.trofimov@arm.com>
  • Loading branch information
setrofim committed Jun 29, 2023
1 parent 354ac99 commit a579021
Show file tree
Hide file tree
Showing 10 changed files with 1,774 additions and 53 deletions.
3 changes: 2 additions & 1 deletion .github/.codecov.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
coverage:
status:
patch: off
project:
default:
target: 89%
target: 89%
39 changes: 38 additions & 1 deletion algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cose

import (
"crypto"
"fmt"
"strconv"
)

Expand Down Expand Up @@ -36,10 +37,12 @@ const (

// PureEdDSA by RFC 8152.
AlgorithmEd25519 Algorithm = -8

// An invalid/unrecognised algorithm.
AlgorithmInvalid Algorithm = 0
)

// Algorithm represents an IANA algorithm entry in the COSE Algorithms registry.
// Algorithms with string values are not supported.
//
// # See Also
//
Expand Down Expand Up @@ -72,6 +75,35 @@ func (a Algorithm) String() string {
}
}

// MarshalCBOR marshals the Algorithm as a CBOR int.
func (a Algorithm) MarshalCBOR() ([]byte, error) {
return encMode.Marshal(int64(a))
}

// UnmarshalCBOR populates the Algorithm from the provided CBOR value (must be
// int or tstr).
func (a *Algorithm) UnmarshalCBOR(data []byte) error {
var raw intOrStr

if err := raw.UnmarshalCBOR(data); err != nil {
return fmt.Errorf("invalid algorithm value: %w", err)
}

if raw.IsString() {
v := algorithmFromString(raw.String())
if v == AlgorithmInvalid {
return fmt.Errorf("unknown algorithm value %q", raw.String())
}

*a = v
} else {
v := raw.Int()
*a = Algorithm(v)
}

return nil
}

// hashFunc returns the hash associated with the algorithm supported by this
// library.
func (a Algorithm) hashFunc() crypto.Hash {
Expand Down Expand Up @@ -103,3 +135,8 @@ func computeHash(h crypto.Hash, data []byte) ([]byte, error) {
}
return hh.Sum(nil), nil
}

// NOTE: there are currently no registered string values for an algorithm.
func algorithmFromString(v string) Algorithm {
return AlgorithmInvalid
}
52 changes: 17 additions & 35 deletions algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,6 @@ func TestAlgorithm_String(t *testing.T) {
alg Algorithm
want string
}{
{
name: "PS256",
alg: AlgorithmPS256,
want: "PS256",
},
{
name: "PS384",
alg: AlgorithmPS384,
want: "PS384",
},
{
name: "PS512",
alg: AlgorithmPS512,
want: "PS512",
},
{
name: "ES256",
alg: AlgorithmES256,
want: "ES256",
},
{
name: "ES384",
alg: AlgorithmES384,
want: "ES384",
},
{
name: "ES512",
alg: AlgorithmES512,
want: "ES512",
},
{
name: "Ed25519",
alg: AlgorithmEd25519,
want: "EdDSA",
},
{
name: "unknown algorithm",
alg: 0,
Expand All @@ -66,6 +31,23 @@ func TestAlgorithm_String(t *testing.T) {
}
}

func TestAlgorithm_CBOR(t *testing.T) {
tvs2 := []struct {
Data []byte
ExpectedError string
}{
{[]byte{0x63, 0x66, 0x6f, 0x6f}, "unknown algorithm value \"foo\""},
{[]byte{0x40}, "invalid algorithm value: must be int or string, found []uint8"},
}

for _, tv := range tvs2 {
var a Algorithm

err := a.UnmarshalCBOR(tv.Data)
assertEqualError(t, err, tv.ExpectedError)
}
}

func TestAlgorithm_computeHash(t *testing.T) {
// run tests
data := []byte("hello world")
Expand Down
96 changes: 96 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package cose

import (
"errors"
"fmt"
)

// intOrStr is a value that can be either an int or a tstr when serialized to
// CBOR.
type intOrStr struct {
intVal int64
strVal string
isString bool
}

func newIntOrStr(v interface{}) *intOrStr {
var ios intOrStr
if err := ios.Set(v); err != nil {
return nil
}
return &ios
}

func (ios intOrStr) Int() int64 {
return ios.intVal
}

func (ios intOrStr) String() string {
if ios.IsString() {
return ios.strVal
}
return fmt.Sprint(ios.intVal)
}

func (ios intOrStr) IsInt() bool {
return !ios.isString
}

func (ios intOrStr) IsString() bool {
return ios.isString
}

func (ios intOrStr) Value() interface{} {
if ios.IsInt() {
return ios.intVal
}

return ios.strVal
}

func (ios *intOrStr) Set(v interface{}) error {
switch t := v.(type) {
case int64:
ios.intVal = t
ios.strVal = ""
ios.isString = false
case int:
ios.intVal = int64(t)
ios.strVal = ""
ios.isString = false
case string:
ios.strVal = t
ios.intVal = 0
ios.isString = true
default:
return fmt.Errorf("must be int or string, found %T", t)
}

return nil
}

// MarshalCBOR returns the encoded CBOR representation of the intOrString, as
// either int or tstr, depending on the value. If no value has been set,
// intOrStr is encoded as a zero-length tstr.
func (ios intOrStr) MarshalCBOR() ([]byte, error) {
if ios.IsInt() {
return encMode.Marshal(ios.intVal)
}

return encMode.Marshal(ios.strVal)
}

// UnmarshalCBOR unmarshals the provided CBOR encoded data (must be an int,
// uint, or tstr).
func (ios *intOrStr) UnmarshalCBOR(data []byte) error {
if len(data) == 0 {
return errors.New("zero length buffer")
}

var val interface{}
if err := decMode.Unmarshal(data, &val); err != nil {
return err
}

return ios.Set(val)
}
140 changes: 140 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package cose

import (
"bytes"
"reflect"
"testing"

"github.com/fxamacker/cbor/v2"
)

func Test_intOrStr(t *testing.T) {
ios := newIntOrStr(3)
assertEqual(t, true, ios.IsInt())
assertEqual(t, false, ios.IsString())
assertEqual(t, 3, ios.Int())
assertEqual(t, "3", ios.String())

ios = newIntOrStr("foo")
assertEqual(t, false, ios.IsInt())
assertEqual(t, true, ios.IsString())
assertEqual(t, 0, ios.Int())
assertEqual(t, "foo", ios.String())

ios = newIntOrStr(3.5)
if ios != nil {
t.Errorf("Expected nil, got %v", ios)
}
}

func Test_intOrStr_CBOR(t *testing.T) {
ios := newIntOrStr(3)
data, err := ios.MarshalCBOR()
requireNoError(t, err)
assertEqual(t, []byte{0x03}, data)

ios = &intOrStr{}
err = ios.UnmarshalCBOR(data)
requireNoError(t, err)
assertEqual(t, true, ios.IsInt())
assertEqual(t, 3, ios.Int())

ios = newIntOrStr("foo")
data, err = ios.MarshalCBOR()
requireNoError(t, err)
assertEqual(t, []byte{0x63, 0x66, 0x6f, 0x6f}, data)

ios = &intOrStr{}
err = ios.UnmarshalCBOR(data)
requireNoError(t, err)
assertEqual(t, true, ios.IsString())
assertEqual(t, "foo", ios.String())

// empty value as field
s := struct {
Field1 intOrStr `cbor:"1,keyasint"`
Field2 int `cbor:"2,keyasint"`
}{Field1: intOrStr{}, Field2: 7}

data, err = cbor.Marshal(s)
requireNoError(t, err)
assertEqual(t, []byte{0xa2, 0x1, 0x00, 0x2, 0x7}, data)

ios = &intOrStr{}
data = []byte{0x22}
err = ios.UnmarshalCBOR(data)
requireNoError(t, err)
assertEqual(t, true, ios.IsInt())
assertEqual(t, -3, ios.Int())

data = []byte{}
err = ios.UnmarshalCBOR(data)
assertEqualError(t, err, "zero length buffer")

data = []byte{0x40}
err = ios.UnmarshalCBOR(data)
assertEqualError(t, err, "must be int or string, found []uint8")

data = []byte{0xff, 0xff}
err = ios.UnmarshalCBOR(data)
assertEqualError(t, err, "cbor: unexpected \"break\" code")
}

func requireNoError(t *testing.T, err error) {
if err != nil {
t.Errorf("Unexpected error: %q", err)
t.Fail()
}
}

func assertEqualError(t *testing.T, err error, expected string) {
if err == nil || err.Error() != expected {
t.Errorf("Unexpected error: want %q, got %q", expected, err)
}
}

func assertEqual(t *testing.T, expected, actual interface{}) {
if !objectsAreEqualValues(expected, actual) {
t.Errorf("Unexpected value: want %v, got %v", expected, actual)
}
}

// taken from github.com/stretchr/testify
func objectsAreEqualValues(expected, actual interface{}) bool {
if objectsAreEqual(expected, actual) {
return true
}

actualType := reflect.TypeOf(actual)
if actualType == nil {
return false
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
}

return false
}

// taken from github.com/stretchr/testify
func objectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
}
4 changes: 4 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ var (
ErrUnavailableHashFunc = errors.New("hash function is not available")
ErrVerification = errors.New("verification error")
ErrInvalidPubKey = errors.New("invalid public key")
ErrInvalidPrivKey = errors.New("invalid private key")
ErrNotPrivKey = errors.New("not a private key")
ErrSignOpNotSupported = errors.New("sign key_op not supported by key")
ErrVerifyOpNotSupported = errors.New("verify key_op not supported by key")
)
Loading

0 comments on commit a579021

Please sign in to comment.