Skip to content

Commit

Permalink
adding more acme nosql unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dopey committed Mar 25, 2021
1 parent 88e6f00 commit 7f9ffbd
Show file tree
Hide file tree
Showing 4 changed files with 730 additions and 45 deletions.
2 changes: 1 addition & 1 deletion acme/db/nosql/challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestDB_getDBChallenge(t *testing.T) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err, tc.acmeErr.Err)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
Expand Down
126 changes: 126 additions & 0 deletions acme/db/nosql/nosql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package nosql

import (
"context"
"testing"

"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
)

func TestNew(t *testing.T) {
type test struct {
db nosql.DB
err error
}
var tests = map[string]test{
"fail/db.CreateTable-error": test{
db: &db.MockNoSQLDB{
MCreateTable: func(bucket []byte) error {
assert.Equals(t, string(bucket), string(accountTable))
return errors.New("force")
},
},
err: errors.Errorf("error creating table %s: force", string(accountTable)),
},
"ok": test{
db: &db.MockNoSQLDB{
MCreateTable: func(bucket []byte) error {
return nil
},
},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
if _, err := New(tc.db); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

type errorThrower string

func (et errorThrower) MarshalJSON() ([]byte, error) {
return nil, errors.New("force")
}

func TestDB_save(t *testing.T) {
type test struct {
db nosql.DB
nu interface{}
old interface{}
err error
}
var tests = map[string]test{
"fail/error-marshaling-new": test{
nu: errorThrower("foo"),
err: errors.New("error marshaling acme type: challenge"),
},
"fail/error-marshaling-old": test{
nu: "new",
old: errorThrower("foo"),
err: errors.New("error marshaling acme type: challenge"),
},
"fail/db.CmpAndSwap-error": test{
nu: "new",
old: "old",
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), "id")
assert.Equals(t, string(old), "\"old\"")
assert.Equals(t, string(nu), "\"new\"")
return nil, false, errors.New("force")
},
},
err: errors.New("error saving acme challenge: force"),
},
"fail/db.CmpAndSwap-false-marshaling-old": test{
nu: "new",
old: "old",
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), "id")
assert.Equals(t, string(old), "\"old\"")
assert.Equals(t, string(nu), "\"new\"")
return nil, false, nil
},
},
err: errors.New("error saving acme challenge; changed since last read"),
},
"ok": test{
nu: "new",
old: "old",
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), "id")
assert.Equals(t, string(old), "\"old\"")
assert.Equals(t, string(nu), "\"new\"")
return nu, true, nil
},
},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
db := &DB{db: tc.db}
if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
90 changes: 46 additions & 44 deletions acme/db/nosql/order.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ import (
var ordersByAccountMux sync.Mutex

type dbOrder struct {
ID string `json:"id"`
AccountID string `json:"accountID"`
ProvisionerID string `json:"provisionerID"`
Created time.Time `json:"created"`
Expires time.Time `json:"expires,omitempty"`
Status acme.Status `json:"status"`
Identifiers []acme.Identifier `json:"identifiers"`
NotBefore time.Time `json:"notBefore,omitempty"`
NotAfter time.Time `json:"notAfter,omitempty"`
Error *acme.Error `json:"error,omitempty"`
Authorizations []string `json:"authorizations"`
CertificateID string `json:"certificate,omitempty"`
ID string `json:"id"`
AccountID string `json:"accountID"`
ProvisionerID string `json:"provisionerID"`
CreatedAt time.Time `json:"createdAt"`
ExpiresAt time.Time `json:"expiresAt,omitempty"`
Status acme.Status `json:"status"`
Identifiers []acme.Identifier `json:"identifiers"`
NotBefore time.Time `json:"notBefore,omitempty"`
NotAfter time.Time `json:"notAfter,omitempty"`
Error *acme.Error `json:"error,omitempty"`
AuthorizationIDs []string `json:"authorizationIDs"`
CertificateID string `json:"certificate,omitempty"`
}

func (a *dbOrder) clone() *dbOrder {
Expand All @@ -38,13 +38,13 @@ func (a *dbOrder) clone() *dbOrder {
func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) {
b, err := db.db.Get(orderTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, acme.WrapError(acme.ErrorMalformedType, err, "order %s not found", id)
return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id)
} else if err != nil {
return nil, errors.Wrapf(err, "error loading order %s", id)
}
o := new(dbOrder)
if err := json.Unmarshal(b, &o); err != nil {
return nil, errors.Wrap(err, "error unmarshaling order")
return nil, errors.Wrapf(err, "error unmarshaling order %s into dbOrder", id)
}
return o, nil
}
Expand All @@ -57,15 +57,17 @@ func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) {
}

o := &acme.Order{
ID: dbo.ID,
AccountID: dbo.AccountID,
ProvisionerID: dbo.ProvisionerID,
CertificateID: dbo.CertificateID,
Status: dbo.Status,
ExpiresAt: dbo.Expires,
ExpiresAt: dbo.ExpiresAt,
Identifiers: dbo.Identifiers,
NotBefore: dbo.NotBefore,
NotAfter: dbo.NotAfter,
AuthorizationIDs: dbo.Authorizations,
ID: dbo.ID,
ProvisionerID: dbo.ProvisionerID,
CertificateID: dbo.CertificateID,
AuthorizationIDs: dbo.AuthorizationIDs,
Error: dbo.Error,
}

return o, nil
Expand All @@ -81,16 +83,16 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {

now := clock.Now()
dbo := &dbOrder{
ID: o.ID,
AccountID: o.AccountID,
ProvisionerID: o.ProvisionerID,
Created: now,
Status: acme.StatusPending,
Expires: o.ExpiresAt,
Identifiers: o.Identifiers,
NotBefore: o.NotBefore,
NotAfter: o.NotBefore,
Authorizations: o.AuthorizationIDs,
ID: o.ID,
AccountID: o.AccountID,
ProvisionerID: o.ProvisionerID,
Status: o.Status,
CreatedAt: now,
ExpiresAt: o.ExpiresAt,
Identifiers: o.Identifiers,
NotBefore: o.NotBefore,
NotAfter: o.NotBefore,
AuthorizationIDs: o.AuthorizationIDs,
}
if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil {
return err
Expand All @@ -103,6 +105,21 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {
return nil
}

// UpdateOrder saves an updated ACME Order to the database.
func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error {
old, err := db.getDBOrder(ctx, o.ID)
if err != nil {
return err
}

nu := old.clone()

nu.Status = o.Status
nu.Error = o.Error
nu.CertificateID = o.CertificateID
return db.save(ctx, old.ID, nu, old, "order", orderTable)
}

type orderIDsByAccount struct{}

func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) {
Expand Down Expand Up @@ -158,18 +175,3 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st
func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
return db.updateAddOrderIDs(ctx, accID)
}

// UpdateOrder saves an updated ACME Order to the database.
func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error {
old, err := db.getDBOrder(ctx, o.ID)
if err != nil {
return err
}

nu := old.clone()

nu.Status = o.Status
nu.Error = o.Error
nu.CertificateID = o.CertificateID
return db.save(ctx, old.ID, nu, old, "order", orderTable)
}

0 comments on commit 7f9ffbd

Please sign in to comment.