/
map_claims.go
195 lines (167 loc) · 4.49 KB
/
map_claims.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"bytes"
"crypto/subtle"
"encoding/json"
"errors"
"time"
jjson "github.com/go-jose/go-jose/v3/json"
"github.com/ory/x/errorsx"
)
// TimeFunc supplies the current time used by MapClaims.Valid. It is a
// package-level variable, so it can be reassigned (for example to pin the
// clock); any replacement must have the signature func() time.Time.
var TimeFunc func() time.Time = time.Now
// MapClaims is a claims set backed by map[string]interface{}, so any JSON
// object can be decoded into it. It is the default claims type when none is
// supplied, and it provides backwards-compatible validations (audience,
// issuer, exp/iat/nbf) that are not available in `go-jose`.
//
// It was taken from [here](https://raw.githubusercontent.com/form3tech-oss/jwt-go/master/map_claims.go).
type MapClaims map[string]interface{}
// VerifyAudience reports whether cmp is one of the audiences listed in the
// "aud" claim. The claim may be a single string, a []string, or a
// []interface{} whose elements are all strings.
//
// If req is false, the check also passes when the claim is unset, whether it
// is absent, JSON null, or an empty list. A claim of any other type, or a
// list containing a non-string element, always fails.
func (m MapClaims) VerifyAudience(cmp string, req bool) bool {
	var aud []string
	switch v := m["aud"].(type) {
	case nil:
		// Unset: leave aud empty so verifyAud applies the req rule. This
		// matches how an empty list and the other Verify* methods behave;
		// previously an absent claim fell into default and always failed.
	case string:
		aud = []string{v}
	case []string:
		aud = v
	case []interface{}:
		aud = make([]string, 0, len(v))
		for _, a := range v {
			s, ok := a.(string)
			if !ok {
				// A malformed audience list is never accepted.
				return false
			}
			aud = append(aud, s)
		}
	default:
		// Unexpected claim type (number, object, ...): reject.
		return false
	}
	return verifyAud(aud, cmp, req)
}
// VerifyExpiresAt checks the "exp" claim against cmp (Valid passes the
// current Unix time) using verifyExp. When the claim is missing or is not a
// number that toInt64 recognizes, the check passes only if req is false.
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
	exp, present := m.toInt64("exp")
	if !present {
		return !req
	}
	return verifyExp(exp, cmp, req)
}
// VerifyIssuedAt checks the "iat" claim against cmp (Valid passes the
// current Unix time) using verifyIat. When the claim is missing or is not a
// number that toInt64 recognizes, the check passes only if req is false.
func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {
	iat, present := m.toInt64("iat")
	if !present {
		return !req
	}
	return verifyIat(iat, cmp, req)
}
// VerifyIssuer reports whether the "iss" claim matches cmp, using verifyIss.
// A missing "iss" claim, or one that is not a string, is treated as an empty
// issuer. verifyIss handles that as unset, so the check then passes only if
// req is false.
func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
	var issuer string
	if s, isString := m["iss"].(string); isString {
		issuer = s
	}
	return verifyIss(issuer, cmp, req)
}
// VerifyNotBefore checks the "nbf" claim against cmp (Valid passes the
// current Unix time) using verifyNbf. When the claim is missing or is not a
// number that toInt64 recognizes, the check passes only if req is false.
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
	nbf, present := m.toInt64("nbf")
	if !present {
		return !req
	}
	return verifyNbf(nbf, cmp, req)
}
// toInt64 returns the value of the named claim as an int64 and reports
// whether the claim holds a supported numeric type. Missing claims and
// non-numeric values (e.g. strings) yield (0, false).
//
// Supported types:
//   - int64 and float64: UnmarshalJSON configures go-jose's
//     UnmarshalIntOrFloat number mode, which presumably yields these
//     (TODO confirm against go-jose);
//   - json.Number, as produced by encoding/json with UseNumber;
//   - int and int32, which appear when claims are built in Go code.
//
// Fractional values are truncated toward zero.
func (m MapClaims) toInt64(claim string) (int64, bool) {
	switch t := m[claim].(type) {
	case int64:
		return t, true
	case int:
		// Untyped integer constants in Go-built claims, e.g.
		// MapClaims{"exp": 1700000000}, are stored as int. Without this case
		// such a claim is treated as absent and the time checks pass silently.
		return int64(t), true
	case int32:
		return int64(t), true
	case float64:
		return int64(t), true
	case json.Number:
		if v, err := t.Int64(); err == nil {
			return v, true
		}
		// Not a plain integer literal (e.g. "1.5" or "1e9"): parse as float.
		vf, err := t.Float64()
		if err != nil {
			return 0, false
		}
		return int64(vf), true
	}
	return 0, false
}
// Valid checks the time-based claims "exp", "iat" and "nbf" against the
// current time from TimeFunc, truncated to Unix seconds. There is no
// allowance for clock skew. None of the claims is required, so a claim
// missing from the token does not cause a failure.
//
// On failure the returned error is a *ValidationError whose Errors field has
// the bit of every failing check set. Inner carries a single message: each
// failing check overwrites the previous one. When several checks fail, the
// "nbf" message therefore wins over "iat", which wins over "exp".
func (m MapClaims) Valid() error {
	now := TimeFunc().Unix()
	verr := &ValidationError{}

	if expOK := m.VerifyExpiresAt(now, false); !expOK {
		verr.Errors |= ValidationErrorExpired
		verr.Inner = errors.New("Token is expired")
	}
	if iatOK := m.VerifyIssuedAt(now, false); !iatOK {
		verr.Errors |= ValidationErrorIssuedAt
		verr.Inner = errors.New("Token used before issued")
	}
	if nbfOK := m.VerifyNotBefore(now, false); !nbfOK {
		verr.Errors |= ValidationErrorNotValidYet
		verr.Inner = errors.New("Token is not valid yet")
	}

	if !verr.valid() {
		return verr
	}
	return nil
}
// UnmarshalJSON decodes the JSON object b into m using go-jose's JSON
// decoder configured with jjson.UnmarshalIntOrFloat. This lets numeric
// claims be decoded as integers where possible instead of always as float64.
//
// The custom method exists because go-jose offers no other way to configure
// its decoding settings (see https://github.com/square/go-jose/issues/353).
// It can be removed if that issue is closed with a better solution.
//
// Because the receiver is a value, decoding relies on the decoder filling
// the existing map in place, as encoding/json does for non-nil maps. A nil
// MapClaims cannot receive the decoded entries: the decoder would allocate
// a new map that is immediately discarded. Rather than silently dropping
// the payload, an error is returned in that case. Initialise the map
// (e.g. MapClaims{}) before decoding.
func (m MapClaims) UnmarshalJSON(b []byte) error {
	if m == nil {
		return errorsx.WithStack(errors.New("jwt: cannot unmarshal claims into a nil MapClaims"))
	}
	d := jjson.NewDecoder(bytes.NewReader(b))
	d.SetNumberType(jjson.UnmarshalIntOrFloat)
	// Decode into the plain map type so this method is not invoked recursively.
	mp := map[string]interface{}(m)
	if err := d.Decode(&mp); err != nil {
		return errorsx.WithStack(err)
	}
	return nil
}
// verifyAud reports whether cmp equals one of the entries of aud. Each entry
// is compared with subtle.ConstantTimeCompare, which still returns early when
// lengths differ. An empty aud is treated as an unset claim and passes only
// when required is false.
func verifyAud(aud []string, cmp string, required bool) bool {
	if len(aud) > 0 {
		want := []byte(cmp)
		for i := range aud {
			if subtle.ConstantTimeCompare([]byte(aud[i]), want) == 1 {
				return true
			}
		}
		return false
	}
	return !required
}
// verifyExp reports whether a token with expiration time exp is still valid
// at now; Valid passes both as Unix seconds. RFC 7519 §4.1.4 says the JWT
// must not be accepted on or after its expiration time, so the comparison is
// strict: now must be before exp. An exp of 0 is treated as unset and passes
// only when required is false.
func verifyExp(exp int64, now int64, required bool) bool {
	if exp == 0 {
		return !required
	}
	return now < exp
}
// verifyIat reports whether the issued-at time iat is not later than now;
// Valid passes both as Unix seconds. An iat of 0 is treated as unset and
// passes only when required is false.
func verifyIat(iat int64, now int64, required bool) bool {
	if iat != 0 {
		return iat <= now
	}
	return !required
}
// verifyIss reports whether the issuer iss equals cmp, compared with
// subtle.ConstantTimeCompare (which returns early when lengths differ). An
// empty iss is treated as unset and passes only when required is false.
func verifyIss(iss string, cmp string, required bool) bool {
	if iss == "" {
		return !required
	}
	return subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) == 1
}
// verifyNbf reports whether the not-before time nbf has been reached at now;
// Valid passes both as Unix seconds. A token is acceptable at exactly nbf. An
// nbf of 0 is treated as unset and passes only when required is false.
func verifyNbf(nbf int64, now int64, required bool) bool {
	if nbf != 0 {
		return nbf <= now
	}
	return !required
}