Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Feb 6, 2024
1 parent 745017c commit 194341e
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 55 deletions.
20 changes: 10 additions & 10 deletions acme/api/order.go
Expand Up @@ -84,12 +84,12 @@ func (n *NewOrderRequest) validateWireIdentifiers() error {
return fmt.Errorf("expected exactly one Wire DeviceID identifier, got %d", len(deviceIdentifiers))
}

wireUserID, err := wire.ParseUserID([]byte(userIdentifiers[0].Value))
wireUserID, err := wire.ParseUserID(userIdentifiers[0].Value)
if err != nil {
return fmt.Errorf("failed parsing Wire UserID: %w", err)
}

wireDeviceID, err := wire.ParseDeviceID([]byte(deviceIdentifiers[0].Value))
wireDeviceID, err := wire.ParseDeviceID(deviceIdentifiers[0].Value)
if err != nil {
return fmt.Errorf("failed parsing Wire DeviceID: %w", err)
}
Expand Down Expand Up @@ -337,26 +337,26 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
var target string
switch az.Identifier.Type {
case acme.WireUser:
wireOptions := prov.GetOptions().GetWireOptions()
if wireOptions == nil {
return acme.NewErrorISE("failed getting Wire options")
wireOptions, err := prov.GetOptions().GetWireOptions()
if err != nil {
return acme.WrapErrorISE(err, "failed getting Wire options")
}
target, err = wireOptions.GetOIDCOptions().EvaluateTarget("")
target, err = wireOptions.GetOIDCOptions().EvaluateTarget("") // TODO(hs): determine if required by Wire
if err != nil {
return acme.WrapError(acme.ErrorMalformedType, err, "invalid Go template registered for 'target'")
}
case acme.WireDevice:
wireID, err := wire.ParseDeviceID([]byte(az.Identifier.Value))
wireID, err := wire.ParseDeviceID(az.Identifier.Value)
if err != nil {
return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing WireDevice")
}
clientID, err := wire.ParseClientID(wireID.ClientID)
if err != nil {
return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing ClientID")
}
wireOptions := prov.GetOptions().GetWireOptions()
if wireOptions == nil {
return acme.NewErrorISE("failed getting Wire options")
wireOptions, err := prov.GetOptions().GetWireOptions()
if err != nil {
return acme.WrapErrorISE(err, "failed getting Wire options")
}
target, err = wireOptions.GetDPOPOptions().EvaluateTarget(clientID.DeviceID)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions acme/api/order_test.go
Expand Up @@ -699,7 +699,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
az: az,
err: &acme.Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Err: errors.New("failed getting Wire options"),
Err: errors.New("failed getting Wire options: no Wire options available"),
Detail: "The server experienced an internal error",
Status: 500,
},
Expand Down Expand Up @@ -765,7 +765,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
az: az,
err: &acme.Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Err: errors.New("failed getting Wire options"),
Err: errors.New("failed getting Wire options: no Wire options available"),
Detail: "The server experienced an internal error",
Status: 500,
},
Expand Down
21 changes: 10 additions & 11 deletions acme/challenge.go
Expand Up @@ -362,22 +362,21 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO
if !ok {
return NewErrorISE("missing provisioner")
}
wireOptions := prov.GetOptions().GetWireOptions()
if wireOptions == nil {
return NewErrorISE("no Wire options available")
wireOptions, err := prov.GetOptions().GetWireOptions()
if err != nil {
return WrapErrorISE(err, "failed getting Wire options")
}
linker, ok := LinkerFromContext(ctx)
if !ok {
return NewErrorISE("missing linker")
}

var oidcPayload wireOidcPayload
err := json.Unmarshal(payload, &oidcPayload)
if err != nil {
if err := json.Unmarshal(payload, &oidcPayload); err != nil {
return WrapError(ErrorMalformedType, err, "error unmarshalling Wire OIDC challenge payload")
}

wireID, err := wire.ParseUserID([]byte(ch.Value))
wireID, err := wire.ParseUserID(ch.Value)
if err != nil {
return WrapErrorISE(err, "error unmarshalling challenge data")
}
Expand Down Expand Up @@ -493,9 +492,9 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j
if !ok {
return NewErrorISE("missing provisioner")
}
wireOptions := prov.GetOptions().GetWireOptions()
if wireOptions == nil {
return NewErrorISE("no Wire options available")
wireOptions, err := prov.GetOptions().GetWireOptions()
if err != nil {
return WrapErrorISE(err, "failed getting Wire options")
}
linker, ok := LinkerFromContext(ctx)
if !ok {
Expand All @@ -507,7 +506,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j
return WrapError(ErrorMalformedType, err, "error unmarshalling Wire DPoP challenge payload")
}

wireID, err := wire.ParseDeviceID([]byte(ch.Value))
wireID, err := wire.ParseDeviceID(ch.Value)
if err != nil {
return WrapErrorISE(err, "error unmarshalling challenge data")
}
Expand Down Expand Up @@ -728,7 +727,7 @@ func parseAndVerifyWireAccessToken(v wireVerifyParams) (*wireAccessToken, *wireD
return nil, nil, fmt.Errorf("invalid display name in Wire DPoP token")
}
if name == "" || name != v.wireID.Name {
return nil, nil, fmt.Errorf("invalid Wire client display name %q", handle)
return nil, nil, fmt.Errorf("invalid Wire client display name %q", name)
}

return &accessToken, &dpopToken, nil
Expand Down
6 changes: 3 additions & 3 deletions acme/order.go
Expand Up @@ -340,7 +340,7 @@ func createWireSubject(o *Order, csr *x509.CertificateRequest) (subject x509util
for _, identifier := range o.Identifiers {
switch identifier.Type {
case WireUser:
wireID, err := wire.ParseUserID([]byte(identifier.Value))
wireID, err := wire.ParseUserID(identifier.Value)
if err != nil {
return subject, NewErrorISE("unmarshal wireID: %s", err)
}
Expand Down Expand Up @@ -406,7 +406,7 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ
orderPIDs[indexPID] = n.Value
indexPID++
case WireUser:
wireID, err := wire.ParseUserID([]byte(n.Value))
wireID, err := wire.ParseUserID(n.Value)
if err != nil {
return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value)
}
Expand All @@ -417,7 +417,7 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ
tmpOrderURIs[indexURI] = handle
indexURI++
case WireDevice:
wireID, err := wire.ParseDeviceID([]byte(n.Value))
wireID, err := wire.ParseDeviceID(n.Value)
if err != nil {
return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value)
}
Expand Down
8 changes: 4 additions & 4 deletions acme/wire/id.go
Expand Up @@ -22,8 +22,8 @@ type DeviceID struct {
Handle string `json:"handle,omitempty"`
}

func ParseUserID(data []byte) (id UserID, err error) {
if err = json.Unmarshal(data, &id); err != nil {
func ParseUserID(value string) (id UserID, err error) {
if err = json.Unmarshal([]byte(value), &id); err != nil {
return
}

Expand All @@ -39,8 +39,8 @@ func ParseUserID(data []byte) (id UserID, err error) {
return
}

func ParseDeviceID(data []byte) (id DeviceID, err error) {
if err = json.Unmarshal(data, &id); err != nil {
func ParseDeviceID(value string) (id DeviceID, err error) {
if err = json.Unmarshal([]byte(value), &id); err != nil {
return
}

Expand Down
30 changes: 15 additions & 15 deletions acme/wire/id_test.go
Expand Up @@ -15,19 +15,19 @@ func TestParseUserID(t *testing.T) {
emptyDomain := `{"name": "Alice Smith", "domain": "", "handle": "wireapp://%40alice_wire@wire.com"}`
tests := []struct {
name string
data []byte
value string
wantWireID UserID
wantErr bool
}{
{name: "ok", data: []byte(ok), wantWireID: UserID{Name: "Alice Smith", Domain: "wire.com", Handle: "wireapp://%40alice_wire@wire.com"}},
{name: "fail/json", data: []byte(failJSON), wantErr: true},
{name: "fail/empty-handle", data: []byte(emptyHandle), wantErr: true},
{name: "fail/empty-name", data: []byte(emptyName), wantErr: true},
{name: "fail/empty-domain", data: []byte(emptyDomain), wantErr: true},
{name: "ok", value: ok, wantWireID: UserID{Name: "Alice Smith", Domain: "wire.com", Handle: "wireapp://%40alice_wire@wire.com"}},
{name: "fail/json", value: failJSON, wantErr: true},
{name: "fail/empty-handle", value: emptyHandle, wantErr: true},
{name: "fail/empty-name", value: emptyName, wantErr: true},
{name: "fail/empty-domain", value: emptyDomain, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotWireID, err := ParseUserID(tt.data)
gotWireID, err := ParseUserID(tt.value)
if tt.wantErr {
assert.Error(t, err)
return
Expand All @@ -48,20 +48,20 @@ func TestParseDeviceID(t *testing.T) {
emptyClientID := `{"name": "device", "domain": "wire.com", "client-id": "", "handle": "wireapp://%40alice_wire@wire.com"}`
tests := []struct {
name string
data []byte
value string
wantWireID DeviceID
wantErr bool
}{
{name: "ok", data: []byte(ok), wantWireID: DeviceID{Name: "device", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}},
{name: "fail/json", data: []byte(failJSON), wantErr: true},
{name: "fail/empty-handle", data: []byte(emptyHandle), wantErr: true},
{name: "fail/empty-name", data: []byte(emptyName), wantErr: true},
{name: "fail/empty-domain", data: []byte(emptyDomain), wantErr: true},
{name: "fail/empty-client-id", data: []byte(emptyClientID), wantErr: true},
{name: "ok", value: ok, wantWireID: DeviceID{Name: "device", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}},
{name: "fail/json", value: failJSON, wantErr: true},
{name: "fail/empty-handle", value: emptyHandle, wantErr: true},
{name: "fail/empty-name", value: emptyName, wantErr: true},
{name: "fail/empty-domain", value: emptyDomain, wantErr: true},
{name: "fail/empty-client-id", value: emptyClientID, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotWireID, err := ParseDeviceID(tt.data)
gotWireID, err := ParseDeviceID(tt.value)
if tt.wantErr {
assert.Error(t, err)
return
Expand Down
10 changes: 5 additions & 5 deletions authority/provisioner/acme.go
Expand Up @@ -235,9 +235,9 @@ func (p *ACME) initializeWireOptions() error {
return nil
}

w := p.GetOptions().GetWireOptions()
if w == nil {
return errors.New("no Wire options available")
w, err := p.GetOptions().GetWireOptions()
if err != nil {
return fmt.Errorf("failed getting Wire options: %w", err)
}

if err := w.Validate(); err != nil {
Expand Down Expand Up @@ -295,13 +295,13 @@ func (p *ACME) AuthorizeOrderIdentifier(_ context.Context, identifier ACMEIdenti
err = x509Policy.IsDNSAllowed(identifier.Value)
case WireUser:
var wireID wire.UserID
if wireID, err = wire.ParseUserID([]byte(identifier.Value)); err != nil {
if wireID, err = wire.ParseUserID(identifier.Value); err != nil {
return fmt.Errorf("failed parsing Wire SANs: %w", err)
}
err = x509Policy.AreSANsAllowed([]string{wireID.Handle})
case WireDevice:
var wireID wire.DeviceID
if wireID, err = wire.ParseDeviceID([]byte(identifier.Value)); err != nil {
if wireID, err = wire.ParseDeviceID(identifier.Value); err != nil {
return fmt.Errorf("failed parsing Wire SANs: %w", err)
}
err = x509Policy.AreSANsAllowed([]string{wireID.ClientID})
Expand Down
13 changes: 12 additions & 1 deletion authority/provisioner/acme_test.go
Expand Up @@ -155,7 +155,18 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
Type: "ACME",
Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01},
},
err: errors.New("failed initializing Wire options: no Wire options available"),
err: errors.New("failed initializing Wire options: failed getting Wire options: no options available"),
}
},
"fail/wire-missing-wire-options": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &ACME{
Name: "foo",
Type: "ACME",
Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01},
Options: &Options{},
},
err: errors.New("failed initializing Wire options: failed getting Wire options: no Wire options available"),
}
},
"fail/wire-validate-options": func(t *testing.T) ProvisionerValidateTest {
Expand Down
12 changes: 8 additions & 4 deletions authority/provisioner/options.go
Expand Up @@ -53,12 +53,16 @@ func (o *Options) GetSSHOptions() *SSHOptions {
return o.SSH
}

// GetWireOptions returns the SSH options.
func (o *Options) GetWireOptions() *wire.Options {
// GetWireOptions returns the Wire options if available. It
// returns an error if they're not available.
func (o *Options) GetWireOptions() (*wire.Options, error) {
if o == nil {
return nil
return nil, errors.New("no options available")
}
if o.Wire == nil {
return nil, errors.New("no Wire options available")
}
return o.Wire
return o.Wire, nil
}

// GetWebhooks returns the webhooks options.
Expand Down

0 comments on commit 194341e

Please sign in to comment.