-
-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
session.go
180 lines (149 loc) · 5.76 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"encoding/json"
"time"
"github.com/pkg/errors"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/mohae/deepcopy"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"
"github.com/ory/x/stringslice"
)
// Session is the OAuth2 session used by this package. It embeds fosite's
// OpenID Connect default session (serialized under the "id_token" JSON key)
// and adds the data needed to build JWT access-token claims and headers.
//
// swagger:ignore
type Session struct {
// DefaultSession is the embedded OpenID Connect session; its Subject, Claims
// and expiry accessors are read by GetJWTClaims.
*openid.DefaultSession `json:"id_token"`
// Extra holds custom claims. GetJWTClaims exposes the whole map under "ext"
// and copies the keys listed in AllowedTopLevelClaims next to it.
Extra map[string]interface{} `json:"extra"`
// KID is emitted as the "kid" entry of the JWT header by GetJWTHeader.
KID string `json:"kid"`
// ClientID is emitted as the "client_id" claim by GetJWTClaims.
ClientID string `json:"client_id"`
// ConsentChallenge is not read by any code in this file.
// NOTE(review): presumably identifies the consent request that produced this
// session - confirm against callers.
ConsentChallenge string `json:"consent_challenge"`
// ExcludeNotBeforeClaim, when true, makes GetJWTClaims leave NotBefore unset;
// otherwise NotBefore is set to the IssuedAt time.
ExcludeNotBeforeClaim bool `json:"exclude_not_before_claim"`
// AllowedTopLevelClaims lists the keys of Extra that GetJWTClaims may copy
// alongside "ext". Reserved claim names in this list are ignored.
AllowedTopLevelClaims []string `json:"allowed_top_level_claims"`
}
// NewSession returns a Session for the given subject that has no allowed
// top-level claims configured. It is shorthand for calling
// NewSessionWithCustomClaims with a nil claim list.
func NewSession(subject string) *Session {
	var noTopLevelClaims []string
	return NewSessionWithCustomClaims(subject, noTopLevelClaims)
}
// NewSessionWithCustomClaims returns a Session for the given subject.
//
// The embedded OpenID Connect session is created with empty ID token claims
// and headers, Extra starts as an empty (non-nil) map, and
// allowedTopLevelClaims is stored as-is (the slice is not copied).
func NewSessionWithCustomClaims(subject string, allowedTopLevelClaims []string) *Session {
	idTokenSession := &openid.DefaultSession{
		Claims:  &jwt.IDTokenClaims{},
		Headers: &jwt.Headers{},
		Subject: subject,
	}

	session := new(Session)
	session.DefaultSession = idTokenSession
	session.Extra = make(map[string]interface{})
	session.AllowedTopLevelClaims = allowedTopLevelClaims
	return session
}
// GetJWTClaims builds the JWT claims container for this session.
//
// The returned claims carry the session subject, the issuer taken from the
// embedded ID token claims, the access-token expiry, and the current time as
// IssuedAt. NotBefore mirrors IssuedAt unless s.ExcludeNotBeforeClaim is set.
//
// The claims' Extra map is assembled as follows:
//   - every key of s.Extra that is listed in s.AllowedTopLevelClaims and is
//     not a reserved claim name is copied as its own entry;
//   - the complete s.Extra map is stored under "ext" (the map is shared with
//     the session, not copied);
//   - s.ClientID is stored under "client_id".
func (s *Session) GetJWTClaims() jwt.JWTClaimsContainer {
	// Claim names that custom claims must never override.
	reservedClaims := []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"}

	// The +2 leaves room for the "ext" and "client_id" entries added below.
	extra := make(map[string]interface{}, len(s.AllowedTopLevelClaims)+2)

	// Promote each allowed, non-reserved claim that is actually present in
	// the session's extra data.
	for _, name := range s.AllowedTopLevelClaims {
		if stringslice.Has(reservedClaims, name) {
			continue
		}
		if value, ok := s.Extra[name]; ok {
			extra[name] = value
		}
	}

	// Everything, including claims that could not be promoted because their
	// names are reserved, stays available under "ext".
	extra["ext"] = s.Extra
	extra["client_id"] = s.ClientID

	claims := &jwt.JWTClaims{
		Subject:   s.Subject,
		Issuer:    s.DefaultSession.Claims.Issuer,
		Extra:     extra,
		ExpiresAt: s.GetExpiresAt(fosite.AccessToken),
		IssuedAt:  time.Now(),

		// No need to set the audience because that's being done by fosite automatically.
		// Audience:  s.Audience,

		// The JTI MUST NOT BE FIXED or refreshing tokens will yield the SAME token
		// JTI:       s.JTI,

		// These are set by the DefaultJWTStrategy
		// Scope:     s.Scope,

		// Setting these here will cause the token to have the same iat/nbf values always
		// IssuedAt:  s.DefaultSession.Claims.IssuedAt,
		// NotBefore: s.DefaultSession.Claims.IssuedAt,
	}
	if !s.ExcludeNotBeforeClaim {
		claims.NotBefore = claims.IssuedAt
	}

	return claims
}
// GetJWTHeader returns a fresh set of JWT headers whose only extra entry is
// "kid", set to the session's KID.
func (s *Session) GetJWTHeader() *jwt.Headers {
	headers := new(jwt.Headers)
	headers.Extra = map[string]interface{}{"kid": s.KID}
	return headers
}
// Clone returns a deep copy of the session as a fosite.Session.
// A nil receiver yields a nil interface value rather than a typed nil.
func (s *Session) Clone() fosite.Session {
	var clone fosite.Session
	if s != nil {
		clone = deepcopy.Copy(s).(fosite.Session)
	}
	return clone
}
// keyRewrites maps alternative JSON paths (left) to the paths that match the
// json struct tags used when decoding a Session (right). Session.UnmarshalJSON
// copies the value found at each left-hand path to the right-hand path and
// then removes the left-hand path. Both sides are dot-separated paths as
// understood by gjson (lookup) and sjson (set/delete).
// NOTE(review): the left-hand names look like an older, untagged
// serialization of Session - confirm where such payloads originate.
var keyRewrites = map[string]string{
"Extra": "extra",
"KID": "kid",
"ClientID": "client_id",
"ConsentChallenge": "consent_challenge",
"ExcludeNotBeforeClaim": "exclude_not_before_claim",
"AllowedTopLevelClaims": "allowed_top_level_claims",
"idToken.Headers.Extra": "id_token.headers.extra",
"idToken.ExpiresAt": "id_token.expires_at",
"idToken.Username": "id_token.username",
"idToken.Subject": "id_token.subject",
"idToken.Claims.JTI": "id_token.id_token_claims.jti",
"idToken.Claims.Issuer": "id_token.id_token_claims.iss",
"idToken.Claims.Subject": "id_token.id_token_claims.sub",
"idToken.Claims.Audience": "id_token.id_token_claims.aud",
"idToken.Claims.Nonce": "id_token.id_token_claims.nonce",
"idToken.Claims.ExpiresAt": "id_token.id_token_claims.exp",
"idToken.Claims.IssuedAt": "id_token.id_token_claims.iat",
"idToken.Claims.RequestedAt": "id_token.id_token_claims.rat",
"idToken.Claims.AuthTime": "id_token.id_token_claims.auth_time",
"idToken.Claims.AccessTokenHash": "id_token.id_token_claims.at_hash",
"idToken.Claims.AuthenticationContextClassReference": "id_token.id_token_claims.acr",
"idToken.Claims.AuthenticationMethodsReferences": "id_token.id_token_claims.amr",
"idToken.Claims.CodeHash": "id_token.id_token_claims.c_hash",
"idToken.Claims.Extra": "id_token.id_token_claims.ext",
}
// UnmarshalJSON decodes a Session from JSON, accepting both the key names
// declared in the struct tags and the alternative paths listed in keyRewrites.
//
// The payload is normalized before decoding: values found at a keyRewrites
// source path are copied to the matching target path, then all source paths
// (and a leftover "idToken" object, if the input had one) are removed. Lookups
// always run against the untouched input, so the random map iteration order
// does not influence the result. Any rewrite or decode error is returned with
// a stack trace attached.
func (s *Session) UnmarshalJSON(original []byte) (err error) {
	parsed := gjson.ParseBytes(original)
	rewritten := original

	// Pass 1: copy every value present under a source path to its target path.
	for sourcePath, targetPath := range keyRewrites {
		value := parsed.Get(sourcePath)
		if value.Exists() {
			rewritten, err = sjson.SetRawBytes(rewritten, targetPath, []byte(value.Raw))
			if err != nil {
				return errors.WithStack(err)
			}
		}
	}

	// Pass 2: drop the source paths so only the target layout remains.
	for sourcePath := range keyRewrites {
		if rewritten, err = sjson.DeleteBytes(rewritten, sourcePath); err != nil {
			return errors.WithStack(err)
		}
	}

	// Remove what is left of the "idToken" container itself.
	if parsed.Get("idToken").Exists() {
		if rewritten, err = sjson.DeleteBytes(rewritten, "idToken"); err != nil {
			return errors.WithStack(err)
		}
	}

	// Decode through a method-less alias type so json.Unmarshal does not call
	// this UnmarshalJSON again.
	type plainSession Session
	if decodeErr := json.Unmarshal(rewritten, (*plainSession)(s)); decodeErr != nil {
		return errors.WithStack(decodeErr)
	}
	return nil
}