Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
274 lines (247 sloc) 6.73 KB
// Copyright (c) 2012-2016 The Revel Framework Authors, All rights reserved.
// Revel Framework source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
package cache
import (
"time"
"github.com/garyburd/redigo/redis"
"github.com/revel/revel"
)
// RedisCache wraps the Redis client to meet the Cache interface.
// Each cache operation borrows a connection from pool and closes it before
// returning.
type RedisCache struct {
// pool supplies the Redis connections used by every cache operation.
pool *redis.Pool
// defaultExpiration is the expiry substituted when a caller passes
// DefaultExpiryTime.
defaultExpiration time.Duration
}
// NewRedisCache builds a RedisCache backed by a connection pool for the Redis
// server at host. Only a single host is accepted because redigo does not yet
// support sharding/clustering.
//
// Pool sizing (cache.redis.maxidle, cache.redis.maxactive,
// cache.redis.idletimeout in seconds) is read from the Revel configuration
// once, when the cache is constructed. The network protocol and the
// connect/read/write timeouts (cache.redis.protocol,
// cache.redis.timeout.{connect,read,write} in milliseconds) are read each time
// a new connection is dialed.
//
// A freshly dialed connection is authenticated with AUTH when password is
// non-empty; otherwise it is probed with PING. Connections are also probed
// with PING every time they are borrowed from the pool.
//
// defaultExpiration is the expiry used when callers pass DefaultExpiryTime.
func NewRedisCache(host string, password string, defaultExpiration time.Duration) RedisCache {
	// millis reads an integer setting and interprets it as milliseconds.
	millis := func(setting string, fallback int) time.Duration {
		return time.Millisecond * time.Duration(revel.Config.IntDefault(setting, fallback))
	}

	// dial opens, then authenticates or probes, a single connection.
	dial := func() (redis.Conn, error) {
		network := revel.Config.StringDefault("cache.redis.protocol", "tcp")
		connectTimeout := millis("cache.redis.timeout.connect", 10000)
		readTimeout := millis("cache.redis.timeout.read", 5000)
		writeTimeout := millis("cache.redis.timeout.write", 5000)

		c, err := redis.Dial(network, host,
			redis.DialConnectTimeout(connectTimeout),
			redis.DialReadTimeout(readTimeout),
			redis.DialWriteTimeout(writeTimeout))
		if err != nil {
			return nil, err
		}

		if password != "" {
			_, err = c.Do("AUTH", password)
		} else {
			// No credentials to verify, so check the link with PING instead.
			_, err = c.Do("PING")
		}
		if err != nil {
			// The connection is unusable; the close error adds nothing.
			_ = c.Close()
			return nil, err
		}
		return c, nil
	}

	// ping is the custom health check run whenever a connection is borrowed.
	ping := func(c redis.Conn, _ time.Time) error {
		_, err := c.Do("PING")
		return err
	}

	return RedisCache{
		pool: &redis.Pool{
			MaxIdle:      revel.Config.IntDefault("cache.redis.maxidle", 5),
			MaxActive:    revel.Config.IntDefault("cache.redis.maxactive", 0),
			IdleTimeout:  time.Duration(revel.Config.IntDefault("cache.redis.idletimeout", 240)) * time.Second,
			Dial:         dial,
			TestOnBorrow: ping,
		},
		defaultExpiration: defaultExpiration,
	}
}
// Set stores value under key unconditionally, overwriting any existing entry.
// expires may be a concrete duration, DefaultExpiryTime or ForEverNeverExpiry;
// see invoke for how each is translated.
func (c RedisCache) Set(key string, value interface{}, expires time.Duration) error {
	conn := c.pool.Get()
	// The close error is deliberately ignored: the operation's own error is
	// what the caller needs.
	defer func() { _ = conn.Close() }()

	err := c.invoke(conn.Do, key, value, expires)
	return err
}
// Add stores value under key only when the key is not already present.
// It returns ErrNotStored when the key exists, or the underlying error when
// the existence check or the write fails.
//
// The existence check and the write are two separate Redis commands.
func (c RedisCache) Add(key string, value interface{}, expires time.Duration) error {
	conn := c.pool.Get()
	defer func() { _ = conn.Close() }()

	present, err := exists(conn, key)
	switch {
	case err != nil:
		return err
	case present:
		return ErrNotStored
	}

	return c.invoke(conn.Do, key, value, expires)
}
// Replace overwrites the value stored under key, but only when the key is
// already present.
//
// It returns ErrNotStored when the key does not exist or when value is nil,
// and the underlying error when the existence check or the write fails.
//
// The nil check happens before the write is attempted, so ErrNotStored always
// means that nothing was written. Previously the write was issued first and
// its outcome was then discarded in favour of ErrNotStored.
//
// The existence check and the write are two separate Redis commands.
func (c RedisCache) Replace(key string, value interface{}, expires time.Duration) error {
	conn := c.pool.Get()
	defer func() {
		_ = conn.Close()
	}()

	existed, err := exists(conn, key)
	if err != nil {
		return err
	}
	if !existed || value == nil {
		return ErrNotStored
	}

	return c.invoke(conn.Do, key, value, expires)
}
// Get fetches the value stored under key and deserializes it into ptrValue.
// It returns ErrCacheMiss when Redis replies with nil for the key, and the
// underlying error when the command, the byte conversion or the
// deserialization fails.
func (c RedisCache) Get(key string, ptrValue interface{}) error {
	conn := c.pool.Get()
	defer func() { _ = conn.Close() }()

	reply, err := conn.Do("GET", key)
	if err != nil {
		return err
	}
	if reply == nil {
		return ErrCacheMiss
	}

	payload, err := redis.Bytes(reply, nil)
	if err != nil {
		return err
	}
	return Deserialize(payload, ptrValue)
}
// generalizeStringSlice copies strs into a newly allocated []interface{} of
// the same length and order, so the strings can be passed as variadic
// interface{} arguments. The result is never nil, even for a nil input.
func generalizeStringSlice(strs []string) []interface{} {
	out := make([]interface{}, 0, len(strs))
	for _, s := range strs {
		out = append(out, s)
	}
	return out
}
// GetMulti fetches several keys with a single MGET and returns a Getter over
// the values that were found.
//
// Keys for which Redis replied with nil (or with something other than a byte
// slice) are left out of the returned map, so the Getter reports ErrCacheMiss
// for them. Previously such keys were inserted with a nil value, which made
// the Getter try to deserialize nil instead of reporting a miss.
//
// With no keys, an empty Getter is returned without contacting Redis, rather
// than sending an MGET that has no arguments.
//
// It returns the underlying error when the command fails, and ErrCacheMiss
// when the reply converts to a nil list without an error.
func (c RedisCache) GetMulti(keys ...string) (Getter, error) {
	if len(keys) == 0 {
		return RedisItemMapGetter{}, nil
	}

	conn := c.pool.Get()
	defer func() {
		_ = conn.Close()
	}()

	items, err := redis.Values(conn.Do("MGET", generalizeStringSlice(keys)...))
	if err != nil {
		return nil, err
	}
	if items == nil {
		return nil, ErrCacheMiss
	}

	m := make(map[string][]byte, len(keys))
	for i, key := range keys {
		if i >= len(items) {
			// Defensive: the reply is shorter than the key list.
			break
		}
		// A nil reply element fails this assertion, so missing keys are
		// skipped and later surface as ErrCacheMiss.
		if data, ok := items[i].([]byte); ok {
			m[key] = data
		}
	}
	return RedisItemMapGetter(m), nil
}
// exists reports whether key is present, by issuing EXISTS on conn and
// converting the reply to a bool.
func exists(conn redis.Conn, key string) (bool, error) {
	reply, err := conn.Do("EXISTS", key)
	return redis.Bool(reply, err)
}
// Delete removes key from the cache. It returns ErrCacheMiss when DEL reports
// that nothing was removed, and the underlying error when the command fails.
func (c RedisCache) Delete(key string) error {
	conn := c.pool.Get()
	defer func() { _ = conn.Close() }()

	removed, err := redis.Bool(conn.Do("DEL", key))
	if err != nil {
		return err
	}
	if !removed {
		return ErrCacheMiss
	}
	return nil
}
// Increment adds delta to the integer stored under key and returns the new
// value. It returns ErrCacheMiss when the key does not exist, and the
// underlying error when a command fails or the stored value is not an integer.
//
// The cache contract requires the key to exist beforehand, whereas Redis
// would create it automatically. The addition is done here rather than with
// INCRBY because Redis does not support wrapping, so one GET serves as both
// the existence check and the read of the current value, keeping round trips
// to a minimum.
//
// The read and the write-back are two separate Redis commands.
func (c RedisCache) Increment(key string, delta uint64) (uint64, error) {
	conn := c.pool.Get()
	defer func() { _ = conn.Close() }()

	reply, err := conn.Do("GET", key)
	switch {
	case err != nil:
		return 0, err
	case reply == nil:
		return 0, ErrCacheMiss
	}

	current, err := redis.Int64(reply, nil)
	if err != nil {
		return 0, err
	}

	next := current + int64(delta)
	if _, err = conn.Do("SET", key, next); err != nil {
		return 0, err
	}
	return uint64(next), nil
}
// Decrement subtracts delta from the integer stored under key and returns the
// new value. The decrement contract only allows going down to 0, so when
// delta exceeds the current value the key is zeroed out instead.
//
// It returns ErrCacheMiss when the key does not exist, and the underlying
// error when a command fails or the stored value is not an integer.
//
// The cache contract requires the key to exist beforehand, whereas Redis
// would create it automatically. As in Increment, one GET serves as both the
// existence check and the read of the current value. This saves a round trip
// compared with EXISTS followed by GET, and a key that disappears between
// those two commands can no longer surface as a conversion error instead of
// ErrCacheMiss.
//
// The read and the DECRBY are two separate Redis commands.
func (c RedisCache) Decrement(key string, delta uint64) (newValue uint64, err error) {
	conn := c.pool.Get()
	defer func() {
		_ = conn.Close()
	}()

	raw, err := conn.Do("GET", key)
	if err != nil {
		return 0, err
	}
	if raw == nil {
		return 0, ErrCacheMiss
	}

	currentVal, err := redis.Int64(raw, nil)
	if err != nil {
		return 0, err
	}

	// Subtract delta, or the whole current value when delta would overshoot.
	// The clamped amount is passed as the original int64 so the argument sent
	// to Redis is unchanged.
	var amount interface{} = delta
	if delta > uint64(currentVal) {
		amount = currentVal
	}

	result, err := redis.Int64(conn.Do("DECRBY", key, amount))
	return uint64(result), err
}
// Flush empties the cache by sending FLUSHALL to the Redis server, returning
// the error from that command, if any.
func (c RedisCache) Flush() error {
	conn := c.pool.Get()
	defer func() { _ = conn.Close() }()

	if _, err := conn.Do("FLUSHALL"); err != nil {
		return err
	}
	return nil
}
// invoke serializes value and writes it under key through f, which is
// expected to be the Do method of a connection the caller already holds.
//
// expires is translated as follows:
//   - DefaultExpiryTime is replaced by the cache's defaultExpiration;
//   - ForEverNeverExpiry is replaced by zero;
//   - a positive result is written with SETEX, in whole seconds;
//   - anything else is written with a plain SET.
//
// It returns the serialization error or the error from f.
//
// invoke does not borrow a connection of its own. All commands go through f,
// so a second pooled connection would only be opened and closed unused,
// costing an extra pool slot on every write.
func (c RedisCache) invoke(f func(string, ...interface{}) (interface{}, error),
	key string, value interface{}, expires time.Duration) error {
	switch expires {
	case DefaultExpiryTime:
		expires = c.defaultExpiration
	case ForEverNeverExpiry:
		expires = time.Duration(0)
	}

	b, err := Serialize(value)
	if err != nil {
		return err
	}

	if expires > 0 {
		// NOTE(review): the division truncates, so a positive expiry below one
		// second is sent as 0 - confirm how the server treats that.
		_, err = f("SETEX", key, int32(expires/time.Second), b)
		return err
	}
	_, err = f("SET", key, b)
	return err
}
// RedisItemMapGetter implements a Getter on top of the returned item map.
type RedisItemMapGetter map[string][]byte
func (g RedisItemMapGetter) Get(key string, ptrValue interface{}) error {
item, ok := g[key]
if !ok {
return ErrCacheMiss
}
return Deserialize(item, ptrValue)
}
You can’t perform that action at this time.