-
Notifications
You must be signed in to change notification settings - Fork 153
/
auth.go
143 lines (119 loc) · 4.13 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"github.com/sethvargo/go-limiter/httplimit"
"github.com/sethvargo/go-limiter/memorystore"
)
const (
	// StateCookieName is the name of the cookie that holds state during the auth flow.
	StateCookieName = "state"
	// IDTokenCookieName is the name of the cookie that holds the ID Token once
	// the user has authenticated successfully with the OIDC Provider.
	IDTokenCookieName = "id_token"
	// AccessTokenCookieName is the name of the cookie that holds the access token once
	// the user has authenticated successfully with the OIDC Provider. It's used for further
	// resource requests from the provider.
	AccessTokenCookieName = "access_token"
	// AuthorizationTokenHeaderName is the name of the header that holds the bearer token
	// used for token passthrough authentication.
	AuthorizationTokenHeaderName = "Authorization"

	// scopeProfile is the "profile" scope.
	scopeProfile = "profile"
	// scopeEmail is the "email" scope.
	scopeEmail = "email"
	// scopeGroups is the "groups" scope.
	scopeGroups = "groups"
)
// RegisterAuthServer registers the auth routes of srv on mux under prefix:
//
//   - prefix            -> srv.OAuth2Flow()
//   - prefix+"/callback" -> srv.Callback(), called by the OIDC Provider to pass
//     back state after the authentication flow completes
//   - prefix+"/sign_in"  -> srv.SignIn(), rate limited (see below)
//   - prefix+"/userinfo" -> srv.UserInfo()
//   - prefix+"/logout"   -> srv.Logout()
//
// Only the sign-in route is rate limited. Limiting is keyed by client IP
// (httplimit.IPKeyFunc) and backed by an in-memory store whose token count is
// loginRequestRateLimit. The refill interval is the memorystore default, since
// no Interval is configured here.
//
// An error is returned, and no routes are registered, if the rate-limit store
// or middleware cannot be created.
func RegisterAuthServer(mux *http.ServeMux, prefix string, srv *AuthServer, loginRequestRateLimit uint64) error {
	store, err := memorystore.New(&memorystore.Config{
		Tokens: loginRequestRateLimit,
	})
	if err != nil {
		return fmt.Errorf("creating sign-in rate limit store: %w", err)
	}

	middleware, err := httplimit.NewMiddleware(store, httplimit.IPKeyFunc())
	if err != nil {
		return fmt.Errorf("creating sign-in rate limit middleware: %w", err)
	}

	mux.Handle(prefix, srv.OAuth2Flow())
	mux.Handle(prefix+"/callback", srv.Callback())
	mux.Handle(prefix+"/sign_in", middleware.Handle(srv.SignIn()))
	mux.Handle(prefix+"/userinfo", srv.UserInfo())
	mux.Handle(prefix+"/logout", srv.Logout())

	return nil
}
// principalCtxKey is the context key under which WithPrincipal stores the
// *UserPrincipal. It is an unexported struct type, so code outside this
// package cannot construct the key to read or overwrite the value.
type principalCtxKey struct{}

// Principal gets the principal stored in ctx by WithPrincipal.
//
// It returns nil when ctx carries no value for the key, or when the stored
// value is not a *UserPrincipal.
func Principal(ctx context.Context) *UserPrincipal {
	// A failed type assertion in the single-value-with-blank form yields the
	// zero value, i.e. a nil *UserPrincipal.
	principal, _ := ctx.Value(principalCtxKey{}).(*UserPrincipal)
	return principal
}

// UserPrincipal is a simple model for the user, including their ID and Groups.
//
// Token is tagged `json:"-"`, so it is never serialized to JSON, and it is
// omitted from String, so it is not exposed through either path.
type UserPrincipal struct {
	ID     string   `json:"id"`
	Groups []string `json:"groups"`
	Token  string   `json:"-"`
}

// String returns the Principal ID and Groups as a string, e.g.
// `id="alice" groups=[dev ops]`.
//
// A nil receiver yields "<nil>" instead of panicking, because Principal
// returns nil for unauthenticated contexts and callers may log that value
// directly.
func (p *UserPrincipal) String() string {
	if p == nil {
		return "<nil>"
	}

	return fmt.Sprintf("id=%q groups=%v", p.ID, p.Groups)
}

// WithPrincipal returns a copy of ctx that carries p, retrievable with
// Principal.
func WithPrincipal(ctx context.Context, p *UserPrincipal) context.Context {
	return context.WithValue(ctx, principalCtxKey{}, p)
}
// WithAPIAuth middleware adds auth validation to API handlers.
//
// Requests whose URL path exactly matches an entry in publicRoutes (see
// IsPublicRoute) are passed straight to next without authentication.
//
// For every other request, a principal is resolved through a
// MultiAuthPrincipal built from these getters, in this order:
//   - a JWT admin cookie getter reading IDTokenCookieName, using
//     srv.tokenSignerVerifier;
//   - a bearer-token passthrough getter reading the
//     AuthorizationTokenHeaderName header;
//   - only when srv.oidcEnabled(): a JWT Authorization-header getter and a JWT
//     cookie getter reading IDTokenCookieName, both using srv.verifier().
//
// Unauthorized requests (no principal resolved, or an error while resolving
// one) are denied with a 401 status code and a JSON error body. Authorized
// requests reach next with the principal stored in the request context,
// retrievable with Principal.
func WithAPIAuth(next http.Handler, srv *AuthServer, publicRoutes []string) http.Handler {
	adminAuth := NewJWTAdminCookiePrincipalGetter(srv.Log, srv.tokenSignerVerifier, IDTokenCookieName)
	tokenAuth := NewBearerTokenPassthroughPrincipalGetter(srv.Log, nil, AuthorizationTokenHeaderName)
	multi := MultiAuthPrincipal{adminAuth, tokenAuth}

	if srv.oidcEnabled() {
		headerAuth := NewJWTAuthorizationHeaderPrincipalGetter(srv.Log, srv.verifier())
		cookieAuth := NewJWTCookiePrincipalGetter(srv.Log, srv.verifier(), IDTokenCookieName)
		multi = append(multi, headerAuth, cookieAuth)
	}

	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		if IsPublicRoute(r.URL, publicRoutes) {
			next.ServeHTTP(rw, r)
			return
		}

		principal, err := multi.Principal(r)
		if err != nil {
			srv.Log.Error(err, "failed to get principal")
		}

		if principal == nil || err != nil {
			JSONError(srv.Log, rw, "Authentication required", http.StatusUnauthorized)
			return
		}

		// WithContext makes a shallow copy that carries the new context.
		// Clone would also deep-copy headers, trailers and form data on every
		// request, which is unnecessary just to attach a context value.
		next.ServeHTTP(rw, r.WithContext(WithPrincipal(r.Context(), principal)))
	})
}
// generateNonce returns 32 bytes read from crypto/rand, encoded with padded
// standard base64 (a 44-character string).
//
// The result may contain '+', '/' and '=' characters. The error from
// rand.Read is returned unchanged, together with an empty string.
func generateNonce() (string, error) {
	var raw [32]byte
	if _, readErr := rand.Read(raw[:]); readErr != nil {
		return "", readErr
	}

	return base64.StdEncoding.EncodeToString(raw[:]), nil
}
// IsPublicRoute reports whether the path of u exactly equals one of
// publicRoutes.
//
// Only u.Path is compared; the query string and fragment are ignored. The
// comparison is an exact, case-sensitive string match, with no prefix
// matching or trailing-slash normalization.
func IsPublicRoute(u *url.URL, publicRoutes []string) bool {
	for i := range publicRoutes {
		if publicRoutes[i] == u.Path {
			return true
		}
	}

	return false
}