-
Notifications
You must be signed in to change notification settings - Fork 3
/
redis_store.go
283 lines (244 loc) · 7.7 KB
/
redis_store.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
package sessionKit
import (
"context"
"crypto/rand"
"encoding/base32"
"errors"
"github.com/richelieu-yang/chimera/v3/src/core/errorKit"
"github.com/richelieu-yang/chimera/v3/src/crypto/base64Kit"
"github.com/richelieu-yang/chimera/v3/src/randomKit"
"github.com/richelieu-yang/chimera/v3/src/serialize/gobKit"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/gorilla/sessions"
"github.com/redis/go-redis/v9"
)
// RedisStore stores gorilla sessions in Redis.
//
// Each session is stored under the Redis key keyPrefix + session.ID, and the
// same session.ID is written as the value of the session cookie.
type RedisStore struct {
// Richelieu: TTL given to the Redis session key when the session's
// Options.MaxAge == 0 (MaxAge > 0 uses MaxAge seconds instead).
expirationForZeroMaxAge time.Duration
// client to connect to redis
client redis.UniversalClient
// default options to use when a new session is created (copied per session)
options sessions.Options
// key prefix with which the session will be stored
keyPrefix string
// key generator (produces session IDs / cookie values)
keyGen KeyGenFunc
// session serializer (encodes/decodes the value stored in Redis)
serializer SessionSerializer
}
// KeyGenFunc defines a function used by store to generate a key.
//
// The returned string becomes the session ID: it is the value of the session
// cookie and the variable suffix of the Redis key (keyPrefix + ID). A non-nil
// error makes Save fail.
type KeyGenFunc func() (string, error)
// NewRedisStore returns a new RedisStore backed by client.
//
// The session cookie's value is produced by keyGen and is also the variable
// part of the Redis key. For example, with redisKeyPrefix "session:", a session
// whose cookie value is "01H815Q2YFYEFDZHAQ2478FFS9" is stored under the Redis
// key "session:01H815Q2YFYEFDZHAQ2478FFS9".
//
// Parameters:
//   - ctx: used for the initial PING that verifies Redis is reachable.
//   - client: Redis client; must not be nil. RedisStore.Close closes it.
//   - redisKeyPrefix: fixed first part of every session key in Redis.
//   - keyGen: generates the variable part of the key (== the cookie value);
//     nil selects the default random generator.
//   - opts: cookie options copied into every newly created session.
//   - expirationForZeroMaxAge: TTL of the Redis key, used only for sessions
//     whose Options.MaxAge == 0 (see redisKit.Client's Set() for values).
//
// It returns an error if client is nil or the PING fails.
func NewRedisStore(ctx context.Context, client redis.UniversalClient, redisKeyPrefix string, keyGen KeyGenFunc,
	opts sessions.Options, expirationForZeroMaxAge time.Duration) (*RedisStore, error) {
	// Guard against a nil client, which would otherwise panic on Ping below.
	if client == nil {
		return nil, errors.New("redisstore: client is nil")
	}
	// Richelieu: fail fast if Redis is unreachable.
	if err := client.Ping(ctx).Err(); err != nil {
		return nil, err
	}

	// Variable part of the Redis key: fall back to the default generator.
	if keyGen == nil {
		keyGen = generateRandomKey
	}

	store := &RedisStore{
		expirationForZeroMaxAge: expirationForZeroMaxAge,
		client:                  client,
		serializer:              GobSerializer{},
	}
	// Fixed part of the Redis key.
	store.KeyPrefix(redisKeyPrefix)
	store.KeyGen(keyGen)
	store.Options(opts)
	return store, nil
}
// Get returns the named session for r by delegating to the gorilla/sessions
// request registry, which adds the session to that registry (creating or
// loading it through this store's New when needed).
func (s *RedisStore) Get(r *http.Request, name string) (*sessions.Session, error) {
	registry := sessions.GetRegistry(r)
	return registry.Get(s, name)
}
// New returns a session for the given name without adding it to the registry.
//
// If r carries a cookie called name, its value is used as the session ID and
// the session data is loaded from Redis; on success the session is marked as
// not new.
//
// When Redis holds no data for that ID (expired, deleted, or never issued by
// this store), the cookie value is discarded. Save then generates a fresh
// server-side ID instead of adopting a client-chosen one, which prevents
// session fixation.
//
// A non-nil error comes from Redis or the serializer. The session is still
// returned alongside it.
func (s *RedisStore) New(r *http.Request, name string) (*sessions.Session, error) {
	session := sessions.NewSession(s, name)
	// Copy the defaults so per-session option changes don't leak into the store.
	opts := s.options
	session.Options = &opts
	session.IsNew = true

	c, err := r.Cookie(name)
	if err != nil {
		// No session cookie: a brand-new session.
		return session, nil
	}

	session.ID = c.Value
	err = s.load(r.Context(), session)
	switch {
	case err == nil:
		session.IsNew = false
	case errors.Is(err, redis.Nil):
		// No data stored for this ID: drop the client-supplied ID so that Save
		// issues a new one.
		session.ID = ""
		err = nil
	}
	return session, err
}
// Save persists session to Redis and writes the session cookie to w.
//
// Richelieu: only Options.MaxAge < 0 deletes the session: the Redis key is
// removed and an empty-valued cookie is written with the session's options.
// (The upstream behavior of also deleting on MaxAge == 0 is disabled.)
// For MaxAge == 0 the data is stored with the store's expirationForZeroMaxAge
// TTL, and for MaxAge > 0 with a TTL of MaxAge seconds. A session without an
// ID gets one from the store's key generator before being stored.
func (s *RedisStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
	if session.Options.MaxAge < 0 {
		// A session with no ID was never stored. Deleting keyPrefix+"" would
		// remove an unrelated key named exactly keyPrefix, so skip it.
		if session.ID != "" {
			if err := s.delete(r.Context(), session); err != nil {
				return err
			}
		}
		http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
		return nil
	}

	if session.ID == "" {
		id, err := s.keyGen()
		if err != nil {
			return errors.New("redisstore: failed to generate session id")
		}
		session.ID = id
	}
	// save may replace session.ID on a key collision, so the cookie is built
	// only afterwards.
	if err := s.save(r.Context(), session); err != nil {
		return err
	}
	http.SetCookie(w, sessions.NewCookie(session.Name(), session.ID, session.Options))
	return nil
}
// Options sets the default options copied into every newly created session.
func (s *RedisStore) Options(opts sessions.Options) {
	s.options = opts
}

// KeyPrefix sets the fixed prefix prepended to session IDs to form Redis keys.
func (s *RedisStore) KeyPrefix(keyPrefix string) {
	s.keyPrefix = keyPrefix
}

// KeyGen sets the function that generates new session IDs.
//
// A nil f restores the default random generator, matching NewRedisStore.
// Storing nil would make Save panic when it needs a new ID.
func (s *RedisStore) KeyGen(f KeyGenFunc) {
	if f == nil {
		f = generateRandomKey
	}
	s.keyGen = f
}

// Serializer sets the serializer used to encode sessions stored in Redis.
//
// A nil ss restores the default GobSerializer, matching NewRedisStore.
// Storing nil would make save/load panic on a nil interface.
func (s *RedisStore) Serializer(ss SessionSerializer) {
	if ss == nil {
		ss = GobSerializer{}
	}
	s.serializer = ss
}
// Close shuts down the store by closing its Redis client. This is the same
// client value that was handed to NewRedisStore, so closing it affects any
// other code that shares that client.
func (s *RedisStore) Close() error {
	client := s.client
	return client.Close()
}
// save serializes session and writes it to Redis under keyPrefix+session.ID.
//
// The key's TTL is expirationForZeroMaxAge when Options.MaxAge == 0, and
// MaxAge seconds otherwise.
//
// Richelieu: new sessions are written with SETNX so an existing key is never
// overwritten, because generated IDs (UUID, ULID, ...) are not trusted to be
// unique. On a collision a new ID is generated (plus a random numeric suffix)
// and the write is retried, up to 3 attempts in total. session.ID is updated
// in place.
//
// After a successful SETNX the session is marked as no longer new. A second
// Save of the same session then overwrites its own key. Otherwise it would
// treat that key as a collision, orphan it, and rotate the ID.
func (s *RedisStore) save(ctx context.Context, session *sessions.Session) error {
	b, err := s.serializer.Serialize(session)
	if err != nil {
		return err
	}

	var expiration time.Duration
	if session.Options.MaxAge == 0 {
		expiration = s.expirationForZeroMaxAge
	} else {
		expiration = time.Duration(session.Options.MaxAge) * time.Second
	}

	// Existing session: plain overwrite.
	if !session.IsNew {
		return s.client.Set(ctx, s.keyPrefix+session.ID, b, expiration).Err()
	}

	const maxAttempts = 3
	for i := 0; i < maxAttempts; i++ {
		ok, err := s.client.SetNX(ctx, s.keyPrefix+session.ID, b, expiration).Result()
		if err != nil {
			return err
		}
		if ok {
			// The key now exists in Redis; later saves must update it in place.
			session.IsNew = false
			return nil
		}
		// Key already taken: regenerate session.ID and retry.
		id, err := s.keyGen()
		if err != nil {
			return errorKit.Newf("fail to regenerate session id")
		}
		session.ID = id + "_" + strconv.Itoa(randomKit.Int(0, 123456))
	}
	return errorKit.Newf("multiple repetition")
}
// load fetches the serialized data stored under keyPrefix+session.ID and
// decodes it into session via the store's serializer. Any Redis error
// (including redis.Nil for a missing key) is returned unchanged.
func (s *RedisStore) load(ctx context.Context, session *sessions.Session) error {
	data, err := s.client.Get(ctx, s.keyPrefix+session.ID).Bytes()
	if err != nil {
		return err
	}
	return s.serializer.Deserialize(data, session)
}
// delete removes the Redis key that holds session (keyPrefix+session.ID).
func (s *RedisStore) delete(ctx context.Context, session *sessions.Session) error {
	key := s.keyPrefix + session.ID
	return s.client.Del(ctx, key).Err()
}
// SessionSerializer provides an interface for serializing/deserializing a
// session.
//
// Serialize returns the bytes RedisStore writes to Redis for s; Deserialize
// decodes bytes read back from Redis into s. The default GobSerializer
// round-trips only s.Values. The session ID and options are set by RedisStore
// itself (in New), not read from the stored data.
type SessionSerializer interface {
Serialize(s *sessions.Session) ([]byte, error)
Deserialize(b []byte, s *sessions.Session) error
}
// GobSerializer is the default SessionSerializer. It gob-encodes the session's
// Values and then base64-encodes the result.
type GobSerializer struct{}

// Serialize gob-encodes s.Values and returns the base64 form of those bytes.
//
// Richelieu: the extra base64 step works around the domestic TongRDS Redis
// implementation, where values read back with GET after SET were incorrect.
func (gs GobSerializer) Serialize(s *sessions.Session) ([]byte, error) {
	data, err := gobKit.Marshal(s.Values)
	if err != nil {
		return nil, err
	}
	return base64Kit.Encode(data), nil
}

// Deserialize reverses Serialize: it base64-decodes d and gob-decodes the
// result into s.Values.
func (gs GobSerializer) Deserialize(d []byte, s *sessions.Session) error {
	data, err := base64Kit.Decode(d)
	if err != nil {
		return err
	}
	return gobKit.Unmarshal(data, &s.Values)
}
// generateRandomKey produces a new session key from 64 bytes read from
// crypto/rand. The bytes are encoded as standard base32 with the trailing '='
// padding removed, giving 103 characters. Any error from the random source is
// returned as-is.
func generateRandomKey() (string, error) {
	buf := make([]byte, 64)
	_, err := io.ReadFull(rand.Reader, buf)
	if err != nil {
		return "", err
	}
	encoded := base32.StdEncoding.EncodeToString(buf)
	return strings.TrimRight(encoded, "="), nil
}