From c30ba58acfc19b96807f72162dcdd913365e2de2 Mon Sep 17 00:00:00 2001 From: "marius a. eriksen" Date: Fri, 6 Jan 2012 09:22:52 -0800 Subject: [PATCH] [split] finagle redis client from Tumblr it's currently unpublished, but fully functional. --- .../twitter/finagle/redis/Exceptions.scala | 4 + .../com/twitter/finagle/redis/Redis.scala | 59 ++ .../finagle/redis/protocol/Codec.scala | 56 ++ .../finagle/redis/protocol/Command.scala | 176 +++++ .../finagle/redis/protocol/Parsers.scala | 51 ++ .../finagle/redis/protocol/Reply.scala | 123 +++ .../protocol/commands/CommandArguments.scala | 99 +++ .../protocol/commands/CommandTypes.scala | 37 + .../redis/protocol/commands/Keys.scala | 133 ++++ .../redis/protocol/commands/SortedSets.scala | 683 +++++++++++++++++ .../redis/protocol/commands/Strings.scala | 290 +++++++ .../finagle/redis/util/Conversions.scala | 64 ++ .../finagle/redis/util/TestServer.scala | 111 +++ .../twitter/finagle/redis/NaggatiSpec.scala | 714 ++++++++++++++++++ .../ClientServerIntegrationSpec.scala | 452 +++++++++++ project/build/Project.scala | 25 + 16 files changed, 3077 insertions(+) create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/Exceptions.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/Redis.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Codec.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Command.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Parsers.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Reply.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/CommandArguments.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/CommandTypes.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/Keys.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/SortedSets.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/Strings.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/util/Conversions.scala create mode 100644 finagle-redis/src/main/scala/com/twitter/finagle/redis/util/TestServer.scala create mode 100644 finagle-redis/src/test/scala/com/twitter/finagle/redis/NaggatiSpec.scala create mode 100644 finagle-redis/src/test/scala/com/twitter/finagle/redis/integration/ClientServerIntegrationSpec.scala diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/Exceptions.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/Exceptions.scala new file mode 100644 index 0000000000..04a004e917 --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/Exceptions.scala @@ -0,0 +1,4 @@ +package com.twitter.finagle.redis + +case class ServerError(message: String) extends Exception(message) +case class ClientError(message: String) extends Exception(message) diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/Redis.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/Redis.scala new file mode 100644 index 0000000000..db61a76b00 --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/Redis.scala @@ -0,0 +1,59 @@ +package com.twitter.finagle.redis + +import protocol.{Command, CommandCodec, Reply, ReplyCodec} + +import com.twitter.finagle.{Codec, CodecFactory, Service} +import com.twitter.finagle.tracing.ClientRequestTracingFilter +import com.twitter.naggati.{Codec => NaggatiCodec} +import com.twitter.util.Future +import org.jboss.netty.channel.{ChannelPipelineFactory, Channels} + +object Redis { + def apply() = new Redis + def get() = apply() +} + +class Redis extends CodecFactory[Command, Reply] { + def server = Function.const { + new Codec[Command, Reply] { + def pipelineFactory = new ChannelPipelineFactory { + def getPipeline() = { + val pipeline = Channels.pipeline() + val commandCodec = new CommandCodec + val replyCodec = new ReplyCodec + + pipeline.addLast("codec", new NaggatiCodec(commandCodec.decode, replyCodec.encode)) + + pipeline + } + } + } + } + + def client = Function.const { + new Codec[Command, Reply] { + + def pipelineFactory = new ChannelPipelineFactory { + def getPipeline() = { + val pipeline = Channels.pipeline() + val commandCodec = new CommandCodec + val replyCodec = new ReplyCodec + + pipeline.addLast("codec", new NaggatiCodec(replyCodec.decode, commandCodec.encode)) + + pipeline + } + } + + override def prepareService(underlying: Service[Command, Reply]) = { + Future.value((new RedisTracingFilter()) andThen underlying) + } + + } + } +} + +private class RedisTracingFilter extends ClientRequestTracingFilter[Command, Reply] { + val serviceName = "redis" + def methodName(req: Command): String = req.getClass().getSimpleName() +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Codec.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Codec.scala new file mode 100644 index 0000000000..9dac72c14c --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Codec.scala @@ -0,0 +1,56 @@ +package com.twitter.finagle.redis +package protocol + +import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers} +import util.StringToChannelBuffer +import scala.collection.immutable.WrappedString + +private[redis] object RedisCodec { + object NilValue extends WrappedString("nil") { + def getBytes(charset: String = "UTF-8") = Array[Byte]() + def getBytes = Array[Byte]() + } + + val STATUS_REPLY = '+' + val ERROR_REPLY = '-' + val INTEGER_REPLY = ':' + val BULK_REPLY = '$' + val MBULK_REPLY = '*' + + val ARG_COUNT_MARKER = '*' + val ARG_SIZE_MARKER = '$' + + val TOKEN_DELIMITER = ' ' + val EOL_DELIMITER = "\r\n" + + val NIL_VALUE = NilValue + val NIL_VALUE_BA = NilValue.getBytes + + def toUnifiedFormat(args: List[Array[Byte]], includeHeader: Boolean = true) = { + val buffer = ChannelBuffers.dynamicBuffer() + includeHeader match { + case true => + val argHeader = "%c%d%s".format(ARG_COUNT_MARKER, args.length, EOL_DELIMITER) + buffer.writeBytes(argHeader.getBytes) + case false => + } + args.foreach { arg => + if (arg.length == 0) { + buffer.writeBytes("%c-1%s".format(ARG_SIZE_MARKER, EOL_DELIMITER).getBytes) + } else { + val sizeHeader = "%c%d%s".format(ARG_SIZE_MARKER, arg.length, EOL_DELIMITER) + buffer.writeBytes(sizeHeader.getBytes) + buffer.writeBytes(arg) + buffer.writeBytes(EOL_DELIMITER.getBytes) + } + } + buffer + } + def toInlineFormat(args: List[String]) = { + StringToChannelBuffer(args.mkString(TOKEN_DELIMITER.toString) + EOL_DELIMITER) + } +} +abstract class RedisMessage { + def toChannelBuffer: ChannelBuffer + def toByteArray: Array[Byte] = toChannelBuffer.array +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Command.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Command.scala new file mode 100644 index 0000000000..083fcafe12 --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Command.scala @@ -0,0 +1,176 @@ +package com.twitter.finagle.redis +package protocol + +import util._ + +object RequireClientProtocol extends ErrorConversion { + override def getException(msg: String) = new ClientError(msg) +} + +abstract class Command extends RedisMessage + +object Commands { + // Key Commands + val DEL = "DEL" + val EXISTS = "EXISTS" + val EXPIRE = "EXPIRE" + val EXPIREAT = "EXPIREAT" + val KEYS = "KEYS" + val PERSIST = "PERSIST" + val RANDOMKEY = "RANDOMKEY" + val RENAME = "RENAME" + val RENAMENX = "RENAMENX" + val TTL = "TTL" + val TYPE = "TYPE" + + // String Commands + val APPEND = "APPEND" + val DECR = "DECR" + val DECRBY = "DECRBY" + val GET = "GET" + val GETBIT = "GETBIT" + val GETRANGE = "GETRANGE" + val GETSET = "GETSET" + val INCR = "INCR" + val INCRBY = "INCRBY" + val MGET = "MGET" + val MSET = "MSET" + val MSETNX = "MSETNX" + val SET = "SET" + val SETBIT = "SETBIT" + val SETEX = "SETEX" + val SETNX = "SETNX" + val SETRANGE = "SETRANGE" + val STRLEN = "STRLEN" + + // Sorted Sets + val ZADD = "ZADD" + val ZCARD = "ZCARD" + val ZCOUNT = "ZCOUNT" + val ZINCRBY = "ZINCRBY" + val ZINTERSTORE = "ZINTERSTORE" + val ZRANGE = "ZRANGE" + val ZRANGEBYSCORE = "ZRANGEBYSCORE" + val ZRANK = "ZRANK" + val ZREM = "ZREM" + val ZREMRANGEBYRANK = "ZREMRANGEBYRANK" + val ZREMRANGEBYSCORE = "ZREMRANGEBYSCORE" + val ZREVRANGE = "ZREVRANGE" + val ZREVRANGEBYSCORE = "ZREVRANGEBYSCORE" + val ZREVRANK = "ZREVRANK" + val ZSCORE = "ZSCORE" + val ZUNIONSTORE = "ZUNIONSTORE" + + val commandMap: Map[String,Function1[List[Array[Byte]],Command]] = Map( + // key commands + DEL -> {args => Del(BytesToString.fromList(args))}, + EXISTS -> {Exists(_)}, + EXPIRE -> {Expire(_)}, + EXPIREAT -> {ExpireAt(_)}, + KEYS -> {Keys(_)}, + PERSIST -> {Persist(_)}, + RANDOMKEY -> {args => Randomkey()}, + RENAME -> {Rename(_)}, + RENAMENX -> {RenameNx(_)}, + TTL -> {Ttl(_)}, + TYPE -> {Type(_)}, + + // string commands + APPEND -> {Append(_)}, + DECR -> {Decr(_)}, + DECRBY -> {DecrBy(_)}, + GET -> {Get(_)}, + GETBIT -> {GetBit(_)}, + GETRANGE -> {GetRange(_)}, + GETSET -> {GetSet(_)}, + INCR -> {Incr(_)}, + INCRBY -> {IncrBy(_)}, + MGET -> {args => MGet(BytesToString.fromList(args))}, + MSET -> {MSet(_)}, + MSETNX -> {MSetNx(_)}, + SET -> {Set(_)}, + SETBIT -> {SetBit(_)}, + SETEX -> {SetEx(_)}, + SETNX -> {SetNx(_)}, + SETRANGE -> {SetRange(_)}, + STRLEN -> {Strlen(_)}, + + // sorted sets + ZADD -> {ZAdd(_)}, + ZCARD -> {ZCard(_)}, + ZCOUNT -> {ZCount(_)}, + ZINCRBY -> {ZIncrBy(_)}, + ZINTERSTORE -> {ZInterStore(_)}, + ZRANGE -> {ZRange(_)}, + ZRANGEBYSCORE -> {ZRangeByScore(_)}, + ZRANK -> {ZRank(_)}, + ZREM -> {ZRem(_)}, + ZREMRANGEBYRANK -> {ZRemRangeByRank(_)}, + ZREMRANGEBYSCORE -> {ZRemRangeByScore(_)}, + ZREVRANGE -> {ZRevRange(_)}, + ZREVRANGEBYSCORE -> {ZRevRangeByScore(_)}, + ZREVRANK -> {ZRevRank(_)}, + ZSCORE -> {ZScore(_)}, + ZUNIONSTORE -> {ZUnionStore(_)} + ) + + def doMatch(cmd: String, args: List[Array[Byte]]) = commandMap.get(cmd).map { + _(args) + }.getOrElse(throw ClientError("Unsupported command: " + cmd)) + + def trimList(list: List[Array[Byte]], count: Int, from: String = "") = { + RequireClientProtocol(list != null, "%s Empty list found".format(from)) + RequireClientProtocol( + list.length == count, + "%s Expected %d elements, found %d".format(from, count, list.length)) + val newList = list.take(count) + newList.foreach { item => RequireClientProtocol(item != null, "Found empty item in list") } + newList + } +} + +class CommandCodec extends UnifiedProtocolCodec { + import com.twitter.naggati.{Emit, Encoder, NextStep} + import com.twitter.naggati.Stages._ + import RedisCodec._ + import com.twitter.logging.Logger + + val log = Logger(getClass) + + val decode = readBytes(1) { bytes => + bytes(0) match { + case ARG_COUNT_MARKER => + val doneFn = { lines => commandDecode(lines) } + RequireClientProtocol.safe { + readLine { line => decodeUnifiedFormat(NumberFormat.toLong(line), doneFn) } + } + case b: Byte => + decodeInlineRequest(b.asInstanceOf[Char]) + } + } + + val encode = new Encoder[Command] { + def encode(obj: Command) = Some(obj.toChannelBuffer) + } + + def decodeInlineRequest(c: Char) = readLine { line => + val listOfArrays = (c + line).split(' ').toList.map { args => args.getBytes("UTF-8") } + val cmd = commandDecode(listOfArrays) + emit(cmd) + } + + def commandDecode(lines: List[Array[Byte]]): Command = { + RequireClientProtocol(lines != null && lines.length > 0, "Invalid client command protocol") + val cmd = new String(lines.head) + val args = lines.tail + try { + Commands.doMatch(cmd, args) + } catch { + case e: ClientError => throw e + case t: Throwable => + log.warning(t, "Unhandled exception %s(%s)".format(t.getClass.toString, t.getMessage)) + throw new ClientError(t.getMessage) + } + } + +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Parsers.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Parsers.scala new file mode 100644 index 0000000000..91c3e797b4 --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Parsers.scala @@ -0,0 +1,51 @@ +package com.twitter.finagle.redis +package protocol + +import util._ +import RedisCodec._ + +import com.twitter.naggati.{Emit, Encoder, NextStep, ProtocolError} +import com.twitter.naggati.Stages._ + +trait UnifiedProtocolCodec { + + type ByteArrays = List[Array[Byte]] + + def decodeUnifiedFormat[T <: AnyRef](argCount: Long, doneFn: ByteArrays => T) = + argCount match { + case n if n < 0 => throw new ProtocolError("Invalid argument count specified") + case n => decodeRequestLines(n, Nil, { lines => doneFn(lines) } ) + } + + def decodeRequestLines[T <: AnyRef]( + i: Long, + lines: ByteArrays, + doneFn: ByteArrays => T): NextStep = + { + if (i <= 0) { + emit(doneFn(lines.reverse)) + } else { + readLine { line => + val header = line(0) + header match { + case ARG_SIZE_MARKER => + val size = NumberFormat.toInt(line.drop(1)) + if (size < 1) { + decodeRequestLines(i - 1, lines.+:(RedisCodec.NIL_VALUE_BA), doneFn) + } else { + readBytes(size) { byteArray => + readBytes(2) { eol => + if (eol(0) != '\r' || eol(1) != '\n') { + throw new ProtocolError("Expected EOL after line data and didn't find it") + } + decodeRequestLines(i - 1, lines.+:(byteArray), doneFn) + } + } + } + case b: Char => + throw new ProtocolError("Expected size marker $, got " + b) + } // header match + } // readLine + } // else + } // decodeRequestLines +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Reply.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Reply.scala new file mode 100644 index 0000000000..b7df52a0b3 --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/Reply.scala @@ -0,0 +1,123 @@ +package com.twitter.finagle.redis +package protocol + +import util._ +import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers} + +object RequireServerProtocol extends ErrorConversion { + override def getException(msg: String) = new ServerError(msg) +} + +sealed abstract class Reply extends RedisMessage +sealed abstract class SingleLineReply extends Reply { // starts with +,-, or : + import RedisCodec.EOL_DELIMITER + + def getMessageTuple(): (Char,String) + override def toChannelBuffer = { + val (c,s) = getMessageTuple + StringToChannelBuffer("%c%s%s".format(c,s,EOL_DELIMITER)) + } +} +sealed abstract class MultiLineReply extends Reply + +case class StatusReply(message: String) extends SingleLineReply { + RequireServerProtocol(message != null && message.length > 0, "StatusReply had empty message") + override def getMessageTuple() = (RedisCodec.STATUS_REPLY, message) +} +case class ErrorReply(message: String) extends SingleLineReply { + RequireServerProtocol(message != null && message.length > 0, "ErrorReply had empty message") + override def getMessageTuple() = (RedisCodec.ERROR_REPLY, message) +} +case class IntegerReply(id: Int) extends SingleLineReply { + override def getMessageTuple() = (RedisCodec.INTEGER_REPLY, id.toString) +} + +case class BulkReply(message: Array[Byte]) extends MultiLineReply { + RequireServerProtocol(message != null && message.length > 0, "BulkReply had empty message") + + import RedisCodec.{ARG_SIZE_MARKER, EOL_DELIMITER} + + override def toChannelBuffer = { + val mlen = message.length + val exlen = 1 + 2 + 2 // 1 byte for marker, 2 bytes for first EOL, 2 for second + val header = "%c%d%s".format(ARG_SIZE_MARKER, mlen, EOL_DELIMITER) + val buffer = ChannelBuffers.dynamicBuffer(mlen + exlen) + buffer.writeBytes(header.getBytes) + buffer.writeBytes(message) + buffer.writeBytes(EOL_DELIMITER.getBytes) + buffer + } +} +case class EmptyBulkReply() extends MultiLineReply { + val message = RedisCodec.NIL_VALUE + val messageBytes = StringToBytes(message) + override def toChannelBuffer = RedisCodec.toInlineFormat(List("$-1")) +} + +case class MBulkReply(messages: List[Array[Byte]]) extends MultiLineReply { + RequireServerProtocol( + messages != null && messages.length > 0, + "Multi-BulkReply had empty message list") + override def toChannelBuffer = RedisCodec.toUnifiedFormat(messages) +} +case class EmptyMBulkReply() extends MultiLineReply { + val message = RedisCodec.NIL_VALUE + val messageBytes = StringToBytes(message) + override def toChannelBuffer = RedisCodec.toInlineFormat(List("*0")) +} + +class ReplyCodec extends UnifiedProtocolCodec { + import com.twitter.naggati.{Emit, Encoder, NextStep} + import com.twitter.naggati.Stages._ + import RedisCodec._ + + val decode = readBytes(1) { bytes => + bytes(0) match { + case STATUS_REPLY => + readLine { line => emit(StatusReply(line)) } + case ERROR_REPLY => + readLine { line => emit(ErrorReply(line)) } + case INTEGER_REPLY => + readLine { line => + RequireServerProtocol.safe { + emit(IntegerReply(NumberFormat.toInt(line))) + } + } + case BULK_REPLY => + decodeBulkReply + case MBULK_REPLY => + val doneFn = { lines: List[Array[Byte]] => + lines.length match { + case empty if empty == 0 => EmptyMBulkReply() + case n => MBulkReply(lines) + } + } + RequireServerProtocol.safe { + readLine { line => decodeUnifiedFormat(NumberFormat.toLong(line), doneFn) } + } + case b: Byte => + throw new ServerError("Unknown response format(%c) found".format(b.asInstanceOf[Char])) + } + } + + val encode = new Encoder[Reply] { + def encode(obj: Reply) = Some(obj.toChannelBuffer) + } + + def decodeBulkReply = readLine { line => + RequireServerProtocol.safe { + NumberFormat.toInt(line) + } match { + case empty if empty < 1 => emit(EmptyBulkReply()) + case replySz => readBytes(replySz) { bytes => + readBytes(2) { eol => + if (eol(0) != '\r' || eol(1) != '\n') { + throw new ServerError("Expected EOL after line data and didn't find it") + } + emit(BulkReply(bytes)) + } //readBytes + } // readBytes + } // match + } // decodeBulkReply + +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/CommandArguments.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/CommandArguments.scala new file mode 100644 index 0000000000..a653a15fcc --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/CommandArguments.scala @@ -0,0 +1,99 @@ +package com.twitter.finagle.redis +package protocol + +import util._ + +trait CommandArgument extends Command { + override def toChannelBuffer = + throw new UnsupportedOperationException("OptionCommand does not support toChannelBuffer") +} + +// Constant case object representing WITHSCORES command arg +case object WithScores extends CommandArgument { + val WITHSCORES = "WITHSCORES" + override def toString = WITHSCORES + def unapply(s: String) = s.toUpperCase match { + case WITHSCORES => Some(s) + case _ => None + } +} + +case class Limit(offset: Int, count: Int) extends CommandArgument { + override def toString = "%s %d %d".format(Limit.LIMIT, offset, count) +} +object Limit { + val LIMIT = "LIMIT" + def apply(args: List[String]) = { + RequireClientProtocol(args != null && args.length == 3, "LIMIT requires two arguments") + RequireClientProtocol(args.head == LIMIT, "LIMIT must start with LIMIT clause") + RequireClientProtocol.safe { + val offset = NumberFormat.toInt(args(1)) + val count = NumberFormat.toInt(args(2)) + new Limit(offset, count) + } + } +} + +// Represents a list of WEIGHTS +class Weights(underlying: Vector[Float]) extends CommandArgument with IndexedSeq[Float] { + override def apply(idx: Int) = underlying(idx) + override def length = underlying.length + override def toString = Weights.toString + " " + this.mkString(" ") +} + +// Handles parsing and manipulation of WEIGHTS arguments +object Weights { + val WEIGHTS = "WEIGHTS" + + def apply(weight: Float) = new Weights(Vector(weight)) + def apply(weights: Float*) = new Weights(Vector(weights:_*)) + def apply(weights: Vector[Float]) = new Weights(weights) + + def apply(args: List[String]): Option[Weights] = { + val argLength = args.length + RequireClientProtocol( + args != null && argLength > 0, + "WEIGHTS can not be specified with an empty list") + args.head.toUpperCase match { + case WEIGHTS => + RequireClientProtocol(argLength > 1, "WEIGHTS requires additional arguments") + val weights: Vector[Float] = RequireClientProtocol.safe { + args.tail.map { item => NumberFormat.toFloat(item) }(collection.breakOut) + } + Some(new Weights(weights)) + case _ => None + } + } + override def toString = Weights.WEIGHTS +} + +// Handles parsing and manipulation of AGGREGATE arguments +sealed abstract class Aggregate(val name: String) { + override def toString = Aggregate.toString + " " + name.toUpperCase + def equals(str: String) = str.equals(name) +} +object Aggregate { + val AGGREGATE = "AGGREGATE" + case object Sum extends Aggregate("SUM") + case object Min extends Aggregate("MIN") + case object Max extends Aggregate("MAX") + override def toString = AGGREGATE + + def apply(args: List[String]): Option[Aggregate] = { + val argLength = args.length + RequireClientProtocol( + args != null && argLength > 0, + "AGGREGATE can not be specified with empty list") + args.head.toUpperCase match { + case AGGREGATE => + RequireClientProtocol(argLength == 2, "AGGREGATE requires a type (MIN, MAX, SUM)") + args(1).toUpperCase match { + case Aggregate.Sum.name => Some(Aggregate.Sum) + case Aggregate.Max.name => Some(Aggregate.Max) + case Aggregate.Min.name => Some(Aggregate.Min) + case _ => throw new ClientError("AGGREGATE type must be one of MIN, MAX or SUM") + } + case _ => None + } + } +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/CommandTypes.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/CommandTypes.scala new file mode 100644 index 0000000000..ba1f8b6c4e --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/CommandTypes.scala @@ -0,0 +1,37 @@ +package com.twitter.finagle.redis +package protocol + +trait KeyCommand extends Command { + val key: String + protected def validate() { + RequireClientProtocol(key != null && key.length > 0, "Empty Key found") + } +} +trait StrictKeyCommand extends KeyCommand { + validate() +} + +trait KeysCommand extends Command { + val keys: List[String] + protected def validate() { + RequireClientProtocol(keys != null && keys.length > 0, "Empty KeySet found") + keys.foreach { key => RequireClientProtocol(key != null && key.length > 0, "Empty key found") } + } +} +trait StrictKeysCommand extends KeysCommand { + validate() +} + +trait ValueCommand extends Command { + val value: Array[Byte] +} +trait StrictValueCommand extends ValueCommand { + RequireClientProtocol(value != null && value.length > 0, "Found unexpected empty value") +} + +trait MemberCommand extends Command { + val member: Array[Byte] +} +trait StrictMemberCommand extends MemberCommand { + RequireClientProtocol(member != null && member.length > 0, "Found unexpected empty set member") +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/Keys.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/Keys.scala new file mode 100644 index 0000000000..2cee2f6961 --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/Keys.scala @@ -0,0 +1,133 @@ +package com.twitter.finagle.redis +package protocol + +import util._ + +import com.twitter.conversions.time._ +import com.twitter.util.Time + +import Commands.trimList + +/** + * TODO + * - EVAL + * - MOVE + * - OBJECT + * - SORT + */ + +case class Del(keys: List[String]) extends StrictKeysCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(Commands.DEL +: keys) +} +object Del { + def apply(key: String) = new Del(List(key)) +} + +case class Exists(key: String) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.EXISTS, key)) +} +object Exists { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 1, "EXISTS") + new Exists(BytesToString(list(0))) + } +} + +case class Expire(key: String, seconds: Long) extends StrictKeyCommand { + RequireClientProtocol(seconds > 0, "Seconds must be greater than 0") + override def toChannelBuffer = + RedisCodec.toInlineFormat(List(Commands.EXPIRE, key, seconds.toString)) +} +object Expire { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2, "EXPIRE") + RequireClientProtocol.safe { + new Expire(BytesToString(list(0)), NumberFormat.toLong(BytesToString(list(1)))) + } + } +} + +case class ExpireAt(key: String, timestamp: Time) extends StrictKeyCommand { + RequireClientProtocol( + timestamp != null && timestamp > Time.now, + "Timestamp must be in the future") + + val seconds = timestamp.inSeconds + + override def toChannelBuffer = + RedisCodec.toInlineFormat(List(Commands.EXPIREAT, key, seconds.toString)) +} +object ExpireAt { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2, "EXPIREAT") + val secondsString = BytesToString(list(1)) + val seconds = RequireClientProtocol.safe { + Time.fromSeconds(NumberFormat.toInt(secondsString)) + } + new ExpireAt(BytesToString(list(0)), seconds) + } +} + +case class Keys(pattern: String) extends Command { + RequireClientProtocol(pattern != null && pattern.length > 0, "Pattern must be specified") + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.KEYS, pattern)) +} +object Keys { + def apply(args: List[Array[Byte]]) = new Keys(BytesToString.fromList(args).mkString) +} + +case class Persist(key: String) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.PERSIST, key)) +} +object Persist { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 1, "PERSIST") + new Persist(BytesToString(list(0))) + } +} + +case class Randomkey() extends Command { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.RANDOMKEY)) +} + +case class Rename(key: String, newkey: String) extends StrictKeyCommand { + RequireClientProtocol(newkey != null && newkey.length > 0, "New key must not be empty") + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.RENAME, key, newkey)) +} +object Rename { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2, "RENAME") + new Rename(BytesToString(list(0)), BytesToString(list(1))) + } +} + +case class RenameNx(key: String, newkey: String) extends StrictKeyCommand { + RequireClientProtocol(newkey != null && newkey.length > 0, "New key must not be empty") + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.RENAMENX, key, newkey)) +} +object RenameNx { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2, "RENAMENX") + new RenameNx(BytesToString(list(0)), BytesToString(list(1))) + } +} + +case class Ttl(key: String) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.TTL, key)) +} +object Ttl { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 1, "TTL") + new Ttl(BytesToString(list(0))) + } +} + +case class Type(key: String) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.TYPE, key)) +} +object Type { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 1, "TYPE") + new Type(BytesToString(list(0))) + } +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/SortedSets.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/SortedSets.scala new file mode 100644 index 0000000000..5d480099fa --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/SortedSets.scala @@ -0,0 +1,683 @@ +package com.twitter.finagle.redis +package protocol + +import util._ +import Commands.trimList + +case class ZAdd(key: String, members: List[ZMember]) + extends StrictKeyCommand + with StrictZMembersCommand +{ + override def toChannelBuffer = { + val cmds = StringToBytes.fromList(List(Commands.ZADD, key)) + RedisCodec.toUnifiedFormat(cmds ::: membersByteArray) + } +} +object ZAdd { + def apply(args: List[Array[Byte]]) = args match { + case head :: tail => + new ZAdd(BytesToString(head), ZMembers(tail)) + case _ => + throw ClientError("Invalid use of ZADD") + } + def apply(key: String, member: ZMember) = new ZAdd(key, List(member)) +} + + +case class ZCard(key: String) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.ZCARD, key)) +} +object ZCard { + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args, 1, "ZCARD")) + new ZCard(list(0)) + } +} + + +case class ZCount(key: String, min: ZInterval, max: ZInterval) extends StrictKeyCommand { + override def toChannelBuffer = + RedisCodec.toInlineFormat(List(Commands.ZCOUNT, key, min.toString, max.toString)) +} +object ZCount { + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args, 3, "ZCOUNT")) + new ZCount(list(0), ZInterval(list(1)), ZInterval(list(2))) + } +} + + +case class ZIncrBy(key: String, amount: Float, member: Array[Byte]) + extends StrictKeyCommand + with StrictMemberCommand +{ + override def toChannelBuffer = + RedisCodec.toUnifiedFormat(List( + StringToBytes(Commands.ZINCRBY), + StringToBytes(key), + StringToBytes(amount.toString), + member)) +} +object ZIncrBy { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 3, "ZINCRBY") + val key = BytesToString(list(0)) + val amount = RequireClientProtocol.safe { + NumberFormat.toFloat(BytesToString(list(1))) + } + new ZIncrBy(key, amount, list(2)) + } +} + + +case class ZInterStore( + destination: String, + numkeys: Int, + keys: List[String], + weights: Option[Weights] = None, + aggregate: Option[Aggregate] = None) + extends ZStore +{ + val command = Commands.ZINTERSTORE + validate() +} +object ZInterStore extends ZStoreCompanion { + def get( + dest: String, + numkeys: Int, + keys: List[String], + weights: Option[Weights], + agg: Option[Aggregate]) = new ZInterStore(dest, numkeys, keys, weights, agg) +} + + +case class ZRange(key: String, start: Int, stop: Int, withScores: Option[CommandArgument] = None) + extends ZRangeCmd +{ + val command = Commands.ZRANGE +} +object ZRange extends ZRangeCmdCompanion { + override def get(key: String, start: Int, stop: Int, withScores: Option[CommandArgument]) = + new ZRange(key, start, stop, withScores) +} + + +case class ZRangeByScore( + key: String, + min: ZInterval, + max: ZInterval, + withScores: Option[CommandArgument] = None, + limit: Option[Limit] = None) + extends ZScoredRange +{ + val command = Commands.ZRANGEBYSCORE + validate() +} +object ZRangeByScore extends ZScoredRangeCompanion { + def get( + key: String, + min: ZInterval, + max: ZInterval, + withScores: Option[CommandArgument], + limit: Option[Limit]): ZScoredRange = new ZRangeByScore(key, min, max, withScores, limit) +} + + +case class ZRank(key: String, member: Array[Byte]) extends ZRankCmd { + val command = Commands.ZRANK +} +object ZRank extends ZRankCmdCompanion { + def get(key: String, member: Array[Byte]) = new ZRank(key, member) +} + + +case class ZRem(key: String, members: List[Array[Byte]]) extends StrictKeyCommand { + RequireClientProtocol( + members != null && members.length > 0, + "Members list must not be empty for ZREM") + + override def toChannelBuffer = { + RedisCodec.toUnifiedFormat( + List(StringToBytes(Commands.ZREM), StringToBytes(key)) ::: members + ) + } +} +object ZRem { + def apply(args: List[Array[Byte]]) = { + RequireClientProtocol(args != null && args.length > 1, "ZREM requires at least one member") + val key = BytesToString(args(0)) + val remaining = args.drop(1) + new ZRem(key, remaining) + } +} + + +case class ZRemRangeByRank(key: String, start: Int, stop: Int) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toUnifiedFormat(StringToBytes.fromList(List( + Commands.ZREMRANGEBYRANK, + key, + start.toString, + stop.toString))) +} +object ZRemRangeByRank { + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args, 3, "ZREMRANGEBYRANK requires 3 arguments")) + val key = list(0) + val start = RequireClientProtocol.safe { NumberFormat.toInt(list(1)) } + val stop = RequireClientProtocol.safe { NumberFormat.toInt(list(2)) } + new ZRemRangeByRank(key, start, stop) + } +} + + +case class ZRemRangeByScore(key: String, min: ZInterval, max: ZInterval) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toUnifiedFormat(StringToBytes.fromList(List( + Commands.ZREMRANGEBYSCORE, + key, + min.toString, + max.toString))) +} +object ZRemRangeByScore { + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args, 3, "ZREMRANGEBYSCORE requires 3 arguments")) + val key = list(0) + val min = ZInterval(list(1)) + val max = ZInterval(list(2)) + new ZRemRangeByScore(key, min, max) + } +} + + +case class ZRevRange( + key: String, + start: Int, + stop: Int, + withScores: Option[CommandArgument] = None) + extends ZRangeCmd +{ + val command = Commands.ZREVRANGE +} +object ZRevRange extends ZRangeCmdCompanion { + override def get(key: String, start: Int, stop: Int, withScores: Option[CommandArgument]) = + new ZRevRange(key, start, stop, withScores) +} + + +case class ZRevRangeByScore( + key: String, + max: ZInterval, + min: ZInterval, + withScores: Option[CommandArgument] = None, + limit: Option[Limit] = None) + extends ZScoredRange +{ + val command = Commands.ZREVRANGEBYSCORE + validate() +} +object ZRevRangeByScore extends ZScoredRangeCompanion { + def get( + key: String, + max: ZInterval, + min: ZInterval, + withScores: Option[CommandArgument], + limit: Option[Limit]): ZScoredRange = new ZRevRangeByScore(key, max, min, withScores, limit) +} + + +case class ZRevRank(key: String, member: Array[Byte]) extends ZRankCmd { + val command = Commands.ZREVRANK +} +object ZRevRank extends ZRankCmdCompanion { + def get(key: String, member: Array[Byte]) = new ZRevRank(key, member) +} + + +case class ZScore(key: String, member: Array[Byte]) + extends StrictKeyCommand + with StrictMemberCommand +{ + override def toChannelBuffer = RedisCodec.toUnifiedFormat(List( + StringToBytes(Commands.ZSCORE), + StringToBytes(key), + member)) +} +object ZScore { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2, "ZSCORE") + new ZScore(BytesToString(args(0)), args(1)) + } +} + + +case class ZUnionStore( + destination: String, + numkeys: Int, + keys: List[String], + weights: Option[Weights] = None, + aggregate: Option[Aggregate] = None) + extends ZStore +{ + val command = Commands.ZUNIONSTORE + validate() +} +object ZUnionStore extends ZStoreCompanion { + def get( + dest: String, + numkeys: Int, + keys: List[String], + weights: Option[Weights], + agg: Option[Aggregate]) = new ZUnionStore(dest, numkeys, keys, weights, agg) +} + +/** + * Internal Helpers + */ + +// Represents part of an interval, helpers in companion object +case class ZInterval(value: String) { + import ZInterval._ + private val representation = value.toLowerCase match { + case N_INF => N_INF + case P_INF => P_INF + case float => float.head match { + case EXCLUSIVE => RequireClientProtocol.safe { + NumberFormat.toFloat(float.tail) + float + } + case f => RequireClientProtocol.safe { + NumberFormat.toFloat(value) + float + } + } + } + override def toString = representation +} +object ZInterval { + private val P_INF = "+inf" + private val N_INF = "-inf" + private val EXCLUSIVE = '(' + + val MAX = new ZInterval(P_INF) + val MIN = new ZInterval(N_INF) + def apply(float: Float) = new ZInterval(float.toString) + def apply(v: Array[Byte]) = new ZInterval(BytesToString(v)) + def exclusive(float: Float) = new ZInterval("%c%s".format(EXCLUSIVE, float.toString)) +} + + +case class ZMember(score: Float, member: Array[Byte]) + extends StrictScoreCommand + with StrictMemberCommand +{ + override def toChannelBuffer = + throw new UnsupportedOperationException("ZMember doesn't support toChannelBuffer") +} + + +sealed trait ScoreCommand extends Command { + val score: Float +} +sealed trait StrictScoreCommand extends ScoreCommand { +} + + +sealed trait ZMembersCommand { + val members: List[ZMember] +} +sealed trait StrictZMembersCommand extends ZMembersCommand { + RequireClientProtocol(members != null && members.length > 0, "Members set must not be empty") + members.foreach { member => + RequireClientProtocol(member != null, "Empty member found") + } + def membersByteArray: List[Array[Byte]] = { + members.map { member => + List( + StringToBytes(member.score.toString), + member.member + ) + }.flatten + } +} + + +object ZMembers { + def apply(args: List[Array[Byte]]): List[ZMember] = { + val size = args.length + RequireClientProtocol(size % 2 == 0 && size > 0, "Unexpected uneven pair of elements") + + args.grouped(2).map { + case score :: member :: Nil => + ZMember( + RequireClientProtocol.safe { + NumberFormat.toFloat(BytesToString(score)) + }, + member) + case _ => + throw ClientError("Unexpected uneven pair of elements in members") + }.toList + } +} + + +abstract class ZStore extends KeysCommand { + val command: String + val destination: String + val numkeys: Int + val keys: List[String] + val weights: Option[Weights] + val aggregate: Option[Aggregate] + + override protected def validate() { + super.validate() + RequireClientProtocol( + destination != null && destination.length > 0, + "destination must not be empty") + RequireClientProtocol(numkeys > 0, "numkeys must be > 0") + RequireClientProtocol(keys.size == numkeys, "must supply the same number of keys as numkeys") + // ensure if weights are specified they are equal to the size of numkeys + weights match { + case Some(list) => + RequireClientProtocol( + list.size == numkeys, + "If WEIGHTS specified, numkeys weights required") + case None => + } + } + + override def toChannelBuffer = { + // FIXME + var args = List(command, destination, numkeys.toString) ::: keys + weights match { + case Some(wlist) => args = args :+ wlist.toString + case None => + } + aggregate match { + case Some(agg) => args = args :+ agg.toString + case None => + } + RedisCodec.toInlineFormat(args) + } +} +trait ZStoreCompanion { + def apply(dest: String, keys: List[String]) = get(dest, keys.length, keys, None, None) + def apply(dest: String, keys: List[String], weights: Weights) = { + get(dest, keys.length, keys, Some(weights), None) + } + def apply(dest: String, keys: List[String], agg: Aggregate) = + get(dest, keys.length, keys, None, Some(agg)) + def apply(dest: String, keys: List[String], weights: Weights, agg: Aggregate) = + get(dest, keys.length, keys, Some(weights), Some(agg)) + + /** get a new instance of the appropriate storage class + * @param d - Destination + * @param n - Number of keys + * @param k - Keys + * @param w - Weights + * @param a - Aggregate + * + * @return new instance + */ + def get(d: String, n: Int, k: List[String], w: Option[Weights], a: Option[Aggregate]): ZStore + + def apply(args: List[Array[Byte]]) = BytesToString.fromList(args) match { + case destination :: nk :: tail => + val numkeys = RequireClientProtocol.safe { NumberFormat.toInt(nk) } + tail.size match { + case done if done == numkeys => + get(destination, numkeys, tail, None, None) + case more if more > numkeys => + parseArgs(destination, numkeys, tail) + case _ => + throw ClientError("Specified keys must equal numkeys") + } + case _ => throw ClientError("Expected a minimum of 3 arguments for command") + } + + protected def parseArgs(dest: String, numkeys: Int, remaining: List[String]) = { + val (keys, args) = remaining.splitAt(numkeys) + args.isEmpty match { + case true => + get(dest, numkeys, keys, None, None) + case false => + val (args0, args1) = findArgs(args, numkeys) + RequireClientProtocol(args0.length > 1, "Length of arguments must be > 1") + val weights = findWeights(args0, args1) + val aggregate = findAggregate(args0, args1) + weights.foreach { w => + RequireClientProtocol(w.size == numkeys, "WEIGHTS length must equal keys length") + } + get(dest, numkeys, keys, weights, aggregate) + } + } + + protected def findArgs(args: List[String], numkeys: Int) = { + RequireClientProtocol(args != null && args.length > 0, "Args list must not be empty") + args.head.toUpperCase match { + case Weights.WEIGHTS => args.splitAt(numkeys+1) + case Aggregate.AGGREGATE => args.splitAt(2) + case s => throw ClientError("AGGREGATE or WEIGHTS argument expected, found %s".format(s)) + } + } + + protected def findWeights(args0: List[String], args1: List[String]) = Weights(args0) match { + case None => args1.length > 0 match { + case true => Weights(args1) match { + case None => throw ClientError("Have additional arguments but unable to process") + case w => w + } + case false => None + } + case w => w + } + + protected def findAggregate(args0: List[String], args1: List[String]) = Aggregate(args0) match { + case None => args1.length > 0 match { + case true => Aggregate(args1) match { + case None => throw ClientError("Have additional arguments but unable to process") + case agg => agg + } + case false => None + } + case agg => agg + } +} + + +trait ZScoredRange extends KeyCommand { self => + val key: String + val min: ZInterval + val max: ZInterval + val withScores: Option[CommandArgument] + val limit: Option[Limit] + + val command: String + + override def validate() { + super.validate() + withScores.map { s => + s match { + case WithScores => + case _ => throw ClientError("withScores must be an instance of WithScores") + } + } + RequireClientProtocol(min != null, "min must not be null") + RequireClientProtocol(max != null, "max must not be null") + } + + override def toChannelBuffer = { + val command = List(self.command, key, min.toString, max.toString) + val scores: List[String] = withScores match { + case Some(WithScores) => List(WithScores.toString) + case None => Nil + } + val limits: List[String] = limit match { + case Some(limit) => List(limit.toString) + case None => Nil + } + RedisCodec.toInlineFormat(command ::: scores ::: limits) + } +} +trait ZScoredRangeCompanion { self => + def get( + key: String, + min: ZInterval, + max: ZInterval, + withScores: Option[CommandArgument], + limit: Option[Limit]): ZScoredRange + + def apply(args: List[Array[Byte]]) = args match { + case key :: min :: max :: Nil => + get(BytesToString(key), ZInterval(min), ZInterval(max), None, None) + case key :: min :: max :: tail => + parseArgs(BytesToString(key), ZInterval(min), ZInterval(max), tail) + case _ => + throw ClientError("Expected either 3, 4 or 5 args for ZRANGEBYSCORE/ZREVRANGEBYSCORE") + } + + def apply(key: String, min: ZInterval, max: ZInterval, withScores: CommandArgument) = { + withScores match { + case WithScores => + get(key, min, max, Some(withScores), None) + case _ => + throw ClientError("Only WITHSCORES is supported") + } + } + + def apply(key: String, min: ZInterval, max: ZInterval, limit: Limit) = + get(key, min, max, None, Some(limit)) + + def apply(key: String, min: ZInterval, max: ZInterval, withScores: CommandArgument, + limit: Limit) = + { + withScores match { + case WithScores => + get(key, min, max, Some(withScores), Some(limit)) + case _ => + throw ClientError("Only WITHSCORES supported") + } + } + + protected def parseArgs(key: String, min: ZInterval, max: ZInterval, args: List[Array[Byte]]) = { + RequireClientProtocol(args != null && args.length > 0, "Expected arguments for command") + val sArgs = BytesToString.fromList(args) + val (arg0, remaining) = doParse(sArgs) + + remaining.isEmpty match { + case true => + get(key, min, max, convertScore(arg0), convertLimit(arg0)) + case false => + val (arg1, leftovers) = doParse(remaining) + leftovers.isEmpty match { + case true => + val score = findScore(arg0, arg1) + val limit = findLimit(arg0, arg1) + get(key, min, max, score, limit) + case false => + throw ClientError("Found unexpected extra arguments for command") + } + } + } + type ScoreOrLimit = Either[CommandArgument,Limit] + protected def doParse(args: List[String]): (ScoreOrLimit, List[String]) = { + args.head match { + case WithScores(s) => + (Left(WithScores), if (args.length > 1) args.drop(1) else Nil) + case _ => + (Right(Limit(args.take(3))), if (args.length > 3) args.drop(3) else Nil) + } + } + protected def findScore(arg0: ScoreOrLimit, arg1: ScoreOrLimit) = convertScore(arg0) match { + case None => convertScore(arg1) match { + case None => throw ClientError("No WITHSCORES found but one expected") + case s => s + } + case s => s + } + protected def convertScore(arg: ScoreOrLimit) = arg match { + case Left(_) => Some(WithScores) + case _ => None + } + protected def findLimit(arg0: ScoreOrLimit, arg1: ScoreOrLimit) = convertLimit(arg0) match { + case None => convertLimit(arg1) match { + case None => throw ClientError("No LIMIT found but one expected") + case s => s + } + case s => s + } + protected def convertLimit(arg: ScoreOrLimit) = arg match { + case Right(limit) => Some(limit) + case _ => None + } + +} + + +abstract class ZRangeCmd extends StrictKeyCommand { + val key: String + val start: Int + val stop: Int + val withScores: Option[CommandArgument] + val command: String + + private def forChannelBuffer = { + val commands = List(this.command, key, start.toString, stop.toString) + val scored = withScores match { + case Some(WithScores) => commands :+ WithScores.toString + case None => commands + } + StringToBytes.fromList(scored) + } + override def toChannelBuffer = RedisCodec.toUnifiedFormat(forChannelBuffer) + +} +trait ZRangeCmdCompanion { + def get(key: String, start: Int, stop: Int, withScores: Option[CommandArgument]): ZRangeCmd + + def apply(args: List[Array[Byte]]) = { + RequireClientProtocol( + args != null && args.length >= 3, + "Expected at least 3 arguments for command") + + BytesToString.fromList(args) match { + case key :: start :: stop :: Nil => + get(key, safeInt(start), safeInt(stop), None) + case key :: start :: stop :: withScores :: Nil => + withScores match { + case WithScores(arg) => + get(key, safeInt(start), safeInt(stop), Some(WithScores)) + case _ => + throw ClientError("Expected 4 arguments with 4th as WITHSCORES") + } + case _ => throw ClientError("Expected 3 or 4 arguments for command") + } + } + + def apply(key: String, start: Int, stop: Int, scored: CommandArgument) = scored match { + case WithScores => get(key, start, stop, Some(scored)) + case _ => throw ClientError("Only WithScores is supported") + } + + protected def safeInt(i: String) = RequireClientProtocol.safe { + NumberFormat.toInt(i) + } +} + + +abstract class ZRankCmd extends StrictKeyCommand with StrictMemberCommand { + val key: String + val member: Array[Byte] + val command: String + + override def toChannelBuffer = RedisCodec.toUnifiedFormat(List( + StringToBytes(this.command), + StringToBytes(key), + member)) +} +trait ZRankCmdCompanion { + def get(key: String, member: Array[Byte]): ZRankCmd + + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2, "ZRANKcmd") + get(BytesToString(args(0)), args(1)) + } +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/Strings.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/Strings.scala new file mode 100644 index 0000000000..98a71fb8d7 --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/protocol/commands/Strings.scala @@ -0,0 +1,290 @@ +package com.twitter.finagle.redis +package protocol + +import util._ +import Commands.trimList + +case class Append(key: String, value: Array[Byte]) + extends StrictKeyCommand + with StrictValueCommand +{ + override def toChannelBuffer = RedisCodec.toUnifiedFormat(List( + StringToBytes(Commands.APPEND), + StringToBytes(key), + value)) +} +object Append { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2) + new Append(BytesToString(list(0)), list(1)) + } +} + +case class Decr(override val key: String) extends DecrBy(key, 1) { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.DECR, key)) +} +object Decr { + def apply(args: List[Array[Byte]]) = { + new Decr(BytesToString(trimList(args, 1, "DECR")(0))) + } +} +class DecrBy(val key: String, val amount: Int) extends StrictKeyCommand { + override def toChannelBuffer = + RedisCodec.toInlineFormat(List(Commands.DECRBY, key, amount.toString)) + override def toString = "DecrBy(%s, %d)".format(key, amount) + override def equals(other: Any) = other match { + case that: DecrBy => that.canEqual(this) && this.key == that.key && this.amount == that.amount + case _ => false + } + def canEqual(other: Any) = other.isInstanceOf[DecrBy] +} +object DecrBy { + def apply(key: String, amount: Int) = new DecrBy(key, amount) + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args, 2, "DECRBY")) + val amount = RequireClientProtocol.safe { + NumberFormat.toInt(list(1)) + } + new DecrBy(list(0), amount) + } +} + +case class Get(key: String) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.GET, key)) +} +object Get { + def apply(args: List[Array[Byte]]) = { + new Get(BytesToString(trimList(args, 1, "GET")(0))) + } +} + +case class GetBit(key: String, offset: Int) extends StrictKeyCommand { + override def toChannelBuffer = + RedisCodec.toInlineFormat(List(Commands.GETBIT, key, offset.toString)) +} +object GetBit { + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args,2,"GETBIT")) + val offset = RequireClientProtocol.safe { NumberFormat.toInt(list(1)) } + new GetBit(list(0), offset) + } +} + +case class GetRange(key: String, start: Int, end: Int) extends StrictKeyCommand { + override def toChannelBuffer = + RedisCodec.toInlineFormat(List(Commands.GETRANGE, key, start.toString, end.toString)) +} +object GetRange { + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args,3,"GETRANGE")) + val start = RequireClientProtocol.safe { NumberFormat.toInt(list(1)) } + val end = RequireClientProtocol.safe { NumberFormat.toInt(list(2)) } + new GetRange(list(0), start, end) + } +} + +case class GetSet(key: String, value: Array[Byte]) + extends StrictKeyCommand + with StrictValueCommand +{ + override def toChannelBuffer = RedisCodec.toUnifiedFormat(List( + StringToBytes(Commands.GETSET), + StringToBytes(key), + value)) +} +object GetSet { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2, "GETSET") + new GetSet(BytesToString(list(0)), list(1)) + } +} + +case class Incr(override val key: String) extends IncrBy(key, 1) { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.INCR, key)) +} +object Incr { + def apply(args: List[Array[Byte]]) = { + new Incr(BytesToString(trimList(args, 1, "INCR")(0))) + } +} + +class IncrBy(val key: String, val amount: Int) extends StrictKeyCommand { + override def toChannelBuffer = + RedisCodec.toInlineFormat(List(Commands.INCRBY, key, amount.toString)) + override def toString = "IncrBy(%s, %d)".format(key, amount) + override def equals(other: Any) = other match { + case that: IncrBy => that.canEqual(this) && this.key == that.key && this.amount == that.amount + case _ => false + } + def canEqual(other: Any) = other.isInstanceOf[IncrBy] +} +object IncrBy { + def apply(key: String, amount: Int) = new IncrBy(key, amount) + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args, 2, "INCRBY")) + val amount = RequireClientProtocol.safe { + NumberFormat.toInt(list(1)) + } + new IncrBy(list(0), amount) + } +} + +case class MGet(keys: List[String]) extends StrictKeysCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(Commands.MGET +: keys) +} + +case class MSet(kv: Map[String, Array[Byte]]) extends MultiSet { + validate() + val command = MSet.command +} +object MSet extends MultiSetCompanion { + val command = Commands.MSET + def get(map: Map[String, Array[Byte]]) = new MSet(map) +} + +case class MSetNx(kv: Map[String, Array[Byte]]) extends MultiSet { + validate() + val command = MSetNx.command +} +object MSetNx extends MultiSetCompanion { + val command = Commands.MSETNX + def get(map: Map[String, Array[Byte]]) = new MSetNx(map) +} + +case class Set(key: String, value: Array[Byte]) + extends StrictKeyCommand + with SetCommand + with StrictValueCommand +{ + val command = Set.command +} +object Set extends SetCommandCompanion { + val command = Commands.SET + def get(key: String, value: Array[Byte]) = new Set(key, value) +} + +case class SetBit(key: String, offset: Int, value: Int) extends StrictKeyCommand { + override def toChannelBuffer = + RedisCodec.toInlineFormat(List(Commands.SETBIT, key, offset.toString, value.toString)) +} +object SetBit { + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args,3,"SETBIT")) + val offset = RequireClientProtocol.safe { NumberFormat.toInt(list(1)) } + val value = RequireClientProtocol.safe { NumberFormat.toInt(list(2)) } + new SetBit(list(0), offset, value) + } +} + +case class SetEx(key: String, seconds: Long, value: Array[Byte]) + extends StrictKeyCommand + with StrictValueCommand +{ + RequireClientProtocol(seconds > 0, "Seconds must be greater than 0") + override def toChannelBuffer = RedisCodec.toUnifiedFormat(List( + StringToBytes(Commands.SETEX), + StringToBytes(key), + StringToBytes(seconds.toString), + value + )) +} +object SetEx { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 3, "SETEX") + val seconds = RequireClientProtocol.safe { NumberFormat.toLong(BytesToString(list(1))) } + new SetEx(BytesToString(list(0)), seconds, list(2)) + } +} + +case class SetNx(key: String, value: Array[Byte]) + extends StrictKeyCommand + with SetCommand + with StrictValueCommand +{ + val command = SetNx.command +} +object SetNx extends SetCommandCompanion { + val command = Commands.SETNX + def get(key: String, value: Array[Byte]) = new SetNx(key, value) +} + +case class SetRange(key: String, offset: Int, value: Array[Byte]) + extends StrictKeyCommand + with StrictValueCommand +{ + override def toChannelBuffer = RedisCodec.toUnifiedFormat(List( + StringToBytes(Commands.SETRANGE), + StringToBytes(key), + StringToBytes(offset.toString), + value + )) +} +object SetRange { + def apply(args: List[Array[Byte]]) = { + val list = trimList(args,3,"SETRANGE") + val key = BytesToString(list(0)) + val offset = RequireClientProtocol.safe { NumberFormat.toInt(BytesToString(list(1))) } + val value = list(2) + new SetRange(key, offset, value) + } +} + +case class Strlen(key: String) extends StrictKeyCommand { + override def toChannelBuffer = RedisCodec.toInlineFormat(List(Commands.STRLEN, key)) +} +object Strlen { + def apply(args: List[Array[Byte]]) = { + val list = BytesToString.fromList(trimList(args,1,"STRLEN")) + new Strlen(list(0)) + } +} + +/** Helpers for common idioms */ +trait SetCommand extends KeyCommand with ValueCommand { + val command: String + + def toChannelBuffer = RedisCodec.toUnifiedFormat(List( + StringToBytes(command), + StringToBytes(key), + value + )) +} +trait SetCommandCompanion { + val command: String + def apply(args: List[Array[Byte]]) = { + val list = trimList(args, 2, command) + get(BytesToString(list(0)), list(1)) + } + def get(key: String, value: Array[Byte]): SetCommand +} + +trait MultiSet extends KeysCommand { + val kv: Map[String, Array[Byte]] + val command: String + override lazy val keys: List[String] = kv.keys.toList + + override def toChannelBuffer = { + val kvList: List[Array[Byte]] = kv.keys.zip(kv.values).flatMap { case(k,v) => + StringToBytes(k) :: v :: Nil + }(collection.breakOut) + RedisCodec.toUnifiedFormat(StringToBytes(command) :: kvList) + } +} +trait MultiSetCompanion { + val command: String + def apply(args: List[Array[Byte]]) = { + val length = args.length + + RequireClientProtocol( + length % 2 == 0 && length > 0, + "Expected even number of k/v pairs for " + command) + + val map = args.grouped(2).map { + case key :: value :: Nil => (BytesToString(key), value) + case _ => throw new ClientError("Unexpected uneven pair of elements in MSET") + }.toMap + RequireClientProtocol(map.size == length/2, "Broken mapping, map size not equal to group size") + get(map) + } + def get(map: Map[String, Array[Byte]]): MultiSet +} diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/util/Conversions.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/util/Conversions.scala new file mode 100644 index 0000000000..9e5f09294c --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/util/Conversions.scala @@ -0,0 +1,64 @@ +package com.twitter.finagle.redis +package util + +import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers} + +trait ErrorConversion { + def getException(msg: String): Throwable + + def apply(requirement: Boolean, message: String = "Prerequisite failed") { + if (!requirement) { + throw getException(message) + } + } + def safe[T](fn: => T): T = { + try { + fn + } catch { + case e: Throwable => throw getException(e.getMessage) + } + } +} + +object BytesToString { + def apply(arg: Array[Byte], charset: String = "UTF-8") = new String(arg, charset) + def fromList(args: List[Array[Byte]], charset: String = "UTF-8") = args.map { arg => + BytesToString(arg, charset) + } +} +object StringToBytes { + def apply(arg: String, charset: String = "UTF-8") = arg.getBytes(charset) + def fromList(args: List[String], charset: String = "UTF-8") = args.map { arg => + arg.getBytes(charset) + } +} +object StringToChannelBuffer { + def apply(string: String, charset: String = "UTF-8") = { + ChannelBuffers.wrappedBuffer(string.getBytes(charset)) + } +} +object NumberFormat { + import com.twitter.naggati.ProtocolError + def toLong(arg: String): Long = { + try { + arg.toLong + } catch { + case e: Throwable => throw new ProtocolError("Unable to convert %s to Long".format(arg)) + } + } + def toInt(arg: String): Int = { + try { + arg.toInt + } catch { + case e: Throwable => throw new ProtocolError("Unable to convert %s to Int".format(arg)) + } + } + def toFloat(arg: String): Float = { + try { + arg.toFloat + } catch { + case e: Throwable => throw new ProtocolError("Unable to convert %s to Float".format(arg)) + } + } +} + diff --git a/finagle-redis/src/main/scala/com/twitter/finagle/redis/util/TestServer.scala b/finagle-redis/src/main/scala/com/twitter/finagle/redis/util/TestServer.scala new file mode 100644 index 0000000000..393125f853 --- /dev/null +++ b/finagle-redis/src/main/scala/com/twitter/finagle/redis/util/TestServer.scala @@ -0,0 +1,111 @@ +package com.twitter.finagle.redis +package util + +import java.lang.ProcessBuilder +import java.net.InetSocketAddress +import java.io.{BufferedWriter, FileWriter, PrintWriter, File} +import com.twitter.util.RandomSocket +import collection.JavaConversions._ +import scala.util.Random + +// Helper classes for spinning up a little redis cluster +private[twitter] object RedisCluster { self => + import collection.mutable.{Stack => MutableStack} + val instanceStack = MutableStack[ExternalRedis]() + + def address: Option[InetSocketAddress] = instanceStack.head.address + def address(i: Int) = instanceStack(i).address + def addresses: Seq[Option[InetSocketAddress]] = instanceStack.map { i => i.address } + + def hostAddresses(): String = { + require(instanceStack.length > 0) + addresses.map { address => + val addy = address.get + "%s:%d".format(addy.getHostName(), addy.getPort()) + }.sorted.mkString(",") + } + + def start(count: Int = 1) { + 0 until count foreach { i => + val instance = new ExternalRedis() + instance.start() + instanceStack.push(instance) + } + } + def stop() { + instanceStack.pop().stop() + } + def stopAll() { + instanceStack.foreach { i => i.stop() } + instanceStack.clear + } + + // Make sure the process is always killed eventually + Runtime.getRuntime().addShutdownHook(new Thread { + override def run() { + self.instanceStack.foreach { instance => instance.stop() } + } + }); +} + +private[twitter] class ExternalRedis() { + private[this] val rand = new Random + private[this] var process: Option[Process] = None + private[this] val forbiddenPorts = 6300.until(7300) + var address: Option[InetSocketAddress] = None + + private[this] def assertRedisBinaryPresent() { + val p = new ProcessBuilder("redis-server", "--help").start() + p.waitFor() + val exitValue = p.exitValue() + require(exitValue == 0 || exitValue == 1, "redis-server binary must be present.") + } + + private[this] def findAddress() { + var tries = 100 + while (address == None && tries >= 0) { + address = Some(RandomSocket.nextAddress()) + if (forbiddenPorts.contains(address.get.getPort)) { + address = None + tries -= 1 + Thread.sleep(5) + } + } + address.getOrElse { error("Couldn't get an address for the external redis instance") } + } + + protected def createConfigFile(port: Int): File = { + val f = File.createTempFile("redis-"+rand.nextInt(1000), ".tmp") + f.deleteOnExit() + val out = new PrintWriter(new BufferedWriter(new FileWriter(f))) + val conf = "port %s".format(port) + out.write(conf) + out.println() + out.close() + f + } + + def start() { + val port = address.get.getPort() + val conf = createConfigFile(port).getAbsolutePath + val cmd: Seq[String] = Seq("redis-server", conf) + val builder = new ProcessBuilder(cmd.toList) + process = Some(builder.start()) + Thread.sleep(200) + } + + def stop() { + process.foreach { p => + p.destroy() + p.waitFor() + } + } + + def restart() { + stop() + start() + } + + assertRedisBinaryPresent() + findAddress() +} diff --git a/finagle-redis/src/test/scala/com/twitter/finagle/redis/NaggatiSpec.scala b/finagle-redis/src/test/scala/com/twitter/finagle/redis/NaggatiSpec.scala new file mode 100644 index 0000000000..9eae8ec239 --- /dev/null +++ b/finagle-redis/src/test/scala/com/twitter/finagle/redis/NaggatiSpec.scala @@ -0,0 +1,714 @@ +package com.twitter.finagle.redis +package protocol + +import util._ + +import org.specs.Specification +import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers} +import com.twitter.conversions.time._ +import com.twitter.naggati.test._ +import com.twitter.util.{Future, Time} +import org.jboss.netty.channel.Channel + +class NaggatiSpec extends Specification { + import com.twitter.logging.Logger + Logger.reset() + val log = Logger() + log.setUseParentHandlers(false) + log.setLevel(Logger.ALL) + + def wrap(s: String) = ChannelBuffers.wrappedBuffer(s.getBytes) + + "A Redis Request" should { + val commandCodec = new CommandCodec + val (codec, counter) = TestCodec(commandCodec.decode, commandCodec.encode) + + "Properly decode" >> { + "inline requests" >> { + "key commands" >> { + "DEL" >> { + codec(wrap("DEL foo\r\n")) mustEqual List(Del(List("foo"))) + codec(wrap("DEL foo bar\r\n")) mustEqual List(Del(List("foo", "bar"))) + codec(wrap("DEL\r\n")) must throwA[ClientError] + } + "EXISTS" >> { + codec(wrap("EXISTS\r\n")) must throwA[ClientError] + codec(wrap("EXISTS foo\r\n")) mustEqual List(Exists("foo")) + } + "EXPIRE" >> { + codec(wrap("EXPIRE foo 100\r\n")) mustEqual List(Expire("foo", 100)) + codec(wrap("EXPIRE foo -1\r\n")) must throwA[ClientError] + } + "EXPIREAT" >> { + codec(wrap("EXPIREAT foo 100\r\n")) must throwA[ClientError] + val time = Time.now + 10.seconds + unwrap(codec(wrap("EXPIREAT foo %d\r\n".format(time.inSeconds)))) { + case ExpireAt("foo", timestamp) => timestamp.inSeconds mustEqual time.inSeconds + } + } + "KEYS" >> { + codec(wrap("KEYS h?llo\r\n")) mustEqual List(Keys("h?llo")) + } + "PERSIST" >> { + codec(wrap("PERSIST\r\n")) must throwA[ClientError] + codec(wrap("PERSIST foo\r\n")) mustEqual List(Persist("foo")) + } + "RENAME" >> { + codec(wrap("RENAME\r\n")) must throwA[ClientError] + codec(wrap("RENAME foo\r\n")) must throwA[ClientError] + codec(wrap("RENAME foo bar\r\n")) mustEqual List(Rename("foo", "bar")) + } + "RENAMENX" >> { + codec(wrap("RENAMENX\r\n")) must throwA[ClientError] + codec(wrap("RENAMENX foo\r\n")) must throwA[ClientError] + codec(wrap("RENAMENX foo bar\r\n")) mustEqual List(RenameNx("foo", "bar")) + } + "RANDOMKEY" >> { + codec(wrap("RANDOMKEY\r\n")) mustEqual List(Randomkey()) + } + "TTL" >> { + codec(wrap("TTL foo\r\n")) mustEqual List(Ttl("foo")) + } + "TYPE" >> { + codec(wrap("TYPE\r\n")) must throwA[ClientError] + codec(wrap("TYPE foo\r\n")) mustEqual List(Type("foo")) + } + } // key commands + + "sorted set commands" >> { + "ZADD" >> { + val bad = List( + "ZADD", "ZADD foo", "ZADD foo 123", "ZADD foo BAD_SCORE bar", + "ZADD foo 123 bar BAD_SCORE bar") + bad.foreach { e => + codec(wrap("%s\r\n".format(e))) must throwA[ClientError] + } + + unwrap(codec(wrap("ZADD nums 3.14159 pi\r\n"))) { + case ZAdd("nums", members) => + unwrap(members) { case ZMember(3.14159f, value) => + BytesToString(value) mustEqual "pi" + } + } + unwrap(codec(wrap("ZADD nums 3.14159 pi 2.71828 e\r\n"))) { + case ZAdd("nums", members) => members match { + case pi :: e :: Nil => + unwrap(List(pi)) { case ZMember(3.14159f, value) => + BytesToString(value) mustEqual "pi" + } + unwrap(List(e)) { case ZMember(2.71828f, value) => + BytesToString(value) mustEqual "e" + } + case _ => fail("Expected two elements in list") + } + } + } // ZADD + + "ZCARD" >> { + codec(wrap("ZCARD\r\n")) must throwA[ClientError] + unwrap(codec(wrap("ZCARD foo\r\n"))) { + case ZCard(key) => key mustEqual "foo" + } + } + + "ZCOUNT" >> { + val bad = List( + "ZCOUNT", "ZCOUNT foo", "ZCOUNT foo 1", "ZCOUNT foo 1 bar", "ZCOUNT foo bar 1", + "ZCOUNT foo -inf foo", "ZCOUNT foo 1 +info", "ZCOUNT foo )1 3", "ZCOUNT foo (1 n") + bad.foreach { b => + codec(wrap("%s\r\n".format(b))) must throwA[ClientError] + } + val good = Map( + "foo -inf +inf" -> ZCount("foo", ZInterval.MIN, ZInterval.MAX), + "foo (1.0 3.0" -> ZCount("foo", ZInterval.exclusive(1), ZInterval(3)) + ) + good.foreach { case(s,v) => + unwrap(codec(wrap("ZCOUNT %s\r\n".format(s)))) { + case c: Command => c mustEqual v + } + } + } + + "ZINCRBY" >> { + val bad = List("ZINCRBY", "ZINCRBY key", "ZINCRBY key 1", "ZINCRBY key bad member") + bad.foreach { b => + codec(wrap("%s\r\n".format(b))) must throwA[ClientError] + } + unwrap(codec(wrap("ZINCRBY key 2 one\r\n"))) { + case ZIncrBy("key", 2, member) => BytesToString(member) mustEqual "one" + } + unwrap(codec(wrap("ZINCRBY key 2.1 one\r\n"))) { + case ZIncrBy("key", value, member) => + value mustEqual 2.1f + BytesToString(member) mustEqual "one" + } + } + + "ZINTERSTORE/ZUNIONSTORE" >> { + val bad = List( + "%s", "%s foo", "%s foo 1 a b", + "%s foo 2 a b WEIGHTS 2", "%s foo 2 a b WEIGHTS 1", + "%s foo 2 a b WEIGHTS 2 2 AGGREGATE foo", + "%s foo 2 a b c WEIGHTS 2 2 2", "%s foo 1 b WEIGHTS 2 AGGREGATE", + "%s foo 2 a b WEIGHTS 2 2 2", "%s foo 1 a WEIGHTS a", + "%s foo 1 a WEIGHTS 2 WEIGHTS 3", + "%s foo 1 a AGGREGATE SUM AGGREGATE MAX") + List("ZINTERSTORE","ZUNIONSTORE").foreach { cmd => + def doCmd(rcmd: String) = codec(wrap(rcmd.format(cmd))) + def verify(k: String, n: Int)(f: (List[String],Option[Weights],Option[Aggregate]) => Unit): PartialFunction[Command,Unit] = + cmd match { + case "ZINTERSTORE" => { + case ZInterStore(k, n, keys, w, a) => f(keys,w,a) + } + case "ZUNIONSTORE" => { + case ZUnionStore(k, n, keys, w, a) => f(keys,w,a) + } + case _ => throw new Exception("Unhandled type") + } + + bad.foreach { b => + val toTry = b.format(cmd) + try { + codec(wrap("%s\r\n".format(toTry))) + fail("Unexpected success for %s".format(toTry)) + } catch { + case e: Throwable => e must haveClass[ClientError] + } + } + unwrap(doCmd("%s out 2 zset1 zset2\r\n")) { + verify("out", 2) { (keys, weights, aggregate) => + keys mustEqual List("zset1", "zset2") + weights must beNone + aggregate must beNone + } + } + unwrap(doCmd("%s out 2 zset1 zset2 WEIGHTS 2 3\r\n")) { + verify("out", 2) { (keys, weights, aggregate) => + keys mustEqual List("zset1", "zset2") + weights must beSome(Weights(2f, 3f)) + aggregate must beNone + } + } + unwrap(doCmd("%s out 2 zset1 zset2 aggregate sum\r\n")) { + verify("out", 2) { (keys, weights, aggregate) => + keys mustEqual List("zset1", "zset2") + aggregate must beSome(Aggregate.Sum) + weights must beNone + } + } + unwrap(doCmd("%s out 2 zset1 zset2 weights 2 3 aggregate min\r\n")) { + verify("out", 2) { (keys, weights, aggregate) => + keys mustEqual List("zset1", "zset2") + weights must beSome(Weights(2f, 3f)) + aggregate must beSome(Aggregate.Min) + } + } + unwrap(doCmd("%s out 2 zset1 zset2 aggregate max weights 2 3\r\n")) { + verify("out", 2) { (keys, weights, aggregate) => + keys mustEqual List("zset1", "zset2") + weights must beSome(Weights(2f, 3f)) + aggregate must beSome(Aggregate.Max) + } + } + } // List(Zinterstore... + } // ZINTERSTORE/ZUNIONSTORE + + "ZRANGE/ZREVRANGE" >> { + val bad = List( + "%s", "%s myset", "%s myset 1", "%s myset 1 foo", + "%s myset foo 1", "%s myset 0 2 blah") + List("ZRANGE","ZREVRANGE").foreach { cmd => + def doCmd(s: String) = codec(wrap("%s\r\n".format(s.format(cmd)))) + def verify(k: String, start: Int, stop: Int, scored: Option[CommandArgument]): PartialFunction[Command,Unit] = { + cmd match { + case "ZRANGE" => { + case ZRange(k, start, stop, scored) => true must beTrue + } + case "ZREVRANGE" => { + case ZRevRange(k, start, stop, scored) => true must beTrue + } + } + } + bad.foreach { b => + val scmd = "%s\r\n".format(b.format(cmd)) + codec(wrap(scmd)) must throwA[ClientError] + } + + unwrap(doCmd("%s myset 0 -1")) { + verify("myset", 0, -1, None) + } + unwrap(doCmd("%s myset 0 -1 withscores")) { + verify("myset", 0, -1, Some(WithScores)) + } + unwrap(doCmd("%s myset 0 -1 WITHSCORES")) { + verify("myset", 0, -1, Some(WithScores)) + } + } + + } // ZRANGE + + "ZRANGEBYSCORE/ZREVRANGEBYSCORE" >> { + val bad = List( + "%s", "%s key", "%s key -inf", "%s key -inf foo", "%s key foo +inf", + "%s key 0 1 NOSCORES", "%s key 0 1 LIMOT 1 2", "%s key 0 1 LIMIT foo 1", + "%s key 0 1 LIMIT 1 foo", "%s key 0 1 WITHSCORES WITHSCORES", + "%s key 0 1 LIMIT 0 1 LIMIT 0 1 WITHSCORES", "%s key 0 1 LIMIT 0 1 NOSCORES", + "%s key 0 1 LIMIT 1", "%s key 0 1 LIMIT 0 1 WITHSCORES NOSCORES"); + List("ZRANGEBYSCORE","ZREVRANGEBYSCORE").foreach { cmd => + def doCmd(rcmd: String) = codec(wrap("%s\r\n".format(rcmd.format(cmd)))) + def verify(k: String, min: ZInterval, max: ZInterval)(f: (Option[CommandArgument],Option[Limit]) => Unit): PartialFunction[Command,Unit] = { + cmd match { + case "ZRANGEBYSCORE" => { + case ZRangeByScore(k, min, max, s, l) => f(s, l) + } + case "ZREVRANGEBYSCORE" => { + case ZRevRangeByScore(k, min, max, s, l) => f(s, l) + } + } + } + + bad.foreach { b => + val cstr = b.format(cmd) + try { + codec(wrap("%s\r\n".format(cstr))) + fail("Unexpected success for %s".format(cstr)) + } catch { + case e: Throwable => e must haveClass[ClientError] + } + } // bad + + unwrap(doCmd("%s myzset -inf +inf")) { + verify("myzset", ZInterval.MIN, ZInterval.MAX) { (s,l) => + s must beNone + l must beNone + } + } + unwrap(doCmd("%s myzset 1 2")) { + verify("myzset", ZInterval(1f), ZInterval(2f)) { (s,l) => + s must beNone + l must beNone + } + } + unwrap(doCmd("%s myzset (1 2")) { + verify("myzset", ZInterval.exclusive(1f), ZInterval(2f)) { (s,l) => + s must beNone + l must beNone + } + } + unwrap(doCmd("%s myzset (1 (2")) { + verify("myzset", ZInterval.exclusive(1f), ZInterval.exclusive(2f)) { (s,l) => + s must beNone + l must beNone + } + } + unwrap(doCmd("%s myzset -inf +inf LIMIT 1 5")) { + verify("myzset", ZInterval.MIN, ZInterval.MAX) { (s,l) => + s must beNone + l must beSome(Limit(1,5)) + } + } + unwrap(doCmd("%s myzset -inf +inf LIMIT 3 9 WITHSCORES")) { + verify("myzset", ZInterval.MIN, ZInterval.MAX) { (s,l) => + s must beSome(WithScores) + l must beSome(Limit(3, 9)) + } + } + } // List(ZRANGEBYSCORE + } // ZRANGEBYSCORE + + "ZRANK/ZREVRANK" >> { + val bad = List("%s", "%s key", "%s key member member") + List("ZRANK", "ZREVRANK").foreach { cmd => + def doCmd(s: String) = codec(wrap("%s\r\n".format(s.format(cmd)))) + def verify(k: String)(f: Array[Byte] => Unit): PartialFunction[Command,Unit] = { + cmd match { + case "ZRANK" => { + case ZRank(k, m) => f(m) + } + case "ZREVRANK" => { + case ZRevRank(k, m) => f(m) + } + } + } + + bad.foreach { b => + val scmd = b.format(cmd) + val fcmd = "%s\r\n".format(scmd) + try { + codec(wrap(fcmd)) + fail("Unexpected success for %s".format(scmd)) + } catch { + case e: Throwable => e must haveClass[ClientError] + } + } + unwrap(doCmd("%s myzset three")) { + verify("myzset") { m => + BytesToString(m) mustEqual "three" + } + } + unwrap(doCmd("%s myzset four")) { + verify("myzset") { m => + BytesToString(m) mustEqual "four" + } + } + } + } // ZRANK/ZREVRANK + + "ZREM" >> { + List("ZREM", "ZREM key").foreach { bad => + codec(wrap("%s\r\n".format(bad))) must throwA[ClientError] + } + unwrap(codec(wrap("ZREM key member1\r\n"))) { + case ZRem("key", members) => + BytesToString.fromList(members) mustEqual List("member1") + } + unwrap(codec(wrap("ZREM key member1 member2\r\n"))) { + case ZRem("key", members) => + BytesToString.fromList(members) mustEqual List("member1", "member2") + } + } + + "ZREMRANGEBYRANK" >> { + val cmd = "ZREMRANGEBYRANK" + val bad = List("%s", "%s key", "%s key start", "%s key 1", "%s key 1 stop", + "%s key start 2") + bad.foreach { b => + codec(wrap("%s\r\n".format(b.format(cmd)))) must throwA[ClientError] + } + unwrap(codec(wrap(cmd + " key 0 1\r\n"))) { + case ZRemRangeByRank("key", start, stop) => + start must beEqualTo(0) + stop must beEqualTo(1) + } + } + + "ZREMRANGEBYSCORE" >> { + val cmd = "ZREMRANGEBYSCORE" + val bad = List("%s", "%s key", "%s key min", "%s key min max", "%s key ( 1", + "%s key (1 max") + bad.foreach { b => + codec(wrap("%s\r\n".format(b.format(cmd)))) must throwA[ClientError] + } + unwrap(codec(wrap(cmd + " key -inf (2.0\r\n"))) { + case ZRemRangeByScore("key", min, max) => + min mustEqual ZInterval.MIN + max mustEqual ZInterval.exclusive(2) + } + } + + "ZSCORE" >> { + List("ZSCORE","ZSCORE key").foreach { bad => + codec(wrap("%s\r\n".format(bad))) must throwA[ClientError] + } + unwrap(codec(wrap("ZSCORE myset one\r\n"))) { + case ZScore("myset", one) => BytesToString(one) mustEqual "one" + } + } + + } // sorted sets should + + "string commands" >> { + "APPEND" >> { + codec(wrap("APPEND\r\n")) must throwA[ClientError] + codec(wrap("APPEND foo\r\n")) must throwA[ClientError] + unwrap(codec(wrap("APPEND foo bar\r\n"))) { + case Append("foo", value) => BytesToString(value) mustEqual "bar" + } + } + "DECR" >> { + codec(wrap("DECR 1\r\n")) mustEqual List(Decr("1")) + codec(wrap("DECR foo\r\n")) mustEqual List(Decr("foo")) + codec(wrap("DECR foo 1\r\n")) must throwA[ClientError] + } + "DECRBY" >> { + codec(wrap("DECRBY foo 1\r\n")) mustEqual List(DecrBy("foo", 1)) + codec(wrap("DECRBY foo 4096\r\n")) mustEqual List(DecrBy("foo", 4096)) + codec(wrap("DECRBY foo\r\n")) must throwA[ClientError] + } + "GET" >> { + codec(wrap("GET foo\r\n")) mustEqual List(Get("foo")) + } + "GETBIT" >> { + codec(wrap("GETBIT\r\n")) must throwA[ClientError] + codec(wrap("GETBIT foo\r\n")) must throwA[ClientError] + codec(wrap("GETBIT foo 0\r\n")) mustEqual List(GetBit("foo", 0)) + } + "GETRANGE" >> { + codec(wrap("GETRANGE\r\n")) must throwA[ClientError] + codec(wrap("GETRANGE key\r\n")) must throwA[ClientError] + codec(wrap("GETRANGE key 0\r\n")) must throwA[ClientError] + codec(wrap("GETRANGE key 0 5\r\n")) mustEqual List(GetRange("key", 0, 5)) + } + "GETSET" >> { + codec(wrap("GETSET\r\n")) must throwA[ClientError] + codec(wrap("GETSET key\r\n")) must throwA[ClientError] + unwrap(codec(wrap("GETSET key value\r\n"))) { + case GetSet("key", value) => BytesToString(value) mustEqual "value" + } + } + "INCR" >> { + codec(wrap("INCR 1\r\n")) mustEqual List(Incr("1")) + codec(wrap("INCR foo\r\n")) mustEqual List(Incr("foo")) + codec(wrap("INCR foo 1\r\n")) must throwA[ClientError] + } + "INCRBY" >> { + codec(wrap("INCRBY foo 1\r\n")) mustEqual List(IncrBy("foo", 1)) + codec(wrap("INCRBY foo 4096\r\n")) mustEqual List(IncrBy("foo", 4096)) + codec(wrap("INCRBY foo\r\n")) must throwA[ClientError] + } + "MGET" >> { + codec(wrap("MGET foo bar\r\n")) mustEqual List(MGet(List("foo","bar"))) + } + "MSETNX" >> { + codec(wrap("MSETNX\r\n")) must throwA[ClientError] + codec(wrap("MSETNX foo\r\n")) must throwA[ClientError] + unwrap(codec(wrap("MSETNX foo bar\r\n"))) { + case MSetNx(map) => + map must haveKey("foo") + BytesToString(map("foo")) mustEqual "bar" + } + } + "SET" >> { + unwrap(codec(wrap("SET foo bar\r\n"))) { + case Set("foo", bar) => BytesToString(bar) mustEqual "bar" + } + } + "SETBIT" >> { + codec(wrap("SETBIT\r\n")) must throwA[ClientError] + codec(wrap("SETBIT foo\r\n")) must throwA[ClientError] + codec(wrap("SETBIT foo 0\r\n")) must throwA[ClientError] + codec(wrap("SETBIT foo 7 1\r\n")) mustEqual List(SetBit("foo", 7, 1)) + } + "SETEX" >> { + codec(wrap("SETEX\r\n")) must throwA[ClientError] + codec(wrap("SETEX key\r\n")) must throwA[ClientError] + codec(wrap("SETEX key 30\r\n")) must throwA[ClientError] + unwrap(codec(wrap("SETEX key 30 value\r\n"))) { + case SetEx("key", 30, value) => BytesToString(value) mustEqual "value" + } + } + "SETNX" >> { + codec(wrap("SETNX\r\n")) must throwA[ClientError] + codec(wrap("SETNX key\r\n")) must throwA[ClientError] + unwrap(codec(wrap("SETNX key value\r\n"))) { + case SetNx("key", value) => BytesToString(value) mustEqual "value" + } + } + "SETRANGE" >> { + codec(wrap("SETRANGE\r\n")) must throwA[ClientError] + codec(wrap("SETRANGE key\r\n")) must throwA[ClientError] + codec(wrap("SETRANGE key 0\r\n")) must throwA[ClientError] + unwrap(codec(wrap("SETRANGE key 0 value\r\n"))) { + case SetRange("key", 0, value) => BytesToString(value) mustEqual "value" + } + } + "STRLEN" >> { + codec(wrap("STRLEN\r\n")) must throwA[ClientError] + codec(wrap("STRLEN foo\r\n")) mustEqual List(Strlen("foo")) + } + } // string commands + } // inline + + def unwrap(list: List[AnyRef])(fn: PartialFunction[Command,Unit]) = list match { + case head :: Nil => head match { + case c: Command => fn.isDefinedAt(c) match { + case true => fn(c) + case false => fail("Didn't find expected type in list: %s".format(c.getClass)) + } + case _ => fail("Expected to find a command in the list") + } + case _ => fail("Expected single element list") + } + + "unified requests" >> { + "GET" >> { + codec(wrap("*2\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("GET\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("bar\r\n")) mustEqual List(Get("bar")) + } + "MGET" >> { + codec(wrap("*3\r\n")) mustEqual Nil + codec(wrap("$4\r\n")) mustEqual Nil + codec(wrap("MGET\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("foo\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("bar\r\n")) mustEqual List(MGet(List("foo","bar"))) + } + "MSET" >> { + codec(wrap("*5\r\n")) mustEqual Nil + codec(wrap("$4\r\n")) mustEqual Nil + codec(wrap("MSET\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("foo\r\n")) mustEqual Nil + codec(wrap("$7\r\n")) mustEqual Nil + codec(wrap("bar baz\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("bar\r\n")) mustEqual Nil + codec(wrap("$5\r\n")) mustEqual Nil + codec(wrap("Hello\r\n")) match { + case MSet(kv) :: Nil => + val nkv = kv.map { case(k,v) => (k, BytesToString(v)) } + nkv mustEqual Map( + "foo" -> "bar baz", + "bar" -> "Hello" + ) + case _ => fail("Expected MSet to be returned") + } + } + } + } // decode + + "Properly encode" >> { + def unpackBuffer(buffer: ChannelBuffer) = { + val bytes = new Array[Byte](buffer.readableBytes) + buffer.readBytes(bytes) + new String(bytes, "UTF-8") + } + + "Inline Requests" >> { + codec.send(Get("foo")) mustEqual List("GET foo\r\n") + } + "Unified Requests" >> { + val value = "bar\r\nbaz" + val valSz = 8 + val expected = "*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$%d\r\n%s\r\n".format(valSz, value) + codec.send(Set("foo", value.getBytes)) mustEqual List(expected) + } + } // Encode properly + + } // A Redis Request + + "A Redis Response" should { + val replyCodec = new ReplyCodec + val (codec, counter) = TestCodec(replyCodec.decode, replyCodec.encode) + + "Properly decode" >> { + "status replies" >> { + codec(wrap("+OK\r\n")) mustEqual List(StatusReply("OK")) + codec(wrap("+OK\r\n+Hello World\r\n")) mustEqual List( + StatusReply("OK"), + StatusReply("Hello World")) + codec(wrap("+\r\n")) must throwA[ServerError] + } + "error replies" >> { + codec(wrap("-BAD\r\n")) mustEqual List(ErrorReply("BAD")) + codec(wrap("-BAD\r\n-Bad Thing\r\n")) mustEqual List( + ErrorReply("BAD"), + ErrorReply("Bad Thing")) + codec(wrap("-\r\n")) must throwA[ServerError] + } + "integer replies" >> { + codec(wrap(":-2147483648\r\n")) mustEqual List(IntegerReply(-2147483648)) + codec(wrap(":0\r\n")) mustEqual List(IntegerReply(0)) + codec(wrap(":2147483647\r\n")) mustEqual List(IntegerReply(2147483647)) + codec(wrap(":2147483648\r\n")) must throwA[ServerError] + codec(wrap(":-2147483649\r\n")) must throwA[ServerError] + codec(wrap(":1\r\n:2\r\n")) mustEqual List(IntegerReply(1), IntegerReply(2)) + codec(wrap(":\r\n")) must throwA[ServerError] + } + "bulk replies" >> { + codec(wrap("$3\r\nfoo\r\n")) match { + case reply :: Nil => reply match { + case BulkReply(msg) => BytesToString(msg) mustEqual "foo" + case _ => fail("Expected BulkReply, got something else") + } + case _ => fail("Found no or multiple reply lines") + } + codec(wrap("$8\r\nfoo\r\nbar\r\n")) match { + case reply :: Nil => reply match { + case BulkReply(msg) => BytesToString(msg) mustEqual "foo\r\nbar" + case _ => fail("Expected BulkReply, got something else") + } + case _ => fail("Found no or multiple reply lines") + } + codec(wrap("$3\r\nfoo\r\n$3\r\nbar\r\n")) match { + case fooR :: barR :: Nil => + fooR match { + case BulkReply(msg) => BytesToString(msg) mustEqual "foo" + case _ => fail("Expected BulkReply") + } + barR match { + case BulkReply(msg) => BytesToString(msg) mustEqual "bar" + case _ => fail("Expected BulkReply") + } + case _ => fail("Expected two elements in list") + } + codec(wrap("$-1\r\n")) match { + case reply :: Nil => reply must haveClass[EmptyBulkReply] + case _ => fail("Invalid reply for empty bulk reply") + } + } + "empty multi-bulk replies" >> { + codec(wrap("*0\r\n")) mustEqual List(EmptyMBulkReply()) + } + "multi-bulk replies" >> { + codec(wrap("*4\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("foo\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("bar\r\n")) mustEqual Nil + codec(wrap("$5\r\n")) mustEqual Nil + codec(wrap("Hello\r\n")) mustEqual Nil + codec(wrap("$5\r\n")) mustEqual Nil + codec(wrap("World\r\n")) match { + case reply :: Nil => reply match { + case MBulkReply(msgs) => + BytesToString.fromList(msgs) mustEqual List("foo","bar","Hello","World") + case _ => fail("Expected MBulkReply") + } + case _ => fail("Expected one element in list") + } + + codec(wrap("*3\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("foo\r\n")) mustEqual Nil + codec(wrap("$-1\r\n")) mustEqual Nil + codec(wrap("$3\r\n")) mustEqual Nil + codec(wrap("bar\r\n")) match { + case reply :: Nil => reply match { + case MBulkReply(msgs) => + BytesToString.fromList(msgs) mustEqual List( + "foo", + BytesToString(RedisCodec.NIL_VALUE_BA), + "bar") + case _ => fail("Expected MBulkReply") + } + case _ => fail("Expected one element in list") + } + + } + } + + "Properly encode" >> { + def unpackBuffer(buffer: ChannelBuffer) = { + val bytes = new Array[Byte](buffer.readableBytes) + buffer.readBytes(bytes) + new String(bytes, "UTF-8") + } + + "Status Replies" >> { + codec.send(StatusReply("OK")) mustEqual List("+OK\r\n") + } + "Error Replies" >> { + codec.send(ErrorReply("BAD")) mustEqual List("-BAD\r\n") + } + "Integer Replies" >> { + codec.send(IntegerReply(123)) mustEqual List(":123\r\n") + codec.send(IntegerReply(-123)) mustEqual List(":-123\r\n") + } + "Bulk Replies" >> { + val expected = "$8\r\nfoo\r\nbar\r\n" + codec.send(BulkReply("foo\r\nbar".getBytes)) mustEqual List(expected) + } + "Multi Bulk Replies" >> { + val messages = StringToBytes.fromList(List("foo","bar","Hello","World")) + val expected = "*4\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$5\r\nHello\r\n$5\r\nWorld\r\n" + codec.send(MBulkReply(messages)) mustEqual List(expected) + } + } // Encode properly + + } +} diff --git a/finagle-redis/src/test/scala/com/twitter/finagle/redis/integration/ClientServerIntegrationSpec.scala b/finagle-redis/src/test/scala/com/twitter/finagle/redis/integration/ClientServerIntegrationSpec.scala new file mode 100644 index 0000000000..410329c0f4 --- /dev/null +++ b/finagle-redis/src/test/scala/com/twitter/finagle/redis/integration/ClientServerIntegrationSpec.scala @@ -0,0 +1,452 @@ +package com.twitter.finagle.redis +package protocol +package integration + +import util._ + +import com.twitter.util.{Future, RandomSocket} +import com.twitter.finagle.Service +import com.twitter.finagle.builder.{ClientBuilder, ServerBuilder} +import org.specs.Specification + +class ClientServerIntegrationSpec extends Specification { + val serverAddress = RandomSocket.nextAddress() + implicit def s2b(s: String) = s.getBytes + + lazy val svcClient = ClientBuilder() + .name("redis-client") + .codec(Redis()) + .hosts(RedisCluster.hostAddresses()) + .hostConnectionLimit(2) + .retries(2) + .build() + + lazy val client = ClientBuilder() + .name("redis-client") + .codec(Redis()) + .hosts("localhost:%d".format(serverAddress.getPort())) + .hostConnectionLimit(2) + .retries(2) + .build() + + val service = new Service[Command, Reply] { + def apply(cmd: Command): Future[Reply] = { + svcClient(cmd) + } + } + + val server = ServerBuilder() + .name("redis-server") + .codec(Redis()) + .bindTo(serverAddress) + .build(service) + + val KEY = "foo" + val VALUE = "bar" + + doBeforeSpec { + RedisCluster.start(1) + client(Del(List(KEY)))() + client(Set(KEY, VALUE))() mustEqual StatusReply("OK") + } + + "A client" should { + "Handle Key Commands" >> { + "DEL" >> { + client(Set("key1", "val1"))() mustEqual StatusReply("OK") + client(Set("key2", "val2"))() mustEqual StatusReply("OK") + client(Del(List("key1", "key2")))() mustEqual IntegerReply(2) + client(Del(null:List[String]))() must throwA[ClientError] + client(Del(List(null)))() must throwA[ClientError] + } + "EXISTS" >> { + client(Exists(KEY))() mustEqual IntegerReply(1) + client(Exists("nosuchkey"))() mustEqual IntegerReply(0) + client(Exists(null:String))() must throwA[ClientError] + } + "EXPIRE" >> { + client(Expire("key1", 30))() mustEqual IntegerReply(0) + client(Expire(null, 30))() must throwA[ClientError] + client(Expire(KEY, 3600))() mustEqual IntegerReply(1) + assertIntegerReply(client(Ttl(KEY)), 3600) + } + "EXPIREAT" >> { + import com.twitter.util.Time + import com.twitter.conversions.time._ + val t = Time.now + 3600.seconds + client(ExpireAt("key1", t))() mustEqual IntegerReply(0) + client(ExpireAt(KEY, t))() mustEqual IntegerReply(1) + assertIntegerReply(client(Ttl(KEY)), 3600) + client(ExpireAt(null, t))() must throwA[ClientError] + } + "KEYS" >> { + val request = client(Keys("%s*".format(KEY))) + val expects = List(KEY) + assertMBulkReply(request, expects, true) + client(Keys("%s.*".format(KEY)))() mustEqual EmptyMBulkReply() + } + "PERSIST" >> { + client(Persist("nosuchkey"))() mustEqual IntegerReply(0) + client(Set("persist", "value"))() + client(Persist("persist"))() mustEqual IntegerReply(0) + client(Expire("persist", 30))() + client(Persist("persist"))() mustEqual IntegerReply(1) + client(Ttl("persist"))() mustEqual IntegerReply(-1) + client(Del(List("persist")))() + } + "RENAME" >> { + client(Set("rename1", "foo"))() + assertBulkReply(client(Get("rename1")), "foo") + client(Rename("rename1", "rename1"))() must haveClass[ErrorReply] + client(Rename("nosuchkey", "foo"))() must haveClass[ErrorReply] + client(Rename("rename1", "rename2"))() mustEqual StatusReply("OK") + assertBulkReply(client(Get("rename2")), "foo") + } + "RENAMENX" >> { + client(Set("renamenx1", "foo"))() + assertBulkReply(client(Get("renamenx1")), "foo") + client(RenameNx("renamenx1", "renamenx1"))() must haveClass[ErrorReply] + client(RenameNx("nosuchkey", "foo"))() must haveClass[ErrorReply] + client(RenameNx("renamenx1", "renamenx2"))() mustEqual IntegerReply(1) + assertBulkReply(client(Get("renamenx2")), "foo") + client(RenameNx("renamenx2", KEY))() mustEqual IntegerReply(0) + } + "RANDOMKEY" >> { + client(Randomkey())() must haveClass[BulkReply] + } + "TTL" >> { // tested by expire/expireat + client(Ttl(null:String))() must throwA[ClientError] + client(Ttl("thing"))() mustEqual IntegerReply(-1) + } + "TYPE" >> { + client(Type(KEY))() mustEqual StatusReply("string") + client(Type("nosuchkey"))() mustEqual StatusReply("none") + } + } + + "Handle Sorted Set Commands" >> { + val ZKEY = "zkey" + val ZVAL = List(ZMember(1, "one"), ZMember(2, "two"), ZMember(3, "three")) + + def zAdd(key: String, members: ZMember*) { + members.foreach { member => + client(ZAdd(key, member))() mustEqual IntegerReply(1) + } + } + + doBefore { + ZVAL.foreach { zv => + client(ZAdd(ZKEY, zv))() mustEqual IntegerReply(1) + } + } + doAfter { + client(Del(ZKEY))() mustEqual IntegerReply(1) + } + + "ZADD" >> { + client(ZAdd("zadd1", ZMember(1, "one")))() mustEqual IntegerReply(1) + client(ZAdd("zadd1", ZMember(2, "two")))() mustEqual IntegerReply(1) + client(ZAdd("zadd1", ZMember(3, "two")))() mustEqual IntegerReply(0) + val expected = List("one", "1", "two", "3") + assertMBulkReply(client(ZRange("zadd1", 0, -1, WithScores)), expected) + assertMBulkReply(client(ZRange("zadd1", 0, -1)), List("one", "two")) + } + "ZCARD" >> { + client(ZCard(ZKEY))() mustEqual IntegerReply(3) + client(ZCard("nosuchkey"))() mustEqual IntegerReply(0) + client(ZCard(KEY))() must haveClass[ErrorReply] + } + "ZCOUNT" >> { + client(ZCount(ZKEY, ZInterval.MIN, ZInterval.MAX))() mustEqual IntegerReply(3) + client(ZCount(ZKEY, ZInterval.exclusive(1), ZInterval(3)))() mustEqual IntegerReply(2) + } + "ZINCRBY" >> { + zAdd("zincrby1", ZMember(1, "one"), ZMember(2, "two")) + assertBulkReply(client(ZIncrBy("zincrby1", 2, "one")), "3") + assertMBulkReply( + client(ZRange("zincrby1", 0, -1, WithScores)), + List("two", "2", "one", "3")) + } + "ZINTERSTORE/ZUNIONSTORE" >> { + val key = "zstore1" + val key2 = "zstore2" + zAdd(key, ZMember(1, "one"), ZMember(2, "two")) + zAdd(key2, ZMember(1, "one"), ZMember(2, "two"), ZMember(3, "three")) + + client(ZInterStore("out", List(key, key2), Weights(2,3)))() mustEqual IntegerReply(2) + assertMBulkReply( + client(ZRange("out", 0, -1, WithScores)), + List("one", "5", "two", "10")) + + client(ZUnionStore("out", List(key, key2), Weights(2,3)))() mustEqual IntegerReply(3) + assertMBulkReply( + client(ZRange("out", 0, -1, WithScores)), + List("one", "5", "three", "9", "two", "10")) + } + "ZRANGE/ZREVRANGE" >> { + zAdd("zrange1", ZMember(1, "one"), ZMember(2, "two"), ZMember(3, "three")) + + assertMBulkReply(client(ZRange("zrange1", 0, -1)), List("one", "two", "three")) + assertMBulkReply( + client(ZRange("zrange1", 0, -1, WithScores)), + List("one", "1", "two", "2", "three", "3")) + assertMBulkReply(client(ZRange("zrange1", 2, 3)), List("three")) + assertMBulkReply( + client(ZRange("zrange1", 2, 3, WithScores)), + List("three", "3")) + assertMBulkReply(client(ZRange("zrange1", -2, -1)), List("two", "three")) + assertMBulkReply( + client(ZRange("zrange1", -2, -1, WithScores)), + List("two", "2", "three", "3")) + + assertMBulkReply( + client(ZRevRange("zrange1", 0, -1)), + List("three", "two", "one")) + assertMBulkReply( + client(ZRevRange("zrange1", 2, 3)), + List("one")) + assertMBulkReply( + client(ZRevRange("zrange1", -2, -1)), + List("two", "one")) + } + "ZRANGEBYSCORE/ZREVRANGEBYSCORE" >> { + val key = "zrangebyscore1" + zAdd(key, ZMember(1, "one"), ZMember(2, "two"), ZMember(3, "three")) + assertMBulkReply( + client(ZRangeByScore(key, ZInterval.MIN, ZInterval.MAX)), + List("one", "two", "three")) + assertMBulkReply( + client(ZRangeByScore(key, ZInterval(1f), ZInterval(2f))), + List("one", "two")) + assertMBulkReply( + client(ZRangeByScore(key, ZInterval.exclusive(1f), ZInterval(2f))), + List("two")) + assertMBulkReply( + client(ZRangeByScore(key, ZInterval.exclusive(1f), ZInterval.exclusive(2f))), + List()) + assertMBulkReply( + client(ZRangeByScore(key, ZInterval.MIN, ZInterval.MAX, Limit(1,5))), + List("two","three")) + + assertMBulkReply( + client(ZRevRangeByScore(key, ZInterval.MAX, ZInterval.MIN)), + List("three", "two", "one")) + assertMBulkReply( + client(ZRevRangeByScore(key, ZInterval(2f), ZInterval(1f))), + List("two", "one")) + assertMBulkReply( + client(ZRevRangeByScore(key, ZInterval(2f), ZInterval.exclusive(1f))), + List("two")) + assertMBulkReply( + client(ZRevRangeByScore(key, ZInterval.exclusive(2f), ZInterval.exclusive(1f))), + List()) + } + "ZRANK/ZREVRANK" >> { + val key = "zrank1" + zAdd(key, ZMember(1, "one"), ZMember(2, "two"), ZMember(3, "three")) + client(ZRank(key, "three"))() mustEqual IntegerReply(2) + client(ZRank(key, "four"))() mustEqual EmptyBulkReply() + client(ZRevRank(key, "one"))() mustEqual IntegerReply(2) + client(ZRevRank(key, "four"))() mustEqual EmptyBulkReply() + } + "ZREM" >> { + val key = "zrem1" + zAdd(key, ZMember(1, "one"), ZMember(2, "two"), ZMember(3, "three")) + client(ZRem(key, List("two")))() mustEqual IntegerReply(1) + client(ZRem(key, List("nosuchmember")))() mustEqual IntegerReply(0) + assertMBulkReply( + client(ZRange(key, 0, -1, WithScores)), + List("one", "1", "three", "3")) + } + "ZREMRANGEBYRANK" >> { + val key = "zremrangebyrank1" + zAdd(key, ZMember(1, "one"), ZMember(2, "two"), ZMember(3, "three")) + client(ZRemRangeByRank(key, 0, 1))() mustEqual IntegerReply(2) + assertMBulkReply( + client(ZRange(key, 0, -1, WithScores)), + List("three", "3")) + } + "ZREMRANGEBYSCORE" >> { + val key = "zremrangebyscore1" + zAdd(key, ZMember(1, "one"), ZMember(2, "two"), ZMember(3, "three")) + client(ZRemRangeByScore(key, ZInterval.MIN, ZInterval.exclusive(2)))() mustEqual IntegerReply(1) + assertMBulkReply( + client(ZRange(key, 0, -1, WithScores)), + List("two", "2", "three", "3")) + } + "ZSCORE" >> { + zAdd("zscore1", ZMember(1, "one")) + assertBulkReply(client(ZScore("zscore1", "one")), "1") + } + } + + "Handle String Commands" >> { + "APPEND" >> { + client(Append("append1", "Hello"))() mustEqual IntegerReply(5) + client(Append("append1", " World"))() mustEqual IntegerReply(11) + assertBulkReply(client(Get("append1")), "Hello World") + } + "DECR" >> { + client(Decr("decr1"))() mustEqual IntegerReply(-1) + client(Decr("decr1"))() mustEqual IntegerReply(-2) + client(Decr(KEY))() must haveClass[ErrorReply] + client(Decr(null:String))() must throwA[ClientError] + } + "DECRBY" >> { + client(DecrBy("decrby1", 1))() mustEqual IntegerReply(-1) + client(DecrBy("decrby1", 10))() mustEqual IntegerReply(-11) + client(DecrBy(KEY, 1))() must haveClass[ErrorReply] + client(DecrBy(null: String, 1))() must throwA[ClientError] + } + "GET" >> { + client(Get("thing"))() must haveClass[EmptyBulkReply] + assertBulkReply(client(Get(KEY)), VALUE) + client(Get(null:String))() must throwA[ClientError] + client(Get(null:List[Array[Byte]]))() must throwA[ClientError] + } + "GETBIT" >> { + client(SetBit("getbit", 7, 1))() mustEqual IntegerReply(0) + client(GetBit("getbit", 0))() mustEqual IntegerReply(0) + client(GetBit("getbit", 7))() mustEqual IntegerReply(1) + client(GetBit("getbit", 100))() mustEqual IntegerReply(0) + } + "GETRANGE" >> { + val key = "getrange" + val value = "This is a string" + client(Set(key, value))() mustEqual StatusReply("OK") + assertBulkReply(client(GetRange(key, 0, 3)), "This") + assertBulkReply(client(GetRange(key, -3, -1)), "ing") + assertBulkReply(client(GetRange(key, 0, -1)), value) + assertBulkReply(client(GetRange(key, 10, 100)), "string") + } + "GETSET" >> { + val key = "getset" + client(Incr(key))() mustEqual IntegerReply(1) + assertBulkReply(client(GetSet(key, "0")), "1") + assertBulkReply(client(Get(key)), "0") + client(GetSet("brandnewkey", "foo"))() mustEqual EmptyBulkReply() + } + "INCR" >> { + client(Incr("incr1"))() mustEqual IntegerReply(1) + client(Incr("incr1"))() mustEqual IntegerReply(2) + client(Incr(KEY))() must haveClass[ErrorReply] + client(Incr(null:String))() must throwA[ClientError] + } + "INCRBY" >> { + client(IncrBy("incrby1", 1))() mustEqual IntegerReply(1) + client(IncrBy("incrby1", 10))() mustEqual IntegerReply(11) + client(IncrBy(KEY, 1))() must haveClass[ErrorReply] + client(IncrBy(null: String, 1))() must throwA[ClientError] + } + "MGET" >> { + val expects = List( + BytesToString(RedisCodec.NIL_VALUE_BA), + VALUE + ) + val req = client(MGet(List("thing", KEY))) + assertMBulkReply(req, expects) + client(MGet(null))() must throwA[ClientError] + client(MGet(List(null)))() must throwA[ClientError] + } + "MSET" >> { + val input = Map( + "thing" -> StringToBytes("thang"), + KEY -> StringToBytes(VALUE), + "stuff" -> StringToBytes("bleh") + ) + client(MSet(input))() mustEqual StatusReply("OK") + val req = client(MGet(List("thing",KEY,"noexists","stuff"))) + val expects = List("thang",VALUE,BytesToString(RedisCodec.NIL_VALUE_BA),"bleh") + assertMBulkReply(req, expects) + } + "MSETNX" >> { + val input1 = Map( + "msnx.key1" -> s2b("Hello"), + "msnx.key2" -> s2b("there") + ) + client(MSetNx(input1))() mustEqual IntegerReply(1) + val input2 = Map( + "msnx.key2" -> s2b("there"), + "msnx.key3" -> s2b("world") + ) + client(MSetNx(input2))() mustEqual IntegerReply(0) + val expects = List("Hello", "there", BytesToString(RedisCodec.NIL_VALUE_BA)) + assertMBulkReply(client(MGet(List("msnx.key1","msnx.key2","msnx.key3"))), expects) + } + "SET" >> { + client(Set(null, null))() must throwA[ClientError] + client(Set("key1", null))() must throwA[ClientError] + client(Set(null, "value1"))() must throwA[ClientError] + } + "SETBIT" >> { + client(SetBit("setbit", 7, 1))() mustEqual IntegerReply(0) + client(SetBit("setbit", 7, 0))() mustEqual IntegerReply(1) + assertBulkReply(client(Get("setbit")), BytesToString(Array[Byte](0))) + } + "SETEX" >> { + val key = "setex" + client(SetEx(key, 10, "Hello"))() mustEqual StatusReply("OK") + client(Ttl(key))() match { + case IntegerReply(seconds) => seconds must beCloseTo(10, 2) + case _ => fail("Expected IntegerReply") + } + assertBulkReply(client(Get(key)), "Hello") + } + "SETNX" >> { + val key = "setnx" + val value1 = "Hello" + val value2 = "World" + client(SetNx(key, value1))() mustEqual IntegerReply(1) + client(SetNx(key, value2))() mustEqual IntegerReply(0) + assertBulkReply(client(Get(key)), value1) + } + "SETRANGE" >> { + val key = "setrange" + val value = "Hello World" + client(Set(key, value))() mustEqual StatusReply("OK") + client(SetRange(key, 6, "Redis"))() mustEqual IntegerReply(11) + assertBulkReply(client(Get(key)), "Hello Redis") + } + "STRLEN" >> { + val key = "strlen" + val value = "Hello World" + client(Set(key, value))() mustEqual StatusReply("OK") + client(Strlen(key))() mustEqual IntegerReply(11) + client(Strlen("nosuchkey"))() mustEqual IntegerReply(0) + } + } + } + + def assertMBulkReply(reply: Future[Reply], expects: List[String], contains: Boolean = false) = + reply() match { + case MBulkReply(msgs) => contains match { + case true => + expects.isEmpty must beFalse + val newMsgs = msgs.map { msg => BytesToString(msg) } + expects.foreach { msg => newMsgs must contain(msg) } + case false => + expects.isEmpty must beFalse + msgs.map { msg => BytesToString(msg) } mustEqual expects + } + case EmptyMBulkReply() => expects.isEmpty must beTrue + case r: Reply => fail("Expected MBulkReply, got %s".format(r)) + case _ => fail("Expected MBulkReply") + } + + def assertBulkReply(reply: Future[Reply], expects: String) = reply() match { + case BulkReply(msg) => BytesToString(msg) mustEqual expects + case _ => fail("Expected BulkReply") + } + + def assertIntegerReply(reply: Future[Reply], expects: Int, delta: Int = 10) = reply() match { + case IntegerReply(amnt) => amnt must beCloseTo(expects, delta) + case _ => fail("Expected IntegerReply") + } + + doAfterSpec { + client.release() + server.close() + RedisCluster.stopAll() + } + +} diff --git a/project/build/Project.scala b/project/build/Project.scala index be307d752a..7374d2d903 100644 --- a/project/build/Project.scala +++ b/project/build/Project.scala @@ -75,6 +75,13 @@ class Project(info: ProjectInfo) extends StandardParentProject(info) "finagle-kestrel", "finagle-kestrel", new KestrelProject(_), coreProject, memcachedProject) + /** + * finagle-redis is a redis codec contributed by Tumblr. + */ + val redisProject = project( + "finagle-redis", "finagle-redis", + new RedisProject(_), coreProject) + /** * finagle-http contains an http codec. */ @@ -190,6 +197,24 @@ class Project(info: ProjectInfo) extends StandardParentProject(info) override def compileOrder = CompileOrder.ScalaThenJava } + class RedisProject(info: ProjectInfo) extends StandardProject(info) + with Defaults with UnpublishedProject + { + val naggati = "com.twitter" % "naggati" % "2.2.0" intransitive() + + // This is currently disabled since it requires the user to have a redis + // installation. We might be able to ship the redis binary with some + // architectures, and make the test conditional. + override def testOptions = { + val name = "com.twitter.finagle.redis.protocol.integration.ClientServerIntegrationSpec" + ExcludeTests(name :: Nil) :: super.testOptions.toList + } + + projectDependencies( + "util" ~ "util-logging" + ) + } + class HttpProject(info: ProjectInfo) extends StandardProject(info) with Defaults {