Skip to content

Commit

Permalink
config: return errors on invalid URLs, fix linting (#1829)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebdoxsey committed Jan 27, 2021
1 parent a8a7032 commit bec9805
Show file tree
Hide file tree
Showing 12 changed files with 255 additions and 148 deletions.
1 change: 0 additions & 1 deletion .golangci.yml
Expand Up @@ -71,7 +71,6 @@ linters:
- stylecheck
- typecheck
- unconvert
- unparam
- unused
- varcheck
# - asciicheck
Expand Down
26 changes: 18 additions & 8 deletions authorize/check_response.go
Expand Up @@ -49,7 +49,7 @@ func (a *Authorize) okResponse(reply *evaluator.Result) *envoy_service_auth_v2.C
func (a *Authorize) deniedResponse(
in *envoy_service_auth_v2.CheckRequest,
code int32, reason string, headers map[string]string,
) *envoy_service_auth_v2.CheckResponse {
) (*envoy_service_auth_v2.CheckResponse, error) {
returnHTMLError := true
inHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders()
if inHeaders != nil {
Expand All @@ -59,15 +59,19 @@ func (a *Authorize) deniedResponse(
if returnHTMLError {
return a.htmlDeniedResponse(in, code, reason, headers)
}
return a.plainTextDeniedResponse(code, reason, headers)
return a.plainTextDeniedResponse(code, reason, headers), nil
}

func (a *Authorize) htmlDeniedResponse(
in *envoy_service_auth_v2.CheckRequest,
code int32, reason string, headers map[string]string,
) *envoy_service_auth_v2.CheckResponse {
) (*envoy_service_auth_v2.CheckResponse, error) {
opts := a.currentOptions.Load()
debugEndpoint := opts.GetAuthenticateURL().ResolveReference(&url.URL{Path: "/.pomerium/"})
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
return nil, err
}
debugEndpoint := authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/"})

// create go-style http request
r := getHTTPRequestFromCheckRequest(in)
Expand Down Expand Up @@ -97,7 +101,7 @@ func (a *Authorize) htmlDeniedResponse(
}

var buf bytes.Buffer
err := a.templates.ExecuteTemplate(&buf, "error.html", map[string]interface{}{
err = a.templates.ExecuteTemplate(&buf, "error.html", map[string]interface{}{
"Status": code,
"StatusText": reason,
"CanDebug": code/100 == 4,
Expand Down Expand Up @@ -127,7 +131,7 @@ func (a *Authorize) htmlDeniedResponse(
Body: buf.String(),
},
},
}
}, nil
}

func (a *Authorize) plainTextDeniedResponse(code int32, reason string, headers map[string]string) *envoy_service_auth_v2.CheckResponse {
Expand All @@ -152,10 +156,16 @@ func (a *Authorize) plainTextDeniedResponse(code int32, reason string, headers m
}
}

func (a *Authorize) redirectResponse(in *envoy_service_auth_v2.CheckRequest) *envoy_service_auth_v2.CheckResponse {
func (a *Authorize) redirectResponse(in *envoy_service_auth_v2.CheckRequest) (*envoy_service_auth_v2.CheckResponse, error) {
opts := a.currentOptions.Load()
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
return nil, err
}

signinURL := opts.GetAuthenticateURL().ResolveReference(&url.URL{Path: "/.pomerium/sign_in"})
signinURL := authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/sign_in",
})
q := signinURL.Query()

// always assume https scheme
Expand Down
3 changes: 2 additions & 1 deletion authorize/check_response_test.go
Expand Up @@ -280,7 +280,8 @@ func TestAuthorize_deniedResponse(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := a.deniedResponse(tc.in, tc.code, tc.reason, tc.headers)
got, err := a.deniedResponse(tc.in, tc.code, tc.reason, tc.headers)
require.NoError(t, err)
assert.Equal(t, tc.want.Status.Code, got.Status.Code)
assert.Equal(t, tc.want.Status.Message, got.Status.Message)
assert.Equal(t, tc.want.GetDeniedResponse().GetHeaders(), got.GetDeniedResponse().GetHeaders())
Expand Down
13 changes: 9 additions & 4 deletions authorize/grpc.go
Expand Up @@ -78,11 +78,11 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
return a.okResponse(reply), nil
case reply.Status == http.StatusUnauthorized:
if isForwardAuth && hreq.URL.Path == "/verify" {
return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil), nil
return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil)
}
return a.redirectResponse(in), nil
return a.redirectResponse(in)
}
return a.deniedResponse(in, int32(reply.Status), reply.Message, nil), nil
return a.deniedResponse(in, int32(reply.Status), reply.Message, nil)
}

func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) error {
Expand Down Expand Up @@ -212,9 +212,14 @@ func (a *Authorize) isForwardAuth(req *envoy_service_auth_v2.CheckRequest) bool
return false
}

forwardAuthURL, err := opts.GetForwardAuthURL()
if err != nil {
return false
}

checkURL := getCheckRequestURL(req)

return urlutil.StripPort(checkURL.Host) == urlutil.StripPort(opts.GetForwardAuthURL().Host)
return urlutil.StripPort(checkURL.Host) == urlutil.StripPort(forwardAuthURL.Host)
}

func (a *Authorize) getEvaluatorRequestFromCheckRequest(in *envoy_service_auth_v2.CheckRequest, sessionState *sessions.State) *evaluator.Request {
Expand Down
41 changes: 21 additions & 20 deletions config/options.go
Expand Up @@ -710,45 +710,46 @@ func (o *Options) Validate() error {
}

// GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1.
func (o *Options) GetAuthenticateURL() *url.URL {
func (o *Options) GetAuthenticateURL() (*url.URL, error) {
if o != nil && o.AuthenticateURL != nil {
return o.AuthenticateURL
return o.AuthenticateURL, nil
}
u, _ := url.Parse("https://127.0.0.1")
return u
return url.Parse("https://127.0.0.1")
}

// GetAuthorizeURL returns the AuthorizeURL in the options or 127.0.0.1:5443.
func (o *Options) GetAuthorizeURL() *url.URL {
func (o *Options) GetAuthorizeURL() (*url.URL, error) {
if o != nil && o.AuthorizeURL != nil {
return o.AuthorizeURL
return o.AuthorizeURL, nil
}
u, _ := url.Parse("http://127.0.0.1" + DefaultAlternativeAddr)
return u
return url.Parse("http://127.0.0.1" + DefaultAlternativeAddr)
}

// GetDataBrokerURL returns the DataBrokerURL in the options or 127.0.0.1:5443.
func (o *Options) GetDataBrokerURL() *url.URL {
func (o *Options) GetDataBrokerURL() (*url.URL, error) {
if o != nil && o.DataBrokerURL != nil {
return o.DataBrokerURL
return o.DataBrokerURL, nil
}
u, _ := url.Parse("http://127.0.0.1" + DefaultAlternativeAddr)
return u
return url.Parse("http://127.0.0.1" + DefaultAlternativeAddr)
}

// GetForwardAuthURL returns the ForwardAuthURL in the options or 127.0.0.1.
func (o *Options) GetForwardAuthURL() *url.URL {
func (o *Options) GetForwardAuthURL() (*url.URL, error) {
if o != nil && o.ForwardAuthURL != nil {
return o.ForwardAuthURL
return o.ForwardAuthURL, nil
}
u, _ := url.Parse("https://127.0.0.1")
return u
return url.Parse("https://127.0.0.1")
}

// GetOauthOptions gets the oauth.Options for the given config options.
func (o *Options) GetOauthOptions() oauth.Options {
redirectURL := o.GetAuthenticateURL()
redirectURL.Path = o.AuthenticateCallbackPath
func (o *Options) GetOauthOptions() (oauth.Options, error) {
redirectURL, err := o.GetAuthenticateURL()
if err != nil {
return oauth.Options{}, err
}
redirectURL = redirectURL.ResolveReference(&url.URL{
Path: o.AuthenticateCallbackPath,
})
return oauth.Options{
RedirectURL: redirectURL,
ProviderName: o.Provider,
Expand All @@ -757,7 +758,7 @@ func (o *Options) GetOauthOptions() oauth.Options {
ClientSecret: o.ClientSecret,
Scopes: o.Scopes,
ServiceAccount: o.ServiceAccount,
}
}, nil
}

// GetAllPolicies gets all the policies in the options.
Expand Down
11 changes: 8 additions & 3 deletions config/options_test.go
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{})
Expand Down Expand Up @@ -498,7 +499,7 @@ func TestOptions_DefaultURL(t *testing.T) {
}
tests := []struct {
name string
f func() *url.URL
f func() (*url.URL, error)
expectedURLStr string
}{
{"default authenticate url", defaultOptions.GetAuthenticateURL, "https://127.0.0.1"},
Expand All @@ -515,7 +516,9 @@ func TestOptions_DefaultURL(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expectedURLStr, tc.f().String())
u, err := tc.f()
require.NoError(t, err)
assert.Equal(t, tc.expectedURLStr, u.String())
})
}
}
Expand All @@ -530,7 +533,9 @@ func mustParseURL(str string) *url.URL {

func TestOptions_GetOauthOptions(t *testing.T) {
opts := &Options{AuthenticateURL: mustParseURL("https://authenticate.example.com")}
oauthOptions, err := opts.GetOauthOptions()
require.NoError(t, err)

// Test that oauth redirect url hostname must point to authenticate url hostname.
assert.Equal(t, opts.AuthenticateURL.Hostname(), opts.GetOauthOptions().RedirectURL.Hostname())
assert.Equal(t, opts.AuthenticateURL.Hostname(), oauthOptions.RedirectURL.Hostname())
}
13 changes: 11 additions & 2 deletions databroker/cache.go
Expand Up @@ -81,13 +81,17 @@ func New(cfg *config.Config) (*DataBroker, error) {
}

dataBrokerServer := newDataBrokerServer(cfg)
dataBrokerURL, err := cfg.Options.GetDataBrokerURL()
if err != nil {
return nil, err
}

c := &DataBroker{
dataBrokerServer: dataBrokerServer,
localListener: localListener,
localGRPCServer: localGRPCServer,
localGRPCConnection: localGRPCConnection,
deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(),
deprecatedCacheClusterDomain: dataBrokerURL.Hostname(),
dataBrokerStorageType: cfg.Options.DataBrokerStorageType,
}
c.Register(c.localGRPCServer)
Expand Down Expand Up @@ -138,7 +142,12 @@ func (c *DataBroker) update(cfg *config.Config) error {
return fmt.Errorf("databroker: bad option: %w", err)
}

authenticator, err := identity.NewAuthenticator(cfg.Options.GetOauthOptions())
oauthOptions, err := cfg.Options.GetOauthOptions()
if err != nil {
return fmt.Errorf("databroker: invalid oauth options: %w", err)
}

authenticator, err := identity.NewAuthenticator(oauthOptions)
if err != nil {
return fmt.Errorf("databroker: failed to create authenticator: %w", err)
}
Expand Down
8 changes: 7 additions & 1 deletion internal/cmd/pomerium/pomerium.go
Expand Up @@ -158,9 +158,15 @@ func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) err
if err != nil {
return fmt.Errorf("error creating authenticate service: %w", err)
}

authenticateURL, err := src.GetConfig().Options.GetAuthenticateURL()
if err != nil {
return fmt.Errorf("error getting authenticate URL: %w", err)
}

src.OnConfigChange(svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig())
host := urlutil.StripPort(src.GetConfig().Options.GetAuthenticateURL().Host)
host := urlutil.StripPort(authenticateURL.Host)
sr := controlPlane.HTTPRouter.Host(host).Subrouter()
svc.Mount(sr)
log.Info().Str("host", host).Msg("enabled authenticate service")
Expand Down
22 changes: 12 additions & 10 deletions internal/controlplane/xds_clusters.go
Expand Up @@ -50,9 +50,9 @@ func (srv *Server) buildClusters(options *config.Options) ([]*envoy_config_clust
Scheme: "http",
Host: srv.HTTPListener.Addr().String(),
}
authzURL := &url.URL{
Scheme: options.GetAuthorizeURL().Scheme,
Host: options.GetAuthorizeURL().Host,
authzURL, err := options.GetAuthorizeURL()
if err != nil {
return nil, err
}

controlGRPC, err := srv.buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true)
Expand Down Expand Up @@ -132,22 +132,22 @@ func (srv *Server) buildPolicyCluster(options *config.Options, policy *config.Po

func (srv *Server) buildInternalEndpoints(options *config.Options, dst *url.URL) ([]Endpoint, error) {
var endpoints []Endpoint
if ts, err := srv.buildInternalTransportSocket(options, dst); err != nil {
ts, err := srv.buildInternalTransportSocket(options, dst)
if err != nil {
return nil, err
} else {
endpoints = append(endpoints, NewEndpoint(dst, ts))
}
endpoints = append(endpoints, NewEndpoint(dst, ts))
return endpoints, nil
}

func (srv *Server) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error) {
var endpoints []Endpoint
for _, dst := range policy.Destinations {
if ts, err := srv.buildPolicyTransportSocket(policy, dst); err != nil {
ts, err := srv.buildPolicyTransportSocket(policy, dst)
if err != nil {
return nil, err
} else {
endpoints = append(endpoints, NewEndpoint(dst, ts))
}
endpoints = append(endpoints, NewEndpoint(dst, ts))
}
return endpoints, nil
}
Expand Down Expand Up @@ -246,7 +246,9 @@ func (srv *Server) buildPolicyTransportSocket(policy *config.Policy, dst *url.UR
}, nil
}

func (srv *Server) buildPolicyValidationContext(policy *config.Policy, dst *url.URL) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
func (srv *Server) buildPolicyValidationContext(
policy *config.Policy, dst *url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
if dst == nil {
return nil, nil
}
Expand Down

0 comments on commit bec9805

Please sign in to comment.