Skip to content

Commit

Permalink
core/zero: add support for managed mode from config file (#4756)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebdoxsey committed Nov 17, 2023
1 parent eb729a5 commit 6810091
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 9 deletions.
4 changes: 2 additions & 2 deletions cmd/pomerium/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ func main() {

ctx := context.Background()
runFn := run
if zero_cmd.IsManagedMode() {
runFn = zero_cmd.Run
if zero_cmd.IsManagedMode(*configFile) {
runFn = func(ctx context.Context) error { return zero_cmd.Run(ctx, *configFile) }
}

if err := runFn(ctx); err != nil && !errors.Is(err, context.Canceled) {
Expand Down
8 changes: 4 additions & 4 deletions internal/zero/cmd/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ import (
)

// Run runs the pomerium zero command.
func Run(ctx context.Context) error {
func Run(ctx context.Context, configFile string) error {
err := setupLogger()
if err != nil {
return fmt.Errorf("error setting up logger: %w", err)
}

token := getToken()
token := getToken(configFile)
if token == "" {
return errors.New("no token provided")
}
Expand All @@ -37,8 +37,8 @@ func Run(ctx context.Context) error {
}

// IsManagedMode returns true if Pomerium should start in managed mode using this command.
func IsManagedMode() bool {
return getToken() != ""
func IsManagedMode(configFile string) bool {
return getToken(configFile) != ""
}

func withInterrupt(ctx context.Context) context.Context {
Expand Down
24 changes: 21 additions & 3 deletions internal/zero/cmd/env.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
package cmd

import "os"
import (
"os"

"github.com/spf13/viper"
)

const (
// PomeriumZeroTokenEnv is the environment variable name for the API token.
//nolint: gosec
PomeriumZeroTokenEnv = "POMERIUM_ZERO_TOKEN"
)

func getToken() string {
return os.Getenv(PomeriumZeroTokenEnv)
func getToken(configFile string) string {
if token, ok := os.LookupEnv(PomeriumZeroTokenEnv); ok {
return token
}

if configFile != "" {
// load the token from the config file
v := viper.New()
v.SetConfigFile(configFile)
if v.ReadInConfig() == nil {
return v.GetString("pomerium_zero_token")
}
}

// we will fallback to normal pomerium if empty
return ""
}
41 changes: 41 additions & 0 deletions internal/zero/cmd/env_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package cmd

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_getToken(t *testing.T) {
t.Run("empty", func(t *testing.T) {
assert.Equal(t, "", getToken(""))
})
t.Run("env", func(t *testing.T) {
t.Setenv("POMERIUM_ZERO_TOKEN", "FROM_ENV")
assert.Equal(t, "FROM_ENV", getToken(""))
})
t.Run("json", func(t *testing.T) {
fp := filepath.Join(t.TempDir(), "config.json")
require.NoError(t, os.WriteFile(fp, []byte(`{
"pomerium_zero_token": "FROM_JSON"
}`), 0o644))
assert.Equal(t, "FROM_JSON", getToken(fp))
})
t.Run("yaml", func(t *testing.T) {
fp := filepath.Join(t.TempDir(), "config.yaml")
require.NoError(t, os.WriteFile(fp, []byte(`
pomerium_zero_token: FROM_YAML
`), 0o644))
assert.Equal(t, "FROM_YAML", getToken(fp))
})
t.Run("toml", func(t *testing.T) {
fp := filepath.Join(t.TempDir(), "config.toml")
require.NoError(t, os.WriteFile(fp, []byte(`
pomerium_zero_token = "FROM_TOML"
`), 0o644))
assert.Equal(t, "FROM_TOML", getToken(fp))
})
}

0 comments on commit 6810091

Please sign in to comment.