Skip to content

Commit

Permalink
feat: allow setting the authority header in the CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik committed Oct 27, 2022
1 parent d98142d commit 17f10ef
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 49 deletions.
74 changes: 43 additions & 31 deletions cmd/client/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,46 @@ const (

FlagInsecureNoTransportSecurity = "insecure-disable-transport-security"
FlagInsecureSkipHostVerification = "insecure-skip-hostname-verification"
FlagAuthority = "authority"

EnvReadRemote = "KETO_READ_REMOTE"
EnvWriteRemote = "KETO_WRITE_REMOTE"
EnvAuthToken = "KETO_BEARER_TOKEN" // nosec G101 -- just the key, not the value
EnvAuthority = "KETO_AUTHORITY"

ContextKeyTimeout contextKeys = "timeout"
)

type securityFlags struct {
type connectionDetails struct {
token, authority string
skipHostVerification bool
noTransportSecurity bool
}

func (sf *securityFlags) transportCredentials() grpc.DialOption {
switch {
case sf.noTransportSecurity:
return grpc.WithTransportCredentials(insecure.NewCredentials())
func (d *connectionDetails) dialOptions() (opts []grpc.DialOption) {
if d.token != "" {
opts = append(opts,
grpc.WithPerRPCCredentials(
oauth.NewOauthAccess(&oauth2.Token{AccessToken: d.token})))
}
if d.authority != "" {
opts = append(opts, grpc.WithAuthority(d.authority))
}

case sf.skipHostVerification:
return grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
// TLS settings
switch {
case d.noTransportSecurity:
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
case d.skipHostVerification:
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
// nolint explicity set through scary flag
InsecureSkipVerify: true,
}))

})))
default:
// Defaults to the default host root CA bundle
return grpc.WithTransportCredentials(credentials.NewTLS(nil))
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(nil)))
}
return opts
}

func getRemote(cmd *cobra.Command, flagRemote, envRemote string) (remote string) {
Expand All @@ -77,12 +89,17 @@ func getRemote(cmd *cobra.Command, flagRemote, envRemote string) (remote string)
return remote
}

func getToken(cmd *cobra.Command) (token string) {
return os.Getenv(EnvAuthToken)
func getAuthority(cmd *cobra.Command) string {
if cmd.Flags().Changed(FlagAuthority) {
return flagx.MustGetString(cmd, FlagAuthority)
}
return os.Getenv(EnvAuthority)
}

func getSecurityFlags(cmd *cobra.Command) securityFlags {
return securityFlags{
func getConnectionDetails(cmd *cobra.Command) connectionDetails {
return connectionDetails{
token: os.Getenv(EnvAuthToken),
authority: getAuthority(cmd),
skipHostVerification: flagx.MustGetBool(cmd, FlagInsecureSkipHostVerification),
noTransportSecurity: flagx.MustGetBool(cmd, FlagInsecureNoTransportSecurity),
}
Expand All @@ -91,20 +108,18 @@ func getSecurityFlags(cmd *cobra.Command) securityFlags {
func GetReadConn(cmd *cobra.Command) (*grpc.ClientConn, error) {
return Conn(cmd.Context(),
getRemote(cmd, FlagReadRemote, EnvReadRemote),
getToken(cmd),
getSecurityFlags(cmd),
getConnectionDetails(cmd),
)
}

func GetWriteConn(cmd *cobra.Command) (*grpc.ClientConn, error) {
return Conn(cmd.Context(),
getRemote(cmd, FlagWriteRemote, EnvWriteRemote),
getToken(cmd),
getSecurityFlags(cmd),
getConnectionDetails(cmd),
)
}

func Conn(ctx context.Context, remote, token string, security securityFlags) (*grpc.ClientConn, error) {
func Conn(ctx context.Context, remote string, details connectionDetails) (*grpc.ClientConn, error) {
timeout := 3 * time.Second
if d, ok := ctx.Value(ContextKeyTimeout).(time.Duration); ok {
timeout = d
Expand All @@ -113,23 +128,20 @@ func Conn(ctx context.Context, remote, token string, security securityFlags) (*g
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

dialOpts := []grpc.DialOption{
grpc.WithBlock(),
grpc.WithDisableHealthCheck(),
}
dialOpts = append(dialOpts, security.transportCredentials())
if token != "" {
dialOpts = append(dialOpts,
grpc.WithPerRPCCredentials(
oauth.NewOauthAccess(&oauth2.Token{AccessToken: token})))
}

return grpc.DialContext(ctx, remote, dialOpts...)
return grpc.DialContext(
ctx,
remote,
append([]grpc.DialOption{
grpc.WithBlock(),
grpc.WithDisableHealthCheck(),
}, details.dialOptions()...)...,
)
}

func RegisterRemoteURLFlags(flags *pflag.FlagSet) {
flags.String(FlagReadRemote, "127.0.0.1:4466", "Remote address of the read API endpoint.")
flags.String(FlagWriteRemote, "127.0.0.1:4467", "Remote address of the write API endpoint.")
flags.String(FlagAuthority, "", "Set the authority header for the remote gRPC server.")
flags.Bool(FlagInsecureNoTransportSecurity, false, "Disables transport security. Do not use this in production.")
flags.Bool(FlagInsecureSkipHostVerification, false, "Disables hostname verification. Do not use this in production.")
}
54 changes: 40 additions & 14 deletions cmd/status/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestStatusCmd(t *testing.T) {
})

t.Run("case=block", func(t *testing.T) {
ctx := context.WithValue(context.Background(), client.ContextKeyTimeout, time.Second)
ctx := context.WithValue(context.Background(), client.ContextKeyTimeout, time.Millisecond)

l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
Expand Down Expand Up @@ -93,25 +93,27 @@ func TestStatusCmd(t *testing.T) {
}
}

func authInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, errors.New("not authorized, no metadata")
func authInterceptor(header, validValue string) func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, errors.New("not authorized, no metadata")
}
vals := md.Get(header)
if len(vals) != 1 {
return nil, errors.New("not authorized, no header values")
}
if vals[0] != validValue {
return nil, errors.New("not authorized, incorrect value")
}
return handler(ctx, req)
}
vals := md.Get("authorization")
if len(vals) != 1 {
return nil, errors.New("not authorized, no authorization header")
}
if vals[0] != "Bearer secret" {
return nil, errors.New("not authorized")
}
return handler(ctx, req)
}

func TestAuthorizedRequest(t *testing.T) {
ts := client.NewTestServer(
t, "read", []*namespace.Namespace{{Name: t.Name()}}, newStatusCmd,
driver.WithGRPCUnaryInterceptors(authInterceptor),
driver.WithGRPCUnaryInterceptors(authInterceptor("authorization", "Bearer secret")),
)
defer ts.Shutdown(t)

Expand All @@ -126,3 +128,27 @@ func TestAuthorizedRequest(t *testing.T) {
assert.Contains(t, out, "SERVING")
})
}

func TestAuthorityRequest(t *testing.T) {
ts := client.NewTestServer(
t, "read", []*namespace.Namespace{{Name: t.Name()}}, newStatusCmd,
driver.WithGRPCUnaryInterceptors(authInterceptor(":authority", "example.com")),
)
defer ts.Shutdown(t)

t.Run("case=no authority", func(t *testing.T) {
out := ts.Cmd.ExecExpectedErr(t)
assert.Contains(t, out, "not authorized")
})

t.Run("case=env authority", func(t *testing.T) {
t.Setenv("KETO_AUTHORITY", "example.com")
out := ts.Cmd.ExecNoErr(t)
assert.Contains(t, out, "SERVING")
})

t.Run("case=flag authority", func(t *testing.T) {
out := ts.Cmd.ExecNoErr(t, "--"+client.FlagAuthority, "example.com")
assert.Contains(t, out, "SERVING")
})
}
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ github.com/cenkalti/backoff/v3 v3.2.2/go.mod h1:cIeZDE3IrqwwJl6VUwCN6trj1oXrTS4r
github.com/cenkalti/backoff/v4 v4.1.3 h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8UtC4=
github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
Expand Down Expand Up @@ -355,7 +354,6 @@ github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaW
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=
github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0 h1:kr3j8iIMR4ywO/O0rvksXaJvauGGCMg2zAZIiNZ9uIQ=
Expand Down Expand Up @@ -422,7 +420,6 @@ github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf
github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf h1:FtEj8sfIcaaBfAKrE1Cwb61YDtYq9JxChK1c7AKce7s=
github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf/go.mod h1:yrqSXGoD/4EKfF26AOGzscPOgTTJcyAwM2rpixWT+t4=
github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65 h1:T25FL3WEzgmKB0m6XCJNZ65nw09/QIp3T1yXr487D+A=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
Expand All @@ -445,7 +442,6 @@ github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5W
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
Expand Down

0 comments on commit 17f10ef

Please sign in to comment.