Skip to content

Commit

Permalink
Merge branch 'master' into revert
Browse files Browse the repository at this point in the history
  • Loading branch information
wxing1292 committed Dec 10, 2020
2 parents 420255e + ad093b3 commit d5ffcb6
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 102 deletions.
2 changes: 1 addition & 1 deletion common/authorization/authorizer.go
Expand Up @@ -36,7 +36,7 @@ const (
)

type (
// Attributes is input for authority to make decision.
// CallTarget is input for authorizer to make decision.
// It can be extended in future if required auth on resources like WorkflowType and TaskQueue
CallTarget struct {
APIName string
Expand Down
178 changes: 116 additions & 62 deletions common/rpc/encryption/localStoreCertProvider.go
Expand Up @@ -37,79 +37,30 @@ import (
)

var _ CertProvider = (*localStoreCertProvider)(nil)
var _ ClientCertProvider = (*localStoreCertProvider)(nil)

type localStoreCertProvider struct {
sync.RWMutex

tlsSettings *config.GroupTLS
tlsSettings *config.GroupTLS
workerTLSSettings *config.WorkerTLS

serverCert *tls.Certificate
clientCert *tls.Certificate
clientCAs *x509.CertPool
serverCAs *x509.CertPool

isLegacyWorkerConfig bool
legacyWorkerSettings *config.ClientTLS
}

func (s *localStoreCertProvider) GetSettings() *config.GroupTLS {
return s.tlsSettings
}

func (s *localStoreCertProvider) FetchServerCertificate() (*tls.Certificate, error) {
if s.tlsSettings.Server.CertFile == "" && s.tlsSettings.Server.CertData == "" {
return nil, nil
}

if s.tlsSettings.Server.CertFile != "" && s.tlsSettings.Server.CertData != "" {
return nil, errors.New("Cannot specify both certFile and certData properties")
}

s.RLock()
if s.serverCert != nil {
defer s.RUnlock()
return s.serverCert, nil
}

s.RUnlock()
s.Lock()
defer s.Unlock()

if s.serverCert != nil {
return s.serverCert, nil
}

var certBytes []byte
var keyBytes []byte
var err error

if s.tlsSettings.Server.CertFile != "" {
certBytes, err = ioutil.ReadFile(s.tlsSettings.Server.CertFile)
if err != nil {
return nil, err
}
} else if s.tlsSettings.Server.CertData != "" {
certBytes, err = base64.StdEncoding.DecodeString(s.tlsSettings.Server.CertData)
if err != nil {
return nil, fmt.Errorf("TLS public certificate could not be decoded: %w", err)
}
}

if s.tlsSettings.Server.KeyFile != "" {
keyBytes, err = ioutil.ReadFile(s.tlsSettings.Server.KeyFile)
if err != nil {
return nil, err
}
} else if s.tlsSettings.Server.KeyData != "" {
keyBytes, err = base64.StdEncoding.DecodeString(s.tlsSettings.Server.KeyData)
if err != nil {
return nil, fmt.Errorf("TLS private key could not be decoded: %w", err)
}
}

serverCert, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return nil, fmt.Errorf("loading server tls certificate failed: %v", err)
}

s.serverCert = &serverCert
return s.serverCert, nil
return s.FetchCertificate(&s.serverCert, s.tlsSettings.Server.CertFile, s.tlsSettings.Server.CertData,
s.tlsSettings.Server.KeyFile, s.tlsSettings.Server.KeyData)
}

func (s *localStoreCertProvider) FetchClientCAs() (*x509.CertPool, error) {
Expand Down Expand Up @@ -154,8 +105,12 @@ func (s *localStoreCertProvider) FetchClientCAs() (*x509.CertPool, error) {
return s.clientCAs, nil
}

func (s *localStoreCertProvider) FetchServerRootCAsForClient() (*x509.CertPool, error) {
if len(s.tlsSettings.Client.RootCAFiles) == 0 && len(s.tlsSettings.Client.RootCAData) == 0 {
func (s *localStoreCertProvider) FetchServerRootCAsForClient(isWorker bool) (*x509.CertPool, error) {
clientSettings := s.getClientTLSSettings(isWorker)
rootCAFiles := clientSettings.RootCAFiles
rootCAData := clientSettings.RootCAData

if len(rootCAFiles) == 0 && len(rootCAData) == 0 {
return nil, nil
}

Expand All @@ -173,12 +128,12 @@ func (s *localStoreCertProvider) FetchServerRootCAsForClient() (*x509.CertPool,
return s.serverCAs, nil
}

serverCAPoolFromFiles, err := buildCAPoolFromFiles(s.tlsSettings.Client.RootCAFiles)
serverCAPoolFromFiles, err := buildCAPoolFromFiles(rootCAFiles)
if err != nil {
return nil, err
}

serverCAPoolFromData, err := buildCAPoolFromData(s.tlsSettings.Client.RootCAData)
serverCAPoolFromData, err := buildCAPoolFromData(rootCAData)
if err != nil {
return nil, err
}
Expand All @@ -196,6 +151,105 @@ func (s *localStoreCertProvider) FetchServerRootCAsForClient() (*x509.CertPool,
return s.serverCAs, nil
}

func (s *localStoreCertProvider) FetchClientCertificate(isWorker bool) (*tls.Certificate, error) {
if isWorker {
return s.fetchWorkerCertificate()
} else {
return s.FetchCertificate(&s.clientCert, s.tlsSettings.Server.CertFile, s.tlsSettings.Server.CertData,
s.tlsSettings.Server.KeyFile, s.tlsSettings.Server.KeyData)
}
}

func (s *localStoreCertProvider) fetchWorkerCertificate() (*tls.Certificate, error) {
if s.isLegacyWorkerConfig {
return s.FetchCertificate(&s.clientCert, s.tlsSettings.Server.CertFile, s.tlsSettings.Server.CertData,
s.tlsSettings.Server.KeyFile, s.tlsSettings.Server.KeyData)
} else {
return s.FetchCertificate(&s.clientCert, s.workerTLSSettings.CertFile, s.workerTLSSettings.CertData,
s.workerTLSSettings.KeyFile, s.workerTLSSettings.KeyData)
}
}

func (s *localStoreCertProvider) FetchCertificate(cachedCert **tls.Certificate,
certFile string, certData string,
keyFile string, keyData string) (*tls.Certificate, error) {
if certFile == "" && certData == "" {
return nil, nil
}

if certFile != "" && certData != "" {
return nil, errors.New("Cannot specify both certFile and certData properties")
}

s.RLock()
if *cachedCert != nil {
defer s.RUnlock()
return *cachedCert, nil
}

s.RUnlock()
s.Lock()
defer s.Unlock()

if *cachedCert != nil {
return *cachedCert, nil
}

var certBytes []byte
var keyBytes []byte
var err error

if certFile != "" {
certBytes, err = ioutil.ReadFile(certFile)
if err != nil {
return nil, err
}
} else if certData != "" {
certBytes, err = base64.StdEncoding.DecodeString(certData)
if err != nil {
return nil, fmt.Errorf("TLS public certificate could not be decoded: %w", err)
}
}

if keyFile != "" {
keyBytes, err = ioutil.ReadFile(keyFile)
if err != nil {
return nil, err
}
} else if keyData != "" {
keyBytes, err = base64.StdEncoding.DecodeString(keyData)
if err != nil {
return nil, fmt.Errorf("TLS private key could not be decoded: %w", err)
}
}

cert, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return nil, fmt.Errorf("loading tls certificate failed: %v", err)
}

*cachedCert = &cert
return *cachedCert, nil
}

func (s *localStoreCertProvider) ServerName(isWorker bool) string {
return s.getClientTLSSettings(isWorker).ServerName
}

func (s *localStoreCertProvider) DisableHostVerification(isWorker bool) bool {
return s.getClientTLSSettings(isWorker).DisableHostVerification
}

func (s *localStoreCertProvider) getClientTLSSettings(isWorker bool) *config.ClientTLS {
if isWorker && s.workerTLSSettings != nil {
return &s.workerTLSSettings.Client // explicit system worker case
} else if isWorker {
return s.legacyWorkerSettings // legacy config case when we use Frontend.Client settings
} else {
return &s.tlsSettings.Client // internode client case
}
}

func buildCAPoolFromData(caData []string) (*x509.CertPool, error) {
atLeastOneCert := false
caPool := x509.NewCertPool()
Expand Down
39 changes: 28 additions & 11 deletions common/rpc/encryption/localStoreTlsFactory.go
Expand Up @@ -39,8 +39,10 @@ type localStoreTlsProvider struct {

settings *config.RootTLS

internodeCertProvider CertProvider
frontendCertProvider CertProvider
internodeCertProvider CertProvider
internodeClientCertProvider ClientCertProvider
frontendCertProvider CertProvider
workerCertProvider ClientCertProvider

frontendPerHostCertProviderFactory PerHostCertProviderFactory

Expand All @@ -51,9 +53,22 @@ type localStoreTlsProvider struct {
}

func NewLocalStoreTlsProvider(tlsConfig *config.RootTLS) (TLSConfigProvider, error) {
internodeProvider := &localStoreCertProvider{tlsSettings: &tlsConfig.Internode}
var workerProvider ClientCertProvider
if tlsConfig.SystemWorker.CertFile != "" || tlsConfig.SystemWorker.CertData != "" { // explcit system worker config
workerProvider = &localStoreCertProvider{workerTLSSettings: &tlsConfig.SystemWorker}
} else { // legacy implicit system worker config case
internodeWorkerProvider := &localStoreCertProvider{tlsSettings: &tlsConfig.Internode}
internodeWorkerProvider.isLegacyWorkerConfig = true
internodeWorkerProvider.legacyWorkerSettings = &tlsConfig.Frontend.Client
workerProvider = internodeWorkerProvider
}

return &localStoreTlsProvider{
internodeCertProvider: &localStoreCertProvider{tlsSettings: &tlsConfig.Internode},
internodeCertProvider: internodeProvider,
internodeClientCertProvider: internodeProvider,
frontendCertProvider: &localStoreCertProvider{tlsSettings: &tlsConfig.Frontend},
workerCertProvider: workerProvider,
frontendPerHostCertProviderFactory: newLocalStorePerHostCertProviderFactory(tlsConfig.Frontend.PerHostOverrides),
RWMutex: sync.RWMutex{},
settings: tlsConfig,
Expand All @@ -64,7 +79,8 @@ func (s *localStoreTlsProvider) GetInternodeClientConfig() (*tls.Config, error)
return s.getOrCreateConfig(
&s.internodeClientConfig,
func() (*tls.Config, error) {
return newClientTLSConfig(s.internodeCertProvider, s.internodeCertProvider)
return newClientTLSConfig(s.internodeClientCertProvider,
s.internodeCertProvider.GetSettings().Server.RequireClientAuth, false)
},
s.internodeCertProvider.GetSettings().IsEnabled(),
)
Expand All @@ -74,7 +90,8 @@ func (s *localStoreTlsProvider) GetFrontendClientConfig() (*tls.Config, error) {
return s.getOrCreateConfig(
&s.frontendClientConfig,
func() (*tls.Config, error) {
return newClientTLSConfig(s.internodeCertProvider, s.frontendCertProvider)
return newClientTLSConfig(s.workerCertProvider,
s.frontendCertProvider.GetSettings().Server.RequireClientAuth, true)
},
s.internodeCertProvider.GetSettings().IsEnabled(),
)
Expand Down Expand Up @@ -193,17 +210,17 @@ func getServerTLSConfigFromCertProvider(certProvider CertProvider) (*tls.Config,
return auth.NewTLSConfigWithClientAuthAndCAs(clientAuthType, []tls.Certificate{*serverCert}, clientCaPool), nil
}

func newClientTLSConfig(clientProvider CertProvider, remoteProvider CertProvider) (*tls.Config, error) {
func newClientTLSConfig(clientProvider ClientCertProvider, isAuthRequired bool, isWorker bool) (*tls.Config, error) {
// Optional ServerCA for client if not already trusted by host
serverCa, err := remoteProvider.FetchServerRootCAsForClient()
serverCa, err := clientProvider.FetchServerRootCAsForClient(isWorker)
if err != nil {
return nil, fmt.Errorf("failed to load client ca: %v", err)
}

// mTLS enabled, present certificate
var clientCerts []tls.Certificate
if remoteProvider.GetSettings().Server.RequireClientAuth {
cert, err := clientProvider.FetchServerCertificate()
if isAuthRequired {
cert, err := clientProvider.FetchClientCertificate(isWorker)
if err != nil {
return nil, err
}
Expand All @@ -217,7 +234,7 @@ func newClientTLSConfig(clientProvider CertProvider, remoteProvider CertProvider
return auth.NewTLSConfigWithCertsAndCAs(
clientCerts,
serverCa,
remoteProvider.GetSettings().Client.ServerName,
!remoteProvider.GetSettings().Client.DisableHostVerification,
clientProvider.ServerName(isWorker),
!clientProvider.DisableHostVerification(isWorker),
), nil
}
28 changes: 8 additions & 20 deletions common/rpc/encryption/tlsFactory.go
Expand Up @@ -44,38 +44,26 @@ type (
CertProvider interface {
FetchServerCertificate() (*tls.Certificate, error)
FetchClientCAs() (*x509.CertPool, error)
FetchServerRootCAsForClient() (*x509.CertPool, error)
GetSettings() *config.GroupTLS
}

// ClientCertProvider is an interface to load raw TLS/X509 primitives for configuring clients.
ClientCertProvider interface {
FetchClientCertificate(isWorker bool) (*tls.Certificate, error)
FetchServerRootCAsForClient(isWorker bool) (*x509.CertPool, error)
ServerName(isWorker bool) string
DisableHostVerification(isWorker bool) bool
}

// PerHostCertProviderFactory creates a CertProvider in the context of a specific Domain.
PerHostCertProviderFactory interface {
GetCertProvider(hostName string) (CertProvider, error)
}

tlsConfigConstructor func() (*tls.Config, error)

providerType string
)

const (
providerTypeLocalStore providerType = "localstore"
providerTypeSelfSigned providerType = "selfsigned"
)

// NewTLSConfigProviderFromConfig creates a new TLS Config provider from RootTLS config
func NewTLSConfigProviderFromConfig(encryptionSettings config.RootTLS) (TLSConfigProvider, error) {
/* if || encryptionSettings.Provider == "" {
return nil, nil
}
*/

/*switch providerType(encryptionSettings.Provider) {
case providerTypeSelfSigned:
return NewSelfSignedTlsFactory(encryptionSettings, hostname)
case providerTypeLocalStore:*/
return NewLocalStoreTlsProvider(&encryptionSettings)
//}

//return nil, fmt.Errorf("unknown provider: %v", encryptionSettings.Provider)
}

0 comments on commit d5ffcb6

Please sign in to comment.