/
token.go
136 lines (118 loc) · 3.36 KB
/
token.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
// Copyright 2021-2022, the SS project owners. All rights reserved.
// Please see the OWNERS and LICENSE files for details.
package apiauth
import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"time"

	"github.com/palchukovsky/ss"
)
// newToken serializes source to JSON and signs the serialized bytes. It returns
// the base64 (standard alphabet, no padding) encoding of
//
//	payload || 0x00 || signature
//
// The zero byte separates the payload from the signature. It is unambiguous
// because encoding/json escapes U+0000 as \u0000, so the JSON payload never
// contains a raw zero byte.
func newToken(source interface{}) (string, error) {
	creator := newTokenCreator()

	payload, err := creator.SerializeValue(source)
	if err != nil {
		return "", err
	}

	signature, err := creator.Sign(payload)
	if err != nil {
		return "", err
	}

	token := make([]byte, 0, len(payload)+1+len(signature))
	token = append(token, payload...)
	token = append(token, 0) // payload/signature separator
	token = append(token, signature...)

	return base64.RawStdEncoding.EncodeToString(token), nil
}
// parseToken reverses newToken. It base64-decodes source (standard alphabet,
// no padding) and splits the bytes at the first zero byte into a payload and a
// signature. It then verifies the signature over the payload and unmarshals
// the JSON payload into result.
//
// Two errors are returned, and at most one of them is non-nil:
//   - the first is set when source is structurally malformed (not valid
//     base64, or no payload/signature separator);
//   - the second is set when signature verification or JSON decoding fails.
//
// NOTE(review): the split presumably lets callers tell "garbage input" apart
// from "well-formed but rejected/undecodable token"; confirm against callers.
func parseToken(source string, result interface{}) (error, error) {
	raw, decodeErr := base64.RawStdEncoding.DecodeString(source)
	if decodeErr != nil {
		return decodeErr, nil
	}

	reader := newTokenReader()
	payload, signature, splitErr := reader.GetValueAndSignature(raw)
	if splitErr != nil {
		return splitErr, nil
	}

	if verifyErr := reader.VerifySignature(payload, signature); verifyErr != nil {
		return nil, verifyErr
	}

	// ParseValue returns a literal nil on success, so this is (nil, nil) then.
	return nil, reader.ParseValue(payload, result)
}
// parseTimeToken decodes a time token. A time token is a Unix timestamp in
// whole seconds, written as a base-16 integer. strconv.ParseInt accepts an
// optional leading sign and no "0x" prefix when the base is 16.
//
// On failure it returns the zero ss.Time and an error that wraps the strconv
// error.
func parseTimeToken(source string) (ss.Time, error) {
	seconds, err := strconv.ParseInt(source, 16, 64)
	if err != nil {
		// The original message read "field to parse"; this wording matches the
		// file's other "failed to ..." errors.
		return ss.Time{}, fmt.Errorf(`failed to parse time token %q: "%w"`,
			source, err)
	}
	return ss.NewTime(time.Unix(seconds, 0)), nil
}
////////////////////////////////////////////////////////////////////////////////
// tokenCreator produces the pieces of a token: the serialized payload and its
// signature. It carries no state.
type tokenCreator struct{}

// newTokenCreator returns a ready-to-use tokenCreator.
func newTokenCreator() tokenCreator {
	var creator tokenCreator
	return creator
}

// SerializeValue encodes value as JSON; the result is the token payload.
// Any json.Marshal failure is wrapped with %w.
func (tokenCreator) SerializeValue(value interface{}) ([]byte, error) {
	encoded, err := json.Marshal(value)
	if err == nil {
		return encoded, nil
	}
	return nil, fmt.Errorf(`failed to serialize: "%w"`, err)
}
// Sign returns an RSA-PSS signature over the SHA-256 digest of data. It signs
// with the RSA private key from ss.S.Config().PrivateKey, uses rand.Reader as
// the randomness source, and lets rsa.PSSSaltLengthAuto choose the salt
// length. Any rsa.SignPSS failure is wrapped with %w.
func (tokenCreator) Sign(data []byte) ([]byte, error) {
	// Hash through crypto/sha256 directly rather than crypto.SHA256.New(). The
	// latter panics unless crypto/sha256 is linked into the binary. Its Write
	// never fails either, so the old hashing error branch was unreachable.
	digest := sha256.Sum256(data)

	result, err := rsa.SignPSS(
		rand.Reader,
		ss.S.Config().PrivateKey.RSA.Get(),
		crypto.SHA256, // must name the same hash that produced digest
		digest[:],
		&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthAuto})
	if err != nil {
		return nil, fmt.Errorf(`failed to sign: "%w"`, err)
	}
	return result, nil
}
////////////////////////////////////////////////////////////////////////////////
// tokenReader takes apart a decoded token: it splits the raw bytes into payload
// and signature, and decodes the JSON payload. It carries no state.
type tokenReader struct{}

// newTokenReader returns a ready-to-use tokenReader.
func newTokenReader() tokenReader {
	var reader tokenReader
	return reader
}

// ParseValue unmarshals the JSON in data into result. result is handed to
// json.Unmarshal, so it should be a non-nil pointer. Any decoding failure is
// wrapped with %w.
func (tokenReader) ParseValue(data []byte, result interface{}) error {
	err := json.Unmarshal(data, result)
	if err == nil {
		return nil
	}
	return fmt.Errorf(`failed to parse: "%w"`, err)
}

// GetValueAndSignature splits data at its first zero byte. The bytes before it
// are the value and the bytes after it are the signature. Both results are
// sub-slices of data and share its backing array. It returns an error when
// data contains no zero byte.
func (tokenReader) GetValueAndSignature(
	data []byte) (value []byte, signature []byte, err error) {
	for pos := 0; pos < len(data); pos++ {
		if data[pos] == 0 {
			value, signature = data[:pos], data[pos+1:]
			return value, signature, nil
		}
	}
	return nil, nil, errors.New("wrong format")
}
// VerifySignature checks that signature is a valid RSA-PSS signature over the
// SHA-256 digest of data. It verifies against the public half of the RSA key
// in ss.S.Config().PrivateKey and lets rsa.PSSSaltLengthAuto detect the salt
// length. It returns nil when the signature is valid and otherwise wraps the
// rsa.VerifyPSS error with %w.
func (tokenReader) VerifySignature(data []byte, signature []byte) error {
	// Hash through crypto/sha256 directly rather than crypto.SHA256.New(). The
	// latter panics unless crypto/sha256 is linked into the binary. Its Write
	// never fails either, so the old hashing error branch was unreachable.
	digest := sha256.Sum256(data)

	err := rsa.VerifyPSS(
		&ss.S.Config().PrivateKey.RSA.Get().PublicKey,
		crypto.SHA256, // must name the same hash that produced digest
		digest[:],
		signature,
		&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthAuto})
	if err != nil {
		return fmt.Errorf(`failed to verify signature: "%w"`, err)
	}
	return nil
}
////////////////////////////////////////////////////////////////////////////////