Skip to content

Commit

Permalink
feat: Support SANs (#157)
Browse files Browse the repository at this point in the history
* Initial support for SANs

* Fix tests

* Update test to include SANs
  • Loading branch information
soerenschneider committed Jun 17, 2022
1 parent ba71a3d commit c2953ee
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 39 deletions.
12 changes: 10 additions & 2 deletions contrib/server.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
{
"domains": [
"domain1.tld",
"domain2.tld"
{
"domain": "domain1.tld",
"sans": [
"domain3.tld",
"domain4.tld"
]
},
{
"domain": "domain2.tld"
}
],
"email": "my@email.tld",
"vaultAddr": "https://vault:8200",
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/soerenschneider/acmevault
go 1.16

require (
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/aws/aws-sdk-go v1.44.32
github.com/blushft/go-diagrams v0.0.0-20201006005127-c78c821223d9
github.com/go-acme/lego/v4 v4.7.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ github.com/armon/go-metrics v0.3.9/go.mod h1:4O98XIr/9W0sxpJ8UaYkvjk10Iff7SnFrb4
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI=
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl3/e6D5CLfI0j/7hiIEtvGVFPCZ7Ei2oq8iQ=
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/awalterschulze/gographviz v0.0.0-20200901124122-0eecad45bd71 h1:m3N1Fv5vE5IcxuTOGFGGV0grrVFHV8UY2SV0wSBXAC8=
github.com/awalterschulze/gographviz v0.0.0-20200901124122-0eecad45bd71/go.mod h1:/ynarkO/43wP/JM2Okn61e8WFMtdbtA8he7GJxW+SFM=
github.com/aws/aws-sdk-go v1.39.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
Expand Down
81 changes: 61 additions & 20 deletions internal/config/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/asaskevich/govalidator"
"github.com/rs/zerolog/log"
"io/ioutil"
"net/url"
Expand All @@ -22,9 +23,35 @@ var (
type AcmeVaultServerConfig struct {
VaultConfig
AcmeConfig
IntervalSeconds int `json:"intervalSeconds"`
Domains []string `json:"domains"`
MetricsAddr string `json:"metricsAddr"`
IntervalSeconds int `json:"intervalSeconds"`
Domains []AcmeServerDomains `json:"domains"`
MetricsAddr string `json:"metricsAddr"`
}

type AcmeServerDomains struct {
Domain string `json:"domain"`
Sans []string `json:"sans,omitempty"`
}

func (a AcmeServerDomains) Verify() error {
if ok := govalidator.IsDNSName(a.Domain); !ok {
return fmt.Errorf("invalid domain name: '%s' is not a domain name", a.Domain)
}
for _, domain := range a.Sans {
if ok := govalidator.IsDNSName(domain); !ok {
return fmt.Errorf("invalid sans domain name: '%s' is not a domain name", domain)
}
}

return nil
}

func (a AcmeServerDomains) String() string {
if len(a.Sans) > 0 {
return fmt.Sprintf("%s (%v)", a.Domain, a.Sans)
}

return a.Domain
}

type AcmeConfig struct {
Expand All @@ -33,14 +60,7 @@ type AcmeConfig struct {
AcmeDnsProvider string `json:"acmeDnsProvider"`
}

func isValidEmail(email string) bool {
if len(email) < 3 || len(email) > 254 {
return false
}
return emailRegex.MatchString(email)
}

func (conf AcmeVaultServerConfig) Validate() error {
func (conf AcmeConfig) Validate() error {
if len(conf.AcmeDnsProvider) == 0 {
return errors.New("field `acmeDnsProvider` not configured")
}
Expand All @@ -49,14 +69,20 @@ func (conf AcmeVaultServerConfig) Validate() error {
return fmt.Errorf("could not parse `acmeDnsProvider`: %v", err)
}

if len(conf.Domains) == 0 {
return errors.New("field `domains` not configured")
if !govalidator.IsEmail(conf.Email) {
return fmt.Errorf("field `email` not configured (correctly): %s", conf.Email)
}

if !isValidEmail(conf.Email) {
return errors.New("field `email` not configured (correctly)")
}
return nil
}

func (conf AcmeConfig) Print() {
log.Info().Msgf("AcmeEmail=%s", conf.Email)
log.Info().Msgf("AcmeUrl=%s", conf.AcmeUrl)
log.Info().Msgf("AcmeDnsProvider=%s", conf.AcmeDnsProvider)
}

func (conf AcmeVaultServerConfig) Validate() error {
if conf.IntervalSeconds < 0 {
return fmt.Errorf("field `intervalSeconds` can not be a negative number: %d", conf.IntervalSeconds)
}
Expand All @@ -65,18 +91,33 @@ func (conf AcmeVaultServerConfig) Validate() error {
return fmt.Errorf("field `intervalSeconds` shouldn't be > 86400: %d", conf.IntervalSeconds)
}

if len(conf.Domains) == 0 {
return errors.New("no domains configured")
}
for _, domain := range conf.Domains {
err := domain.Verify()
if err != nil {
return err
}
}
err := conf.AcmeConfig.Validate()
if err != nil {
return err
}

return conf.VaultConfig.Validate()
}

func (conf AcmeVaultServerConfig) Print() {
log.Info().Msg("--- Server Config Start ---")
conf.VaultConfig.Print()
log.Info().Msgf("AcmeDomains=%s", conf.Domains)
log.Info().Msgf("AcmeEmail=%s", conf.Email)
log.Info().Msgf("AcmeUrl=%s", conf.AcmeUrl)
conf.AcmeConfig.Print()
for index, domain := range conf.Domains {
log.Info().Msgf("AcmeDomains[%d]=%s", index, domain.String())
}

log.Info().Msgf("IntervalSeconds=%d", conf.IntervalSeconds)
log.Info().Msgf("MetricsAddr=%s", conf.MetricsAddr)
log.Info().Msgf("AcmeDnsProvider=%s", conf.AcmeDnsProvider)
log.Info().Msg("--- Server Config End ---")
}

Expand Down
12 changes: 10 additions & 2 deletions internal/config/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,16 @@ func TestAcmeVaultServerConfigFromFile(t *testing.T) {
AcmeDnsProvider: "",
},
IntervalSeconds: 43200,
Domains: []string{"domain1.tld", "domain2.tld"},
MetricsAddr: "127.0.0.1:9112",
Domains: []AcmeServerDomains{
{
Domain: "domain1.tld",
Sans: []string{"domain3.tld", "domain4.tld"},
},
{
Domain: "domain2.tld",
},
},
MetricsAddr: "127.0.0.1:9112",
},
wantErr: false,
},
Expand Down
6 changes: 4 additions & 2 deletions internal/server/acme/lego.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ func (l *GoLego) RegisterAccount() (*registration.Resource, error) {
return l.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
}

func (l *GoLego) ObtainCert(domain string) (*certstorage.AcmeCertificate, error) {
func (l *GoLego) ObtainCert(domain config.AcmeServerDomains) (*certstorage.AcmeCertificate, error) {
domains := []string{domain.Domain}
domains = append(domains, domain.Sans...)
request := certificate.ObtainRequest{
Domains: []string{domain},
Domains: domains,
Bundle: true,
}

Expand Down
3 changes: 2 additions & 1 deletion internal/server/acme/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/registration"
"github.com/soerenschneider/acmevault/internal/config"
"github.com/soerenschneider/acmevault/pkg/certstorage"
)

Expand All @@ -14,7 +15,7 @@ const (

type AcmeDealer interface {
RegisterAccount() (*registration.Resource, error)
ObtainCert(domain string) (*certstorage.AcmeCertificate, error)
ObtainCert(domain config.AcmeServerDomains) (*certstorage.AcmeCertificate, error)
RenewCert(cert *certstorage.AcmeCertificate) (*certstorage.AcmeCertificate, error)
}

Expand Down
21 changes: 11 additions & 10 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/rs/zerolog/log"
"github.com/soerenschneider/acmevault/internal/config"
"github.com/soerenschneider/acmevault/internal/metrics"
"github.com/soerenschneider/acmevault/internal/server/acme"
"github.com/soerenschneider/acmevault/pkg/certstorage"
Expand All @@ -19,11 +20,11 @@ const (
type AcmeVaultServer struct {
acmeClient acme.AcmeDealer
certStorage certstorage.CertStorage
domains []string
domains []config.AcmeServerDomains
}

func NewAcmeVaultServer(domains []string, acmeClient acme.AcmeDealer, storage certstorage.CertStorage) (*AcmeVaultServer, error) {
if nil == domains || len(domains) == 0 {
func NewAcmeVaultServer(domains []config.AcmeServerDomains, acmeClient acme.AcmeDealer, storage certstorage.CertStorage) (*AcmeVaultServer, error) {
if len(domains) == 0 {
return nil, errors.New("no domains given")
}

Expand Down Expand Up @@ -54,17 +55,17 @@ func (c *AcmeVaultServer) CheckCerts() {
c.certStorage.Logout()
}

func (c *AcmeVaultServer) obtainAndHandleCert(domain string) error {
log.Info().Msgf("Trying to read certificate data for domain %s from storage", domain)
read, err := c.certStorage.ReadPublicCertificateData(domain)
func (c *AcmeVaultServer) obtainAndHandleCert(domain config.AcmeServerDomains) error {
log.Info().Msgf("Trying to read certificate data for domain %s from storage", domain.Domain)
read, err := c.certStorage.ReadPublicCertificateData(domain.Domain)
if err != nil || read == nil {
log.Error().Msgf("Error reading cert data from storage for domain %s: %v", domain, err)
log.Info().Msgf("Trying to obtain cert from configured ACME provider for domain %s", domain)
log.Error().Msgf("Error reading cert data from storage for domain %s: %v", domain.Domain, err)
log.Info().Msgf("Trying to obtain cert from configured ACME provider for domain %s", domain.Domain)
obtained, err := c.acmeClient.ObtainCert(domain)
metrics.CertificatesRetrieved.Inc()
if err != nil {
metrics.CertificatesRetrievalErrors.Inc()
return fmt.Errorf("obtaining cert for domain %s failed: %v", domain, err)
return fmt.Errorf("obtaining cert for domain %s failed: %v", domain.Domain, err)
}
return handleReceivedCert(obtained, c.certStorage)
}
Expand All @@ -76,7 +77,7 @@ func (c *AcmeVaultServer) obtainAndHandleCert(domain string) error {
} else {
timeLeft := expiry.Sub(time.Now().UTC())
if timeLeft > MinCertLifetime {
metrics.CertServerExpiryTimestamp.WithLabelValues(domain).Set(float64(expiry.Unix()))
metrics.CertServerExpiryTimestamp.WithLabelValues(domain.Domain).Set(float64(expiry.Unix()))
log.Info().Msgf("Not renewing cert for domain %s, still valid for %v", domain, timeLeft)
return nil
}
Expand Down
6 changes: 4 additions & 2 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"github.com/go-acme/lego/v4/registration"
"github.com/soerenschneider/acmevault/internal/config"
"github.com/soerenschneider/acmevault/pkg/certstorage"
"github.com/stretchr/testify/mock"
"testing"
Expand All @@ -13,7 +14,7 @@ func TestServerHappyPathRenewal(t *testing.T) {
server := AcmeVaultServer{
acmeClient: dealer,
certStorage: certStorage,
domains: []string{"domain"},
domains: []config.AcmeServerDomains{{Domain: "example.com"}},
}

old := &certstorage.AcmeCertificate{}
Expand All @@ -38,7 +39,8 @@ func (m *MockAcmeDealer) RegisterAccount() (*registration.Resource, error) {
}
return args.Get(0).(*registration.Resource), args.Error(1)
}
func (m *MockAcmeDealer) ObtainCert(domain string) (*certstorage.AcmeCertificate, error) {

func (m *MockAcmeDealer) ObtainCert(domains config.AcmeServerDomains) (*certstorage.AcmeCertificate, error) {
args := m.Called()
if nil == args.Get(0) {
return nil, args.Error(1)
Expand Down

0 comments on commit c2953ee

Please sign in to comment.