forked from klanmiko/jwx
/
rsa.go
86 lines (75 loc) · 1.95 KB
/
rsa.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package verify
import (
"crypto"
"crypto/rsa"
"github.com/lestrrat-go/jwx/jwa"
"github.com/pkg/errors"
)
// rsaVerifyFuncs maps each supported RSA signature algorithm to the function
// that verifies signatures for it. It is populated by init below and looked
// up by newRSA.
var rsaVerifyFuncs = make(map[jwa.SignatureAlgorithm]rsaVerifyFunc)

// init registers the RSASSA-PKCS1-v1_5 algorithms (RS256/RS384/RS512) and the
// RSASSA-PSS algorithms (PS256/PS384/PS512), each paired with the SHA-2 hash
// whose size matches the algorithm name.
func init() {
	pkcs1v15Hashes := map[jwa.SignatureAlgorithm]crypto.Hash{
		jwa.RS256: crypto.SHA256,
		jwa.RS384: crypto.SHA384,
		jwa.RS512: crypto.SHA512,
	}
	for alg, h := range pkcs1v15Hashes {
		rsaVerifyFuncs[alg] = makeVerifyPKCS1v15(h)
	}

	pssHashes := map[jwa.SignatureAlgorithm]crypto.Hash{
		jwa.PS256: crypto.SHA256,
		jwa.PS384: crypto.SHA384,
		jwa.PS512: crypto.SHA512,
	}
	for alg, h := range pssHashes {
		rsaVerifyFuncs[alg] = makeVerifyPSS(h)
	}
}
// makeVerifyPKCS1v15 returns an rsaVerifyFunc that digests the payload with
// hash and checks the signature against that digest using RSASSA-PKCS1-v1_5
// (rsa.VerifyPKCS1v15). A nil return from the produced function means the
// signature is valid.
//
// NOTE(review): crypto.Hash.New panics when the hash implementation is not
// linked into the binary. This file does not import crypto/sha256 or
// crypto/sha512 itself — confirm another file in the package does.
func makeVerifyPKCS1v15(hash crypto.Hash) rsaVerifyFunc {
	return func(payload, signature []byte, key *rsa.PublicKey) error {
		digester := hash.New()
		// hash.Hash.Write is documented never to return an error.
		digester.Write(payload)
		digest := digester.Sum(nil)
		return rsa.VerifyPKCS1v15(key, hash, digest, signature)
	}
}
// makeVerifyPSS returns an rsaVerifyFunc that digests the payload with hash
// and checks the signature against that digest using RSASSA-PSS
// (rsa.VerifyPSS). A nil return from the produced function means the
// signature is valid.
//
// The PSS options are passed as nil, so rsa.VerifyPSS applies its defaults
// and auto-detects the salt length.
// NOTE(review): RFC 7518 §3.5 specifies a salt as long as the hash output;
// auto-detection also accepts other salt lengths. Consider
// rsa.PSSSaltLengthEqualsHash if stricter checking is wanted.
func makeVerifyPSS(hash crypto.Hash) rsaVerifyFunc {
	var verifier rsaVerifyFunc = func(payload, signature []byte, key *rsa.PublicKey) error {
		hasher := hash.New()
		hasher.Write(payload) // never fails for hash.Hash
		sum := hasher.Sum(nil)
		return rsa.VerifyPSS(key, hash, sum, signature, nil)
	}
	return verifier
}
// newRSA builds an RSAVerifier for alg using the verification function
// registered for it in rsaVerifyFuncs. It returns an error when alg has no
// registered RSA verifier.
func newRSA(alg jwa.SignatureAlgorithm) (*RSAVerifier, error) {
	if fn, found := rsaVerifyFuncs[alg]; found {
		return &RSAVerifier{verify: fn}, nil
	}
	return nil, errors.Errorf(`unsupported algorithm while trying to create RSA verifier: %s`, alg)
}
// Verify checks signature over payload with the RSA verification function
// this verifier was built with. key must be a non-nil *rsa.PublicKey.
//
// It returns an error when key is missing (including a typed nil
// *rsa.PublicKey), when key has any other type, or when the underlying
// verification function rejects the signature.
func (v RSAVerifier) Verify(payload, signature []byte, key interface{}) error {
	if key == nil {
		return errors.New(`missing public key while verifying payload`)
	}
	rsakey, ok := key.(*rsa.PublicKey)
	if !ok {
		return errors.Errorf(`invalid key type %T. *rsa.PublicKey is required`, key)
	}
	// An interface holding a typed nil pointer is itself non-nil, so it slips
	// past the key == nil check above. Reject it here rather than passing a
	// nil key on to v.verify.
	if rsakey == nil {
		return errors.New(`missing public key while verifying payload`)
	}
	return v.verify(payload, signature, rsakey)
}