Skip to content

Commit

Permalink
API: Add support for JWT generated by IDPs (#871)
Browse files Browse the repository at this point in the history
* Adding JWT_WELLKNOWN and JWT_JWKS to support RSA 256 JWT Tokens
* Update of the Go version and tweak to docker-compose to bring back local unit tests
* Updated the unit tests for fetchJWKS and wellknown parameter
* - Update of Dockerfile of testdata to use golang 1.22
- Refactoring of logic in JwtMiddleware
- Removing of comment in french that was a duplicate of correct comment in middlewares_test.go
- Added timeout to WellKnown Config http call
- Renaming of JWTWellKnown config parameter to JWTWellKnownURL for more precision
* Moved the close server statements to defer function to ensure proper cleanup.
  • Loading branch information
mcharest-mcn committed Apr 16, 2024
1 parent 366719f commit a3b6302
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 8 deletions.
68 changes: 68 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
package config

import (
"context"
"encoding/json"
"fmt"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/prest/prest/adapters"
"github.com/prest/prest/cache"

"net/http"
"time"

homedir "github.com/mitchellh/go-homedir"
"github.com/spf13/viper"
"github.com/structy/log"
Expand Down Expand Up @@ -82,6 +89,8 @@ type Prest struct {
PGCache bool
JWTKey string
JWTAlgo string
JWTWellKnownURL string
JWTJWKS string
JWTWhiteList []string
JSONAggType string
MigrationsPath string
Expand Down Expand Up @@ -171,6 +180,8 @@ func viperCfg() {

viper.SetDefault("jwt.default", true)
viper.SetDefault("jwt.algo", "HS256")
viper.SetDefault("jwt.wellknownurl", "")
viper.SetDefault("jwt.jwks", "")
viper.SetDefault("jwt.whitelist", []string{"/auth"})

viper.SetDefault("json.agg.type", "jsonb_agg")
Expand Down Expand Up @@ -271,7 +282,10 @@ func Parse(cfg *Prest) {
cfg.SingleDB = viper.GetBool("pg.single")
cfg.JWTKey = viper.GetString("jwt.key")
cfg.JWTAlgo = viper.GetString("jwt.algo")
cfg.JWTWellKnownURL = viper.GetString("jwt.wellknownurl")
cfg.JWTJWKS = viper.GetString("jwt.jwks")
cfg.JWTWhiteList = viper.GetStringSlice("jwt.whitelist")
fetchJWKS(cfg)

cfg.JSONAggType = getJSONAgg()

Expand Down Expand Up @@ -358,6 +372,60 @@ func parseDatabaseURL(cfg *Prest) {
}
}

// fetchJWKS tries to get the JWKS from the URL in the config
func fetchJWKS(cfg *Prest) {
if cfg.JWTWellKnownURL == "" {
log.Debugln("no JWT WellKnown url found, skipping")
return
}
if cfg.JWTJWKS != "" {
log.Debugln("JWKS already set, skipping")
return
}

// Call provider to obtain .well-known config
client := &http.Client{
Timeout: 5 * time.Second,
}

r, err := client.Get(cfg.JWTWellKnownURL)
if err != nil {
log.Errorf("Cannot get .well-known configuration from '%s'. err: %v\n", cfg.JWTWellKnownURL, err)
return
}
defer r.Body.Close()

var wellKnown map[string]interface{}
err = json.NewDecoder(r.Body).Decode(&wellKnown)
if err != nil {
log.Errorf("Failed to decode JSON: %v\n", err)
return
}

//Retrieve the JWKS from the endpoint
uri, ok := wellKnown["jwks_uri"].(string)
if !ok {
log.Errorf("Unable to convert .WellKnown configuration of jwks_uri to a string.")
return
}

JWKSet, err := jwk.Fetch(context.Background(), uri)
if err != nil {
err := fmt.Errorf("failed to parse JWK: %s", err)
log.Errorf("Failed to fetch JWK: %v\n", err)
return
}

//Convert set to json string
jwkSetJSON, err := json.Marshal(JWKSet)
if err != nil {
log.Errorf("Failed to marshal JWKSet to JSON: %v\n", err)
return
}

cfg.JWTJWKS = string(jwkSetJSON)
}

func portFromEnv(cfg *Prest) {
if os.Getenv("PORT") == "" {
log.Debugln("could not find PORT in env")
Expand Down
55 changes: 55 additions & 0 deletions config/config_test.go

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,24 @@ require (

require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.4 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/jwx/v2 v2.0.20 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
Expand All @@ -46,9 +55,9 @@ require (
github.com/tidwall/rtred v0.1.2 // indirect
github.com/tidwall/tinyqueue v0.1.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
Expand Down
23 changes: 23 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand All @@ -37,6 +41,18 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8=
github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.0.20 h1:sAgXuWS/t8ykxS9Bi2Qtn5Qhpakw1wrcjxChudjolCc=
github.com/lestrrat-go/jwx/v2 v2.0.20/go.mod h1:UlCSmKqw+agm5BsOBfEAbTvKsEApaGNqHAEUTv5PJC4=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
Expand Down Expand Up @@ -68,6 +84,8 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
Expand All @@ -86,6 +104,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
Expand Down Expand Up @@ -127,6 +146,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
Expand All @@ -144,6 +165,8 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
Expand Down
2 changes: 1 addition & 1 deletion middlewares/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func initApp() {
if !config.PrestConf.Debug && config.PrestConf.EnableDefaultJWT {
MiddlewareStack = append(
MiddlewareStack,
JwtMiddleware(config.PrestConf.JWTKey, config.PrestConf.JWTAlgo))
JwtMiddleware(config.PrestConf.JWTKey, config.PrestConf.JWTJWKS))
}
if config.PrestConf.Cache.Enabled {
MiddlewareStack = append(MiddlewareStack, CacheMiddleware(&config.PrestConf.Cache))
Expand Down
36 changes: 34 additions & 2 deletions middlewares/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/prest/prest/config"
pctx "github.com/prest/prest/context"
"github.com/prest/prest/controllers/auth"
Expand Down Expand Up @@ -122,7 +123,7 @@ func AccessControl() negroni.Handler {
}

// JwtMiddleware check if actual request have JWT
func JwtMiddleware(key string, algo string) negroni.Handler {
func JwtMiddleware(key string, JWKSet string) negroni.Handler {
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
match, err := MatchURL(r.URL.String())
if err != nil {
Expand All @@ -147,7 +148,38 @@ func JwtMiddleware(key string, algo string) negroni.Handler {
return
}
out := auth.Claims{}
if err := tok.Claims([]byte(key), &out); err != nil {
var rawkey interface{} = []byte(key)

if JWKSet != "" {
parsedJWKSet, err := jwk.ParseString(JWKSet)
if err != nil {
err := fmt.Errorf("failed to parse JWKSet JSON string: %v", err)
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
for it := parsedJWKSet.Keys(context.Background()); it.Next(context.Background()); {
pair := it.Pair()
key := pair.Value.(jwk.Key)

if key.KeyID() == tok.Headers[0].KeyID {
if err := key.Raw(&rawkey); err != nil {
err := fmt.Errorf("failed to create public key: %s", err)
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
}
}
//Check if rawkey is empty
if key, ok := rawkey.(string); ok {
if key == "" {
err := fmt.Errorf("the token's key was not found in the JWKS")
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
}
}

if err := tok.Claims(rawkey, &out); err != nil {
http.Error(w, ErrJWTValidate.Error(), http.StatusUnauthorized)
return
}
Expand Down
Loading

0 comments on commit a3b6302

Please sign in to comment.