/
session_token.go
336 lines (314 loc) · 11.1 KB
/
session_token.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sessionstates
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"crypto/x509"
"encoding/json"
"strings"
"sync"
"time"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
)
// Token-based authentication is used in session migration. We don't use typical password authentication because
// the proxy cannot store user passwords for security reasons.
//
// The process of token-based authentication:
// 1. Before migrating the session, the proxy requires a token from server A.
// 2. Server A generates a token and signs it with a private key defined in the certificate.
// 3. The proxy authenticates with server B and sends the signed token as the password.
// 4. Server B checks the signature with the public key defined in the certificate and then verifies the token.
//
// The key point is that the certificates on all the servers should be the same at all times.
// However, the certificates should also be rotated periodically. In case different certificates are
// used to sign and to check during a rotation, a server keeps the old certificates for a while and
// tries every cached certificate that has not expired when checking a signature.
const (
// tokenLifetime is how long a token stays valid after it is signed.
// A token needs a limited lifetime to make brute force attacks impractical.
// CreateSessionToken uses it to compute ExpireTime, and ValidateSessionToken rejects
// tokens whose SignTime is more than tokenLifetime in the past.
tokenLifetime = time.Minute
// LoadCertInterval is the interval of reloading the certificate. The certificate should be rotated periodically.
LoadCertInterval = 10 * time.Minute
// oldCertValidTime is how long a certificate is still accepted after it has been replaced.
// Each loaded certificate is cached until LoadCertInterval + oldCertValidTime after its load time.
// oldCertValidTime must be a little longer than LoadCertInterval, because the previous server may
// sign with the old cert while the new server already checks with the new cert:
// - server A loads the old cert at 00:00:00.
// - the cert is rotated at 00:00:01 on all servers.
// - server B loads the new cert at 00:00:02.
// - server A signs a token with the old cert at 00:10:00.
// - server B reloads the same new cert again at 00:10:01, and it has 3 certs now.
// - server B receives the token at 00:10:02, so the old cert should be valid for more than 10m after replacement.
oldCertValidTime = 15 * time.Minute
)
// SessionToken represents the token used to authenticate with the new server.
// It is serialized as JSON. The signature is computed over the JSON encoding of the
// token with Signature left empty (the field is dropped from the output by omitempty).
type SessionToken struct {
// Username is the user the token was issued for; it is compared case-insensitively on validation.
Username string `json:"username"`
// SignTime is the time at which the token was created on the signing server.
SignTime time.Time `json:"sign-time"`
// ExpireTime is SignTime plus tokenLifetime; a token is rejected once the current time is after it.
ExpireTime time.Time `json:"expire-time"`
// Signature is the signature over the other fields, produced with the signing certificate's private key.
Signature []byte `json:"signature,omitempty"`
}
// CreateSessionToken issues a signed token for username on behalf of the proxy.
//
// The token is stamped with the current time, expires tokenLifetime later, and is
// signed with the newest cached signing certificate. A signing failure (for example
// when no certificate has been loaded) is reported as ErrCannotMigrateSession.
func CreateSessionToken(username string) (*SessionToken, error) {
	signTime := getNow()
	token := &SessionToken{
		Username:   username,
		SignTime:   signTime,
		ExpireTime: signTime.Add(tokenLifetime),
	}

	// The signed payload is the JSON form of the token while Signature is still empty.
	payload, marshalErr := json.Marshal(token)
	if marshalErr != nil {
		return nil, errors.Trace(marshalErr)
	}

	signature, signErr := globalSigningCert.sign(payload)
	if signErr != nil {
		return nil, ErrCannotMigrateSession.GenWithStackByArgs(signErr.Error())
	}
	token.Signature = signature
	return token, nil
}
// ValidateSessionToken checks a JSON-encoded token received from the proxy against username.
//
// The checks run in this order, and the first failure is returned:
//  1. tokenBytes must decode into a SessionToken.
//  2. The signature must verify against one of the cached, unexpired certificates.
//  3. The current time must not be after ExpireTime.
//  4. SignTime must not be more than tokenLifetime in the past.
//  5. username must equal the token's Username, ignoring case.
//
// Decoding/encoding errors are returned traced; all other failures are ErrCannotMigrateSession.
func ValidateSessionToken(tokenBytes []byte, username string) (err error) {
	var token SessionToken
	if decodeErr := json.Unmarshal(tokenBytes, &token); decodeErr != nil {
		return errors.Trace(decodeErr)
	}

	// Rebuild the content that was originally signed: the token without its signature.
	signature := token.Signature
	token.Signature = nil
	signedContent, encodeErr := json.Marshal(token)
	if encodeErr != nil {
		return errors.Trace(encodeErr)
	}
	if verifyErr := globalSigningCert.checkSignature(signedContent, signature); verifyErr != nil {
		return ErrCannotMigrateSession.GenWithStackByArgs(verifyErr.Error())
	}

	now := getNow()
	switch {
	case now.After(token.ExpireTime):
		return ErrCannotMigrateSession.GenWithStackByArgs("token expired", token.ExpireTime.String())
	// An attacker may forge a very long lifetime to brute force, so `SignTime` is checked as well.
	// The check has to tolerate two things:
	// - `tokenLifetime` may change between TiDB versions, so we can't require
	//   `token.SignTime.Add(tokenLifetime).Equal(token.ExpireTime)`.
	// - Clocks may differ between TiDB instances, so we can't require `now.After(token.SignTime)`.
	case token.SignTime.Add(tokenLifetime).Before(now):
		return ErrCannotMigrateSession.GenWithStackByArgs("token lifetime is too long", token.SignTime.String())
	case !strings.EqualFold(username, token.Username):
		return ErrCannotMigrateSession.GenWithStackByArgs("username does not match", username, token.Username)
	}
	return nil
}
// SetKeyPath sets the path of the private key file (key.pem) used for signing session tokens.
// If the path differs from the current one and a certificate path is already set, the
// certificate and key are loaded again immediately; a load failure is only logged.
func SetKeyPath(keyPath string) {
globalSigningCert.setKeyPath(keyPath)
}
// SetCertPath sets the path of the certificate file (cert.pem) used for checking session tokens.
// If the path differs from the current one and a key path is already set, the
// certificate and key are loaded again immediately; a load failure is only logged.
func SetCertPath(certPath string) {
globalSigningCert.setCertPath(certPath)
}
// ReloadSigningCert is used to load the certificate periodically in a separate goroutine.
// Each successful call puts the freshly loaded certificate at the front of the cache and
// drops the cached certificates that have expired.
// It's impossible to know when the old certificate should expire without this goroutine:
// - If the certificate is rotated a minute ago, the old certificate should be still valid for a while.
// - If the certificate is rotated a month ago, the old certificate should expire for safety.
func ReloadSigningCert() {
globalSigningCert.lockAndLoad()
}
// globalSigningCert is the package-wide certificate cache behind CreateSessionToken,
// ValidateSessionToken, SetKeyPath, SetCertPath and ReloadSigningCert.
// Its zero value has no paths and no certificates, so signing and checking fail until both paths are set.
var globalSigningCert signingCert
// signingCert represents the parsed certificate used for token-based auth.
// The embedded RWMutex guards all fields: the setters and the loader take the write lock,
// while sign and checkSignature take the read lock.
type signingCert struct {
sync.RWMutex
// certPath is the path of the certificate file; empty until SetCertPath is called.
certPath string
// keyPath is the path of the private key file; empty until SetKeyPath is called.
keyPath string
// The cert file may happen to be rotated between signing and checking, so we keep the old cert for a while.
// certs contain all the certificates that are not expired yet.
// The most recently loaded certificate is first, so expireTime is in descending order.
certs []*certInfo
}
// certInfo is one cached certificate together with its private key.
type certInfo struct {
// cert is the parsed leaf certificate; its public key is used to check signatures.
cert *x509.Certificate
// privKey is the private key loaded with the certificate; it is used to sign tokens.
privKey crypto.PrivateKey
// expireTime is when this entry stops being used for checking signatures:
// the load time plus LoadCertInterval + oldCertValidTime (not the certificate's own NotAfter).
expireTime time.Time
}
// setCertPath records certPath and, if it differs from the stored path, tries to
// load the certificate right away. Setting the same path again is a no-op.
func (sc *signingCert) setCertPath(certPath string) {
	sc.Lock()
	// Release with defer, as sign and checkSignature do, so the mutex cannot stay
	// locked forever if loading panics.
	defer sc.Unlock()
	if certPath == sc.certPath {
		return
	}
	sc.certPath = certPath
	// Nothing is loaded yet if the key path has not been set; that is expected.
	sc.checkAndLoadCert()
}
// setKeyPath records keyPath and, if it differs from the stored path, tries to
// load the certificate right away. Setting the same path again is a no-op.
func (sc *signingCert) setKeyPath(keyPath string) {
	sc.Lock()
	// Release with defer, as sign and checkSignature do, so the mutex cannot stay
	// locked forever if loading panics.
	defer sc.Unlock()
	if keyPath == sc.keyPath {
		return
	}
	sc.keyPath = keyPath
	// Nothing is loaded yet if the cert path has not been set; that is expected.
	sc.checkAndLoadCert()
}
// lockAndLoad reloads the certificate while holding the write lock.
func (sc *signingCert) lockAndLoad() {
	sc.Lock()
	// Release with defer, as sign and checkSignature do, so the mutex cannot stay
	// locked forever if loading panics.
	defer sc.Unlock()
	sc.checkAndLoadCert()
}
// checkAndLoadCert loads the certificate if both the cert path and the key path are set,
// and logs the outcome. It does nothing while either path is still empty.
// The caller must hold the write lock.
func (sc *signingCert) checkAndLoadCert() {
	if sc.certPath == "" || sc.keyPath == "" {
		return
	}

	pathFields := []zap.Field{
		zap.String("cert path", sc.certPath),
		zap.String("key path", sc.keyPath),
	}
	err := sc.loadCert()
	if err == nil {
		logutil.BgLogger().Info("signing cert is loaded successfully", pathFields...)
		return
	}
	// A failure is not fatal: the previously cached certificates stay in place.
	logutil.BgLogger().Warn("loading signing cert failed", append(pathFields, zap.Error(err))...)
}
// loadCert reads the certificate/key pair from certPath and keyPath and puts it at the
// front of the cert list, dropping the cached entries that have already expired.
// On failure the cert list is left untouched and the wrapped error is returned.
// The caller must hold the write lock.
func (sc *signingCert) loadCert() error {
	tlsCert, err := tls.LoadX509KeyPair(sc.certPath, sc.keyPath)
	if err != nil {
		return errors.Wrapf(err, "load x509 failed, cert path: %s, key path: %s", sc.certPath, sc.keyPath)
	}

	// Use the parsed leaf when the loader already provides it; otherwise parse the first certificate.
	leaf := tlsCert.Leaf
	if leaf == nil {
		if leaf, err = x509.ParseCertificate(tlsCert.Certificate[0]); err != nil {
			return errors.Wrapf(err, "parse x509 cert failed, cert path: %s, key path: %s", sc.certPath, sc.keyPath)
		}
	}

	// Rotate certs. The new entry goes first so that expireTime stays in descending order.
	now := getNow()
	rotated := make([]*certInfo, 0, len(sc.certs)+1)
	rotated = append(rotated, &certInfo{
		cert:       leaf,
		privKey:    tlsCert.PrivateKey,
		expireTime: now.Add(LoadCertInterval + oldCertValidTime),
	})
	for _, cached := range sc.certs {
		// Because of the ordering, everything after the first expired entry is expired too.
		if now.After(cached.expireTime) {
			break
		}
		rotated = append(rotated, cached)
	}
	sc.certs = rotated
	return nil
}
// sign produces a signature over content with the private key of the newest cached certificate.
//
// How content is prepared depends on the key type:
//   - Ed25519: content is signed as-is.
//   - RSA: content is hashed with the digest named by the certificate's signature algorithm and
//     signed with PSS options for the RSA-PSS algorithms, or with the plain hash as options otherwise.
//   - ECDSA: content is handed to the signer as-is with nil options.
//
// An error is returned when no certificate is cached, or when the key type or the RSA
// signature algorithm is not one of the supported ones.
//
// NOTE(review): for ECDSA the content is not hashed before signing. crypto/ecdsa documents that an
// over-long digest is truncated to the bit length of the curve order, so presumably only a prefix
// of content is covered by the signature — confirm. checkSignature verifies the same unhashed
// content, so any change has to be made on both sides.
func (sc *signingCert) sign(content []byte) ([]byte, error) {
	sc.RLock()
	defer sc.RUnlock()

	if len(sc.certs) == 0 {
		return nil, errors.New("no certificate or key file to sign the data")
	}
	// Always sign the token with the latest cert.
	newest := sc.certs[0]

	var (
		signer crypto.Signer
		opts   crypto.SignerOpts
	)
	switch key := newest.privKey.(type) {
	case ed25519.PrivateKey:
		signer, opts = key, crypto.Hash(0)
	case *rsa.PrivateKey:
		signer = key
		// usePSS replaces content with its digest and selects PSS signing options.
		usePSS := func(hashAlg crypto.Hash) {
			digest := hashAlg.New()
			digest.Write(content)
			content = digest.Sum(nil)
			opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashAlg}
		}
		switch newest.cert.SignatureAlgorithm {
		case x509.SHA256WithRSAPSS:
			usePSS(crypto.SHA256)
		case x509.SHA384WithRSAPSS:
			usePSS(crypto.SHA384)
		case x509.SHA512WithRSAPSS:
			usePSS(crypto.SHA512)
		case x509.SHA256WithRSA:
			sum := sha256.Sum256(content)
			content, opts = sum[:], crypto.SHA256
		case x509.SHA384WithRSA:
			sum := sha512.Sum384(content)
			content, opts = sum[:], crypto.SHA384
		case x509.SHA512WithRSA:
			sum := sha512.Sum512(content)
			content, opts = sum[:], crypto.SHA512
		default:
			return nil, errors.Errorf("not supported private key type '%s' for signing", newest.cert.SignatureAlgorithm.String())
		}
	case *ecdsa.PrivateKey:
		signer = key
	default:
		return nil, errors.Errorf("not supported private key type '%s' for signing", newest.cert.SignatureAlgorithm.String())
	}
	return signer.Sign(rand.Reader, content, opts)
}
// checkSignature verifies signature over content against the cached certificates.
//
// The certificates are tried from newest to oldest, stopping at the first expired one.
// It returns nil as soon as any certificate verifies the signature. Otherwise it returns
// the verification error of the last certificate tried, or a "no valid certificate" error
// when no unexpired certificate is cached.
func (sc *signingCert) checkSignature(content, signature []byte) error {
	sc.RLock()
	defer sc.RUnlock()
	now := getNow()
	var lastErr error
	for _, info := range sc.certs {
		// The expireTime is in descending order. So if the first one is expired, we skip the following.
		if now.After(info.expireTime) {
			break
		}
		// The error is scoped to this iteration. A failure recorded for a previous
		// certificate must not make a successful check with this one look like a failure.
		var err error
		switch info.privKey.(type) {
		// ECDSA is special: `PrivateKey.Sign` doesn't match with `Certificate.CheckSignature`.
		case *ecdsa.PrivateKey:
			if !ecdsa.VerifyASN1(info.cert.PublicKey.(*ecdsa.PublicKey), content, signature) {
				err = errors.New("x509: ECDSA verification failure")
			}
		default:
			err = info.cert.CheckSignature(info.cert.SignatureAlgorithm, content, signature)
		}
		if err == nil {
			return nil
		}
		lastErr = err
	}
	// no certs (possible) or all certs are expired (impossible)
	if lastErr == nil {
		return errors.Errorf("no valid certificate to check the signature, cached certificates: %d", len(sc.certs))
	}
	return lastErr
}
// getNow returns the current time. All time comparisons in this file go through it so
// that tests can shift the clock with the "mockNowOffset" failpoint.
func getNow() time.Time {
now := time.Now()
// Test hook: when the "mockNowOffset" failpoint is enabled, its value must be an int.
// A non-zero value is added to the current time as a time.Duration, i.e. in nanoseconds.
failpoint.Inject("mockNowOffset", func(val failpoint.Value) {
if s := uint64(val.(int)); s != 0 {
now = now.Add(time.Duration(s))
}
})
return now
}