/
private_key_jwt.go
146 lines (127 loc) · 4.68 KB
/
private_key_jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
// Licensed to SolID under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. SolID licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package clientauthentication
import (
"context"
"encoding/json"
"fmt"
"time"
"gopkg.in/square/go-jose.v2"
clientv1 "zntr.io/solid/api/oidc/client/v1"
"zntr.io/solid/oidc"
"zntr.io/solid/sdk/jwk"
"zntr.io/solid/sdk/rfcerrors"
"zntr.io/solid/server/storage"
)
// PrivateKeyJWT returns an AuthenticationProcessor implementing the
// private_key_jwt client authentication method. Clients referenced by the
// assertion "iss" claim are resolved through the given storage reader.
func PrivateKeyJWT(clients storage.ClientReader) AuthenticationProcessor {
	processor := new(privateKeyJWTAuthentication)
	processor.clients = clients
	return processor
}
// privateJWTClaims holds the registered claims read from a private_key_jwt
// client assertion (RFC 7523 §3, OpenID Connect Core §9).
type privateJWTClaims struct {
	JTI     string `json:"jti"`
	Subject string `json:"sub"`
	Issuer  string `json:"iss"`
	// Audience is the first value of the "aud" claim, or "" when the claim
	// is absent, null, an empty string or an empty array.
	Audience string `json:"aud"`
	// Audiences holds every value of the "aud" claim. It is populated by
	// UnmarshalJSON and never serialized.
	Audiences []string `json:"-"`
	Expires   uint64   `json:"exp"`
	IssuedAt  uint64   `json:"iat"`
}

// UnmarshalJSON decodes the claims, accepting "aud" either as a single
// string or as an array of strings, as allowed by RFC 7519 §4.1.3. A plain
// string field would reject the array form, which many JWT libraries emit.
func (c *privateJWTClaims) UnmarshalJSON(data []byte) error {
	// claimsAlias has the same fields but no methods, so decoding into it
	// does not recurse back into this UnmarshalJSON.
	type claimsAlias privateJWTClaims

	var wire struct {
		claimsAlias
		// A field at a shallower depth wins over the embedded "aud" field
		// in encoding/json, so the raw audience lands here.
		RawAudience json.RawMessage `json:"aud"`
	}
	if err := json.Unmarshal(data, &wire); err != nil {
		return err
	}

	audiences, err := decodePrivateJWTAudience(wire.RawAudience)
	if err != nil {
		return err
	}

	*c = privateJWTClaims(wire.claimsAlias)
	c.Audiences = audiences
	c.Audience = ""
	if len(audiences) > 0 {
		c.Audience = audiences[0]
	}

	return nil
}

// decodePrivateJWTAudience parses a raw "aud" value that may be absent, null,
// a string or an array of strings. An absent, null or empty-string value
// yields no audience.
func decodePrivateJWTAudience(raw json.RawMessage) ([]string, error) {
	if len(raw) == 0 {
		return nil, nil
	}

	// Unmarshalling JSON null into a string is a no-op, so null also ends up
	// as an empty audience here.
	var single string
	if err := json.Unmarshal(raw, &single); err == nil {
		if single == "" {
			return nil, nil
		}
		return []string{single}, nil
	}

	var many []string
	if err := json.Unmarshal(raw, &many); err != nil {
		return nil, fmt.Errorf("aud must be a string or an array of strings: %w", err)
	}

	return many, nil
}
// privateKeyJWTAuthentication implements the private_key_jwt client
// authentication method on top of a client storage reader.
type privateKeyJWTAuthentication struct {
	clients storage.ClientReader
}

// Authenticate authenticates a client with the private_key_jwt method.
//
// The client_assertion request parameter is parsed and its claims are
// checked. The client named by the "iss" claim is then loaded from storage,
// and the assertion signature is verified against that client's registered
// JWKS. On success res.Client is set and the error is nil. On failure the
// response carries an OAuth error (invalid_request, invalid_client or
// server_error), and a descriptive Go error is returned alongside it.
//
// NOTE(review): "aud" is only checked for presence. RFC 7523 §3 requires
// the authorization server to reject assertions whose audience does not
// identify it. That needs the server identifier, which this type does not
// have. "jti" is not tracked for replay either; RFC 7523 makes that
// optional.
func (p *privateKeyJWTAuthentication) Authenticate(ctx context.Context, req *clientv1.AuthenticateRequest) (*clientv1.AuthenticateResponse, error) {
	res := &clientv1.AuthenticateResponse{}

	// Every failure while decoding the assertion is a malformed request.
	assertion, claims, err := p.decodeAssertion(req)
	if err != nil {
		res.Error = rfcerrors.InvalidRequest().Build()
		return res, err
	}

	// The assertion is self-issued (iss == sub), so iss is the client ID.
	client, err := p.clients.Get(ctx, claims.Issuer)
	switch {
	// NOTE(review): identity comparison only; use errors.Is if a storage
	// backend may wrap ErrNotFound.
	case err == storage.ErrNotFound:
		res.Error = rfcerrors.InvalidClient().Build()
		return res, fmt.Errorf("client not found")
	case err != nil:
		res.Error = rfcerrors.ServerError().Build()
		return res, fmt.Errorf("error during client retrieval: %w", err)
	}

	// The claims above were read without verification; they are trusted only
	// once the signature checks out against the client's keys.
	if errVerify := p.verifyAssertionSignature(client.Jwks, assertion); errVerify != nil {
		res.Error = rfcerrors.InvalidClient().Build()
		return res, errVerify
	}

	res.Client = client
	return res, nil
}

// decodeAssertion validates the client_assertion_type and client_assertion
// request parameters, parses the JWS and checks the mandatory claims. The
// payload is read WITHOUT signature verification. Every error it returns
// maps to an invalid_request OAuth error.
func (*privateKeyJWTAuthentication) decodeAssertion(req *clientv1.AuthenticateRequest) (*jose.JSONWebSignature, *privateJWTClaims, error) {
	switch {
	case req == nil:
		return nil, nil, fmt.Errorf("unable to process nil request")
	case req.ClientAssertionType == nil:
		return nil, nil, fmt.Errorf("client_assertion_type must be defined")
	case *req.ClientAssertionType != oidc.AssertionTypeJWTBearer:
		return nil, nil, fmt.Errorf("client_assertion_type must equals '%s', got '%s'", oidc.AssertionTypeJWTBearer, *req.ClientAssertionType)
	case req.ClientAssertion == nil:
		return nil, nil, fmt.Errorf("client_assertion must be defined")
	case *req.ClientAssertion == "":
		return nil, nil, fmt.Errorf("client_assertion must not be empty")
	}

	assertion, err := jose.ParseSigned(*req.ClientAssertion)
	if err != nil {
		return nil, nil, fmt.Errorf("assertion is syntaxically invalid: %w", err)
	}

	var claims privateJWTClaims
	if errDecode := json.Unmarshal(assertion.UnsafePayloadWithoutVerification(), &claims); errDecode != nil {
		return nil, nil, fmt.Errorf("unable to decode payload claims: %w", errDecode)
	}

	now := uint64(time.Now().Unix())
	switch {
	case claims.Issuer == "" || claims.Subject == "" || claims.Audience == "" || claims.JTI == "" || claims.Expires == 0:
		return nil, nil, fmt.Errorf("iss, sub, aud, jti, exp are mandatory and not empty")
	case claims.Issuer != claims.Subject:
		return nil, nil, fmt.Errorf("iss and sub must be identic")
	// RFC 7519 §4.1.4: the current time MUST be before "exp", so a token
	// expiring this very second is already expired.
	case claims.Expires <= now:
		return nil, nil, fmt.Errorf("expired token")
	}

	return assertion, &claims, nil
}

// verifyAssertionSignature decodes the client's registered JWKS (raw JSON)
// and validates the assertion signature against it. Every error it returns
// maps to an invalid_client OAuth error.
func (*privateKeyJWTAuthentication) verifyAssertionSignature(rawJWKS []byte, assertion *jose.JSONWebSignature) error {
	if len(rawJWKS) == 0 {
		return fmt.Errorf("client jwks is nil")
	}

	var jwks jose.JSONWebKeySet
	if err := json.Unmarshal(rawJWKS, &jwks); err != nil {
		return fmt.Errorf("client jwks is invalid: %w", err)
	}

	// Accepts the assertion when one of the client's keys validates it.
	if err := jwk.ValidateSignature(&jwks, assertion); err != nil {
		return fmt.Errorf("client assertion is invalid: %w", err)
	}

	return nil
}