Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ type Authority interface {
GetTLSOptions() *tlsutil.TLSOptions
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
GetEncryptedKey(kid string) (string, error)
GetRoots() (federation []*x509.Certificate, err error)
GetFederation() ([]*x509.Certificate, error)
}

// Certificate wraps a *x509.Certificate and adds the json.Marshaler interface.
Expand Down Expand Up @@ -186,6 +188,16 @@ type SignResponse struct {
TLS *tls.ConnectionState `json:"-"`
}

// RootsResponse is the response object of the roots request.
type RootsResponse struct {
Certificates []Certificate `json:"crts"`
}

// FederationResponse is the response object of the federation request.
type FederationResponse struct {
Certificates []Certificate `json:"crts"`
}

// caHandler is the type used to implement the different CA HTTP endpoints.
type caHandler struct {
Authority Authority
Expand All @@ -205,6 +217,8 @@ func (h *caHandler) Route(r Router) {
r.MethodFunc("POST", "/renew", h.Renew)
r.MethodFunc("GET", "/provisioners", h.Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
r.MethodFunc("GET", "/roots", h.Roots)
r.MethodFunc("GET", "/federation", h.Federation)
// For compatibility with old code:
r.MethodFunc("POST", "/re-sign", h.Renew)
}
Expand Down Expand Up @@ -320,6 +334,44 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
JSON(w, &ProvisionerKeyResponse{key})
}

// Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
WriteError(w, Forbidden(err))
return
}

certs := make([]Certificate, len(roots))
for i := range roots {
certs[i] = Certificate{roots[i]}
}

w.WriteHeader(http.StatusCreated)
JSON(w, &RootsResponse{
Certificates: certs,
})
}

// Federation returns all the public certificates in the federation.
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation()
if err != nil {
WriteError(w, Forbidden(err))
return
}

certs := make([]Certificate, len(federated))
for i := range federated {
certs[i] = Certificate{federated[i]}
}

w.WriteHeader(http.StatusCreated)
JSON(w, &FederationResponse{
Certificates: certs,
})
}

func parseCursor(r *http.Request) (cursor string, limit int, err error) {
q := r.URL.Query()
cursor = q.Get("cursor")
Expand Down
108 changes: 108 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ type mockAuthority struct {
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
getEncryptedKey func(kid string) (string, error)
getRoots func() ([]*x509.Certificate, error)
getFederation func() ([]*x509.Certificate, error)
}

func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
Expand Down Expand Up @@ -443,6 +445,20 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
return m.ret1.(string), m.err
}

func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) {
if m.getRoots != nil {
return m.getRoots()
}
return m.ret1.([]*x509.Certificate), m.err
}

func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
if m.getFederation != nil {
return m.getFederation()
}
return m.ret1.([]*x509.Certificate), m.err
}

func Test_caHandler_Route(t *testing.T) {
type fields struct {
Authority Authority
Expand Down Expand Up @@ -812,3 +828,95 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
})
}
}

func Test_caHandler_Roots(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
tests := []struct {
name string
tls *tls.ConnectionState
cert *x509.Certificate
root *x509.Certificate
err error
statusCode int
}{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
}

expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Roots(w, req)
res := w.Result()

if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}

body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Roots unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected)
}
}
})
}
}

func Test_caHandler_Federation(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
tests := []struct {
name string
tls *tls.ConnectionState
cert *x509.Certificate
root *x509.Certificate
err error
statusCode int
}{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
}

expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Federation(w, req)
res := w.Result()

if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}

body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Federation unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected)
}
}
})
}
}
32 changes: 23 additions & 9 deletions authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package authority

import (
"crypto/sha256"
realx509 "crypto/x509"
"crypto/x509"
"encoding/hex"
"fmt"
"sync"
Expand All @@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority"
// Authority implements the Certificate Authority internal interface.
type Authority struct {
config *Config
rootX509Crt *realx509.Certificate
rootX509Certs []*x509.Certificate
intermediateIdentity *x509util.Identity
validateOnce bool
certificates *sync.Map
Expand Down Expand Up @@ -79,15 +79,29 @@ func (a *Authority) init() error {
}

var err error
// First load the root using our modified pem/x509 package.
a.rootX509Crt, err = pemutil.ReadCertificate(a.config.Root)
if err != nil {
return err

// Load the root certificates and add them to the certificate store
a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root))
for i, path := range a.config.Root {
crt, err := pemutil.ReadCertificate(path)
if err != nil {
return err
}
// Add root certificate to the certificate map
sum := sha256.Sum256(crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
a.rootX509Certs[i] = crt
}

// Add root certificate to the certificate map
sum := sha256.Sum256(a.rootX509Crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
// Add federated roots
for _, path := range a.config.FederatedRoots {
crt, err := pemutil.ReadCertificate(path)
if err != nil {
return err
}
sum := sha256.Sum256(crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
}

// Decrypt and load intermediate public / private key pair.
if len(a.config.Password) > 0 {
Expand Down
8 changes: 4 additions & 4 deletions authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func testAuthority(t *testing.T) *Authority {
}
c := &Config{
Address: "127.0.0.1:443",
Root: "testdata/secrets/root_ca.crt",
Root: []string{"testdata/secrets/root_ca.crt"},
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
IntermediateKey: "testdata/secrets/intermediate_ca_key",
DNSNames: []string{"test.ca.smallstep.com"},
Expand Down Expand Up @@ -68,7 +68,7 @@ func TestAuthorityNew(t *testing.T) {
"fail bad root": func(t *testing.T) *newTest {
c, err := LoadConfiguration("../ca/testdata/ca.json")
assert.FatalError(t, err)
c.Root = "foo"
c.Root = []string{"foo"}
return &newTest{
config: c,
err: errors.New("open foo failed: no such file or directory"),
Expand Down Expand Up @@ -105,10 +105,10 @@ func TestAuthorityNew(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
sum := sha256.Sum256(auth.rootX509Crt.Raw)
sum := sha256.Sum256(auth.rootX509Certs[0].Raw)
root, ok := auth.certificates.Load(hex.EncodeToString(sum[:]))
assert.Fatal(t, ok)
assert.Equals(t, auth.rootX509Crt, root)
assert.Equals(t, auth.rootX509Certs[0], root)

assert.True(t, auth.initOnce)
assert.NotNil(t, auth.intermediateIdentity)
Expand Down
34 changes: 3 additions & 31 deletions authority/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,38 +33,10 @@ var (
}
)

type duration struct {
time.Duration
}

// MarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}

// UnmarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) UnmarshalJSON(data []byte) (err error) {
var s string
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
if d.Duration, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
return
}

// Config represents the CA configuration and it's mapped to a JSON object.
type Config struct {
Root string `json:"root"`
Root multiString `json:"root"`
FederatedRoots []string `json:"federatedRoots"`
IntermediateCert string `json:"crt"`
IntermediateKey string `json:"key"`
Address string `json:"address"`
Expand Down Expand Up @@ -145,7 +117,7 @@ func (c *Config) Validate() error {
case c.Address == "":
return errors.New("address cannot be empty")

case c.Root == "":
case c.Root.HasEmpties():
return errors.New("root cannot be empty")

case c.IntermediateCert == "":
Expand Down
Loading