-
Notifications
You must be signed in to change notification settings - Fork 0
/
saml.go
334 lines (291 loc) · 10.4 KB
/
saml.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
package auth
import (
"bytes"
"compress/flate"
"encoding/base64"
"io/ioutil"
"time"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
log "github.com/Sirupsen/logrus"
"github.com/beevik/etree"
saml2 "github.com/russellhaering/gosaml2"
)
// UpsertSAMLConnector stores the given SAML connector by delegating to the
// identity service's UpsertSAMLConnector. Any error is wrapped with trace
// information, matching the error handling used by the other methods in
// this file; a nil error stays nil.
func (s *AuthServer) UpsertSAMLConnector(connector services.SAMLConnector) error {
	return trace.Wrap(s.Identity.UpsertSAMLConnector(connector))
}
// DeleteSAMLConnector removes the SAML connector with the given name by
// delegating to the identity service's DeleteSAMLConnector. Any error is
// wrapped with trace information, matching the error handling used by the
// other methods in this file; a nil error stays nil.
func (s *AuthServer) DeleteSAMLConnector(connectorName string) error {
	return trace.Wrap(s.Identity.DeleteSAMLConnector(connectorName))
}
// CreateSAMLAuthRequest prepares a SAML authentication request for the
// connector referenced by req.ConnectorID.
//
// It builds an auth request document with the connector's service provider,
// copies the document root's "ID" attribute into req.ID, derives
// req.RedirectURL from the same document and stores the request in the
// identity service with defaults.SAMLAuthRequestTTL. The returned pointer
// refers to the updated copy of req.
//
// A BadParameter error is returned when the generated document carries no
// (or an empty) ID attribute; all other failures are returned wrapped.
func (s *AuthServer) CreateSAMLAuthRequest(req services.SAMLAuthRequest) (*services.SAMLAuthRequest, error) {
	// NOTE(review): the second argument is presumably "with secrets" — confirm
	// against the identity service definition.
	conn, err := s.Identity.GetSAMLConnector(req.ConnectorID, true)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	sp, err := s.getSAMLProvider(conn)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	authnDoc, err := sp.BuildAuthRequestDocument()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// The request ID is generated by the provider and lives on the root element.
	idAttr := authnDoc.Root().SelectAttr("ID")
	if idAttr == nil || idAttr.Value == "" {
		return nil, trace.BadParameter("missing auth request ID")
	}
	req.ID = idAttr.Value

	// NOTE(review): the empty first argument is presumably the relay state — confirm.
	redirectURL, err := sp.BuildAuthURLFromDocument("", authnDoc)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	req.RedirectURL = redirectURL

	if err := s.Identity.CreateSAMLAuthRequest(req, defaults.SAMLAuthRequestTTL); err != nil {
		return nil, trace.Wrap(err)
	}
	return &req, nil
}
// getSAMLProvider returns the SAML service provider for conn, holding s.lock
// for the duration of the call.
//
// A provider cached in s.samlProviders under the connector's name is reused
// when the connector it was built from Equals conn. Otherwise the cached
// entry (if any) is dropped, a new provider is built with
// conn.GetServiceProvider(s.clock), stored in the cache and returned. When
// building fails the stale entry stays removed and the error is returned
// wrapped.
func (s *AuthServer) getSAMLProvider(conn services.SAMLConnector) (*saml2.SAMLServiceProvider, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	name := conn.GetName()
	if cached, found := s.samlProviders[name]; found && cached.connector.Equals(conn) {
		return cached.provider, nil
	}

	// Either nothing is cached or the connector changed: rebuild from scratch.
	delete(s.samlProviders, name)
	sp, err := conn.GetServiceProvider(s.clock)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	s.samlProviders[name] = &samlProvider{connector: conn, provider: sp}
	return sp, nil
}
// buildSAMLRoles resolves the role names for a SAML assertion.
//
// Roles produced by connector.MapAttributes are returned as-is when there is
// at least one. Otherwise a role is created from the connector's role
// template, upserted with a TTL that ends at expiresAt, and its name is
// returned as the only role. Failures to build or upsert the templated role
// are logged and reported as AccessDenied.
func (a *AuthServer) buildSAMLRoles(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, expiresAt time.Time) ([]string, error) {
	if mapped := connector.MapAttributes(assertionInfo); len(mapped) != 0 {
		return mapped, nil
	}

	// No direct mapping matched, fall back to the role template.
	templated, err := connector.RoleFromTemplate(assertionInfo)
	if err != nil {
		log.Warningf("[SAML] Unable to map claims to roles or role templates for %q: %v", connector.GetName(), err)
		return nil, trace.AccessDenied("unable to map claims to roles or role templates for %q: %v", connector.GetName(), err)
	}

	// expires = now + ttl, hence ttl = expires - now
	roleTTL := expiresAt.Sub(a.clock.Now())

	if err := a.Access.UpsertRole(templated, roleTTL); err != nil {
		log.Warningf("[SAML] Unable to upsert templated role for connector: %q: %v", connector.GetName(), err)
		return nil, trace.AccessDenied("unable to upsert templated role: %q: %v", connector.GetName(), err)
	}

	return []string{templated.GetName()}, nil
}
// createSAMLUser creates or updates the Teleport user that represents the
// SAML identity described by assertionInfo.
//
// The user is named after assertionInfo.NameID, receives the roles computed
// by buildSAMLRoles, expires at expiresAt and is recorded as created by
// "system" through the given SAML connector.
//
// If a user with the same name already exists and was not created by this
// same SAML connector, an AlreadyExists error is returned and nothing is
// upserted. Other failures are returned wrapped.
func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, expiresAt time.Time) error {
	roles, err := a.buildSAMLRoles(connector, assertionInfo, expiresAt)
	if err != nil {
		return trace.Wrap(err)
	}

	log.Debugf("[SAML] %v/%v is a dynamic identity, generating user with roles: %v", connector.GetName(), assertionInfo.NameID, roles)
	user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{
		Kind:    services.KindUser,
		Version: services.V2,
		Metadata: services.Metadata{
			Name:      assertionInfo.NameID,
			Namespace: defaults.Namespace,
		},
		Spec: services.UserSpecV2{
			Roles:          roles,
			Expires:        expiresAt,
			SAMLIdentities: []services.ExternalIdentity{{ConnectorID: connector.GetName(), Username: assertionInfo.NameID}},
			CreatedBy: services.CreatedBy{
				User: services.UserRef{Name: "system"},
				// Use the server clock rather than the wall clock so the
				// creation time comes from the same time source as the
				// expiry and TTL calculations.
				Time: a.clock.Now().UTC(),
				Connector: &services.ConnectorRef{
					Type:     teleport.ConnectorSAML,
					ID:       connector.GetName(),
					Identity: assertionInfo.NameID,
				},
			},
		},
	})
	if err != nil {
		return trace.Wrap(err)
	}

	// Check if a user exists already; "not found" simply means there is
	// nothing to conflict with.
	existingUser, err := a.GetUser(assertionInfo.NameID)
	if err != nil && !trace.IsNotFound(err) {
		return trace.Wrap(err)
	}

	// Refuse to overwrite a user that was not created by this SAML connector.
	if existingUser != nil {
		connectorRef := existingUser.GetCreatedBy().Connector
		if connectorRef == nil || connectorRef.Type != teleport.ConnectorSAML || connectorRef.ID != connector.GetName() {
			return trace.AlreadyExists("user %q already exists and is not SAML user", existingUser.GetName())
		}
	}

	// No conflicting user exists: create or update the existing SAML user.
	if err := a.UpsertUser(user); err != nil {
		return trace.Wrap(err)
	}
	return nil
}
func parseSAMLInResponseTo(response string) (string, error) {
raw, _ := base64.StdEncoding.DecodeString(response)
doc := etree.NewDocument()
err := doc.ReadFromBytes(raw)
if err != nil {
// Attempt to inflate the response in case it happens to be compressed (as with one case at saml.oktadev.com)
buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(raw)))
if err != nil {
return "", trace.Wrap(err)
}
doc = etree.NewDocument()
err = doc.ReadFromBytes(buf)
if err != nil {
return "", trace.Wrap(err)
}
}
if doc.Root() == nil {
return "", trace.BadParameter("unable to parse response")
}
el := doc.Root()
responseTo := el.SelectAttr("InResponseTo")
if responseTo == nil {
return "", trace.BadParameter("identity provider initiated flows are not supported")
}
if responseTo.Value == "" {
return "", trace.BadParameter("InResponseTo can not be empty")
}
return responseTo.Value, nil
}
// SAMLAuthResponse is the result of ValidateSAMLResponse: it is returned
// once the auth server has validated the callback parameters sent back by
// the SAML identity provider.
type SAMLAuthResponse struct {
// Username is the name of the authenticated teleport user
Username string `json:"username"`
// Identity contains the validated SAML identity (connector ID and the
// NameID reported by the identity provider)
Identity services.ExternalIdentity `json:"identity"`
// Session is the web session generated by the auth server; it is only set
// when the original SAMLAuthRequest has CreateWebSession enabled
Session services.WebSession `json:"session,omitempty"`
// Cert is the user certificate generated for the public key supplied in
// the original request; it is only set when the request carries a public key
Cert []byte `json:"cert,omitempty"`
// Req is the original SAML auth request
Req services.SAMLAuthRequest `json:"req"`
// HostSigners is a list of host certificate authorities trusted by proxy,
// used in console login; it is only filled in when a certificate is issued
HostSigners []services.CertAuthority `json:"host_signers"`
}
// ValidateSAMLResponse validates a SAML response received from an identity
// provider and turns it into a teleport login.
//
// The response is matched to a stored auth request through its InResponseTo
// attribute, and its assertions are retrieved with the service provider of
// the request's connector. Responses that fail retrieval, carry an invalid
// time or are not addressed to this audience are rejected with AccessDenied.
// Connectors without attribute-to-role mappings are rejected with
// BadParameter.
//
// On success the SAML user is created or updated and the returned
// SAMLAuthResponse contains the original request, the external identity and
// the user name. A web session is added when the request asked for one, and
// a certificate plus the host certificate authorities are added when the
// request carries a public key.
func (a *AuthServer) ValidateSAMLResponse(samlResponse string) (*SAMLAuthResponse, error) {
	requestID, err := parseSAMLInResponseTo(samlResponse)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	authRequest, err := a.Identity.GetSAMLAuthRequest(requestID)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	conn, err := a.Identity.GetSAMLConnector(authRequest.ConnectorID, true)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	sp, err := a.getSAMLProvider(conn)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	info, err := sp.RetrieveAssertionInfo(samlResponse)
	if err != nil {
		log.Warningf("SAML error: %v", err)
		return nil, trace.AccessDenied("bad SAML response")
	}
	// The details are only logged; the caller always gets the same generic error.
	switch {
	case info.WarningInfo.InvalidTime:
		log.Warningf("SAML error, invalid time")
		return nil, trace.AccessDenied("bad SAML response")
	case info.WarningInfo.NotInAudience:
		log.Warningf("SAML error, not in audience")
		return nil, trace.AccessDenied("bad SAML response")
	}

	log.Debugf("[SAML] Obtained Assertions for %q", info.NameID)
	for name, attr := range info.Values {
		flattened := make([]string, 0, len(attr.Values))
		for _, item := range attr.Values {
			flattened = append(flattened, item.Value)
		}
		log.Debugf("[SAML] Assertion: %q: %q", name, flattened)
	}
	log.Debugf("[SAML] Assertion Warnings: %+v", info.WarningInfo)

	mappingCount := len(conn.GetAttributesToRoles())
	log.Debugf("[SAML] Applying %v claims to roles mappings", mappingCount)
	if mappingCount == 0 {
		return nil, trace.BadParameter("SAML does not support binding to local users")
	}

	// TODO(klizhentas) use SessionNotOnOrAfter to calculate expiration time
	expiresAt := a.clock.Now().Add(defaults.CertDuration)
	if err := a.createSAMLUser(conn, *info, expiresAt); err != nil {
		return nil, trace.Wrap(err)
	}

	identity := services.ExternalIdentity{
		ConnectorID: authRequest.ConnectorID,
		Username:    info.NameID,
	}
	user, err := a.Identity.GetUserBySAMLIdentity(identity)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	var roleSet services.RoleSet
	roleSet, err = services.FetchRoles(user.GetRoles(), a.Access)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	sessionTTL := roleSet.AdjustSessionTTL(utils.ToTTL(a.clock, expiresAt))
	bearerTokenTTL := utils.MinTTL(BearerTokenTTL, sessionTTL)

	response := &SAMLAuthResponse{
		Req:      *authRequest,
		Identity: identity,
		Username: user.GetName(),
	}

	if authRequest.CreateWebSession {
		webSession, err := a.NewWebSession(user.GetName())
		if err != nil {
			return nil, trace.Wrap(err)
		}
		// session will expire based on identity TTL and allowed session TTL
		webSession.SetExpiryTime(a.clock.Now().UTC().Add(sessionTTL))
		// bearer token will expire based on the expected session renewal
		webSession.SetBearerTokenExpiryTime(a.clock.Now().UTC().Add(bearerTokenTTL))
		if err := a.UpsertWebSession(user.GetName(), webSession); err != nil {
			return nil, trace.Wrap(err)
		}
		response.Session = webSession
	}

	if len(authRequest.PublicKey) != 0 {
		// The certificate never outlives the identity nor the requested TTL.
		certTTL := utils.MinTTL(utils.ToTTL(a.clock, expiresAt), authRequest.CertTTL)
		logins, err := roleSet.CheckLogins(certTTL)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		signed, err := a.GenerateUserCert(authRequest.PublicKey, user, logins, certTTL, roleSet.CanForwardAgents(), authRequest.Compatibility)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		response.Cert = signed

		hostCAs, err := a.GetCertAuthorities(services.HostCA, false)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		for i := range hostCAs {
			response.HostSigners = append(response.HostSigners, hostCAs[i])
		}
	}

	return response, nil
}