Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add collections interface and play around with collections. #889

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions authority/admins.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@ import (
"go.step.sm/linkedca"
)

// Admins is the interface used by the admin collection.
type Admins interface {
Store(adm *linkedca.Admin, prov provisioner.Interface) error
Update(id string, nu *linkedca.Admin) (*linkedca.Admin, error)
Remove(id string) error
LoadByID(id string) (*linkedca.Admin, bool)
LoadBySubProv(sub, provName string) (*linkedca.Admin, bool)
LoadByProvisioner(provName string) ([]*linkedca.Admin, bool)
Find(cursor string, limit int) ([]*linkedca.Admin, string)
SuperCount() int
SuperCountByProvisioner(provName string) int
}

// LoadAdminByID returns an *linkedca.Admin with the given ID.
func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
a.adminMutex.RLock()
Expand Down
13 changes: 10 additions & 3 deletions authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/smallstep/certificates/authority/admin"
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/administrator"
"github.com/smallstep/certificates/authority/cache"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas"
Expand All @@ -35,8 +36,8 @@ import (
type Authority struct {
config *config.Config
keyManager kms.KeyManager
provisioners *provisioner.Collection
admins *administrator.Collection
provisioners Provisioners
admins Admins
db db.AuthDB
adminDB admin.DB
templates *templates.Templates
Expand Down Expand Up @@ -76,6 +77,7 @@ type Authority struct {
getIdentityFunc provisioner.GetIdentityFunc
authorizeRenewFunc provisioner.AuthorizeRenewFunc
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
cachePool cache.Pool

adminMutex sync.RWMutex
}
Expand Down Expand Up @@ -175,7 +177,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
}

// Create provisioner collection.
provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
provClxn := provisioner.NewCollection(provisionerConfig)
for _, p := range provList {
if err := p.Init(provisionerConfig); err != nil {
return err
Expand Down Expand Up @@ -502,6 +504,11 @@ func (a *Authority) init() error {
}
}

// Initialize the default cache pool.
if a.cachePool == nil {
a.cachePool = cache.DefaultPool()
}

provs, err := a.adminDB.GetProvisioners(context.Background())
if err != nil {
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
Expand Down
89 changes: 89 additions & 0 deletions authority/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package cache

import (
"context"
"errors"
"sync"
)

var ErrNotFound = errors.New("not found")

type Cache interface {
Get(context.Context, string) ([]byte, error)
Set(context.Context, string, []byte) error
Delete(context.Context, string) error
}

type Getter interface {
Get(ctx context.Context, key string) ([]byte, error)
}

// A GetterFunc implements Getter with a function.
type GetterFunc func(ctx context.Context, key string) ([]byte, error)

func (f GetterFunc) Get(ctx context.Context, key string) ([]byte, error) {
return f(ctx, key)
}

type Pool interface {
New(name string, getter Getter) Cache
Get(name string) (Cache, bool)
}

func DefaultPool() Pool {
return &defaultPool{
caches: make(map[string]Cache),
}
}

type defaultPool struct {
mu sync.RWMutex
caches map[string]Cache
}

func (p *defaultPool) New(name string, getter Getter) Cache {
c := &mapCache{
m: new(sync.Map),
getter: getter,
}
p.mu.Lock()
p.caches[name] = c
p.mu.Unlock()
return c
}

func (p *defaultPool) Get(name string) (Cache, bool) {
p.mu.RLock()
c, ok := p.caches[name]
p.mu.RUnlock()
return c, ok
}

type mapCache struct {
name string
m *sync.Map
getter Getter
}

func (m *mapCache) Get(ctx context.Context, key string) ([]byte, error) {
v, ok := m.m.Load(key)
if !ok {
b, err := m.getter.Get(ctx, key)
if err != nil {
return nil, err
}
m.m.Store(key, b)
return b, nil
}
return v.([]byte), nil
}

func (m *mapCache) Set(ctx context.Context, key string, value []byte) error {
m.m.Store(key, value)
return nil
}

func (m *mapCache) Delete(ctx context.Context, key string) error {
m.m.Delete(key)
return nil
}
9 changes: 9 additions & 0 deletions authority/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/cache"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas"
Expand Down Expand Up @@ -284,6 +285,14 @@ func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option {
}
}

// WithCachePool is an options that allows to define a custom cache pool.
func WithCachePool(pool cache.Pool) Option {
return func(a *Authority) error {
a.cachePool = pool
return nil
}
}

func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
var block *pem.Block
var certs []*x509.Certificate
Expand Down
54 changes: 48 additions & 6 deletions authority/provisioner/collection.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
package provisioner

import (
"context"
"crypto/sha1"
"crypto/x509"
"encoding/asn1"
"encoding/binary"
"encoding/hex"
"fmt"
"log"
"net/url"
"sort"
"strings"
"sync"
"time"

"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/cache"
"go.step.sm/crypto/jose"
"go.step.sm/linkedca"
"google.golang.org/protobuf/proto"
)

// DefaultProvisionersLimit is the default limit for listing provisioners.
Expand All @@ -33,6 +39,10 @@ func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }

func defaultContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 5*time.Second)
}

// loadByTokenPayload is a payload used to extract the id used to load the
// provisioner.
type loadByTokenPayload struct {
Expand All @@ -50,22 +60,53 @@ type Collection struct {
byTokenID *sync.Map
sorted provisionerSlice
audiences Audiences

// new
byIDCache cache.Cache
byNameCache cache.Cache
}

// NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner.
func NewCollection(audiences Audiences) *Collection {
func NewCollection(config Config) *Collection {
byID := config.CachePool.New("provisioner_by_id", cache.GetterFunc(func(ctx context.Context, id string) ([]byte, error) {
p, err := config.AdminDB.GetProvisioner(ctx, id)
if err != nil {
return nil, err
}
return proto.Marshal(p)
}))
// byName maps a name with a provisioner id, we will manually fill this cache.
byName := config.CachePool.New("provisioner_by_name", cache.GetterFunc(func(ctx context.Context, name string) ([]byte, error) {
return nil, cache.ErrNotFound
}))

return &Collection{
byID: new(sync.Map),
byKey: new(sync.Map),
byName: new(sync.Map),
byTokenID: new(sync.Map),
audiences: audiences,
byID: new(sync.Map),
byKey: new(sync.Map),
byName: new(sync.Map),
byTokenID: new(sync.Map),
audiences: config.Audiences,
byIDCache: byID,
byNameCache: byName,
}
}

// Load a provisioner by the ID.
func (c *Collection) Load(id string) (Interface, bool) {
ctx, cancel := defaultContext()
defer cancel()

b, err := c.byIDCache.Get(ctx, id)
if err != nil {
return nil, false
}

var p linkedca.Provisioner
if err := proto.Unmarshal(b, &p); err != nil {
return nil, false
}
log.Printf("Provisioner.Load(%s): %v", id, p)
return loadProvisioner(c.byID, id)
}

Expand Down Expand Up @@ -180,6 +221,7 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
// Store adds a provisioner to the collection and enforces the uniqueness of
// provisioner IDs.
func (c *Collection) Store(p Interface) error {

// Store provisioner always in byID. ID must be unique.
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
return admin.NewError(admin.ErrorBadRequestType,
Expand Down
6 changes: 6 additions & 0 deletions authority/provisioner/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"strings"

"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/cache"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -214,6 +216,8 @@ type Config struct {
Audiences Audiences
// DB is the interface to the authority DB client.
DB db.AuthDB
// AdminDB is the interface to the administration DB client.
AdminDB admin.DB
// SSHKeys are the root SSH public keys
SSHKeys *SSHKeys
// GetIdentityFunc is a function that returns an identity that will be
Expand All @@ -225,6 +229,8 @@ type Config struct {
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
// certificate can be renewed.
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
// CachePool is a type that allows to create new caches.
CachePool cache.Pool
}

type provisioner struct {
Expand Down
14 changes: 14 additions & 0 deletions authority/provisioners.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ import (
"gopkg.in/square/go-jose.v2/jwt"
)

// Provisioners is the interface used by the provisioners collection.
type Provisioners interface {
Load(id string) (provisioner.Interface, bool)
Store(p provisioner.Interface) error
Update(p provisioner.Interface) error
Remove(id string) error
LoadByName(name string) (provisioner.Interface, bool)
LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (provisioner.Interface, bool)
LoadByTokenID(tokenProvisionerID string) (provisioner.Interface, bool)
LoadByCertificate(cert *x509.Certificate) (provisioner.Interface, bool)
Find(cursor string, limit int) (provisioner.List, string)
LoadEncryptedKey(keyID string) (string, bool)
}

// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
a.adminMutex.RLock()
Expand Down