From 3489cf5e193d1a89b5d58f89d2086d0a60877fe5 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 19 May 2022 13:33:30 +0200 Subject: [PATCH] check for coexisting iv and partialiv in different headers Signed-off-by: qmuntal --- headers.go | 51 ++++++++++++++++++++++++++++++++++++++------ sign.go | 12 ++--------- sign1.go | 6 +----- sign1_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++ sign_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+), 21 deletions(-) diff --git a/headers.go b/headers.go index cd0eba8..884d206 100644 --- a/headers.go +++ b/headers.go @@ -267,6 +267,23 @@ type Headers struct { Unprotected UnprotectedHeader } +// marshal encoded both headers. +// It returns RawProtected and RawUnprotected if those are set. +func (h *Headers) marshal() (cbor.RawMessage, cbor.RawMessage, error) { + if err := h.ensureIV(); err != nil { + return nil, nil, err + } + protected, err := h.MarshalProtected() + if err != nil { + return nil, nil, err + } + unprotected, err := h.MarshalUnprotected() + if err != nil { + return nil, nil, err + } + return protected, unprotected, nil +} + // MarshalProtected encodes the protected header. // RawProtected is returned if it is not set to nil. func (h *Headers) MarshalProtected() ([]byte, error) { @@ -294,6 +311,9 @@ func (h *Headers) UnmarshalFromRaw() error { if err := decMode.Unmarshal(h.RawUnprotected, &h.Unprotected); err != nil { return fmt.Errorf("cbor: invalid unprotected header: %w", err) } + if err := h.ensureIV(); err != nil { + return err + } return nil } @@ -345,17 +365,36 @@ func (h *Headers) ensureVerificationAlgorithm(alg Algorithm, external []byte) er return err } +// ensureIV ensures IV and Partial IV are not both present +// in the protected and unprotected headers. +// It does not check if they are both present within one header, +// as it will be checked later on. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1 +func (h *Headers) ensureIV() error { + if hasLabel(h.Protected, HeaderLabelIV) && hasLabel(h.Unprotected, HeaderLabelPartialIV) { + return errors.New("IV (protected) and PartialIV (unprotected) parameters must not both be present") + } + if hasLabel(h.Protected, HeaderLabelPartialIV) && hasLabel(h.Unprotected, HeaderLabelIV) { + return errors.New("IV (unprotected) and PartialIV (protected) parameters must not both be present") + } + return nil +} + +// hasLabel returns true if h contains label. +func hasLabel(h map[interface{}]interface{}, label interface{}) bool { + _, ok := h[label] + return ok +} + // ensureHeaderIV ensures IV and Partial IV are not both present in the header. // // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1 func ensureHeaderIV(h map[interface{}]interface{}) error { - if _, ok := h[HeaderLabelIV]; !ok { - return nil + if hasLabel(h, HeaderLabelIV) && hasLabel(h, HeaderLabelPartialIV) { + return errors.New("IV and PartialIV parameters must not both be present") } - if _, ok := h[HeaderLabelPartialIV]; !ok { - return nil - } - return errors.New("the 'Initialization Vector' and 'Partial Initialization Vector' parameters must not both be present") + return nil } // validateHeaderLabel validates if all header labels are integers or strings. diff --git a/sign.go b/sign.go index 0995f3b..e9ca759 100644 --- a/sign.go +++ b/sign.go @@ -73,11 +73,7 @@ func (s *Signature) MarshalCBOR() ([]byte, error) { if len(s.Signature) == 0 { return nil, ErrEmptySignature } - protected, err := s.Headers.MarshalProtected() - if err != nil { - return nil, err - } - unprotected, err := s.Headers.MarshalUnprotected() + protected, unprotected, err := s.Headers.marshal() if err != nil { return nil, err } @@ -329,11 +325,7 @@ func (m *SignMessage) MarshalCBOR() ([]byte, error) { if len(m.Signatures) == 0 { return nil, ErrNoSignatures } - protected, err := m.Headers.MarshalProtected() - if err != nil { - return nil, err - } - unprotected, err := m.Headers.MarshalUnprotected() + protected, unprotected, err := m.Headers.marshal() if err != nil { return nil, err } diff --git a/sign1.go b/sign1.go index 9d8606f..2fcb201 100644 --- a/sign1.go +++ b/sign1.go @@ -58,11 +58,7 @@ func (m *Sign1Message) MarshalCBOR() ([]byte, error) { if len(m.Signature) == 0 { return nil, ErrEmptySignature } - protected, err := m.Headers.MarshalProtected() - if err != nil { - return nil, err - } - unprotected, err := m.Headers.MarshalUnprotected() + protected, unprotected, err := m.Headers.marshal() if err != nil { return nil, err } diff --git a/sign1_test.go b/sign1_test.go index 890e7ba..0ac77e1 100644 --- a/sign1_test.go +++ b/sign1_test.go @@ -130,6 +130,40 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { }, wantErr: true, }, + { + name: "protected has IV and unprotected has PartialIV error", + m: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + HeaderLabelIV: "", + }, + Unprotected: UnprotectedHeader{ + HeaderLabelPartialIV: "", + }, + }, + Payload: []byte("foo"), + Signature: []byte("bar"), + }, + wantErr: true, + }, + { + name: "protected has PartialIV and unprotected has IV error", + m: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + HeaderLabelPartialIV: "", + }, + Unprotected: UnprotectedHeader{ + HeaderLabelIV: "", + }, + }, + Payload: []byte("foo"), + Signature: []byte("bar"), + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -324,6 +358,30 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { }, wantErr: true, }, + { + name: "protected has IV and unprotected has PartialIV", + data: []byte{ + 0xd2, // tag + 0x84, + 0x46, 0xa1, 0x5, 0x63, 0x66, 0x6f, 0x6f, // protected + 0xa1, 0x6, 0x63, 0x62, 0x61, 0x72, // unprotected + 0xf6, // payload + 0x43, 0x62, 0x61, 0x72, // signature + }, + wantErr: true, + }, + { + name: "protected has PartialIV and unprotected has IV", + data: []byte{ + 0xd2, // tag + 0x84, + 0x46, 0xa1, 0x6, 0x63, 0x66, 0x6f, 0x6f, // protected + 0xa1, 0x5, 0x63, 0x62, 0x61, 0x72, // unprotected + 0xf6, // payload + 0x43, 0x62, 0x61, 0x72, // signature + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/sign_test.go b/sign_test.go index cc66581..b572c13 100644 --- a/sign_test.go +++ b/sign_test.go @@ -102,6 +102,38 @@ func TestSignature_MarshalCBOR(t *testing.T) { }, wantErr: true, }, + { + name: "protected has IV and unprotected has PartialIV error", + s: &Signature{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + HeaderLabelIV: "", + }, + Unprotected: UnprotectedHeader{ + HeaderLabelPartialIV: "", + }, + }, + Signature: []byte("bar"), + }, + wantErr: true, + }, + { + name: "protected has PartialIV and unprotected has IV error", + s: &Signature{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + HeaderLabelPartialIV: "", + }, + Unprotected: UnprotectedHeader{ + HeaderLabelIV: "", + }, + }, + Signature: []byte("bar"), + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -227,6 +259,26 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { }, wantErr: true, }, + { + name: "protected has IV and unprotected has PartialIV", + data: []byte{ + 0x83, + 0x46, 0xa1, 0x5, 0x63, 0x66, 0x6f, 0x6f, // protected + 0xa1, 0x6, 0x63, 0x62, 0x61, 0x72, // unprotected + 0x43, 0x62, 0x61, 0x72, // signature + }, + wantErr: true, + }, + { + name: "protected has PartialIV and unprotected has IV", + data: []byte{ + 0x83, + 0x46, 0xa1, 0x6, 0x63, 0x66, 0x6f, 0x6f, // protected + 0xa1, 0x5, 0x63, 0x62, 0x61, 0x72, // unprotected + 0x43, 0x62, 0x61, 0x72, // signature + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {