Skip to content

Commit

Permalink
feat: added initial implementation of authorization middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
JigarJoshi committed Mar 22, 2022
1 parent 37a2ec7 commit 5678cce
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 264 deletions.
5 changes: 4 additions & 1 deletion go.mod
Expand Up @@ -4,6 +4,7 @@ go 1.18

require (
github.com/apple/foundationdb/bindings/go v0.0.0-20211207225159-47b9a81d1c10
github.com/auth0/go-jwt-middleware/v2 v2.0.0
github.com/buger/jsonparser v1.1.1
github.com/davecgh/go-spew v1.1.1
github.com/deepmap/oapi-codegen v1.9.1
Expand All @@ -23,6 +24,7 @@ require (
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.10.1
github.com/stretchr/testify v1.7.0
github.com/valyala/bytebufferpool v1.0.0
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5
google.golang.org/grpc v1.43.0
Expand Down Expand Up @@ -59,7 +61,6 @@ require (
github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.33.0 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
Expand All @@ -68,11 +69,13 @@ require (
github.com/yudai/gojsondiff v1.0.0 // indirect
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect
github.com/yudai/pp v2.0.1+incompatible // indirect
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce // indirect
golang.org/x/net v0.0.0-20220111093109-d55c255bac03 // indirect
golang.org/x/sys v0.0.0-20220111092808-5a964db01320 // indirect
golang.org/x/text v0.3.7 // indirect
google.golang.org/grpc/examples v0.0.0-20220215234149-ec717cad7395 // indirect
gopkg.in/ini.v1 v1.66.2 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
moul.io/http2curl v1.0.0 // indirect
)
255 changes: 5 additions & 250 deletions go.sum

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions server/config/options.go
Expand Up @@ -15,6 +15,8 @@
package config

import (
"time"

"github.com/tigrisdata/tigrisdb/util/log"
)

Expand All @@ -26,9 +28,17 @@ type ServerConfig struct {
type Config struct {
Server ServerConfig `yaml:"server" json:"server"`
Log log.LogConfig
Auth AuthConfig `yaml:"auth" json:"auth"`
FoundationDB FoundationDBConfig
}

type AuthConfig struct {
IssuerURL string
Audience string
JWKSCacheTimeout time.Duration
LogOnly bool
}

var DefaultConfig = Config{
Log: log.LogConfig{
Level: "trace",
Expand All @@ -37,6 +47,12 @@ var DefaultConfig = Config{
Host: "0.0.0.0",
Port: 8081,
},
Auth: AuthConfig{
IssuerURL: "https://tigrisdata-dev.us.auth0.com/",
Audience: "https://tigris-db-api",
JWKSCacheTimeout: 5 * time.Minute,
LogOnly: true,
},
}

// FoundationDBConfig keeps FoundationDB configuration parameters
Expand Down
2 changes: 1 addition & 1 deletion server/grpc/grpc.go
Expand Up @@ -30,7 +30,7 @@ type Server struct {
func NewServer(cfg *config.Config) *Server {
s := &Server{}

unary, stream := middleware.Get()
unary, stream := middleware.Get(cfg)
s.Server = grpc.NewServer(grpc.StreamInterceptor(stream), grpc.UnaryInterceptor(unary))

return s
Expand Down
2 changes: 1 addition & 1 deletion server/http/http.go
Expand Up @@ -43,7 +43,7 @@ func NewServer(cfg *config.Config) *Server {
},
}

unary, stream := middleware.Get()
unary, stream := middleware.Get(cfg)

s.Inproc.WithServerStreamInterceptor(stream)
s.Inproc.WithServerUnaryInterceptor(unary)
Expand Down
57 changes: 49 additions & 8 deletions server/midddleware/auth.go
Expand Up @@ -16,12 +16,17 @@ package middleware

import (
"context"
"net/url"
"strings"
"time"

"github.com/auth0/go-jwt-middleware/v2/jwks"
"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
"github.com/rs/zerolog/log"
api "github.com/tigrisdata/tigrisdb/api/server/v1"
"github.com/tigrisdata/tigrisdb/server/config"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
Expand All @@ -40,20 +45,56 @@ func getHeader(ctx context.Context, header string) string {
func AuthFromMD(ctx context.Context, expectedScheme string) (string, error) {
val := getHeader(ctx, headerAuthorize)
if val == "" {
return "", status.Errorf(codes.Unauthenticated, "Request unauthenticated with "+expectedScheme)
return "", api.Error(codes.Unauthenticated, "request unauthenticated with "+expectedScheme)
}
splits := strings.SplitN(val, " ", 2)
if len(splits) < 2 {
return "", status.Errorf(codes.Unauthenticated, "Bad authorization string")
return "", api.Error(codes.Unauthenticated, "bad authorization string")
}
if !strings.EqualFold(splits[0], expectedScheme) {
return "", status.Errorf(codes.Unauthenticated, "Request unauthenticated with "+expectedScheme)
return "", api.Error(codes.Unauthenticated, "request unauthenticated with bearer")
}
return splits[1], nil
}

func AuthFunc(ctx context.Context) (context.Context, error) {
_, err := AuthFromMD(ctx, "bearer")
log.Debug().Str("error", err.Error()).Msg("auth interceptor")
return context.WithValue(ctx, key("role"), "admin"), nil
func GetJWTValidator(config *config.Config) *validator.Validator {
issuerURL, _ := url.Parse(config.Auth.IssuerURL)
provider := jwks.NewCachingProvider(issuerURL, config.Auth.JWKSCacheTimeout)

jwtValidator, err := validator.New(
provider.KeyFunc,
validator.RS256,
issuerURL.String(),
[]string{config.Auth.Audience},
validator.WithAllowedClockSkew(time.Minute),
)

if err != nil {
log.Fatal().Err(err).Msg("Failed to configure JWTValidator")
}
return jwtValidator
}

func AuthFunction(ctx context.Context, jwtValidator *validator.Validator, config *config.Config) (ctxResult context.Context, err error) {
defer func() {
if err != nil {
log.Warn().Bool("log_only?", config.Auth.LogOnly).Str("error", err.Error()).Err(err).Msg("could not validate token")
if config.Auth.LogOnly {
err = nil
}
}
}()

token, err := AuthFromMD(ctx, "bearer")
if err != nil {
return ctx, err
}

validToken, err := jwtValidator.ValidateToken(ctx, token)
if err != nil {
return ctx, api.Error(codes.Unauthenticated, err.Error())
}

log.Debug().Msg("Valid token received")
return context.WithValue(ctx, key("token"), validToken), nil
}
63 changes: 63 additions & 0 deletions server/midddleware/auth_test.go
@@ -0,0 +1,63 @@
package middleware

import (
"context"
"testing"

"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/stretchr/testify/require"
api "github.com/tigrisdata/tigrisdb/api/server/v1"
"github.com/tigrisdata/tigrisdb/server/config"
"github.com/tigrisdata/tigrisdb/util/log"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
)

func TestNoToken(t *testing.T) {
enforcedAuthConfig := config.Config{
Server: config.ServerConfig{},
Log: log.LogConfig{
Level: "trace",
},
Auth: config.AuthConfig{
IssuerURL: "",
Audience: "",
JWKSCacheTimeout: 0,
LogOnly: false,
},
FoundationDB: config.FoundationDBConfig{},
}

t.Run("log_only mode: no token", func(t *testing.T) {
ctx, err := AuthFunction(context.TODO(), &validator.Validator{}, &config.DefaultConfig)
require.NotNil(t, ctx)
require.Nil(t, err)
})

t.Run("enforcing mode: no token", func(t *testing.T) {
_, err := AuthFunction(context.TODO(), &validator.Validator{}, &enforcedAuthConfig)
require.NotNil(t, err)
require.Equal(t, err, api.Error(codes.Unauthenticated, "request unauthenticated with bearer"))
})

t.Run("enforcing mode: Bad authorization string1", func(t *testing.T) {
incomingCtx := metadata.NewIncomingContext(context.TODO(), metadata.Pairs("authorization", "bearer"))
_, err := AuthFunction(incomingCtx, &validator.Validator{}, &enforcedAuthConfig)
require.NotNil(t, err)
require.Equal(t, err, api.Error(codes.Unauthenticated, "bad authorization string"))
})

t.Run("enforcing mode: Bad token", func(t *testing.T) {
incomingCtx := metadata.NewIncomingContext(context.TODO(), metadata.Pairs("authorization", "bearer somebadtoken"))
_, err := AuthFunction(incomingCtx, &validator.Validator{}, &enforcedAuthConfig)
require.NotNil(t, err)
require.Equal(t, err, api.Error(codes.Unauthenticated, "could not parse the token: square/go-jose: compact JWS format must have three parts"))
})

t.Run("enforcing mode: Bad token 2", func(t *testing.T) {
incomingCtx := metadata.NewIncomingContext(context.TODO(), metadata.Pairs("authorization", "bearer some.bad.token"))
_, err := AuthFunction(incomingCtx, &validator.Validator{}, &enforcedAuthConfig)
require.NotNil(t, err)
require.Contains(t, err.Error(), "could not parse the token: illegal base64 data")
})
}
15 changes: 12 additions & 3 deletions server/midddleware/middleware.go
Expand Up @@ -15,6 +15,8 @@
package middleware

import (
"context"

middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_zerolog "github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
Expand All @@ -23,17 +25,24 @@ import (
grpc_ratelimit "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
"github.com/rs/zerolog/log"
"github.com/tigrisdata/tigrisdb/server/config"
"google.golang.org/grpc"
)

func Get() (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) {
func Get(config *config.Config) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) {
jwtValidator := GetJWTValidator(config)
// inline closure to access the state of jwtValidator
authFunction := func(ctx context.Context) (context.Context, error) {
return AuthFunction(ctx, jwtValidator, config)
}

// adding all the middlewares for the server stream
//
// Note: we don't add validate here and rather call it in server code because the validator interceptor returns gRPC
// error which is not convertible to the internal rest error code.
stream := middleware.ChainStreamServer(
grpc_ratelimit.StreamServerInterceptor(&RateLimiter{}),
grpc_auth.StreamServerInterceptor(AuthFunc),
grpc_auth.StreamServerInterceptor(authFunction),
grpc_logging.StreamServerInterceptor(grpc_zerolog.InterceptorLogger(log.Logger), []grpc_logging.Option{}...),
// grpc_logging.PayloadStreamServerInterceptor(grpc_zerolog.InterceptorLogger(log.Logger), alwaysLoggingDeciderServer, time.RFC3339), // To log payload
grpc_opentracing.StreamServerInterceptor(),
Expand All @@ -47,7 +56,7 @@ func Get() (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) {
unary := middleware.ChainUnaryServer(
PprofUnaryServerInterceptor(),
grpc_ratelimit.UnaryServerInterceptor(&RateLimiter{}),
grpc_auth.UnaryServerInterceptor(AuthFunc),
grpc_auth.UnaryServerInterceptor(authFunction),
grpc_logging.UnaryServerInterceptor(grpc_zerolog.InterceptorLogger(log.Logger)),
// grpc_logging.PayloadUnaryServerInterceptor(grpc_zerolog.InterceptorLogger(log.Logger), alwaysLoggingDeciderServer, time.RFC3339), //To log payload
TimeoutUnaryServerInterceptor(DefaultTimeout),
Expand Down

0 comments on commit 5678cce

Please sign in to comment.