/
sessions.go
207 lines (185 loc) · 5.71 KB
/
sessions.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
package server
import (
"crypto/rand"
"encoding/ascii85"
"encoding/gob"
"errors"
"net/http"
"time"
"github.com/alexedwards/scs/v2"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/sjansen/bouncer/internal/authz"
)
// Cookie names and lifetimes for the two session managers configured by
// addSCS.
const (
	// sessionCookieName names the cookie carrying the login session token.
	sessionCookieName = "sessionid"
	// idleTimeout is the IdleTimeout given to the login session manager.
	idleTimeout = time.Hour // TODO
	// sessionLifetime is the Lifetime given to the login session manager.
	sessionLifetime = 8 * time.Hour // TODO

	// trackerCookieName names the cookie carrying the relay-state token.
	trackerCookieName = "relaystate"
	// trackerLifetime is used as both IdleTimeout and Lifetime of the
	// relay-state session manager.
	trackerLifetime = 5 * time.Minute
)
// addSCS creates the two scs session managers used by the server and wires
// the server into the SAML middleware as both its RequestTracker and its
// Session provider.
//
// relaystate and sessions are the backing stores for the relay-state manager
// and the login session manager respectively; when either is nil the store
// that scs.New set up is left in place.
//
// Cookies are marked Secure unless the application hostname is "localhost"
// or "127.0.0.1".
func (s *Server) addSCS(relaystate, sessions scs.Store) {
	host := s.config.AppURL.Hostname()
	local := host == "localhost" || host == "127.0.0.1"

	// relaystate
	tracker := scs.New()
	tracker.Cookie.Name = trackerCookieName
	tracker.Cookie.Domain = host
	tracker.Cookie.HttpOnly = true
	tracker.Cookie.Persist = false
	tracker.Cookie.SameSite = http.SameSiteNoneMode
	tracker.Cookie.Secure = !local
	tracker.IdleTimeout = trackerLifetime
	tracker.Lifetime = trackerLifetime
	if relaystate != nil {
		tracker.Store = relaystate
	}
	s.relaystate = tracker
	s.saml.RequestTracker = s
	// The tracked requests are stored in the session as a slice value, so
	// the concrete type is registered with gob.
	gob.Register([]samlsp.TrackedRequest{})

	// sessions
	login := scs.New()
	login.Cookie.Name = sessionCookieName
	login.Cookie.Domain = host
	login.Cookie.HttpOnly = true
	login.Cookie.Persist = true
	login.Cookie.Secure = !local
	if !local {
		// NOTE: enabling strict mode triggers a redirect loop when
		// running behind CloudFront and I haven't figured out why
		//login.Cookie.SameSite = http.SameSiteStrictMode
		login.Cookie.SameSite = http.SameSiteNoneMode
	}
	login.IdleTimeout = idleTimeout
	login.Lifetime = sessionLifetime
	if sessions != nil {
		login.Store = sessions
	}
	s.sess = login
	s.saml.Session = s
}
// CreateSession is called once a valid SAML assertion has been received. It
// renews the session token, builds an authz.User from the assertion's
// attributes and stores it in the session under the "User" key.
//
// Attributes are matched by FriendlyName, falling back to Name when the
// friendly name is empty. "email", "firstName" and "lastName" keep the last
// value supplied; every "roles" value is appended to the user's roles.
//
// An error is returned only when renewing the session token fails.
func (s *Server) CreateSession(w http.ResponseWriter, r *http.Request, assertion *saml.Assertion) error {
	ctx := r.Context()
	if err := s.sess.RenewToken(ctx); err != nil {
		return err
	}

	var user authz.User
	for _, stmt := range assertion.AttributeStatements {
		for _, attr := range stmt.Attributes {
			values := attr.Values
			if len(values) == 0 {
				continue
			}

			claim := attr.FriendlyName
			if claim == "" {
				claim = attr.Name
			}

			// For single-valued claims the final value wins.
			last := values[len(values)-1].Value
			switch claim {
			case "email":
				user.Email = last
			case "firstName":
				user.GivenName = last
			case "lastName":
				user.FamilyName = last
			case "roles":
				for _, v := range values {
					user.Roles = append(user.Roles, v.Value)
				}
			}
		}
	}
	s.sess.Put(ctx, "User", user)

	// The relay-state session is torn down on a best-effort basis; a
	// failure here is explicitly discarded and does not fail the login.
	_ = s.relaystate.Destroy(ctx)
	return nil
}
// DeleteSession removes the current session by destroying the login session
// associated with the request's context. It returns whatever error the
// session manager's Destroy call reports.
func (s *Server) DeleteSession(w http.ResponseWriter, r *http.Request) error {
	ctx := r.Context()
	return s.sess.Destroy(ctx)
}
// GetSession returns the samlsp.Session for the request: a pointer to a copy
// of the authz.User stored in the session under the "User" key. When no such
// value is present it returns samlsp.ErrNoSession.
func (s *Server) GetSession(r *http.Request) (samlsp.Session, error) {
	user, ok := s.sess.Get(r.Context(), "User").(authz.User)
	if !ok {
		return nil, samlsp.ErrNoSession
	}
	return &user, nil
}
// ErrNoTrackedRequest is returned for invalid and expired relay states.
var ErrNoTrackedRequest = errors.New("saml: tracked request not present")

const (
	// trackedRequestsKey is the relay-state session key under which the
	// slice of pending tracked requests is stored.
	trackedRequestsKey = "TrackedRequests"
	// trackedRequestsLimit is the number of pending requests kept per
	// relay-state session; once reached, TrackRequest replaces the oldest
	// entry instead of growing the slice.
	trackedRequestsLimit = 10
)
// GetTrackedRequest returns the pending tracked request whose Index equals
// index, looked up in the relay-state session of the request. It returns
// ErrNoTrackedRequest when the session holds no tracked requests or none of
// them matches.
func (s *Server) GetTrackedRequest(r *http.Request, index string) (*samlsp.TrackedRequest, error) {
	stored := s.relaystate.Get(r.Context(), trackedRequestsKey)
	pending, ok := stored.([]samlsp.TrackedRequest)
	if !ok {
		return nil, ErrNoTrackedRequest
	}
	for i := range pending {
		if pending[i].Index != index {
			continue
		}
		// Return a copy so the caller does not alias the stored slice.
		match := pending[i]
		return &match, nil
	}
	return nil, ErrNoTrackedRequest
}
// GetTrackedRequests returns all pending tracked requests stored in the
// relay-state session of the request. When none are stored it returns an
// empty, non-nil slice.
func (s *Server) GetTrackedRequests(r *http.Request) []samlsp.TrackedRequest {
	pending, ok := s.relaystate.Get(r.Context(), trackedRequestsKey).([]samlsp.TrackedRequest)
	if !ok {
		return []samlsp.TrackedRequest{}
	}
	return pending
}
// StopTrackingRequest stops tracking the SAML request given by index, which
// is a string previously returned from TrackRequest.
//
// Every stored entry whose Index equals index is dropped. If entries remain
// they are written back to the relay-state session; otherwise the key is
// removed. When the session holds no tracked requests nothing is changed.
// The returned error is always nil.
func (s *Server) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error {
	ctx := r.Context()
	pending, ok := s.relaystate.Get(ctx, trackedRequestsKey).([]samlsp.TrackedRequest)
	if !ok {
		return nil
	}

	// Filter in place, reusing the backing array of the stored slice.
	kept := pending[:0]
	for _, tracked := range pending {
		if tracked.Index != index {
			kept = append(kept, tracked)
		}
	}

	if len(kept) == 0 {
		s.relaystate.Remove(ctx, trackedRequestsKey)
		return nil
	}
	s.relaystate.Put(ctx, trackedRequestsKey, kept)
	return nil
}
// TrackRequest starts tracking the SAML request with the given ID. It returns an
// `index` that should be used as the RelayState in the SAML request flow.
//
// The index is 29 bytes from crypto/rand, ascii85-encoded. The new request is
// recorded, together with the URL of r, in the relay-state session. At most
// trackedRequestsLimit requests are kept: once the limit is reached the
// oldest entry is dropped to make room for the new one.
//
// An error is returned only when reading random bytes fails.
func (s *Server) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (string, error) {
	src := make([]byte, 29)
	if _, err := rand.Read(src); err != nil {
		return "", err
	}
	dst := make([]byte, ascii85.MaxEncodedLen(len(src)))
	// Encode reports how many bytes it actually wrote, which is fewer than
	// MaxEncodedLen for a 29-byte input. Trim to that length so the index
	// does not end in unused NUL bytes from the buffer.
	n := ascii85.Encode(dst, src)
	index := string(dst[:n])

	request := samlsp.TrackedRequest{
		Index:         index,
		SAMLRequestID: samlRequestID,
		URI:           r.URL.String(),
	}

	ctx := r.Context()
	requests, ok := s.relaystate.Get(ctx, trackedRequestsKey).([]samlsp.TrackedRequest)
	switch {
	case ok && len(requests) < trackedRequestsLimit:
		requests = append(requests, request)
	case ok:
		// At the limit: shift everything down one slot, discarding the
		// oldest entry, and place the new request last.
		copy(requests, requests[1:])
		requests[len(requests)-1] = request
	default:
		requests = []samlsp.TrackedRequest{request}
	}
	s.relaystate.Put(ctx, trackedRequestsKey, requests)

	return index, nil
}