diff --git a/pkg/ensign/contexts/contexts.go b/pkg/ensign/contexts/contexts.go new file mode 100644 index 000000000..9599932a4 --- /dev/null +++ b/pkg/ensign/contexts/contexts.go @@ -0,0 +1,38 @@ +package contexts + +import ( + "context" + + "github.com/rotationalio/ensign/pkg/quarterdeck/tokens" +) + +// Ensign-specific context keys for passing values to concurrent requests +type contextKey uint8 + +// Allocate context keys to simplify context key usage in Ensign +const ( + KeyUnknown contextKey = iota + KeyClaims +) + +// WithClaims returns a copy of the parent context with the access claims stored as a +// value on the new context. Users can fetch the claims using the ClaimsFrom function. +func WithClaims(parent context.Context, claims *tokens.Claims) context.Context { + return context.WithValue(parent, KeyClaims, claims) +} + +// ClaimsFrom returns the claims from the context if they exist or false if not. +func ClaimsFrom(ctx context.Context) (*tokens.Claims, bool) { + claims, ok := ctx.Value(KeyClaims).(*tokens.Claims) + return claims, ok +} + +var contextKeyNames = []string{"unknown", "claims"} + +// String returns a human readable representation of the context key for easier debugging. +func (c contextKey) String() string { + if int(c) < len(contextKeyNames) { + return contextKeyNames[c] + } + return contextKeyNames[0] +} diff --git a/pkg/ensign/contexts/contexts_test.go b/pkg/ensign/contexts/contexts_test.go new file mode 100644 index 000000000..4d6e07853 --- /dev/null +++ b/pkg/ensign/contexts/contexts_test.go @@ -0,0 +1,42 @@ +package contexts_test + +import ( + "context" + "fmt" + "testing" + + "github.com/rotationalio/ensign/pkg/ensign/contexts" + "github.com/rotationalio/ensign/pkg/quarterdeck/tokens" + "github.com/stretchr/testify/require" +) + +func TestClaimsContext(t *testing.T) { + claims := &tokens.Claims{ + Name: "Barbara Testly", + Email: "btest@testing.io", + } + + parent, cancel := context.WithCancel(context.Background()) + ctx := contexts.WithClaims(parent, claims) + + cmpt, ok := contexts.ClaimsFrom(ctx) + require.True(t, ok) + require.Same(t, claims, cmpt) + + cancel() + require.ErrorIs(t, ctx.Err(), context.Canceled) +} + +func TestKeyString(t *testing.T) { + testCases := []struct { + key fmt.Stringer + expected string + }{ + {contexts.KeyUnknown, "unknown"}, + {contexts.KeyClaims, "claims"}, + } + + for _, tc := range testCases { + require.Equal(t, tc.expected, tc.key.String()) + } +} diff --git a/pkg/ensign/contexts/stream.go b/pkg/ensign/contexts/stream.go new file mode 100644 index 000000000..0c9e670d3 --- /dev/null +++ b/pkg/ensign/contexts/stream.go @@ -0,0 +1,24 @@ +package contexts + +import ( + "context" + + "google.golang.org/grpc" +) + +// Stream allows users to override the context on a grpc.ServerStream handler so that +// it returns a new context rather than the old context. It is advised to use the +// original stream's context as the new context's parent but this method does not +// enforce it and instead simply returns the context specified. +func Stream(s grpc.ServerStream, ctx context.Context) grpc.ServerStream { + return &stream{s, ctx} +} + +type stream struct { + grpc.ServerStream + ctx context.Context +} + +func (s *stream) Context() context.Context { + return s.ctx +} diff --git a/pkg/ensign/contexts/stream_test.go b/pkg/ensign/contexts/stream_test.go new file mode 100644 index 000000000..529d07219 --- /dev/null +++ b/pkg/ensign/contexts/stream_test.go @@ -0,0 +1,34 @@ +package contexts_test + +import ( + "context" + "testing" + + "github.com/rotationalio/ensign/pkg/ensign/contexts" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +func TestStream(t *testing.T) { + mock := &MockStream{} + stream := contexts.Stream(mock, context.WithValue(mock.Context(), contexts.KeyUnknown, "bar")) + + ctx := stream.Context() + require.Equal(t, "bar", ctx.Value(contexts.KeyUnknown).(string)) + + mock.cancel() + require.ErrorIs(t, ctx.Err(), context.Canceled) +} + +type MockStream struct { + grpc.ServerStream + ctx context.Context + cancel context.CancelFunc +} + +func (m *MockStream) Context() context.Context { + if m.ctx == nil { + m.ctx, m.cancel = context.WithCancel(context.Background()) + } + return m.ctx +} diff --git a/pkg/ensign/interceptors/auth.go b/pkg/ensign/interceptors/auth.go new file mode 100644 index 000000000..76b3b369d --- /dev/null +++ b/pkg/ensign/interceptors/auth.go @@ -0,0 +1,120 @@ +package interceptors + +import ( + "context" + "strings" + + "github.com/rotationalio/ensign/pkg/ensign/contexts" + "github.com/rotationalio/ensign/pkg/quarterdeck/middleware" + "github.com/rotationalio/ensign/pkg/quarterdeck/tokens" + "github.com/rs/zerolog/log" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +const ( + header = "authorization" // MUST BE LOWER CASE! + bearer = "Bearer " // MUST INCLUDE TRAILING SPACE! +) + +// Authenticator ensures that the RPC request has a valid Quarterdeck-issued JWT token +// in the credentials metadata of the request, otherwise it stops processing and returns +// an Unauthenticated error. A valid JWT token means that the token is supplied in the +// credentials, is unexpired, was signed by Quarterdeck private keys, and has the +// correct audience and issuer. +// +// This interceptor extracts the claims from the JWT token and adds them to the context +// of the request, ensuring that downstream interceptors and the handlers can access the +// claims without having to parse the JWT token in the credentials. +// +// In order to perform authentication, this middleware fetches public JSON Web Key Sets +// (JWKS) from the authorizing Quarterdeck server and caches them according to the +// Cache-Control or Expires headers in the response. As Quarterdeck keys are rotated, +// the cache must refresh the public keys in a background routine to correctly +// authenticate incoming credentials. Users can control how the JWKS are fetched and +// cached using AuthOptions from the Quarterdeck middleware package. +// +// Both Unary and Streaming interceptors can be returned from this middleware handler. +type Authenticator struct { + conf middleware.AuthOptions + validator tokens.Validator +} + +// Create an authenticator to handle both unary and streaming RPC calls, modifying the +// behavior of the authenticator using auth options from Quarterdeck middleware. +func NewAuthenticator(opts ...middleware.AuthOption) (auth *Authenticator, err error) { + auth = &Authenticator{ + conf: middleware.NewAuthOptions(opts...), + } + + if auth.validator, err = auth.conf.Validator(); err != nil { + return nil, err + } + return auth, nil +} + +// Authenticate a request using the access token credentials provided in the metadata. +func (a *Authenticator) authenticate(ctx context.Context) (_ context.Context, err error) { + var ( + claims *tokens.Claims + md metadata.MD + ok bool + ) + + if md, ok = metadata.FromIncomingContext(ctx); !ok { + return nil, status.Error(codes.Unauthenticated, "missing credentials") + } + + // Extract the authorization credentials (we expect [at least] 1 JWT token) + values := md[header] + if len(values) == 0 { + return nil, status.Error(codes.Unauthenticated, "missing credentials") + } + + // Loop through credentials to find the first valid claims + // NOTE: we only expect one token but are trying to future-proof the interceptor + for _, value := range values { + if !strings.HasPrefix(value, bearer) { + continue + } + + token := strings.TrimPrefix(value, bearer) + if claims, err = a.validator.Verify(token); err == nil { + break + } + } + + // Check to see if we found any valid claims in the request + if claims == nil { + log.Debug().Err(err).Int("tokens", len(values)).Msg("could not find a valid access token in request") + return nil, status.Error(codes.Unauthenticated, "invalid credentials") + } + + // Add the claims to the context so that downstream handlers can access it + return contexts.WithClaims(ctx, claims), nil +} + +// Return the Unary interceptor that uses the Authenticator handler. +func (a *Authenticator) Unary(opts ...middleware.AuthOption) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { + if ctx, err = a.authenticate(ctx); err != nil { + return nil, err + } + return handler(ctx, req) + } +} + +// Return the Stream interceptor that uses the Authenticator handler. +func (a *Authenticator) Stream(opts ...middleware.AuthOption) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + var ctx context.Context + if ctx, err = a.authenticate(stream.Context()); err != nil { + return err + } + + stream = contexts.Stream(stream, ctx) + return handler(srv, stream) + } +} diff --git a/pkg/ensign/interceptors/auth_test.go b/pkg/ensign/interceptors/auth_test.go new file mode 100644 index 000000000..2eb98486d --- /dev/null +++ b/pkg/ensign/interceptors/auth_test.go @@ -0,0 +1,144 @@ +package interceptors_test + +import ( + "context" + "testing" + + api "github.com/rotationalio/ensign/pkg/api/v1beta1" + "github.com/rotationalio/ensign/pkg/ensign/contexts" + "github.com/rotationalio/ensign/pkg/ensign/interceptors" + "github.com/rotationalio/ensign/pkg/ensign/mock" + "github.com/rotationalio/ensign/pkg/quarterdeck/authtest" + "github.com/rotationalio/ensign/pkg/quarterdeck/middleware" + "github.com/rotationalio/ensign/pkg/quarterdeck/tokens" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" +) + +type testCredentials struct { + token string +} + +func (t *testCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return map[string]string{ + "Authorization": "Bearer " + t.token, + }, nil +} + +func (t *testCredentials) RequireTransportSecurity() bool { + return false +} + +func TestAuthenticator(t *testing.T) { + // Create the test authentication server + auth, err := authtest.NewServer() + require.NoError(t, err, "could not start authtest server") + defer auth.Close() + + // Create the interceptors and the mock gRPC server to test with + authenticator, err := interceptors.NewAuthenticator( + middleware.WithJWKSEndpoint(auth.KeysURL()), + middleware.WithAudience(authtest.Audience), + middleware.WithIssuer(authtest.Issuer), + ) + require.NoError(t, err, "could not create authenticator interceptors") + + opts := make([]grpc.ServerOption, 0, 2) + opts = append(opts, grpc.UnaryInterceptor(authenticator.Unary())) + opts = append(opts, grpc.StreamInterceptor(authenticator.Stream())) + srv := mock.New(nil, opts...) + + t.Run("Unary", func(t *testing.T) { + t.Cleanup(srv.Reset) + + // Ensure that the unary endpoint returns a decent response + srv.OnListTopics = func(ctx context.Context, _ *api.PageInfo) (*api.TopicsPage, error) { + // Make sure that the claims are in the context, otherwise return an error. + if _, ok := contexts.ClaimsFrom(ctx); !ok { + return nil, status.Error(codes.PermissionDenied, "no claims in context") + } + return &api.TopicsPage{}, nil + } + + // Create a client to trigger requests + ctx := context.Background() + client, err := srv.ResetClient(ctx) + require.NoError(t, err, "could not connect client to mock") + + // Should not be able to connect to RPC without authentication + _, err = client.ListTopics(ctx, &api.PageInfo{}) + require.EqualError(t, err, "rpc error: code = Unauthenticated desc = missing credentials") + + // Should not be able to connect with an invalid JWT token + client, err = srv.ResetClient(ctx, grpc.WithPerRPCCredentials(&testCredentials{"notarealjwtoken"}), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err, "could not connect client to mock") + _, err = client.ListTopics(ctx, &api.PageInfo{}) + require.EqualError(t, err, "rpc error: code = Unauthenticated desc = invalid credentials") + + // Should be able to connect with a valid auth token and claims should be in context + token, err := auth.CreateAccessToken(&tokens.Claims{Email: "test@example.com"}) + require.NoError(t, err, "could not create access token") + client, err = srv.ResetClient(ctx, grpc.WithPerRPCCredentials(&testCredentials{token}), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err, "could not connect client to mock") + + _, err = client.ListTopics(ctx, &api.PageInfo{}) + require.NoError(t, err, "could not access endpoint with valid token") + }) + + t.Run("Stream", func(t *testing.T) { + t.Cleanup(srv.Reset) + + // Create a client to trigger requests + ctx := context.Background() + client, err := srv.ResetClient(ctx) + require.NoError(t, err, "could not connect client to mock") + + // Handle stream RPC + srv.OnPublish = func(stream api.Ensign_PublishServer) error { + // Make sure that the claims are in the context, otherwise return an error. + if _, ok := contexts.ClaimsFrom(stream.Context()); !ok { + return status.Error(codes.PermissionDenied, "no claims in context") + } + + stream.Send(&api.Publication{}) + return nil + } + + // Should be able to connect to RPC without authentication + stream, err := client.Publish(ctx) + require.NoError(t, err, "expected to connect to stream without error") + + // Should not be able to send a message without authentication + _, err = stream.Recv() + require.EqualError(t, err, "rpc error: code = Unauthenticated desc = missing credentials") + + // Should not be able to connect with an invalid JWT token + client, err = srv.ResetClient(ctx, grpc.WithPerRPCCredentials(&testCredentials{"notarealjwtoken"}), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err, "could not connect client to mock") + + // Should be able to connect to RPC without authentication + stream, err = client.Publish(ctx) + require.NoError(t, err, "expected to connect to stream without error") + + // Should not be able to send a message without authentication + _, err = stream.Recv() + require.EqualError(t, err, "rpc error: code = Unauthenticated desc = invalid credentials") + + // Should be able to connect with a valid auth token and claims should be in context + token, err := auth.CreateAccessToken(&tokens.Claims{Email: "test@example.com"}) + require.NoError(t, err, "could not create access token") + client, err = srv.ResetClient(ctx, grpc.WithPerRPCCredentials(&testCredentials{token}), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err, "could not connect client to mock") + + // Should be able to connect to RPC without authentication + stream, err = client.Publish(ctx) + require.NoError(t, err, "expected to connect to stream without error") + + // Should not be able to send a message without authentication + _, err = stream.Recv() + require.NoError(t, err, "could not authenticate stream") + }) +} diff --git a/pkg/ensign/mock/mock.go b/pkg/ensign/mock/mock.go index 5137c1a9b..571ce3da1 100644 --- a/pkg/ensign/mock/mock.go +++ b/pkg/ensign/mock/mock.go @@ -85,6 +85,12 @@ func (s *Ensign) Client(ctx context.Context, opts ...grpc.DialOption) (client ap return s.client, nil } +// Reset the client with the new dial options +func (s *Ensign) ResetClient(ctx context.Context, opts ...grpc.DialOption) (api.EnsignClient, error) { + s.client = nil + return s.Client(ctx, opts...) +} + // Return the bufconn channel (helpful for dialing) func (s *Ensign) Channel() *bufconn.Listener { return s.bufnet