-
Notifications
You must be signed in to change notification settings - Fork 351
/
auth_middleware.go
156 lines (148 loc) · 5.17 KB
/
auth_middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package api
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/routers"
"github.com/getkin/kin-openapi/routers/legacy"
"github.com/golang-jwt/jwt"
"github.com/treeverse/lakefs/pkg/auth"
"github.com/treeverse/lakefs/pkg/auth/model"
"github.com/treeverse/lakefs/pkg/logging"
)
// extractSecurityRequirements finds the OpenAPI operation that matches r and
// returns the security requirements that apply to it.
//
// When the operation declares its own `security` section, that section wins
// (even if it is empty); otherwise the document-wide requirements are used.
// A route-lookup failure is returned to the caller unchanged.
func extractSecurityRequirements(router routers.Router, r *http.Request) (openapi3.SecurityRequirements, error) {
	route, _, findErr := router.FindRoute(r)
	if findErr != nil {
		return nil, findErr
	}
	if opSecurity := route.Operation.Security; opSecurity != nil {
		return *opSecurity, nil
	}
	// No operation-level override: fall back to the top-level security section.
	return route.Swagger.Security, nil
}
// AuthMiddleware returns HTTP middleware that authenticates each request
// against the security requirements declared in swagger for the matched
// operation.
//
// The router is built from swagger once, when the middleware is created, and a
// construction error panics. For every request:
//   - a route-lookup failure is answered with 400 Bad Request;
//   - an authentication failure is answered with 401 Unauthorized;
//   - otherwise the request is forwarded to next, carrying the authenticated
//     user (when one was resolved) in its context under UserContextKey.
func AuthMiddleware(logger logging.Logger, swagger *openapi3.Swagger, authenticator auth.Authenticator, authService auth.Service) func(next http.Handler) http.Handler {
	router, routerErr := legacy.NewRouter(swagger)
	if routerErr != nil {
		panic(routerErr)
	}
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			requirements, err := extractSecurityRequirements(router, r)
			if err != nil {
				writeError(w, http.StatusBadRequest, err)
				return
			}
			user, err := checkSecurityRequirements(r, requirements, logger, authenticator, authService)
			if err != nil {
				writeError(w, http.StatusUnauthorized, err)
				return
			}
			if user == nil {
				// No user was resolved and no error was reported: forward the
				// request without attaching a user to its context.
				next.ServeHTTP(w, r)
				return
			}
			withUser := context.WithValue(r.Context(), UserContextKey, user)
			next.ServeHTTP(w, r.WithContext(withUser))
		}
		return http.HandlerFunc(serve)
	}
}
// checkSecurityRequirements walks securityRequirements and tries every provider
// scheme against the credentials carried by r.
//
// Recognized providers:
//   - "jwt_token":   a "Bearer <token>" value in the Authorization header;
//   - "basic_auth":  HTTP Basic credentials (access key / secret key);
//   - "cookie_auth": a JWT stored in the JWTCookieName cookie.
//
// A provider whose credentials are absent from the request is skipped.
// Providers inside a single requirement are visited in Go map iteration order,
// which is unspecified.
//
// Return values:
//   - the first user successfully resolved, with a nil error;
//   - (nil, err) as soon as a provider's credentials are present but rejected;
//   - (nil, ErrAuthenticatingRequest) when a provider name is not recognized;
//   - (nil, nil) when no provider yields a user.
func checkSecurityRequirements(r *http.Request, securityRequirements openapi3.SecurityRequirements, logger logging.Logger, authenticator auth.Authenticator, authService auth.Service) (*model.User, error) {
	ctx := r.Context()
	reqLogger := logger.WithContext(ctx)
	for _, requirement := range securityRequirements {
		for provider := range requirement {
			var (
				found   *model.User
				authErr error
			)
			switch provider {
			case "jwt_token":
				header := r.Header.Get("Authorization")
				if header == "" {
					continue
				}
				fields := strings.Fields(header)
				isBearer := len(fields) == 2 && strings.EqualFold(fields[0], "Bearer")
				if !isBearer {
					continue
				}
				found, authErr = userByToken(ctx, reqLogger, authService, fields[1])
			case "basic_auth":
				accessKey, secretKey, hasBasic := r.BasicAuth()
				if !hasBasic {
					continue
				}
				found, authErr = userByAuth(ctx, reqLogger, authenticator, authService, accessKey, secretKey)
			case "cookie_auth":
				// r.Cookie only fails when the cookie is missing, in which case it returns nil.
				cookie, _ := r.Cookie(JWTCookieName)
				if cookie == nil {
					continue
				}
				found, authErr = userByToken(ctx, reqLogger, authService, cookie.Value)
			default:
				reqLogger.WithField("provider", provider).Error("Authentication middleware unknown security requirement provider")
				return nil, ErrAuthenticatingRequest
			}
			switch {
			case authErr != nil:
				return nil, authErr
			case found != nil:
				return found, nil
			}
		}
	}
	return nil, nil
}
// userByToken validates tokenString as an HMAC-signed JWT and returns the
// corresponding user.
//
// The token is verified against the auth service's shared secret, and the user
// is looked up by the numeric ID stored in the token's Subject claim.
//
// Every failure is reported to the caller as ErrAuthenticatingRequest:
//   - parse or signature errors, or a non-HMAC signing method;
//   - the jwt library rejecting the token;
//   - a Subject that is not a 32-bit integer;
//   - a failed user lookup.
//
// The underlying cause is only logged.
func userByToken(ctx context.Context, logger logging.Logger, authService auth.Service, tokenString string) (*model.User, error) {
	claims := &jwt.StandardClaims{}
	token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
		// The key below is a shared secret, so only HMAC methods may be honored;
		// accepting the token's own choice of algorithm would be unsafe.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("%w: %v", ErrUnexpectedSigningMethod, token.Header["alg"])
		}
		return authService.SecretStore().SharedSecret(), nil
	})
	if err != nil {
		logger.WithError(err).Debug("could not parse or validate token")
		return nil, ErrAuthenticatingRequest
	}
	// ParseWithClaims decodes into the claims value it was given; the assertion
	// below is a defensive check only, so claims is not re-declared.
	if _, ok := token.Claims.(*jwt.StandardClaims); !ok || !token.Valid {
		return nil, ErrAuthenticatingRequest
	}
	const base = 10
	const bitSize = 32
	id, err := strconv.ParseInt(claims.Subject, base, bitSize)
	if err != nil {
		logger.WithField("subject", claims.Subject).Info("could not parse user ID on token")
		return nil, ErrAuthenticatingRequest
	}
	userData, err := authService.GetUserByID(ctx, int(id))
	if err != nil {
		// Include the cause (consistent with userByAuth) so a storage failure
		// can be told apart from an unknown user.
		logger.WithError(err).WithFields(logging.Fields{
			"user_id": id,
			"subject": claims.Subject,
		}).Debug("could not find user id by credentials")
		return nil, ErrAuthenticatingRequest
	}
	return userData, nil
}
// userByAuth checks an access key / secret key pair with authenticator and then
// loads the resulting user from authService.
//
// An authentication failure is logged at error level, and a user-lookup
// failure at debug level. In both cases the caller only receives
// ErrAuthenticatingRequest.
func userByAuth(ctx context.Context, logger logging.Logger, authenticator auth.Authenticator, authService auth.Service, accessKey string, secretKey string) (*model.User, error) {
	// TODO(ariels): Rename keys.
	userID, authErr := authenticator.AuthenticateUser(ctx, accessKey, secretKey)
	if authErr != nil {
		logger.WithError(authErr).WithField("user", accessKey).Error("authenticate")
		return nil, ErrAuthenticatingRequest
	}
	user, lookupErr := authService.GetUserByID(ctx, userID)
	if lookupErr != nil {
		lookupFields := logging.Fields{"user_id": userID}
		logger.WithError(lookupErr).WithFields(lookupFields).Debug("could not find user id by credentials")
		return nil, ErrAuthenticatingRequest
	}
	return user, nil
}