Skip to content

Commit

Permalink
JWT white list support
Browse files Browse the repository at this point in the history
Signed-off-by: Avelino <t@avelino.xxx>
  • Loading branch information
avelino committed Oct 31, 2020
1 parent ab086a5 commit e6a4244
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 1 deletion.
3 changes: 3 additions & 0 deletions config/config.go
Expand Up @@ -53,6 +53,7 @@ type Prest struct {
PGConnTimeout int
JWTKey string
JWTAlgo string
JWTWhiteList string `toml:"jwt_whitelist" cfg:"jwt_whitelist"`
MigrationsPath string
QueriesPath string
AccessConf AccessConf
Expand Down Expand Up @@ -103,6 +104,7 @@ func viperCfg() {
viper.SetDefault("debug", false)
viper.SetDefault("jwt.default", true)
viper.SetDefault("jwt.algo", "HS256")
viper.SetDefault("jwt.whitelist", "/auth")
viper.SetDefault("cors.allowheaders", []string{"*"})
viper.SetDefault("cache.enable", true)
viper.SetDefault("context", "/")
Expand Down Expand Up @@ -177,6 +179,7 @@ func Parse(cfg *Prest) (err error) {
cfg.PGConnTimeout = viper.GetInt("pg.conntimeout")
cfg.JWTKey = viper.GetString("jwt.key")
cfg.JWTAlgo = viper.GetString("jwt.algo")
cfg.JWTWhiteList = viper.GetString("jwt.whitelist")
cfg.MigrationsPath = viper.GetString("migrations")
cfg.AccessConf.Restrict = viper.GetBool("access.restrict")
cfg.QueriesPath = viper.GetString("queries.location")
Expand Down
22 changes: 21 additions & 1 deletion middlewares/middlewares.go
Expand Up @@ -3,6 +3,7 @@ package middlewares
import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"strconv"
Expand Down Expand Up @@ -104,7 +105,26 @@ func JwtMiddleware(key string, algo string) negroni.Handler {
},
SigningMethod: jwt.GetSigningMethod(algo),
})
return negroni.HandlerFunc(jwtMiddleware.HandlerWithNext)
// return negroni.HandlerFunc(jwtMiddleware.HandlerWithNext)

return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
match, err := MatchURL(r.URL.String())
if err != nil {
http.Error(w, fmt.Sprintf(`{"error": "%v"}`, err), http.StatusInternalServerError)
return
}
if match {
next(w, r)
return
}
err = jwtMiddleware.CheckJWT(w, r)
if err != nil {
log.Println("check jwt error", err.Error())
w.Write([]byte(fmt.Sprintf(`{"error": "%v"}`, err.Error())))
return
}
next(w, r)
})
}

// Cors middleware
Expand Down
13 changes: 13 additions & 0 deletions middlewares/utils.go
Expand Up @@ -6,9 +6,11 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"regexp"
"strings"

"github.com/clbanning/mxj/j2x"
"github.com/prest/prest/config"
"github.com/prest/prest/middlewares/statements"
)

Expand Down Expand Up @@ -112,3 +114,14 @@ func checkCors(r *http.Request, origin []string) (allowed bool) {
}
return
}

// MatchURL matches the given url with a whitelist from config.core
func MatchURL(url string) (match bool, err error) {
for _, exp := range strings.Fields(config.PrestConf.JWTWhiteList) {
match, err = regexp.Match(exp, []byte(url))
if match || err != nil {
return
}
}
return
}
48 changes: 48 additions & 0 deletions middlewares/utils_test.go
Expand Up @@ -56,3 +56,51 @@ func Test_checkCors(t *testing.T) {
t.Error("expected false, got true")
}
}


func TestMatchURL(t *testing.T) {
test := []struct {
Label string
URL string
JWTWhiteList string
match bool
}{
{
Label: "auth",
URL: "/auth",
JWTWhiteList: `\/auth`,
match: true,
},
{
Label: "auth regex",
URL: "/auth/any",
JWTWhiteList: `\/auth\/.*`,
match: true,
},
{
Label: "auth2 lock",
URL: "/auth2",
JWTWhiteList: `\/auth`,
match: true,
},
{
Label: "multi allow",
URL: "/auth",
JWTWhiteList: `\/auth \/databases`,
match: true,
}
}

for _, tt := range test {
t.Run(tt.Label, func(t *testing.T) {
config.Get.JWTWhiteList = tt.JWTWhiteList
match, err := MatchURL(tt.URL)
if err != nil {
t.Error(err)
}
if match != tt.match {
t.Errorf("expected %v, but got %v\n", tt.match, match)
}
})
}
}

0 comments on commit e6a4244

Please sign in to comment.