From 3c1f2bd45a62c38c604b39ce4111ca0b1514641c Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 28 May 2015 15:51:19 +0300 Subject: [PATCH] Allow setting and scaning interface{} values. --- command.go | 98 +++++++---- command_test.go | 9 +- commands.go | 411 ++++++++++++++++++++++++++++++++--------------- commands_test.go | 93 +++++++++++ conn.go | 6 +- multi.go | 12 +- parser.go | 239 ++++++++++++++++++++++++--- parser_test.go | 2 +- pubsub.go | 29 ++-- redis_test.go | 37 +++-- 10 files changed, 716 insertions(+), 220 deletions(-) diff --git a/command.go b/command.go index fcea708cb..f32569749 100644 --- a/command.go +++ b/command.go @@ -1,6 +1,7 @@ package redis import ( + "bytes" "fmt" "strconv" "strings" @@ -28,7 +29,7 @@ var ( ) type Cmder interface { - args() []string + args() []interface{} parseReply(*bufio.Reader) error setErr(error) reset() @@ -38,7 +39,7 @@ type Cmder interface { clusterKey() string Err() error - String() string + fmt.Stringer } func setCmdsErr(cmds []Cmder, e error) { @@ -54,12 +55,21 @@ func resetCmds(cmds []Cmder) { } func cmdString(cmd Cmder, val interface{}) string { - s := strings.Join(cmd.args(), " ") + var ss []string + for _, arg := range cmd.args() { + ss = append(ss, fmt.Sprint(arg)) + } + s := strings.Join(ss, " ") if err := cmd.Err(); err != nil { return s + ": " + err.Error() } if val != nil { - return s + ": " + fmt.Sprint(val) + switch vv := val.(type) { + case []byte: + return s + ": " + string(vv) + default: + return s + ": " + fmt.Sprint(val) + } } return s @@ -68,7 +78,7 @@ func cmdString(cmd Cmder, val interface{}) string { //------------------------------------------------------------------------------ type baseCmd struct { - _args []string + _args []interface{} err error @@ -84,7 +94,7 @@ func (cmd *baseCmd) Err() error { return nil } -func (cmd *baseCmd) args() []string { +func (cmd *baseCmd) args() []interface{} { return cmd._args } @@ -102,7 +112,7 @@ func (cmd *baseCmd) writeTimeout() *time.Duration { func (cmd *baseCmd) clusterKey() string { if cmd._clusterKeyPos > 0 && cmd._clusterKeyPos < len(cmd._args) { - return cmd._args[cmd._clusterKeyPos] + return fmt.Sprint(cmd._args[cmd._clusterKeyPos]) } return "" } @@ -123,7 +133,7 @@ type Cmd struct { val interface{} } -func NewCmd(args ...string) *Cmd { +func NewCmd(args ...interface{}) *Cmd { return &Cmd{baseCmd: baseCmd{_args: args}} } @@ -146,6 +156,11 @@ func (cmd *Cmd) String() string { func (cmd *Cmd) parseReply(rd *bufio.Reader) error { cmd.val, cmd.err = parseReply(rd, parseSlice) + // Convert to string to preserve old behaviour. + // TODO: remove in v4 + if v, ok := cmd.val.([]byte); ok { + cmd.val = string(v) + } return cmd.err } @@ -157,7 +172,7 @@ type SliceCmd struct { val []interface{} } -func NewSliceCmd(args ...string) *SliceCmd { +func NewSliceCmd(args ...interface{}) *SliceCmd { return &SliceCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -196,11 +211,11 @@ type StatusCmd struct { val string } -func NewStatusCmd(args ...string) *StatusCmd { +func NewStatusCmd(args ...interface{}) *StatusCmd { return &StatusCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } -func newKeylessStatusCmd(args ...string) *StatusCmd { +func newKeylessStatusCmd(args ...interface{}) *StatusCmd { return &StatusCmd{baseCmd: baseCmd{_args: args}} } @@ -227,7 +242,7 @@ func (cmd *StatusCmd) parseReply(rd *bufio.Reader) error { cmd.err = err return err } - cmd.val = v.(string) + cmd.val = string(v.([]byte)) return nil } @@ -239,7 +254,7 @@ type IntCmd struct { val int64 } -func NewIntCmd(args ...string) *IntCmd { +func NewIntCmd(args ...interface{}) *IntCmd { return &IntCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -279,7 +294,7 @@ type DurationCmd struct { precision time.Duration } -func NewDurationCmd(precision time.Duration, args ...string) *DurationCmd { +func NewDurationCmd(precision time.Duration, args ...interface{}) *DurationCmd { return &DurationCmd{ precision: precision, baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}, @@ -321,7 +336,7 @@ type BoolCmd struct { val bool } -func NewBoolCmd(args ...string) *BoolCmd { +func NewBoolCmd(args ...interface{}) *BoolCmd { return &BoolCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -342,6 +357,8 @@ func (cmd *BoolCmd) String() string { return cmdString(cmd, cmd.val) } +var ok = []byte("OK") + func (cmd *BoolCmd) parseReply(rd *bufio.Reader) error { v, err := parseReply(rd, nil) // `SET key value NX` returns nil when key already exists. @@ -357,8 +374,8 @@ func (cmd *BoolCmd) parseReply(rd *bufio.Reader) error { case int64: cmd.val = vv == 1 return nil - case string: - cmd.val = vv == "OK" + case []byte: + cmd.val = bytes.Equal(vv, ok) return nil default: return fmt.Errorf("got %T, wanted int64 or string") @@ -370,23 +387,27 @@ func (cmd *BoolCmd) parseReply(rd *bufio.Reader) error { type StringCmd struct { baseCmd - val string + val []byte } -func NewStringCmd(args ...string) *StringCmd { +func NewStringCmd(args ...interface{}) *StringCmd { return &StringCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } func (cmd *StringCmd) reset() { - cmd.val = "" + cmd.val = nil cmd.err = nil } func (cmd *StringCmd) Val() string { - return cmd.val + return string(cmd.val) } func (cmd *StringCmd) Result() (string, error) { + return cmd.Val(), cmd.err +} + +func (cmd *StringCmd) Bytes() ([]byte, error) { return cmd.val, cmd.err } @@ -394,21 +415,28 @@ func (cmd *StringCmd) Int64() (int64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseInt(cmd.val, 10, 64) + return strconv.ParseInt(cmd.Val(), 10, 64) } func (cmd *StringCmd) Uint64() (uint64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseUint(cmd.val, 10, 64) + return strconv.ParseUint(cmd.Val(), 10, 64) } func (cmd *StringCmd) Float64() (float64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseFloat(cmd.val, 64) + return strconv.ParseFloat(cmd.Val(), 64) +} + +func (cmd *StringCmd) Scan(val interface{}) error { + if cmd.err != nil { + return cmd.err + } + return scan(cmd.val, val) } func (cmd *StringCmd) String() string { @@ -421,7 +449,9 @@ func (cmd *StringCmd) parseReply(rd *bufio.Reader) error { cmd.err = err return err } - cmd.val = v.(string) + b := v.([]byte) + cmd.val = make([]byte, len(b)) + copy(cmd.val, b) return nil } @@ -433,7 +463,7 @@ type FloatCmd struct { val float64 } -func NewFloatCmd(args ...string) *FloatCmd { +func NewFloatCmd(args ...interface{}) *FloatCmd { return &FloatCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -456,7 +486,7 @@ func (cmd *FloatCmd) parseReply(rd *bufio.Reader) error { cmd.err = err return err } - cmd.val, cmd.err = strconv.ParseFloat(v.(string), 64) + cmd.val, cmd.err = strconv.ParseFloat(string(v.([]byte)), 64) return cmd.err } @@ -468,7 +498,7 @@ type StringSliceCmd struct { val []string } -func NewStringSliceCmd(args ...string) *StringSliceCmd { +func NewStringSliceCmd(args ...interface{}) *StringSliceCmd { return &StringSliceCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -507,7 +537,7 @@ type BoolSliceCmd struct { val []bool } -func NewBoolSliceCmd(args ...string) *BoolSliceCmd { +func NewBoolSliceCmd(args ...interface{}) *BoolSliceCmd { return &BoolSliceCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -546,7 +576,7 @@ type StringStringMapCmd struct { val map[string]string } -func NewStringStringMapCmd(args ...string) *StringStringMapCmd { +func NewStringStringMapCmd(args ...interface{}) *StringStringMapCmd { return &StringStringMapCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -585,7 +615,7 @@ type StringIntMapCmd struct { val map[string]int64 } -func NewStringIntMapCmd(args ...string) *StringIntMapCmd { +func NewStringIntMapCmd(args ...interface{}) *StringIntMapCmd { return &StringIntMapCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -624,7 +654,7 @@ type ZSliceCmd struct { val []Z } -func NewZSliceCmd(args ...string) *ZSliceCmd { +func NewZSliceCmd(args ...interface{}) *ZSliceCmd { return &ZSliceCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -664,7 +694,7 @@ type ScanCmd struct { keys []string } -func NewScanCmd(args ...string) *ScanCmd { +func NewScanCmd(args ...interface{}) *ScanCmd { return &ScanCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } @@ -720,7 +750,7 @@ type ClusterSlotCmd struct { val []ClusterSlotInfo } -func NewClusterSlotCmd(args ...string) *ClusterSlotCmd { +func NewClusterSlotCmd(args ...interface{}) *ClusterSlotCmd { return &ClusterSlotCmd{baseCmd: baseCmd{_args: args, _clusterKeyPos: 1}} } diff --git a/command_test.go b/command_test.go index c1c968e94..1218724e4 100644 --- a/command_test.go +++ b/command_test.go @@ -26,7 +26,7 @@ var _ = Describe("Command", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("should have a plain string result", func() { + It("should implement Stringer", func() { set := client.Set("foo", "bar", 0) Expect(set.String()).To(Equal("SET foo bar: OK")) @@ -117,6 +117,13 @@ var _ = Describe("Command", func() { Expect(f).To(Equal(float64(10))) }) + It("Cmd should return string", func() { + cmd := redis.NewCmd("PING") + client.Process(cmd) + Expect(cmd.Err()).NotTo(HaveOccurred()) + Expect(cmd.Val()).To(Equal("PONG")) + }) + Describe("races", func() { var C, N = 10, 1000 if testing.Short() { diff --git a/commands.go b/commands.go index 7f4cbe082..56167d508 100644 --- a/commands.go +++ b/commands.go @@ -7,14 +7,18 @@ import ( "time" ) -func formatFloat(f float64) string { - return strconv.FormatFloat(f, 'f', -1, 64) -} - func formatInt(i int64) string { return strconv.FormatInt(i, 10) } +func formatUint(i uint64) string { + return strconv.FormatUint(i, 10) +} + +func formatFloat(f float64) string { + return strconv.FormatFloat(f, 'f', -1, 64) +} + func readTimeout(timeout time.Duration) time.Duration { if timeout == 0 { return 0 @@ -22,14 +26,6 @@ func readTimeout(timeout time.Duration) time.Duration { return timeout + time.Second } -type commandable struct { - process func(cmd Cmder) -} - -func (c *commandable) Process(cmd Cmder) { - c.process(cmd) -} - func usePrecise(dur time.Duration) bool { return dur < time.Second || dur%time.Second != 0 } @@ -41,7 +37,7 @@ func formatMs(dur time.Duration) string { dur, time.Millisecond, ) } - return strconv.FormatInt(int64(dur/time.Millisecond), 10) + return formatInt(int64(dur / time.Millisecond)) } func formatSec(dur time.Duration) string { @@ -51,7 +47,15 @@ func formatSec(dur time.Duration) string { dur, time.Second, ) } - return strconv.FormatInt(int64(dur/time.Second), 10) + return formatInt(int64(dur / time.Second)) +} + +type commandable struct { + process func(cmd Cmder) +} + +func (c *commandable) Process(cmd Cmder) { + c.process(cmd) } //------------------------------------------------------------------------------ @@ -80,7 +84,7 @@ func (c *commandable) Quit() *StatusCmd { } func (c *commandable) Select(index int64) *StatusCmd { - cmd := newKeylessStatusCmd("SELECT", strconv.FormatInt(index, 10)) + cmd := newKeylessStatusCmd("SELECT", formatInt(index)) c.Process(cmd) return cmd } @@ -88,7 +92,11 @@ func (c *commandable) Select(index int64) *StatusCmd { //------------------------------------------------------------------------------ func (c *commandable) Del(keys ...string) *IntCmd { - args := append([]string{"DEL"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "DEL" + for i, key := range keys { + args[1+i] = key + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -113,7 +121,7 @@ func (c *commandable) Expire(key string, expiration time.Duration) *BoolCmd { } func (c *commandable) ExpireAt(key string, tm time.Time) *BoolCmd { - cmd := NewBoolCmd("EXPIREAT", key, strconv.FormatInt(tm.Unix(), 10)) + cmd := NewBoolCmd("EXPIREAT", key, formatInt(tm.Unix())) c.Process(cmd) return cmd } @@ -130,7 +138,7 @@ func (c *commandable) Migrate(host, port, key string, db int64, timeout time.Dur host, port, key, - strconv.FormatInt(db, 10), + formatInt(db), formatMs(timeout), ) cmd._clusterKeyPos = 3 @@ -140,13 +148,18 @@ func (c *commandable) Migrate(host, port, key string, db int64, timeout time.Dur } func (c *commandable) Move(key string, db int64) *BoolCmd { - cmd := NewBoolCmd("MOVE", key, strconv.FormatInt(db, 10)) + cmd := NewBoolCmd("MOVE", key, formatInt(db)) c.Process(cmd) return cmd } func (c *commandable) ObjectRefCount(keys ...string) *IntCmd { - args := append([]string{"OBJECT", "REFCOUNT"}, keys...) + args := make([]interface{}, 2+len(keys)) + args[0] = "OBJECT" + args[1] = "REFCOUNT" + for i, key := range keys { + args[2+i] = key + } cmd := NewIntCmd(args...) cmd._clusterKeyPos = 2 c.Process(cmd) @@ -154,7 +167,12 @@ func (c *commandable) ObjectRefCount(keys ...string) *IntCmd { } func (c *commandable) ObjectEncoding(keys ...string) *StringCmd { - args := append([]string{"OBJECT", "ENCODING"}, keys...) + args := make([]interface{}, 2+len(keys)) + args[0] = "OBJECT" + args[1] = "ENCODING" + for i, key := range keys { + args[2+i] = key + } cmd := NewStringCmd(args...) cmd._clusterKeyPos = 2 c.Process(cmd) @@ -162,7 +180,12 @@ func (c *commandable) ObjectEncoding(keys ...string) *StringCmd { } func (c *commandable) ObjectIdleTime(keys ...string) *DurationCmd { - args := append([]string{"OBJECT", "IDLETIME"}, keys...) + args := make([]interface{}, 2+len(keys)) + args[0] = "OBJECT" + args[1] = "IDLETIME" + for i, key := range keys { + args[2+i] = key + } cmd := NewDurationCmd(time.Second, args...) cmd._clusterKeyPos = 2 c.Process(cmd) @@ -185,7 +208,7 @@ func (c *commandable) PExpireAt(key string, tm time.Time) *BoolCmd { cmd := NewBoolCmd( "PEXPIREAT", key, - strconv.FormatInt(tm.UnixNano()/int64(time.Millisecond), 10), + formatInt(tm.UnixNano()/int64(time.Millisecond)), ) c.Process(cmd) return cmd @@ -219,7 +242,7 @@ func (c *commandable) Restore(key string, ttl int64, value string) *StatusCmd { cmd := NewStatusCmd( "RESTORE", key, - strconv.FormatInt(ttl, 10), + formatInt(ttl), value, ) c.Process(cmd) @@ -236,7 +259,7 @@ type Sort struct { } func (c *commandable) Sort(key string, sort Sort) *StringSliceCmd { - args := []string{"SORT", key} + args := []interface{}{"SORT", key} if sort.By != "" { args = append(args, "BY", sort.By) } @@ -273,12 +296,12 @@ func (c *commandable) Type(key string) *StatusCmd { } func (c *commandable) Scan(cursor int64, match string, count int64) *ScanCmd { - args := []string{"SCAN", strconv.FormatInt(cursor, 10)} + args := []interface{}{"SCAN", formatInt(cursor)} if match != "" { args = append(args, "MATCH", match) } if count > 0 { - args = append(args, "COUNT", strconv.FormatInt(count, 10)) + args = append(args, "COUNT", formatInt(count)) } cmd := NewScanCmd(args...) c.Process(cmd) @@ -286,12 +309,12 @@ func (c *commandable) Scan(cursor int64, match string, count int64) *ScanCmd { } func (c *commandable) SScan(key string, cursor int64, match string, count int64) *ScanCmd { - args := []string{"SSCAN", key, strconv.FormatInt(cursor, 10)} + args := []interface{}{"SSCAN", key, formatInt(cursor)} if match != "" { args = append(args, "MATCH", match) } if count > 0 { - args = append(args, "COUNT", strconv.FormatInt(count, 10)) + args = append(args, "COUNT", formatInt(count)) } cmd := NewScanCmd(args...) c.Process(cmd) @@ -299,12 +322,12 @@ func (c *commandable) SScan(key string, cursor int64, match string, count int64) } func (c *commandable) HScan(key string, cursor int64, match string, count int64) *ScanCmd { - args := []string{"HSCAN", key, strconv.FormatInt(cursor, 10)} + args := []interface{}{"HSCAN", key, formatInt(cursor)} if match != "" { args = append(args, "MATCH", match) } if count > 0 { - args = append(args, "COUNT", strconv.FormatInt(count, 10)) + args = append(args, "COUNT", formatInt(count)) } cmd := NewScanCmd(args...) c.Process(cmd) @@ -312,12 +335,12 @@ func (c *commandable) HScan(key string, cursor int64, match string, count int64) } func (c *commandable) ZScan(key string, cursor int64, match string, count int64) *ScanCmd { - args := []string{"ZSCAN", key, strconv.FormatInt(cursor, 10)} + args := []interface{}{"ZSCAN", key, formatInt(cursor)} if match != "" { args = append(args, "MATCH", match) } if count > 0 { - args = append(args, "COUNT", strconv.FormatInt(count, 10)) + args = append(args, "COUNT", formatInt(count)) } cmd := NewScanCmd(args...) c.Process(cmd) @@ -337,12 +360,12 @@ type BitCount struct { } func (c *commandable) BitCount(key string, bitCount *BitCount) *IntCmd { - args := []string{"BITCOUNT", key} + args := []interface{}{"BITCOUNT", key} if bitCount != nil { args = append( args, - strconv.FormatInt(bitCount.Start, 10), - strconv.FormatInt(bitCount.End, 10), + formatInt(bitCount.Start), + formatInt(bitCount.End), ) } cmd := NewIntCmd(args...) @@ -351,8 +374,13 @@ func (c *commandable) BitCount(key string, bitCount *BitCount) *IntCmd { } func (c *commandable) bitOp(op, destKey string, keys ...string) *IntCmd { - args := []string{"BITOP", op, destKey} - args = append(args, keys...) + args := make([]interface{}, 3+len(keys)) + args[0] = "BITOP" + args[1] = op + args[2] = destKey + for i, key := range keys { + args[3+i] = key + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -375,13 +403,17 @@ func (c *commandable) BitOpNot(destKey string, key string) *IntCmd { } func (c *commandable) BitPos(key string, bit int64, pos ...int64) *IntCmd { - args := []string{"BITPOS", key, formatInt(bit)} + args := make([]interface{}, 3+len(pos)) + args[0] = "BITPOS" + args[1] = key + args[2] = formatInt(bit) switch len(pos) { case 0: case 1: - args = append(args, formatInt(pos[0])) + args[3] = formatInt(pos[0]) case 2: - args = append(args, formatInt(pos[0]), formatInt(pos[1])) + args[3] = formatInt(pos[0]) + args[4] = formatInt(pos[1]) default: panic("too many arguments") } @@ -397,7 +429,7 @@ func (c *commandable) Decr(key string) *IntCmd { } func (c *commandable) DecrBy(key string, decrement int64) *IntCmd { - cmd := NewIntCmd("DECRBY", key, strconv.FormatInt(decrement, 10)) + cmd := NewIntCmd("DECRBY", key, formatInt(decrement)) c.Process(cmd) return cmd } @@ -409,7 +441,7 @@ func (c *commandable) Get(key string) *StringCmd { } func (c *commandable) GetBit(key string, offset int64) *IntCmd { - cmd := NewIntCmd("GETBIT", key, strconv.FormatInt(offset, 10)) + cmd := NewIntCmd("GETBIT", key, formatInt(offset)) c.Process(cmd) return cmd } @@ -418,14 +450,14 @@ func (c *commandable) GetRange(key string, start, end int64) *StringCmd { cmd := NewStringCmd( "GETRANGE", key, - strconv.FormatInt(start, 10), - strconv.FormatInt(end, 10), + formatInt(start), + formatInt(end), ) c.Process(cmd) return cmd } -func (c *commandable) GetSet(key, value string) *StringCmd { +func (c *commandable) GetSet(key string, value interface{}) *StringCmd { cmd := NewStringCmd("GETSET", key, value) c.Process(cmd) return cmd @@ -438,7 +470,7 @@ func (c *commandable) Incr(key string) *IntCmd { } func (c *commandable) IncrBy(key string, value int64) *IntCmd { - cmd := NewIntCmd("INCRBY", key, strconv.FormatInt(value, 10)) + cmd := NewIntCmd("INCRBY", key, formatInt(value)) c.Process(cmd) return cmd } @@ -450,28 +482,43 @@ func (c *commandable) IncrByFloat(key string, value float64) *FloatCmd { } func (c *commandable) MGet(keys ...string) *SliceCmd { - args := append([]string{"MGET"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "MGET" + for i, key := range keys { + args[1+i] = key + } cmd := NewSliceCmd(args...) c.Process(cmd) return cmd } func (c *commandable) MSet(pairs ...string) *StatusCmd { - args := append([]string{"MSET"}, pairs...) + args := make([]interface{}, 1+len(pairs)) + args[0] = "MSET" + for i, pair := range pairs { + args[1+i] = pair + } cmd := NewStatusCmd(args...) c.Process(cmd) return cmd } func (c *commandable) MSetNX(pairs ...string) *BoolCmd { - args := append([]string{"MSETNX"}, pairs...) + args := make([]interface{}, 1+len(pairs)) + args[0] = "MSETNX" + for i, pair := range pairs { + args[1+i] = pair + } cmd := NewBoolCmd(args...) c.Process(cmd) return cmd } -func (c *commandable) Set(key, value string, expiration time.Duration) *StatusCmd { - args := []string{"SET", key, value} +func (c *commandable) Set(key string, value interface{}, expiration time.Duration) *StatusCmd { + args := make([]interface{}, 3, 5) + args[0] = "SET" + args[1] = key + args[2] = value if expiration > 0 { if usePrecise(expiration) { args = append(args, "PX", formatMs(expiration)) @@ -488,14 +535,14 @@ func (c *commandable) SetBit(key string, offset int64, value int) *IntCmd { cmd := NewIntCmd( "SETBIT", key, - strconv.FormatInt(offset, 10), - strconv.FormatInt(int64(value), 10), + formatInt(offset), + formatInt(int64(value)), ) c.Process(cmd) return cmd } -func (c *commandable) SetNX(key, value string, expiration time.Duration) *BoolCmd { +func (c *commandable) SetNX(key string, value interface{}, expiration time.Duration) *BoolCmd { var cmd *BoolCmd if expiration == 0 { // Use old `SETNX` to support old Redis versions. @@ -511,7 +558,7 @@ func (c *commandable) SetNX(key, value string, expiration time.Duration) *BoolCm return cmd } -func (c *Client) SetXX(key, value string, expiration time.Duration) *BoolCmd { +func (c *Client) SetXX(key string, value interface{}, expiration time.Duration) *BoolCmd { var cmd *BoolCmd if usePrecise(expiration) { cmd = NewBoolCmd("SET", key, value, "PX", formatMs(expiration), "XX") @@ -523,7 +570,7 @@ func (c *Client) SetXX(key, value string, expiration time.Duration) *BoolCmd { } func (c *commandable) SetRange(key string, offset int64, value string) *IntCmd { - cmd := NewIntCmd("SETRANGE", key, strconv.FormatInt(offset, 10), value) + cmd := NewIntCmd("SETRANGE", key, formatInt(offset), value) c.Process(cmd) return cmd } @@ -537,7 +584,12 @@ func (c *commandable) StrLen(key string) *IntCmd { //------------------------------------------------------------------------------ func (c *commandable) HDel(key string, fields ...string) *IntCmd { - args := append([]string{"HDEL", key}, fields...) + args := make([]interface{}, 2+len(fields)) + args[0] = "HDEL" + args[1] = key + for i, field := range fields { + args[2+i] = field + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -568,7 +620,7 @@ func (c *commandable) HGetAllMap(key string) *StringStringMapCmd { } func (c *commandable) HIncrBy(key, field string, incr int64) *IntCmd { - cmd := NewIntCmd("HINCRBY", key, field, strconv.FormatInt(incr, 10)) + cmd := NewIntCmd("HINCRBY", key, field, formatInt(incr)) c.Process(cmd) return cmd } @@ -592,14 +644,26 @@ func (c *commandable) HLen(key string) *IntCmd { } func (c *commandable) HMGet(key string, fields ...string) *SliceCmd { - args := append([]string{"HMGET", key}, fields...) + args := make([]interface{}, 2+len(fields)) + args[0] = "HMGET" + args[1] = key + for i, field := range fields { + args[2+i] = field + } cmd := NewSliceCmd(args...) c.Process(cmd) return cmd } func (c *commandable) HMSet(key, field, value string, pairs ...string) *StatusCmd { - args := append([]string{"HMSET", key, field, value}, pairs...) + args := make([]interface{}, 4+len(pairs)) + args[0] = "HMSET" + args[1] = key + args[2] = field + args[3] = value + for i, pair := range pairs { + args[4+i] = pair + } cmd := NewStatusCmd(args...) c.Process(cmd) return cmd @@ -626,8 +690,12 @@ func (c *commandable) HVals(key string) *StringSliceCmd { //------------------------------------------------------------------------------ func (c *commandable) BLPop(timeout time.Duration, keys ...string) *StringSliceCmd { - args := append([]string{"BLPOP"}, keys...) - args = append(args, formatSec(timeout)) + args := make([]interface{}, 2+len(keys)) + args[0] = "BLPOP" + for i, key := range keys { + args[1+i] = key + } + args[len(args)-1] = formatSec(timeout) cmd := NewStringSliceCmd(args...) cmd.setReadTimeout(readTimeout(timeout)) c.Process(cmd) @@ -635,8 +703,12 @@ func (c *commandable) BLPop(timeout time.Duration, keys ...string) *StringSliceC } func (c *commandable) BRPop(timeout time.Duration, keys ...string) *StringSliceCmd { - args := append([]string{"BRPOP"}, keys...) - args = append(args, formatSec(timeout)) + args := make([]interface{}, 2+len(keys)) + args[0] = "BRPOP" + for i, key := range keys { + args[1+i] = key + } + args[len(args)-1] = formatSec(timeout) cmd := NewStringSliceCmd(args...) cmd.setReadTimeout(readTimeout(timeout)) c.Process(cmd) @@ -656,7 +728,7 @@ func (c *commandable) BRPopLPush(source, destination string, timeout time.Durati } func (c *commandable) LIndex(key string, index int64) *StringCmd { - cmd := NewStringCmd("LINDEX", key, strconv.FormatInt(index, 10)) + cmd := NewStringCmd("LINDEX", key, formatInt(index)) c.Process(cmd) return cmd } @@ -680,7 +752,12 @@ func (c *commandable) LPop(key string) *StringCmd { } func (c *commandable) LPush(key string, values ...string) *IntCmd { - args := append([]string{"LPUSH", key}, values...) + args := make([]interface{}, 2+len(values)) + args[0] = "LPUSH" + args[1] = key + for i, value := range values { + args[2+i] = value + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -696,21 +773,21 @@ func (c *commandable) LRange(key string, start, stop int64) *StringSliceCmd { cmd := NewStringSliceCmd( "LRANGE", key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), + formatInt(start), + formatInt(stop), ) c.Process(cmd) return cmd } func (c *commandable) LRem(key string, count int64, value string) *IntCmd { - cmd := NewIntCmd("LREM", key, strconv.FormatInt(count, 10), value) + cmd := NewIntCmd("LREM", key, formatInt(count), value) c.Process(cmd) return cmd } func (c *commandable) LSet(key string, index int64, value string) *StatusCmd { - cmd := NewStatusCmd("LSET", key, strconv.FormatInt(index, 10), value) + cmd := NewStatusCmd("LSET", key, formatInt(index), value) c.Process(cmd) return cmd } @@ -719,8 +796,8 @@ func (c *commandable) LTrim(key string, start, stop int64) *StatusCmd { cmd := NewStatusCmd( "LTRIM", key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), + formatInt(start), + formatInt(stop), ) c.Process(cmd) return cmd @@ -739,7 +816,12 @@ func (c *commandable) RPopLPush(source, destination string) *StringCmd { } func (c *commandable) RPush(key string, values ...string) *IntCmd { - args := append([]string{"RPUSH", key}, values...) + args := make([]interface{}, 2+len(values)) + args[0] = "RPUSH" + args[1] = key + for i, value := range values { + args[2+i] = value + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -754,7 +836,12 @@ func (c *commandable) RPushX(key string, value string) *IntCmd { //------------------------------------------------------------------------------ func (c *commandable) SAdd(key string, members ...string) *IntCmd { - args := append([]string{"SADD", key}, members...) + args := make([]interface{}, 2+len(members)) + args[0] = "SADD" + args[1] = key + for i, member := range members { + args[2+i] = member + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -767,28 +854,46 @@ func (c *commandable) SCard(key string) *IntCmd { } func (c *commandable) SDiff(keys ...string) *StringSliceCmd { - args := append([]string{"SDIFF"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "SDIFF" + for i, key := range keys { + args[1+i] = key + } cmd := NewStringSliceCmd(args...) c.Process(cmd) return cmd } func (c *commandable) SDiffStore(destination string, keys ...string) *IntCmd { - args := append([]string{"SDIFFSTORE", destination}, keys...) + args := make([]interface{}, 2+len(keys)) + args[0] = "SDIFFSTORE" + args[1] = destination + for i, key := range keys { + args[2+i] = key + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd } func (c *commandable) SInter(keys ...string) *StringSliceCmd { - args := append([]string{"SINTER"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "SINTER" + for i, key := range keys { + args[1+i] = key + } cmd := NewStringSliceCmd(args...) c.Process(cmd) return cmd } func (c *commandable) SInterStore(destination string, keys ...string) *IntCmd { - args := append([]string{"SINTERSTORE", destination}, keys...) + args := make([]interface{}, 2+len(keys)) + args[0] = "SINTERSTORE" + args[1] = destination + for i, key := range keys { + args[2+i] = key + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -825,21 +930,35 @@ func (c *commandable) SRandMember(key string) *StringCmd { } func (c *commandable) SRem(key string, members ...string) *IntCmd { - args := append([]string{"SREM", key}, members...) + args := make([]interface{}, 2+len(members)) + args[0] = "SREM" + args[1] = key + for i, member := range members { + args[2+i] = member + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd } func (c *commandable) SUnion(keys ...string) *StringSliceCmd { - args := append([]string{"SUNION"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "SUNION" + for i, key := range keys { + args[1+i] = key + } cmd := NewStringSliceCmd(args...) c.Process(cmd) return cmd } func (c *commandable) SUnionStore(destination string, keys ...string) *IntCmd { - args := append([]string{"SUNIONSTORE", destination}, keys...) + args := make([]interface{}, 2+len(keys)) + args[0] = "SUNIONSTORE" + args[1] = destination + for i, key := range keys { + args[2+i] = key + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -858,7 +977,7 @@ type ZStore struct { } func (c *commandable) ZAdd(key string, members ...Z) *IntCmd { - args := make([]string, 2+2*len(members)) + args := make([]interface{}, 2+2*len(members)) args[0] = "ZADD" args[1] = key for i, m := range members { @@ -893,12 +1012,17 @@ func (c *commandable) ZInterStore( store ZStore, keys ...string, ) *IntCmd { - args := []string{"ZINTERSTORE", destination, strconv.FormatInt(int64(len(keys)), 10)} - args = append(args, keys...) + args := make([]interface{}, 3+len(keys)) + args[0] = "ZINTERSTORE" + args[1] = destination + args[2] = strconv.Itoa(len(keys)) + for i, key := range keys { + args[3+i] = key + } if len(store.Weights) > 0 { args = append(args, "WEIGHTS") for _, weight := range store.Weights { - args = append(args, strconv.FormatInt(weight, 10)) + args = append(args, formatInt(weight)) } } if store.Aggregate != "" { @@ -910,11 +1034,11 @@ func (c *commandable) ZInterStore( } func (c *commandable) zRange(key string, start, stop int64, withScores bool) *StringSliceCmd { - args := []string{ + args := []interface{}{ "ZRANGE", key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), + formatInt(start), + formatInt(stop), } if withScores { args = append(args, "WITHSCORES") @@ -929,11 +1053,11 @@ func (c *commandable) ZRange(key string, start, stop int64) *StringSliceCmd { } func (c *commandable) ZRangeWithScores(key string, start, stop int64) *ZSliceCmd { - args := []string{ + args := []interface{}{ "ZRANGE", key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), + formatInt(start), + formatInt(stop), "WITHSCORES", } cmd := NewZSliceCmd(args...) @@ -947,7 +1071,7 @@ type ZRangeByScore struct { } func (c *commandable) zRangeByScore(key string, opt ZRangeByScore, withScores bool) *StringSliceCmd { - args := []string{"ZRANGEBYSCORE", key, opt.Min, opt.Max} + args := []interface{}{"ZRANGEBYSCORE", key, opt.Min, opt.Max} if withScores { args = append(args, "WITHSCORES") } @@ -955,8 +1079,8 @@ func (c *commandable) zRangeByScore(key string, opt ZRangeByScore, withScores bo args = append( args, "LIMIT", - strconv.FormatInt(opt.Offset, 10), - strconv.FormatInt(opt.Count, 10), + formatInt(opt.Offset), + formatInt(opt.Count), ) } cmd := NewStringSliceCmd(args...) @@ -969,13 +1093,13 @@ func (c *commandable) ZRangeByScore(key string, opt ZRangeByScore) *StringSliceC } func (c *commandable) ZRangeByScoreWithScores(key string, opt ZRangeByScore) *ZSliceCmd { - args := []string{"ZRANGEBYSCORE", key, opt.Min, opt.Max, "WITHSCORES"} + args := []interface{}{"ZRANGEBYSCORE", key, opt.Min, opt.Max, "WITHSCORES"} if opt.Offset != 0 || opt.Count != 0 { args = append( args, "LIMIT", - strconv.FormatInt(opt.Offset, 10), - strconv.FormatInt(opt.Count, 10), + formatInt(opt.Offset), + formatInt(opt.Count), ) } cmd := NewZSliceCmd(args...) @@ -990,7 +1114,12 @@ func (c *commandable) ZRank(key, member string) *IntCmd { } func (c *commandable) ZRem(key string, members ...string) *IntCmd { - args := append([]string{"ZREM", key}, members...) + args := make([]interface{}, 2+len(members)) + args[0] = "ZREM" + args[1] = key + for i, member := range members { + args[2+i] = member + } cmd := NewIntCmd(args...) c.Process(cmd) return cmd @@ -1000,8 +1129,8 @@ func (c *commandable) ZRemRangeByRank(key string, start, stop int64) *IntCmd { cmd := NewIntCmd( "ZREMRANGEBYRANK", key, - strconv.FormatInt(start, 10), - strconv.FormatInt(stop, 10), + formatInt(start), + formatInt(stop), ) c.Process(cmd) return cmd @@ -1026,13 +1155,13 @@ func (c *commandable) ZRevRangeWithScores(key string, start, stop int64) *ZSlice } func (c *commandable) ZRevRangeByScore(key string, opt ZRangeByScore) *StringSliceCmd { - args := []string{"ZREVRANGEBYSCORE", key, opt.Max, opt.Min} + args := []interface{}{"ZREVRANGEBYSCORE", key, opt.Max, opt.Min} if opt.Offset != 0 || opt.Count != 0 { args = append( args, "LIMIT", - strconv.FormatInt(opt.Offset, 10), - strconv.FormatInt(opt.Count, 10), + formatInt(opt.Offset), + formatInt(opt.Count), ) } cmd := NewStringSliceCmd(args...) @@ -1041,13 +1170,13 @@ func (c *commandable) ZRevRangeByScore(key string, opt ZRangeByScore) *StringSli } func (c *commandable) ZRevRangeByScoreWithScores(key string, opt ZRangeByScore) *ZSliceCmd { - args := []string{"ZREVRANGEBYSCORE", key, opt.Max, opt.Min, "WITHSCORES"} + args := []interface{}{"ZREVRANGEBYSCORE", key, opt.Max, opt.Min, "WITHSCORES"} if opt.Offset != 0 || opt.Count != 0 { args = append( args, "LIMIT", - strconv.FormatInt(opt.Offset, 10), - strconv.FormatInt(opt.Count, 10), + formatInt(opt.Offset), + formatInt(opt.Count), ) } cmd := NewZSliceCmd(args...) @@ -1068,12 +1197,17 @@ func (c *commandable) ZScore(key, member string) *FloatCmd { } func (c *commandable) ZUnionStore(dest string, store ZStore, keys ...string) *IntCmd { - args := []string{"ZUNIONSTORE", dest, strconv.FormatInt(int64(len(keys)), 10)} - args = append(args, keys...) + args := make([]interface{}, 3+len(keys)) + args[0] = "ZUNIONSTORE" + args[1] = dest + args[2] = strconv.Itoa(len(keys)) + for i, key := range keys { + args[3+i] = key + } if len(store.Weights) > 0 { args = append(args, "WEIGHTS") for _, weight := range store.Weights { - args = append(args, strconv.FormatInt(weight, 10)) + args = append(args, formatInt(weight)) } } if store.Aggregate != "" { @@ -1182,11 +1316,11 @@ func (c *commandable) Save() *StatusCmd { } func (c *commandable) shutdown(modifier string) *StatusCmd { - var args []string + var args []interface{} if modifier == "" { - args = []string{"SHUTDOWN"} + args = []interface{}{"SHUTDOWN"} } else { - args = []string{"SHUTDOWN", modifier} + args = []interface{}{"SHUTDOWN", modifier} } cmd := newKeylessStatusCmd(args...) c.Process(cmd) @@ -1239,9 +1373,17 @@ func (c *commandable) Time() *StringSliceCmd { //------------------------------------------------------------------------------ func (c *commandable) Eval(script string, keys []string, args []string) *Cmd { - cmdArgs := []string{"EVAL", script, strconv.FormatInt(int64(len(keys)), 10)} - cmdArgs = append(cmdArgs, keys...) - cmdArgs = append(cmdArgs, args...) + cmdArgs := make([]interface{}, 3+len(keys)+len(args)) + cmdArgs[0] = "EVAL" + cmdArgs[1] = script + cmdArgs[2] = strconv.Itoa(len(keys)) + for i, key := range keys { + cmdArgs[3+i] = key + } + pos := 3 + len(keys) + for i, arg := range args { + cmdArgs[pos+i] = arg + } cmd := NewCmd(cmdArgs...) if len(keys) > 0 { cmd._clusterKeyPos = 3 @@ -1251,9 +1393,17 @@ func (c *commandable) Eval(script string, keys []string, args []string) *Cmd { } func (c *commandable) EvalSha(sha1 string, keys []string, args []string) *Cmd { - cmdArgs := []string{"EVALSHA", sha1, strconv.FormatInt(int64(len(keys)), 10)} - cmdArgs = append(cmdArgs, keys...) - cmdArgs = append(cmdArgs, args...) + cmdArgs := make([]interface{}, 3+len(keys)+len(args)) + cmdArgs[0] = "EVALSHA" + cmdArgs[1] = sha1 + cmdArgs[2] = strconv.Itoa(len(keys)) + for i, key := range keys { + cmdArgs[3+i] = key + } + pos := 3 + len(keys) + for i, arg := range args { + cmdArgs[pos+i] = arg + } cmd := NewCmd(cmdArgs...) if len(keys) > 0 { cmd._clusterKeyPos = 3 @@ -1263,7 +1413,12 @@ func (c *commandable) EvalSha(sha1 string, keys []string, args []string) *Cmd { } func (c *commandable) ScriptExists(scripts ...string) *BoolSliceCmd { - args := append([]string{"SCRIPT", "EXISTS"}, scripts...) + args := make([]interface{}, 2+len(scripts)) + args[0] = "SCRIPT" + args[1] = "EXISTS" + for i, script := range scripts { + args[2+i] = script + } cmd := NewBoolSliceCmd(args...) cmd._clusterKeyPos = 0 c.Process(cmd) @@ -1301,7 +1456,7 @@ func (c *commandable) DebugObject(key string) *StringCmd { //------------------------------------------------------------------------------ func (c *commandable) PubSubChannels(pattern string) *StringSliceCmd { - args := []string{"PUBSUB", "CHANNELS"} + args := []interface{}{"PUBSUB", "CHANNELS"} if pattern != "*" { args = append(args, pattern) } @@ -1312,8 +1467,12 @@ func (c *commandable) PubSubChannels(pattern string) *StringSliceCmd { } func (c *commandable) PubSubNumSub(channels ...string) *StringIntMapCmd { - args := []string{"PUBSUB", "NUMSUB"} - args = append(args, channels...) + args := make([]interface{}, 2+len(channels)) + args[0] = "PUBSUB" + args[1] = "NUMSUB" + for i, channel := range channels { + args[2+i] = channel + } cmd := NewStringIntMapCmd(args...) cmd._clusterKeyPos = 0 c.Process(cmd) @@ -1369,11 +1528,11 @@ func (c *commandable) ClusterFailover() *StatusCmd { } func (c *commandable) ClusterAddSlots(slots ...int) *StatusCmd { - args := make([]string, len(slots)+2) + args := make([]interface{}, 2+len(slots)) args[0] = "CLUSTER" - args[1] = "addslots" + args[1] = "ADDSLOTS" for i, num := range slots { - args[i+2] = strconv.Itoa(num) + args[2+i] = strconv.Itoa(num) } cmd := newKeylessStatusCmd(args...) c.Process(cmd) diff --git a/commands_test.go b/commands_test.go index ef593e1eb..0e0405edd 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1,7 +1,9 @@ package redis_test import ( + "encoding/json" "fmt" + "reflect" "strconv" "sync" "testing" @@ -2286,4 +2288,95 @@ var _ = Describe("Commands", func() { }) + Describe("marshaling/unmarshaling", func() { + + type convTest struct { + value interface{} + wanted string + dest interface{} + } + + convTests := []convTest{ + {nil, "", nil}, + {"hello", "hello", new(string)}, + {[]byte("hello"), "hello", new([]byte)}, + {int(1), "1", new(int)}, + {int8(1), "1", new(int8)}, + {int16(1), "1", new(int16)}, + {int32(1), "1", new(int32)}, + {int64(1), "1", new(int64)}, + {uint(1), "1", new(uint)}, + {uint8(1), "1", new(uint8)}, + {uint16(1), "1", new(uint16)}, + {uint32(1), "1", new(uint32)}, + {uint64(1), "1", new(uint64)}, + {float32(1.0), "1", new(float32)}, + {float64(1.0), "1", new(float64)}, + {true, "1", new(bool)}, + {false, "0", new(bool)}, + } + + It("should convert to string", func() { + for _, test := range convTests { + err := client.Set("key", test.value, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + s, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(Equal(test.wanted)) + + if test.dest == nil { + continue + } + + err = client.Get("key").Scan(test.dest) + Expect(err).NotTo(HaveOccurred()) + Expect(deref(test.dest)).To(Equal(test.value)) + } + }) + + }) + + Describe("json marshaling/unmarshaling", func() { + BeforeEach(func() { + value := &numberStruct{Number: 42} + err := client.Set("key", value, 0).Err() + Expect(err).NotTo(HaveOccurred()) + }) + + It("should marshal custom values using json", func() { + s, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(Equal(`{"Number":42}`)) + }) + + It("should scan custom values using json", func() { + value := &numberStruct{} + err := client.Get("key").Scan(value) + Expect(err).To(BeNil()) + Expect(value.Number).To(Equal(42)) + }) + + }) + }) + +type numberStruct struct { + Number int +} + +func (s *numberStruct) MarshalBinary() ([]byte, error) { + return json.Marshal(s) +} + +func (s *numberStruct) UnmarshalBinary(b []byte) error { + return json.Unmarshal(b, s) +} + +func deref(viface interface{}) interface{} { + v := reflect.ValueOf(viface) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + return v.Interface() +} diff --git a/conn.go b/conn.go index 72f82b0fd..bf104d9dd 100644 --- a/conn.go +++ b/conn.go @@ -67,7 +67,11 @@ func (cn *conn) init(opt *Options) error { func (cn *conn) writeCmds(cmds ...Cmder) error { buf := cn.buf[:0] for _, cmd := range cmds { - buf = appendArgs(buf, cmd.args()) + var err error + buf, err = appendArgs(buf, cmd.args()) + if err != nil { + return err + } } _, err := cn.Write(buf) diff --git a/multi.go b/multi.go index 9d87de9a5..63ecdd589 100644 --- a/multi.go +++ b/multi.go @@ -44,14 +44,22 @@ func (c *Multi) Close() error { } func (c *Multi) Watch(keys ...string) *StatusCmd { - args := append([]string{"WATCH"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "WATCH" + for i, key := range keys { + args[1+i] = key + } cmd := NewStatusCmd(args...) c.Process(cmd) return cmd } func (c *Multi) Unwatch(keys ...string) *StatusCmd { - args := append([]string{"UNWATCH"}, keys...) + args := make([]interface{}, 1+len(keys)) + args[0] = "UNWATCH" + for i, key := range keys { + args[1+i] = key + } cmd := NewStatusCmd(args...) c.Process(cmd) return cmd diff --git a/parser.go b/parser.go index 5a9c79f44..48fe7696e 100644 --- a/parser.go +++ b/parser.go @@ -17,18 +17,201 @@ var ( //------------------------------------------------------------------------------ -func appendArgs(buf []byte, args []string) []byte { - buf = append(buf, '*') - buf = strconv.AppendUint(buf, uint64(len(args)), 10) - buf = append(buf, '\r', '\n') +// Copy of encoding.BinaryMarshaler. +type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) +} + +// Copy of encoding.BinaryUnmarshaler. +type binaryUnmarshaler interface { + UnmarshalBinary(data []byte) error +} + +func appendString(b []byte, s string) []byte { + b = append(b, '$') + b = strconv.AppendUint(b, uint64(len(s)), 10) + b = append(b, '\r', '\n') + b = append(b, s...) + b = append(b, '\r', '\n') + return b +} + +func appendBytes(b, bb []byte) []byte { + b = append(b, '$') + b = strconv.AppendUint(b, uint64(len(bb)), 10) + b = append(b, '\r', '\n') + b = append(b, bb...) + b = append(b, '\r', '\n') + return b +} + +func appendArg(b []byte, val interface{}) ([]byte, error) { + switch v := val.(type) { + case nil: + b = appendString(b, "") + case string: + b = appendString(b, v) + case []byte: + b = appendBytes(b, v) + case int: + b = appendString(b, formatInt(int64(v))) + case int8: + b = appendString(b, formatInt(int64(v))) + case int16: + b = appendString(b, formatInt(int64(v))) + case int32: + b = appendString(b, formatInt(int64(v))) + case int64: + b = appendString(b, formatInt(v)) + case uint: + b = appendString(b, formatUint(uint64(v))) + case uint8: + b = appendString(b, formatUint(uint64(v))) + case uint16: + b = appendString(b, formatUint(uint64(v))) + case uint32: + b = appendString(b, formatUint(uint64(v))) + case uint64: + b = appendString(b, formatUint(v)) + case float32: + b = appendString(b, formatFloat(float64(v))) + case float64: + b = appendString(b, formatFloat(v)) + case bool: + if v { + b = appendString(b, "1") + } else { + b = appendString(b, "0") + } + default: + if bm, ok := val.(binaryMarshaler); ok { + bb, err := bm.MarshalBinary() + if err != nil { + return nil, err + } + b = appendBytes(b, bb) + } else { + err := fmt.Errorf( + "redis: can't marshal %T (consider implementing BinaryMarshaler)", val) + return nil, err + } + } + return b, nil +} + +func appendArgs(b []byte, args []interface{}) ([]byte, error) { + b = append(b, '*') + b = strconv.AppendUint(b, uint64(len(args)), 10) + b = append(b, '\r', '\n') for _, arg := range args { - buf = append(buf, '$') - buf = strconv.AppendUint(buf, uint64(len(arg)), 10) - buf = append(buf, '\r', '\n') - buf = append(buf, arg...) - buf = append(buf, '\r', '\n') + var err error + b, err = appendArg(b, arg) + if err != nil { + return nil, err + } + } + return b, nil +} + +func scan(b []byte, val interface{}) error { + switch v := val.(type) { + case nil: + return errorf("redis: Scan(nil)") + case *string: + *v = string(b) + return nil + case *[]byte: + *v = b + return nil + case *int: + var err error + *v, err = strconv.Atoi(string(b)) + return err + case *int8: + n, err := strconv.ParseInt(string(b), 10, 8) + if err != nil { + return err + } + *v = int8(n) + return nil + case *int16: + n, err := strconv.ParseInt(string(b), 10, 16) + if err != nil { + return err + } + *v = int16(n) + return nil + case *int32: + n, err := strconv.ParseInt(string(b), 10, 16) + if err != nil { + return err + } + *v = int32(n) + return nil + case *int64: + n, err := strconv.ParseInt(string(b), 10, 64) + if err != nil { + return err + } + *v = n + return nil + case *uint: + n, err := strconv.ParseUint(string(b), 10, 64) + if err != nil { + return err + } + *v = uint(n) + return nil + case *uint8: + n, err := strconv.ParseUint(string(b), 10, 8) + if err != nil { + return err + } + *v = uint8(n) + return nil + case *uint16: + n, err := strconv.ParseUint(string(b), 10, 16) + if err != nil { + return err + } + *v = uint16(n) + return nil + case *uint32: + n, err := strconv.ParseUint(string(b), 10, 32) + if err != nil { + return err + } + *v = uint32(n) + return nil + case *uint64: + n, err := strconv.ParseUint(string(b), 10, 64) + if err != nil { + return err + } + *v = n + return nil + case *float32: + n, err := strconv.ParseFloat(string(b), 32) + if err != nil { + return err + } + *v = float32(n) + return err + case *float64: + var err error + *v, err = strconv.ParseFloat(string(b), 64) + return err + case *bool: + *v = len(b) == 1 && b[0] == '1' + return nil + default: + if bu, ok := val.(binaryUnmarshaler); ok { + return bu.UnmarshalBinary(b) + } + err := fmt.Errorf( + "redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", val) + return err } - return buf } //------------------------------------------------------------------------------ @@ -120,7 +303,7 @@ func parseReply(rd *bufio.Reader, p multiBulkParser) (interface{}, error) { case '-': return nil, errorf(string(line[1:])) case '+': - return string(line[1:]), nil + return line[1:], nil case ':': v, err := strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { @@ -141,7 +324,7 @@ func parseReply(rd *bufio.Reader, p multiBulkParser) (interface{}, error) { if err != nil { return nil, err } - return string(b[:replyLen]), nil + return b[:replyLen], nil case '*': if len(line) == 3 && line[1] == '-' && line[2] == '1' { return nil, Nil @@ -166,7 +349,12 @@ func parseSlice(rd *bufio.Reader, n int64) (interface{}, error) { } else if err != nil { return nil, err } else { - vals = append(vals, v) + switch vv := v.(type) { + case []byte: + vals = append(vals, string(vv)) + default: + vals = append(vals, v) + } } } return vals, nil @@ -179,11 +367,11 @@ func parseStringSlice(rd *bufio.Reader, n int64) (interface{}, error) { if err != nil { return nil, err } - v, ok := viface.(string) + v, ok := viface.([]byte) if !ok { return nil, fmt.Errorf("got %T, expected string", viface) } - vals = append(vals, v) + vals = append(vals, string(v)) } return vals, nil } @@ -211,7 +399,7 @@ func parseStringStringMap(rd *bufio.Reader, n int64) (interface{}, error) { if err != nil { return nil, err } - key, ok := keyiface.(string) + key, ok := keyiface.([]byte) if !ok { return nil, fmt.Errorf("got %T, expected string", keyiface) } @@ -220,12 +408,12 @@ func parseStringStringMap(rd *bufio.Reader, n int64) (interface{}, error) { if err != nil { return nil, err } - value, ok := valueiface.(string) + value, ok := valueiface.([]byte) if !ok { return nil, fmt.Errorf("got %T, expected string", valueiface) } - m[key] = value + m[string(key)] = string(value) } return m, nil } @@ -237,7 +425,7 @@ func parseStringIntMap(rd *bufio.Reader, n int64) (interface{}, error) { if err != nil { return nil, err } - key, ok := keyiface.(string) + key, ok := keyiface.([]byte) if !ok { return nil, fmt.Errorf("got %T, expected string", keyiface) } @@ -248,15 +436,14 @@ func parseStringIntMap(rd *bufio.Reader, n int64) (interface{}, error) { } switch value := valueiface.(type) { case int64: - m[key] = value + m[string(key)] = value case string: - m[key], err = strconv.ParseInt(value, 10, 64) + m[string(key)], err = strconv.ParseInt(value, 10, 64) if err != nil { return nil, fmt.Errorf("got %v, expected number", value) } default: return nil, fmt.Errorf("got %T, expected number or string", valueiface) - } } return m, nil @@ -271,21 +458,21 @@ func parseZSlice(rd *bufio.Reader, n int64) (interface{}, error) { if err != nil { return nil, err } - member, ok := memberiface.(string) + member, ok := memberiface.([]byte) if !ok { return nil, fmt.Errorf("got %T, expected string", memberiface) } - z.Member = member + z.Member = string(member) scoreiface, err := parseReply(rd, nil) if err != nil { return nil, err } - scorestr, ok := scoreiface.(string) + scoreb, ok := scoreiface.([]byte) if !ok { return nil, fmt.Errorf("got %T, expected string", scoreiface) } - score, err := strconv.ParseFloat(scorestr, 64) + score, err := strconv.ParseFloat(string(scoreb), 64) if err != nil { return nil, err } diff --git a/parser_test.go b/parser_test.go index 1b9e15810..b71305a7a 100644 --- a/parser_test.go +++ b/parser_test.go @@ -47,7 +47,7 @@ func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr func BenchmarkAppendArgs(b *testing.B) { buf := make([]byte, 0, 64) - args := []string{"hello", "world", "foo", "bar"} + args := []interface{}{"hello", "world", "foo", "bar"} for i := 0; i < b.N; i++ { appendArgs(buf, args) } diff --git a/pubsub.go b/pubsub.go index 86e4bf6c0..26fa85289 100644 --- a/pubsub.go +++ b/pubsub.go @@ -80,11 +80,11 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { reply := cmd.Val() - msgName := reply[0].(string) - switch msgName { + kind := reply[0].(string) + switch kind { case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": return &Subscription{ - Kind: msgName, + Kind: kind, Channel: reply[1].(string), Count: int(reply[2].(int64)), }, nil @@ -101,7 +101,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { }, nil } - return nil, fmt.Errorf("redis: unsupported message name: %q", msgName) + return nil, fmt.Errorf("redis: unsupported pubsub notification: %q", kind) } func (c *PubSub) subscribe(cmd string, channels ...string) error { @@ -110,7 +110,11 @@ func (c *PubSub) subscribe(cmd string, channels ...string) error { return err } - args := append([]string{cmd}, channels...) + args := make([]interface{}, 1+len(channels)) + args[0] = cmd + for i, channel := range channels { + args[1+i] = channel + } req := NewSliceCmd(args...) return cn.writeCmds(req) } @@ -123,21 +127,10 @@ func (c *PubSub) PSubscribe(patterns ...string) error { return c.subscribe("PSUBSCRIBE", patterns...) } -func (c *PubSub) unsubscribe(cmd string, channels ...string) error { - cn, err := c.conn() - if err != nil { - return err - } - - args := append([]string{cmd}, channels...) - req := NewSliceCmd(args...) - return cn.writeCmds(req) -} - func (c *PubSub) Unsubscribe(channels ...string) error { - return c.unsubscribe("UNSUBSCRIBE", channels...) + return c.subscribe("UNSUBSCRIBE", channels...) } func (c *PubSub) PUnsubscribe(patterns ...string) error { - return c.unsubscribe("PUNSUBSCRIBE", patterns...) + return c.subscribe("PUNSUBSCRIBE", patterns...) } diff --git a/redis_test.go b/redis_test.go index d8c4f7aee..ac5ee7c99 100644 --- a/redis_test.go +++ b/redis_test.go @@ -192,30 +192,34 @@ func BenchmarkRedisSet(b *testing.B) { }) } -func BenchmarkRedisSetBytes(b *testing.B) { +func BenchmarkRedisGetNil(b *testing.B) { client := redis.NewClient(&redis.Options{ Addr: benchRedisAddr, }) defer client.Close() - value := bytes.Repeat([]byte{'1'}, 10000) + if err := client.FlushDb().Err(); err != nil { + b.Fatal(err) + } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - if err := client.Set("key", string(value), 0).Err(); err != nil { + if err := client.Get("key").Err(); err != redis.Nil { b.Fatal(err) } } }) } -func BenchmarkRedisGetNil(b *testing.B) { +func BenchmarkRedisGet(b *testing.B) { client := redis.NewClient(&redis.Options{ Addr: benchRedisAddr, }) defer client.Close() - if err := client.FlushDb().Err(); err != nil { + + value := bytes.Repeat([]byte{'1'}, 10000) + if err := client.Set("key", value, 0).Err(); err != nil { b.Fatal(err) } @@ -223,29 +227,40 @@ func BenchmarkRedisGetNil(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - if err := client.Get("key").Err(); err != redis.Nil { + s, err := client.Get("key").Result() + if err != nil { b.Fatal(err) } + if len(s) != 10000 { + panic("len(s) != 10000") + } } }) } -func BenchmarkRedisGet(b *testing.B) { +func BenchmarkRedisGetSetBytes(b *testing.B) { client := redis.NewClient(&redis.Options{ Addr: benchRedisAddr, }) defer client.Close() - if err := client.Set("key", "hello", 0).Err(); err != nil { - b.Fatal(err) - } + + src := bytes.Repeat([]byte{'1'}, 10000) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - if err := client.Get("key").Err(); err != nil { + if err := client.Set("key", src, 0).Err(); err != nil { b.Fatal(err) } + + dst, err := client.Get("key").Bytes() + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(dst, src) { + panic("len(dst) != 10000") + } } }) }