-
Notifications
You must be signed in to change notification settings - Fork 0
/
auth.go
146 lines (120 loc) · 3.24 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package auth
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)
// AuthLib abstracts an authorization library behind a common interface so
// callers can make access decisions without depending on a specific backend.
type AuthLib interface {
	// Name returns the name of the underlying auth library.
	Name() AuthLibName
	// Authorize reports whether the user identified by userID, belonging to
	// usergroups, may take the given action on the given resource.
	Authorize(userID string, usergroups []string, resource any, action string) bool
}

// AuthLibName identifies the library backing an AuthLib implementation.
type AuthLibName string

// Known AuthLibName values.
const (
	AuthLibNameCasbin AuthLibName = "casbin"
	AuthLibNameGorbac AuthLibName = "gorbac"
	AuthLibNameOso    AuthLibName = "oso"
)
// JOSEHeader holds the decoded JOSE header parameters of a JWT. Every value
// is stored as a string, so a header containing non-string JSON values
// cannot be decoded into it.
type JOSEHeader map[string]string

// Header parameter names looked up in a JOSEHeader.
const (
	HeaderKeyAlgorithm = "alg" // must be present; see JOSEHeader.validate
	HeaderKeyID        = "kid" // returned by (*JWT).KeyID
	HeaderMediaType    = "typ"
)

// JWT is a compact-serialized JSON Web Token split into its three segments.
// The Raw* fields keep the encoded text of a segment; Header, Payload and
// Signature hold decoded forms.
type JWT struct {
	// RawHeader is the encoded first segment.
	RawHeader string

	// Header is the decoded header parameters.
	Header JOSEHeader

	// RawPayload is the encoded second segment.
	RawPayload string

	// Payload is the base64url-decoded second segment; it is parsed as JSON
	// only when Claims or DecodeClaims is called.
	Payload []byte

	// Signature is the base64url-decoded third segment.
	Signature []byte
}
// KeyID returns the "kid" header parameter and whether the header contained
// it at all.
func (j *JWT) KeyID() (string, bool) {
	if kid, found := j.Header[HeaderKeyID]; found {
		return kid, true
	}
	return "", false
}
// Claims decodes the JSON payload into a generic Claims map. On failure the
// returned map is nil and the error comes from decodeClaims.
func (j *JWT) Claims() (Claims, error) {
	claims, err := decodeClaims(j.Payload)
	if err != nil {
		return nil, err
	}
	return claims, nil
}
// DecodeClaims unmarshals the JSON payload into out, which, as with
// json.Unmarshal, must be a non-nil pointer (for example to a struct with
// json tags). Errors from json.Unmarshal are returned unchanged.
func (j *JWT) DecodeClaims(out any) error {
	return json.Unmarshal(j.Payload, out)
}
// Data returns the encoded data part of the token that a signature may
// cover: the raw header and payload segments joined by a '.'.
func (j *JWT) Data() string {
	return j.RawHeader + "." + j.RawPayload
}
// Encode returns the full compact token string, header.claims.signature.
// The header and claims come from the Raw* fields (via Data), so changes
// made to Header or Payload after parsing are not reflected; only Signature
// is re-encoded.
func (j *JWT) Encode() string {
	return j.Data() + "." + encodeSegment(j.Signature)
}
// ParseJWT splits a compact-serialized token of the form
// header.payload.signature and base64url-decodes each segment.
//
// The header must decode to a JSON object whose values are all strings (it
// is stored as a JOSEHeader) and must contain an "alg" parameter. The
// payload is only base64url-decoded here; call Claims or DecodeClaims to
// parse its JSON. ParseJWT does not verify the signature, so callers must do
// that before trusting any header value or claim.
//
// Underlying decode errors are wrapped with %w, so callers can inspect them
// with errors.Is / errors.As.
func ParseJWT(raw string) (*JWT, error) {
	parts := strings.Split(raw, ".")
	if len(parts) != 3 {
		// The count can be too high as well as too low, so report both sides.
		return nil, fmt.Errorf("malformed JWT, expected 3 segments, got %d", len(parts))
	}
	rawHeader, rawPayload, rawSig := parts[0], parts[1], parts[2]

	header, err := decodeHeader(rawHeader)
	if err != nil {
		return nil, fmt.Errorf("malformed JWT, unable to decode header: %w", err)
	}
	if err = header.validate(); err != nil {
		return nil, fmt.Errorf("malformed JWT, %w", err)
	}
	payload, err := decodeSegment(rawPayload)
	if err != nil {
		return nil, fmt.Errorf("malformed JWT, unable to decode payload: %w", err)
	}
	sig, err := decodeSegment(rawSig)
	if err != nil {
		return nil, fmt.Errorf("malformed JWT, unable to decode signature: %w", err)
	}
	return &JWT{
		RawHeader:  rawHeader,
		Header:     header,
		RawPayload: rawPayload,
		Payload:    payload,
		Signature:  sig,
	}, nil
}
// decodeHeader base64url-decodes a header segment and unmarshals the JSON
// into a JOSEHeader. Because JOSEHeader is map[string]string, a header with
// any non-string value (number, bool, array, object) fails to unmarshal and
// is reported as an error.
func decodeHeader(seg string) (JOSEHeader, error) {
	raw, err := decodeSegment(seg)
	if err != nil {
		return nil, err
	}
	var header JOSEHeader
	if err := json.Unmarshal(raw, &header); err != nil {
		return nil, err
	}
	return header, nil
}
// decodeSegment decodes one JWT segment. JWT segments use the URL-safe
// base64 alphabet with trailing '=' padding removed, so the missing padding
// is restored before decoding. Input that is already padded is accepted too.
func decodeSegment(seg string) ([]byte, error) {
	missing := (4 - len(seg)%4) % 4
	padded := seg + strings.Repeat("=", missing)
	return base64.URLEncoding.DecodeString(padded)
}
// validate checks the header parameters that ParseJWT insists on. Currently
// it only requires that "alg" is present; the value itself (even an empty
// string) is not inspected.
func (h JOSEHeader) validate() error {
	_, hasAlg := h[HeaderKeyAlgorithm]
	if hasAlg {
		return nil
	}
	return fmt.Errorf("header missing %q parameter", HeaderKeyAlgorithm)
}
// Claims is a JWT claims set decoded into generic JSON values, following
// encoding/json's rules for decoding into any (for example, numbers become
// float64 and nested objects become map[string]any).
type Claims map[string]any

// decodeClaims unmarshals payload into a Claims map.
//
// A claims set must be a JSON object. Other JSON types (arrays, strings,
// numbers) already fail to unmarshal into a map, but a literal null would
// silently yield a nil Claims with no error, so it is rejected explicitly.
// Decode errors are wrapped with %w so callers can use errors.As (for
// example, to reach a *json.SyntaxError).
func decodeClaims(payload []byte) (Claims, error) {
	var c Claims
	if err := json.Unmarshal(payload, &c); err != nil {
		return nil, fmt.Errorf("malformed JWT claims, unable to decode: %w", err)
	}
	if c == nil {
		return nil, fmt.Errorf("malformed JWT claims, payload is JSON null, not an object")
	}
	return c, nil
}
// encodeSegment encodes seg with the URL-safe base64 alphabet and no '='
// padding, the form JWT segments use. base64.RawURLEncoding produces this
// directly, so there is no need to encode with padding and trim it off.
func encodeSegment(seg []byte) string {
	return base64.RawURLEncoding.EncodeToString(seg)
}