Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add health checks to OIDC discovery provider #3151

Merged
merged 10 commits into from
Jun 22, 2022
122 changes: 70 additions & 52 deletions support/oidc-discovery-provider/README.md

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions support/oidc-discovery-provider/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package main

import (
"crypto/x509"
"sync"
"time"

"github.com/spiffe/spire/pkg/common/pemutil"
"gopkg.in/square/go-jose.v2"
)

var (
Expand All @@ -13,3 +16,37 @@ JBCRRy24/UAZY70ZviCRAJ4ePscJtnN1y1wDH13GgOAL2y52xIbtkshYmw==
-----END PUBLIC KEY-----`))
ec256PubkeyPKIX, _ = x509.MarshalPKIXPublicKey(ec256Pubkey)
)

type FakeKeySetSource struct {
mu sync.Mutex
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
}

func (s *FakeKeySetSource) SetKeySet(jwks *jose.JSONWebKeySet, modTime time.Time, pollTime time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwks = jwks
s.modTime = modTime
s.pollTime = pollTime
}

func (s *FakeKeySetSource) FetchKeySet() (*jose.JSONWebKeySet, time.Time, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.jwks == nil {
return nil, time.Time{}, false
}
return s.jwks, s.modTime, true
}

func (s *FakeKeySetSource) Close() error {
return nil
}

func (s *FakeKeySetSource) LastSuccessfulPoll() time.Time {
s.mu.Lock()
defer s.mu.Unlock()
return s.pollTime
}
32 changes: 29 additions & 3 deletions support/oidc-discovery-provider/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
)

const (
defaultLogLevel = "info"
defaultPollInterval = time.Second * 10
defaultCacheDir = "./.acme-cache"
defaultLogLevel = "info"
defaultPollInterval = time.Second * 10
defaultCacheDir = "./.acme-cache"
defaultHealthChecksBindPort = 8008
defaultHealthChecksReadyPath = "/ready"
defaultHealthChecksLivePath = "/live"
)

type Config struct {
Expand Down Expand Up @@ -60,6 +63,9 @@ type Config struct {
// as the source for the public keys. Only one source can be configured.
WorkloadAPI *WorkloadAPIConfig `hcl:"workload_api"`

// Health checks enable Liveness and Readiness probes.
HealthChecks *HealthChecksConfig `hcl:"health_checks"`

// Experimental options that are subject to change or removal.
Experimental experimentalConfig `hcl:"experimental"`
}
Expand Down Expand Up @@ -125,6 +131,14 @@ type WorkloadAPIConfig struct {
Experimental experimentalWorkloadAPIConfig `hcl:"experimental"`
}

type HealthChecksConfig struct {
// Listener port binding
BindPort int `hcl:"bind_port"`
// Paths for /ready and /live
LivePath string `hcl:"live_path"`
ReadyPath string `hcl:"ready_path"`
}

type experimentalConfig struct {
// ListenNamedPipeName specifies the pipe name of the named pipe
// to listen for plaintext HTTP on, for when deployed behind another
Expand Down Expand Up @@ -201,6 +215,18 @@ func ParseConfig(hclConfig string) (_ *Config, err error) {
methodCount++
}

if c.HealthChecks != nil {
if c.HealthChecks.BindPort <= 0 {
c.HealthChecks.BindPort = defaultHealthChecksBindPort
}
if c.HealthChecks.ReadyPath == "" {
c.HealthChecks.ReadyPath = defaultHealthChecksReadyPath
}
if c.HealthChecks.LivePath == "" {
c.HealthChecks.LivePath = defaultHealthChecksLivePath
}
}

if err := c.validateOS(); err != nil {
return nil, err
}
Expand Down
96 changes: 96 additions & 0 deletions support/oidc-discovery-provider/config_posix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,5 +388,101 @@ func parseConfigCasesOS() []parseConfigCase {
`,
err: "trust_domain must be configured in the workload_api configuration section",
},
{
name: "health checks default values",
in: `
domains = ["domain.test"]
acme {
email = "admin@domain.test"
tos_accepted = true
}
server_api {
address = "unix:///some/socket/path"
}
health_checks {}
`,
out: &Config{
LogLevel: defaultLogLevel,
Domains: []string{"domain.test"},
ACME: &ACMEConfig{
CacheDir: defaultCacheDir,
Email: "admin@domain.test",
ToSAccepted: true,
},
ServerAPI: &ServerAPIConfig{
Address: "unix:///some/socket/path",
PollInterval: defaultPollInterval,
},
HealthChecks: &HealthChecksConfig{
BindPort: defaultHealthChecksBindPort,
ReadyPath: defaultHealthChecksReadyPath,
LivePath: defaultHealthChecksLivePath,
},
},
},
{
name: "health checks config overrides",
in: `
domains = ["domain.test"]
acme {
email = "admin@domain.test"
tos_accepted = true
}
server_api {
address = "unix:///some/socket/path"
}
health_checks {
bind_address = "127.0.0.1"
bind_port = "8888"
live_path = "/live/override"
ready_path = "/ready/override"
}
`,
out: &Config{
LogLevel: defaultLogLevel,
Domains: []string{"domain.test"},
ACME: &ACMEConfig{
CacheDir: defaultCacheDir,
Email: "admin@domain.test",
ToSAccepted: true,
},
ServerAPI: &ServerAPIConfig{
Address: "unix:///some/socket/path",
PollInterval: defaultPollInterval,
},
HealthChecks: &HealthChecksConfig{
BindPort: 8888,
LivePath: "/live/override",
ReadyPath: "/ready/override",
},
},
},
{
name: "health checks disabled",
in: `
domains = ["domain.test"]
acme {
email = "admin@domain.test"
tos_accepted = true
}
server_api {
address = "unix:///some/socket/path"
}
`,
out: &Config{
LogLevel: defaultLogLevel,
Domains: []string{"domain.test"},
ACME: &ACMEConfig{
CacheDir: defaultCacheDir,
Email: "admin@domain.test",
ToSAccepted: true,
},
ServerAPI: &ServerAPIConfig{
Address: "unix:///some/socket/path",
PollInterval: defaultPollInterval,
},
HealthChecks: nil,
},
},
}
}
67 changes: 22 additions & 45 deletions support/oidc-discovery-provider/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

Expand All @@ -23,6 +22,7 @@ func TestHandlerHTTPS(t *testing.T) {
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
setKeyUse bool
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestHandlerHTTPS(t *testing.T) {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

r, err := http.NewRequest(testCase.method, "https://localhost"+testCase.path, nil)
require.NoError(t, err)
Expand All @@ -187,13 +187,14 @@ func TestHandlerHTTPInsecure(t *testing.T) {
log, _ := test.NewNullLogger()
log.Level = logrus.DebugLevel
testCases := []struct {
name string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
code int
body string
name string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
}{
{
name: "GET well-known",
Expand Down Expand Up @@ -279,7 +280,7 @@ func TestHandlerHTTPInsecure(t *testing.T) {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

r, err := http.NewRequest(testCase.method, "http://localhost"+testCase.path, nil)
require.NoError(t, err)
Expand All @@ -305,6 +306,7 @@ func TestHandlerHTTP(t *testing.T) {
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
}{
Expand Down Expand Up @@ -443,7 +445,7 @@ func TestHandlerHTTP(t *testing.T) {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

host := "domain.test"
if testCase.overrideHost != "" {
Expand All @@ -468,13 +470,14 @@ func TestHandlerProxied(t *testing.T) {
log, _ := test.NewNullLogger()
log.Level = logrus.DebugLevel
testCases := []struct {
name string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
code int
body string
name string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
}{
{
name: "GET well-known",
Expand Down Expand Up @@ -560,7 +563,7 @@ func TestHandlerProxied(t *testing.T) {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

r, err := http.NewRequest(testCase.method, "http://localhost"+testCase.path, nil)
require.NoError(t, err)
Expand All @@ -578,32 +581,6 @@ func TestHandlerProxied(t *testing.T) {
}
}

type FakeKeySetSource struct {
mu sync.Mutex
jwks *jose.JSONWebKeySet
modTime time.Time
}

func (s *FakeKeySetSource) SetKeySet(jwks *jose.JSONWebKeySet, modTime time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwks = jwks
s.modTime = modTime
}

func (s *FakeKeySetSource) FetchKeySet() (*jose.JSONWebKeySet, time.Time, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.jwks == nil {
return nil, time.Time{}, false
}
return s.jwks, s.modTime, true
}

func (s *FakeKeySetSource) Close() error {
return nil
}

func domainAllowlist(t *testing.T, domains ...string) DomainPolicy {
policy, err := DomainAllowlist(domains...)
require.NoError(t, err)
Expand Down