Skip to content

Commit

Permalink
Check all the C/N and SANs of provided certificates before to generat…
Browse files Browse the repository at this point in the history
…e ACME certificates only for domains which need it
  • Loading branch information
nmengin authored and traefiker committed Feb 26, 2018
1 parent 700b7a1 commit 4c79a9a
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 36 deletions.
18 changes: 18 additions & 0 deletions acme/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,24 @@ func (dc *DomainsCertificates) exists(domainToFind Domain) (*DomainsCertificate,
return nil, false
}

func (dc *DomainsCertificates) toDomainsMap() map[string]*tls.Certificate {
domainsCertificatesMap := make(map[string]*tls.Certificate)
for _, domainCertificate := range dc.Certs {
certKey := domainCertificate.Domains.Main
if domainCertificate.Domains.SANs != nil {
sort.Strings(domainCertificate.Domains.SANs)
for _, dnsName := range domainCertificate.Domains.SANs {
if dnsName != domainCertificate.Domains.Main {
certKey += fmt.Sprintf(",%s", dnsName)
}
}

}
domainsCertificatesMap[certKey] = domainCertificate.tlsCert
}
return domainsCertificatesMap
}

// DomainsCertificate contains a certificate for multiple domains
type DomainsCertificate struct {
Domains Domain
Expand Down
120 changes: 88 additions & 32 deletions acme/acme.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func (a *ACME) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificat
domain := types.CanonicalDomain(clientHello.ServerName)
account := a.store.Get().(*Account)

if providedCertificate := a.getProvidedCertificate([]string{domain}); providedCertificate != nil {
if providedCertificate := a.getProvidedCertificate(domain); providedCertificate != nil {
return providedCertificate, nil
}

Expand Down Expand Up @@ -625,11 +625,6 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {

domains = fun.Map(types.CanonicalDomain, domains).([]string)

// Check provided certificates
if a.getProvidedCertificate(domains) != nil {
return
}

operation := func() error {
if a.client == nil {
return errors.New("ACME client still not built")
Expand All @@ -647,32 +642,34 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
return
}
account := a.store.Get().(*Account)
var domain Domain
if len(domains) > 1 {
domain = Domain{Main: domains[0], SANs: domains[1:]}
} else {
domain = Domain{Main: domains[0]}
}
if _, exists := account.DomainsCertificate.exists(domain); exists {
// domain already exists

// Check provided certificates
uncheckedDomains := a.getUncheckedDomains(domains, account)
if len(uncheckedDomains) == 0 {
return
}
certificate, err := a.getDomainsCertificates(domains)
certificate, err := a.getDomainsCertificates(uncheckedDomains)
if err != nil {
log.Errorf("Error getting ACME certificates %+v : %v", domains, err)
log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err)
return
}
log.Debugf("Got certificate for domains %+v", domains)
log.Debugf("Got certificate for domains %+v", uncheckedDomains)
transaction, object, err := a.store.Begin()

if err != nil {
log.Errorf("Error creating transaction %+v : %v", domains, err)
log.Errorf("Error creating transaction %+v : %v", uncheckedDomains, err)
return
}
var domain Domain
if len(uncheckedDomains) > 1 {
domain = Domain{Main: uncheckedDomains[0], SANs: uncheckedDomains[1:]}
} else {
domain = Domain{Main: uncheckedDomains[0]}
}
account = object.(*Account)
_, err = account.DomainsCertificate.addCertificateForDomains(certificate, domain)
if err != nil {
log.Errorf("Error adding ACME certificates %+v : %v", domains, err)
log.Errorf("Error adding ACME certificates %+v : %v", uncheckedDomains, err)
return
}
if err = transaction.Commit(account); err != nil {
Expand All @@ -684,36 +681,95 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {

// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (a *ACME) getProvidedCertificate(domains []string) *tls.Certificate {
func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate {
log.Debugf("Looking for provided certificate to validate %s...", domains)
cert := searchProvidedCertificateForDomains(domains, a.TLSConfig.NameToCertificate)
if cert == nil && a.dynamicCerts != nil && a.dynamicCerts.Get() != nil {
cert = searchProvidedCertificateForDomains(domains, a.dynamicCerts.Get().(*traefikTls.DomainsCertificates).Get().(map[string]*tls.Certificate))
}
log.Debugf("No provided certificate found for domains %s, get ACME certificate.", domains)
if cert == nil {
log.Debugf("No provided certificate found for domains %s, get ACME certificate.", domains)
}
return cert
}

func searchProvidedCertificateForDomains(domains []string, certs map[string]*tls.Certificate) *tls.Certificate {
func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Certificate) *tls.Certificate {
// Use regex to test for provided certs that might have been added into TLSConfig
providedCertMatch := false
for k := range certs {
selector := "^" + strings.Replace(k, "*.", "[^\\.]*\\.?", -1) + "$"
for _, domainToCheck := range domains {
providedCertMatch, _ = regexp.MatchString(selector, domainToCheck)
if !providedCertMatch {
for certDomains := range certs {
domainCheck := false
for _, certDomain := range strings.Split(certDomains, ",") {
selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.?", -1) + "$"
domainCheck, _ = regexp.MatchString(selector, domain)
if domainCheck {
break
}
}
if providedCertMatch {
log.Debugf("Got provided certificate for domains %s", domains)
return certs[k]

if domainCheck {
log.Debugf("Domain %q checked by provided certificate %q", domain, certDomains)
return certs[certDomains]
}
}
return nil
}

// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string {
log.Debugf("Looking for provided certificate to validate %s...", domains)
allCerts := make(map[string]*tls.Certificate)

// Get static certificates
for domains, certificate := range a.TLSConfig.NameToCertificate {
allCerts[domains] = certificate
}

// Get dynamic certificates
if a.dynamicCerts != nil && a.dynamicCerts.Get() != nil {
for domains, certificate := range a.dynamicCerts.Get().(*traefikTls.DomainsCertificates).Get().(map[string]*tls.Certificate) {
allCerts[domains] = certificate
}
}

// Get ACME certificates
if account != nil {
for domains, certificate := range account.DomainsCertificate.toDomainsMap() {
allCerts[domains] = certificate
}
}

return searchUncheckedDomains(domains, allCerts)
}

func searchUncheckedDomains(domains []string, certs map[string]*tls.Certificate) []string {
uncheckedDomains := []string{}
for _, domainToCheck := range domains {
domainCheck := false
for certDomains := range certs {
domainCheck = false
for _, certDomain := range strings.Split(certDomains, ",") {
// Use regex to test for provided certs that might have been added into TLSConfig
selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.?", -1) + "$"
domainCheck, _ = regexp.MatchString(selector, domainToCheck)
if domainCheck {
break
}
}
if domainCheck {
break
}
}
if !domainCheck {
uncheckedDomains = append(uncheckedDomains, domainToCheck)
}
}
if len(uncheckedDomains) == 0 {
log.Debugf("No ACME certificate to generate for domains %q.", domains)
} else {
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domains, strings.Join(uncheckedDomains, ","))
}
return uncheckedDomains
}

func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) {
domains = fun.Map(types.CanonicalDomain, domains).([]string)
log.Debugf("Loading ACME certificates %s...", domains)
Expand Down
35 changes: 31 additions & 4 deletions acme/acme_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,44 @@ cijFkALeQp/qyeXdFld2v9gUN3eCgljgcl0QweRoIc=---`)
}
}

func TestAcme_getProvidedCertificate(t *testing.T) {
func TestAcme_getUncheckedCertificates(t *testing.T) {
mm := make(map[string]*tls.Certificate)
mm["*.containo.us"] = &tls.Certificate{}
mm["traefik.acme.io"] = &tls.Certificate{}

a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}}

domains := []string{"traefik.containo.us", "trae.containo.us"}
certificate := a.getProvidedCertificate(domains)
assert.NotNil(t, certificate)
uncheckedDomains := a.getUncheckedDomains(domains, nil)
assert.Empty(t, uncheckedDomains)
domains = []string{"traefik.acme.io", "trae.acme.io"}
certificate = a.getProvidedCertificate(domains)
uncheckedDomains = a.getUncheckedDomains(domains, nil)
assert.Len(t, uncheckedDomains, 1)
domainsCertificates := DomainsCertificates{Certs: []*DomainsCertificate{
{
tlsCert: &tls.Certificate{},
Domains: Domain{
Main: "*.acme.wtf",
SANs: []string{"trae.acme.io"},
},
},
}}
account := Account{DomainsCertificate: domainsCertificates}
uncheckedDomains = a.getUncheckedDomains(domains, &account)
assert.Empty(t, uncheckedDomains)
}

func TestAcme_getProvidedCertificate(t *testing.T) {
mm := make(map[string]*tls.Certificate)
mm["*.containo.us"] = &tls.Certificate{}
mm["traefik.acme.io"] = &tls.Certificate{}

a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}}

domain := "traefik.containo.us"
certificate := a.getProvidedCertificate(domain)
assert.NotNil(t, certificate)
domain = "trae.acme.io"
certificate = a.getProvidedCertificate(domain)
assert.Nil(t, certificate)
}

0 comments on commit 4c79a9a

Please sign in to comment.