Skip to content

Commit

Permalink
Merge pull request #7 from puppetlabs/add-labels
Browse files Browse the repository at this point in the history
Add support for labels on the managed secret
  • Loading branch information
Iristyle committed Aug 20, 2020
2 parents f7f27cf + 136112a commit 6835ed2
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ A tiny utility for ensuring TLS certificates in Kubernetes are up-to-date.

To use `tlser` in a cluster, include `puppet/tlser:1.1.1` as an `initContainer`, mount a CA cert/key pair as a volume, and specify necessary arguments (`-name` is required) such as
```
tlser -cacert /cert/tls.crt -cakey /cert/tls.key -name app-tls -subject example.com -dns example.com,localhost,app -ip 10.0.0.1 -expire 365
tlser -cacert /cert/tls.crt -cakey /cert/tls.key -name app-tls -subject example.com -dns example.com,localhost,app -ip 10.0.0.1 -expire 365 -label app=myapp -label part-of=myapp
```

When run, `tlser` will check whether a secret exists. If it exists, is not expired or about to expire, and its properties already match the parameters, it won't be regenerated. Otherwise it generates a new certificate and updates or creates the appropriate secret.
Expand Down
47 changes: 47 additions & 0 deletions labels.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package main

import (
"fmt"
"strings"
)

type labels map[string]string

func (f *labels) String() string {
if f == nil {
return ""
}

strArray := make([]string, 0, len(*f))
for k, v := range *f {
strArray = append(strArray, k+"="+v)
}
return strings.Join(strArray, ",")
}

func (f *labels) Set(value string) error {
if *f == nil {
*f = make(labels)
}

pair := strings.SplitN(value, "=", 2)
if len(pair) != 2 {
return fmt.Errorf("label must be in the form <label>=<value>, not %v", value)
}
(*f)[pair[0]] = pair[1]
return nil
}

func (f *labels) Equals(other map[string]string) bool {
if len(*f) != len(other) {
return false
}

// Maps have equal size, so if other contains all our keys and their values match then they're equal.
for k, v := range *f {
if other[k] != v {
return false
}
}
return true
}
54 changes: 54 additions & 0 deletions labels_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package main

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestLabelsString(t *testing.T) {
req := assert.New(t)

l := labels{"key1": "value1", "key2": "value2"}
req.Equal("key1=value1,key2=value2", l.String())

req.Equal("", (&labels{}).String())
req.Equal("", (*labels)(nil).String())
}

func TestLabelsSet(t *testing.T) {
req := assert.New(t)

var l labels
req.Error(l.Set("key"))

req.NoError(l.Set("key=value"))
req.NoError(l.Set("key2=value=more"))
req.Equal(2, len(l))
req.Equal("value", l["key"])
req.Equal("value=more", l["key2"])

l = labels{"key1": "value1", "key2": "value2"}
req.NoError(l.Set("key=value"))
req.Equal(3, len(l))
req.Equal("value", l["key"])
req.Equal("value1", l["key1"])
req.Equal("value2", l["key2"])
}

func TestLabelsEqual(t *testing.T) {
req := assert.New(t)

var l labels
var other map[string]string
req.True(l.Equals(map[string]string{}))
req.True(l.Equals(l))
req.True(l.Equals(labels{}))
req.True(l.Equals(other))

other = map[string]string{"key": "value"}
req.False(l.Equals(other))

req.NoError(l.Set("key=value"))
req.True(l.Equals(other))
}
7 changes: 1 addition & 6 deletions secrets.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ type secrets interface {
setSecret(secret *secret, update bool) error
}

func getTLSFromSecret(c secrets, id identifier) (certificate, error) {
secret, err := c.getSecret(id)
if err != nil {
return certificate{}, err
}

func getTLSFromSecret(secret *secret, id identifier) (certificate, error) {
if secret.Type != tlsSecretType {
return certificate{}, fmt.Errorf("secret %v must have type %v, not %v", id, tlsSecretType, secret.Type)
}
Expand Down
33 changes: 4 additions & 29 deletions secrets_test.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,11 @@
package main

import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

type secretMock struct {
mock.Mock
}

func (m *secretMock) getSecret(id identifier) (*secret, error) {
args := m.Called(id)
return args.Get(0).(*secret), args.Error(1)
}

func (m *secretMock) setSecret(secret *secret, update bool) error {
args := m.Called(secret, update)
return args.Error(0)
}

var (
testCert = `-----BEGIN CERTIFICATE-----
MIIC5zCCAc+gAwIBAgIJAOC7Munm9txXMA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV
Expand Down Expand Up @@ -73,32 +57,23 @@ Sucz81ym6QREo7DZ4lDXuz5PhPW4KLeoWRw8syyraVQ/o6RsbHQ1
func TestGetTLSFromSecret(t *testing.T) {
req := assert.New(t)
id := identifier{name: "foo", namespace: "default"}
var m1, m2 secretMock

m1.On("getSecret", id).Return((*secret)(nil), errors.New("failed"))
_, err := getTLSFromSecret(&m1, id)
req.Error(err)

var secret secret
m2.On("getSecret", id).Return(&secret, nil)
_, err = getTLSFromSecret(&m2, id)
_, err := getTLSFromSecret(&secret, id)
req.Error(err)

secret.Type = tlsSecretType
_, err = getTLSFromSecret(&m2, id)
_, err = getTLSFromSecret(&secret, id)
req.Error(err)

secret.Data = make(map[string][]byte)
secret.Data["tls.crt"] = []byte(testCert)
_, err = getTLSFromSecret(&m2, id)
_, err = getTLSFromSecret(&secret, id)
req.Error(err)

secret.Data["tls.key"] = []byte(testKey)
cert, err := getTLSFromSecret(&m2, id)
cert, err := getTLSFromSecret(&secret, id)
req.NoError(err)
req.NotNil(cert.cert)
req.NotNil(cert.key)

m1.AssertExpectations(t)
m2.AssertExpectations(t)
}
25 changes: 22 additions & 3 deletions syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type syncer struct {
subject string
ip, dns []string
daysValid int
labels labels
getSigner func() (certificate, error)
}

Expand All @@ -24,14 +25,31 @@ func (s syncer) sync() error {
return fmt.Errorf("unable to get signing certificate: %w", err)
}

previous, err := getTLSFromSecret(s.secrets, s.id)
priorSecret, err := s.secrets.getSecret(s.id)
if err != nil && !k8errors.IsNotFound(err) {
return fmt.Errorf("unable to retrieve secret %v: %w", s.id, err)
}

var previous certificate
if priorSecret != nil {
if previous, err = getTLSFromSecret(priorSecret, s.id); err != nil {
return err
}
}

// Check whether it needs to be updated.
if previous.cert != nil && previous.isValid(signer) && previous.inSync(s.subject, s.ip, s.dns) {
log.Print("Previous cert matches parameters, no update performed.")
if priorSecret != nil && previous.isValid(signer) && previous.inSync(s.subject, s.ip, s.dns) {
if s.labels.Equals(priorSecret.Labels) {
log.Print("Previous secret matches parameters, no update performed.")
return nil
}

log.Printf("Labels out-of-sync: %+v, %+v", priorSecret.Labels, s.labels)
log.Printf("Updating labels on secret %v", s.id)
priorSecret.Labels = s.labels
if err := s.secrets.setSecret(priorSecret, true); err != nil {
return fmt.Errorf("unable to update secret %v: %w", s.id, err)
}
return nil
}

Expand Down Expand Up @@ -62,6 +80,7 @@ func (s syncer) sync() error {
secret.Namespace = s.id.namespace
secret.Data = map[string][]byte{"tls.crt": []byte(cert), "tls.key": []byte(key)}
secret.Type = tlsSecretType
secret.Labels = s.labels
if err := s.secrets.setSecret(&secret, previous.cert != nil); err != nil {
return fmt.Errorf("unable to update secret %v: %w", s.id, err)
}
Expand Down
50 changes: 47 additions & 3 deletions syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
)

type secretMock struct {
mock.Mock
}

func (m *secretMock) getSecret(id identifier) (*secret, error) {
args := m.Called(id)
return args.Get(0).(*secret), args.Error(1)
}

func (m *secretMock) setSecret(secret *secret, update bool) error {
args := m.Called(secret, update)
return args.Error(0)
}

func TestSync(t *testing.T) {
const (
cn = "foo.com"
Expand All @@ -34,26 +48,30 @@ func TestSync(t *testing.T) {
getSigner: func() (certificate, error) { return signer, nil },
}

var m1, m2, m3, m4, m5 secretMock
var m1, m2, m3, m4, m5, m6, m7, m8 secretMock

// Errors if getting the secret fails
m1.On("getSecret", sync.id).Return((*secret)(nil), errors.New("failed"))
sync.secrets = &m1
req.Error(sync.sync())
m1.AssertExpectations(t)

// Creates the secret if NotFound
notFound := k8errors.NewNotFound(schema.GroupResource{Resource: "Secret"}, sync.id.name)
m2.On("getSecret", sync.id).Return((*secret)(nil), notFound)
m2.On("setSecret", mock.AnythingOfType("*v1.Secret"), false).Return(nil)
sync.secrets = &m2
req.NoError(sync.sync())
m2.AssertExpectations(t)

// Errors if creating the secret fails
m3.On("getSecret", sync.id).Return((*secret)(nil), notFound)
m3.On("setSecret", mock.AnythingOfType("*v1.Secret"), false).Return(errors.New("failed"))
sync.secrets = &m3
req.Error(sync.sync())
m3.AssertExpectations(t)

// Does not update the secret if in-sync
key, err := rsa.GenerateKey(rand.Reader, 2048)
req.NoError(err)
certBytes, keyBytes, err := generateSignedCert(cn, []string{ip1, ip2}, []string{dns1, dns2}, 50, key, signer)
Expand All @@ -69,13 +87,39 @@ func TestSync(t *testing.T) {
req.NoError(sync.sync())
m4.AssertExpectations(t)

// Updates the secret if certs not in-sync
req.NoError(err)
certBytes, keyBytes, err = generateSignedCert(cn, []string{}, []string{}, 50, key, signer)
newCertBytes, newKeyBytes, err := generateSignedCert(cn, []string{}, []string{}, 50, key, signer)
req.NoError(err)
secret.Data = map[string][]byte{"tls.crt": []byte(certBytes), "tls.key": []byte(keyBytes)}
secret.Data = map[string][]byte{"tls.crt": []byte(newCertBytes), "tls.key": []byte(newKeyBytes)}
m5.On("getSecret", sync.id).Return(&secret, nil)
m5.On("setSecret", mock.AnythingOfType("*v1.Secret"), true).Return(nil)
sync.secrets = &m5
req.NoError(sync.sync())
m5.AssertExpectations(t)

// Updates the secret if labels missing
secret.Data = map[string][]byte{"tls.crt": []byte(certBytes), "tls.key": []byte(keyBytes)}
sync.labels = labels{"key": "value"}
m6.On("getSecret", sync.id).Return(&secret, nil)
m6.On("setSecret", mock.AnythingOfType("*v1.Secret"), true).Return(nil)
sync.secrets = &m6
req.NoError(sync.sync())
m6.AssertExpectations(t)

// Updates the secret if labels have different values
sync.labels = labels{"key": "value"}
secret.Labels["key"] = "other value"
m7.On("getSecret", sync.id).Return(&secret, nil)
m7.On("setSecret", mock.AnythingOfType("*v1.Secret"), true).Return(nil)
sync.secrets = &m7
req.NoError(sync.sync())
m7.AssertExpectations(t)

// Does not update the secret if labels in-sync
secret.Labels = sync.labels
m8.On("getSecret", sync.id).Return(&secret, nil)
sync.secrets = &m8
req.NoError(sync.sync())
m8.AssertExpectations(t)
}
3 changes: 3 additions & 0 deletions tlser.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ var (

k8sName = flag.String("name", "", "Name of the Kubernetes secret to update")
k8sNs = flag.String("namespace", "default", "Namespace of the Kubernetes secret to update")
label = labels{}
interval = flag.String("interval", "", "Interval to check if cert is insync (ex: 1h, 30m)")
)

func main() {
flag.Var(&label, "label", "Specify a label as key=value to put on the generated secret; can appear repeatedly for multiple labels")
log.SetFlags(0)
flag.Parse()

Expand Down Expand Up @@ -104,6 +106,7 @@ func main() {
ip: ipStrings,
dns: dnsStrings,
daysValid: *expire,
labels: label,
getSigner: func() (certificate, error) { return readCa(*cacrt, *cakey) },
}

Expand Down

0 comments on commit 6835ed2

Please sign in to comment.