Skip to content

Commit

Permalink
Add health checks to OIDC discovery provider (spiffe#3151)
Browse files Browse the repository at this point in the history
Signed-off-by: Christoph Dalski <chdalski.coding@gmail.com>
  • Loading branch information
chdalski authored and stevend-uber committed Oct 13, 2023
1 parent 13d980a commit 956443c
Show file tree
Hide file tree
Showing 11 changed files with 518 additions and 110 deletions.
122 changes: 70 additions & 52 deletions support/oidc-discovery-provider/README.md

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions support/oidc-discovery-provider/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package main

import (
"crypto/x509"
"sync"
"time"

"github.com/spiffe/spire/pkg/common/pemutil"
"gopkg.in/square/go-jose.v2"
)

var (
Expand All @@ -13,3 +16,37 @@ JBCRRy24/UAZY70ZviCRAJ4ePscJtnN1y1wDH13GgOAL2y52xIbtkshYmw==
-----END PUBLIC KEY-----`))
ec256PubkeyPKIX, _ = x509.MarshalPKIXPublicKey(ec256Pubkey)
)

type FakeKeySetSource struct {
mu sync.Mutex
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
}

func (s *FakeKeySetSource) SetKeySet(jwks *jose.JSONWebKeySet, modTime time.Time, pollTime time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwks = jwks
s.modTime = modTime
s.pollTime = pollTime
}

func (s *FakeKeySetSource) FetchKeySet() (*jose.JSONWebKeySet, time.Time, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.jwks == nil {
return nil, time.Time{}, false
}
return s.jwks, s.modTime, true
}

func (s *FakeKeySetSource) Close() error {
return nil
}

func (s *FakeKeySetSource) LastSuccessfulPoll() time.Time {
s.mu.Lock()
defer s.mu.Unlock()
return s.pollTime
}
32 changes: 29 additions & 3 deletions support/oidc-discovery-provider/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
)

const (
defaultLogLevel = "info"
defaultPollInterval = time.Second * 10
defaultCacheDir = "./.acme-cache"
defaultLogLevel = "info"
defaultPollInterval = time.Second * 10
defaultCacheDir = "./.acme-cache"
defaultHealthChecksBindPort = 8008
defaultHealthChecksReadyPath = "/ready"
defaultHealthChecksLivePath = "/live"
)

type Config struct {
Expand Down Expand Up @@ -60,6 +63,9 @@ type Config struct {
// as the source for the public keys. Only one source can be configured.
WorkloadAPI *WorkloadAPIConfig `hcl:"workload_api"`

// Health checks enable Liveness and Readiness probes.
HealthChecks *HealthChecksConfig `hcl:"health_checks"`

// Experimental options that are subject to change or removal.
Experimental experimentalConfig `hcl:"experimental"`
}
Expand Down Expand Up @@ -125,6 +131,14 @@ type WorkloadAPIConfig struct {
Experimental experimentalWorkloadAPIConfig `hcl:"experimental"`
}

type HealthChecksConfig struct {
// Listener port binding
BindPort int `hcl:"bind_port"`
// Paths for /ready and /live
LivePath string `hcl:"live_path"`
ReadyPath string `hcl:"ready_path"`
}

type experimentalConfig struct {
// ListenNamedPipeName specifies the pipe name of the named pipe
// to listen for plaintext HTTP on, for when deployed behind another
Expand Down Expand Up @@ -201,6 +215,18 @@ func ParseConfig(hclConfig string) (_ *Config, err error) {
methodCount++
}

if c.HealthChecks != nil {
if c.HealthChecks.BindPort <= 0 {
c.HealthChecks.BindPort = defaultHealthChecksBindPort
}
if c.HealthChecks.ReadyPath == "" {
c.HealthChecks.ReadyPath = defaultHealthChecksReadyPath
}
if c.HealthChecks.LivePath == "" {
c.HealthChecks.LivePath = defaultHealthChecksLivePath
}
}

if err := c.validateOS(); err != nil {
return nil, err
}
Expand Down
96 changes: 96 additions & 0 deletions support/oidc-discovery-provider/config_posix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,5 +388,101 @@ func parseConfigCasesOS() []parseConfigCase {
`,
err: "trust_domain must be configured in the workload_api configuration section",
},
{
name: "health checks default values",
in: `
domains = ["domain.test"]
acme {
email = "admin@domain.test"
tos_accepted = true
}
server_api {
address = "unix:///some/socket/path"
}
health_checks {}
`,
out: &Config{
LogLevel: defaultLogLevel,
Domains: []string{"domain.test"},
ACME: &ACMEConfig{
CacheDir: defaultCacheDir,
Email: "admin@domain.test",
ToSAccepted: true,
},
ServerAPI: &ServerAPIConfig{
Address: "unix:///some/socket/path",
PollInterval: defaultPollInterval,
},
HealthChecks: &HealthChecksConfig{
BindPort: defaultHealthChecksBindPort,
ReadyPath: defaultHealthChecksReadyPath,
LivePath: defaultHealthChecksLivePath,
},
},
},
{
name: "health checks config overrides",
in: `
domains = ["domain.test"]
acme {
email = "admin@domain.test"
tos_accepted = true
}
server_api {
address = "unix:///some/socket/path"
}
health_checks {
bind_address = "127.0.0.1"
bind_port = "8888"
live_path = "/live/override"
ready_path = "/ready/override"
}
`,
out: &Config{
LogLevel: defaultLogLevel,
Domains: []string{"domain.test"},
ACME: &ACMEConfig{
CacheDir: defaultCacheDir,
Email: "admin@domain.test",
ToSAccepted: true,
},
ServerAPI: &ServerAPIConfig{
Address: "unix:///some/socket/path",
PollInterval: defaultPollInterval,
},
HealthChecks: &HealthChecksConfig{
BindPort: 8888,
LivePath: "/live/override",
ReadyPath: "/ready/override",
},
},
},
{
name: "health checks disabled",
in: `
domains = ["domain.test"]
acme {
email = "admin@domain.test"
tos_accepted = true
}
server_api {
address = "unix:///some/socket/path"
}
`,
out: &Config{
LogLevel: defaultLogLevel,
Domains: []string{"domain.test"},
ACME: &ACMEConfig{
CacheDir: defaultCacheDir,
Email: "admin@domain.test",
ToSAccepted: true,
},
ServerAPI: &ServerAPIConfig{
Address: "unix:///some/socket/path",
PollInterval: defaultPollInterval,
},
HealthChecks: nil,
},
},
}
}
67 changes: 22 additions & 45 deletions support/oidc-discovery-provider/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

Expand All @@ -23,6 +22,7 @@ func TestHandlerHTTPS(t *testing.T) {
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
setKeyUse bool
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestHandlerHTTPS(t *testing.T) {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

r, err := http.NewRequest(testCase.method, "https://localhost"+testCase.path, nil)
require.NoError(t, err)
Expand All @@ -187,13 +187,14 @@ func TestHandlerHTTPInsecure(t *testing.T) {
log, _ := test.NewNullLogger()
log.Level = logrus.DebugLevel
testCases := []struct {
name string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
code int
body string
name string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
}{
{
name: "GET well-known",
Expand Down Expand Up @@ -279,7 +280,7 @@ func TestHandlerHTTPInsecure(t *testing.T) {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

r, err := http.NewRequest(testCase.method, "http://localhost"+testCase.path, nil)
require.NoError(t, err)
Expand All @@ -305,6 +306,7 @@ func TestHandlerHTTP(t *testing.T) {
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
}{
Expand Down Expand Up @@ -443,7 +445,7 @@ func TestHandlerHTTP(t *testing.T) {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

host := "domain.test"
if testCase.overrideHost != "" {
Expand All @@ -468,13 +470,14 @@ func TestHandlerProxied(t *testing.T) {
log, _ := test.NewNullLogger()
log.Level = logrus.DebugLevel
testCases := []struct {
name string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
code int
body string
name string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
}{
{
name: "GET well-known",
Expand Down Expand Up @@ -560,7 +563,7 @@ func TestHandlerProxied(t *testing.T) {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

r, err := http.NewRequest(testCase.method, "http://localhost"+testCase.path, nil)
require.NoError(t, err)
Expand All @@ -578,32 +581,6 @@ func TestHandlerProxied(t *testing.T) {
}
}

type FakeKeySetSource struct {
mu sync.Mutex
jwks *jose.JSONWebKeySet
modTime time.Time
}

func (s *FakeKeySetSource) SetKeySet(jwks *jose.JSONWebKeySet, modTime time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwks = jwks
s.modTime = modTime
}

func (s *FakeKeySetSource) FetchKeySet() (*jose.JSONWebKeySet, time.Time, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.jwks == nil {
return nil, time.Time{}, false
}
return s.jwks, s.modTime, true
}

func (s *FakeKeySetSource) Close() error {
return nil
}

func domainAllowlist(t *testing.T, domains ...string) DomainPolicy {
policy, err := DomainAllowlist(domains...)
require.NoError(t, err)
Expand Down

0 comments on commit 956443c

Please sign in to comment.