Skip to content
This repository has been archived by the owner on Sep 18, 2021. It is now read-only.

Commit

Permalink
convert the flood test over to client-neutral code too
Browse files Browse the repository at this point in the history
  • Loading branch information
Robey Pointer committed Sep 29, 2011
1 parent aedfd94 commit 60dd3cd
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 39 deletions.
44 changes: 41 additions & 3 deletions src/test/scala/net/lag/kestrel/load/Client.scala
Expand Up @@ -34,6 +34,10 @@ trait Client {

def flush(queueName: String): ByteBuffer
def flushSuccess(): ByteBuffer

def get(queueName: String, timeoutMsec: Option[Int]): ByteBuffer
def getEmpty(queueName: String): ByteBuffer
def getSuccess(queueName: String, data: String): ByteBuffer
}

object MemcacheClient extends Client {
Expand Down Expand Up @@ -62,6 +66,18 @@ object MemcacheClient extends Client {
def flushSuccess() = {
ByteBuffer.wrap("END\r\n".getBytes)
}

def get(queueName: String, timeoutMsec: Option[Int]) = {
ByteBuffer.wrap(("get " + queueName + timeoutMsec.map { "/t=" + _ }.getOrElse("") + "\r\n").getBytes)
}

def getEmpty(queueName: String) = {
ByteBuffer.wrap("END\r\n".getBytes)
}

def getSuccess(queueName: String, data: String) = {
ByteBuffer.wrap(("VALUE " + queueName + " 0 " + data.length + "\r\n" + data + "\r\nEND\r\n").getBytes)
}
}

object ThriftClient extends Client {
Expand All @@ -70,7 +86,7 @@ object ThriftClient extends Client {
import org.apache.thrift.protocol.{TBinaryProtocol, TMessage, TMessageType, TProtocol}
import org.apache.thrift.transport.{TFramedTransport, TIOStreamTransport, TMemoryBuffer}

def withProtocol(f: TProtocol => Unit) = {
private def withProtocol(f: TProtocol => Unit) = {
val buffer = new TMemoryBuffer(512)
val protocol = new TBinaryProtocol(new TFramedTransport(buffer))
f(protocol)
Expand All @@ -83,15 +99,15 @@ object ThriftClient extends Client {

def putSuccess() = putNSuccess(1)

def putN(queueName: String, data: Seq[String]): ByteBuffer = {
def putN(queueName: String, data: Seq[String]) = {
withProtocol { p =>
p.writeMessageBegin(new TMessage("put", TMessageType.CALL, 0))
val item = data.map { item => ByteBuffer.wrap(item.getBytes) }
(new thrift.Kestrel.put_args(queueName, item, 0)).write(p)
}
}

def putNSuccess(count: Int): ByteBuffer = {
def putNSuccess(count: Int) = {
withProtocol { p =>
p.writeMessageBegin(new TMessage("put", TMessageType.REPLY, 0))
(new thrift.Kestrel.put_result(success = Some(count))).write(p)
Expand All @@ -111,4 +127,26 @@ object ThriftClient extends Client {
(new thrift.Kestrel.flush_result()).write(p)
}
}

def get(queueName: String, timeoutMsec: Option[Int]) = {
withProtocol { p =>
p.writeMessageBegin(new TMessage("get", TMessageType.CALL, 0))
(new thrift.Kestrel.get_args(queueName, 1, timeoutMsec.getOrElse(0), true)).write(p)
}
}

def getEmpty(queueName: String) = {
withProtocol { p =>
p.writeMessageBegin(new TMessage("get", TMessageType.REPLY, 0))
(new thrift.Kestrel.get_result(success = Some(Nil))).write(p)
}
}

def getSuccess(queueName: String, data: String) = {
withProtocol { p =>
p.writeMessageBegin(new TMessage("get", TMessageType.REPLY, 0))
val item = new thrift.Item(ByteBuffer.wrap(data.getBytes), 0)
(new thrift.Kestrel.get_result(success = Some(Seq(item)))).write(p)
}
}
}
69 changes: 48 additions & 21 deletions src/test/scala/net/lag/kestrel/load/Flood.scala
Expand Up @@ -31,33 +31,32 @@ import com.twitter.conversions.string._
object Flood extends LoadTesting {
private val DATA = "x" * 1024

private val EXPECT = ByteBuffer.wrap("STORED\r\n".getBytes)

def put(socket: SocketChannel, queueName: String, n: Int, data: String) = {
val spam = ByteBuffer.wrap(("set " + queueName + " 0 0 " + data.length + "\r\n" + data + "\r\n").getBytes)
val buffer = ByteBuffer.allocate(EXPECT.limit)
val spam = client.put(queueName, data)
val expect = client.putSuccess()
val buffer = ByteBuffer.allocate(expect.limit)

for (i <- 0 until n) {
send(socket, spam)
if (receive(socket, buffer) != EXPECT) {
if (receive(socket, buffer) != expect) {
// the "!" is important.
throw new Exception("Unexpected response at " + i + "!")
}
}
}

def get(socket: SocketChannel, queueName: String, n: Int, data: String, blockingReads: Boolean): Int = {
val req = ByteBuffer.wrap(("get " + queueName + (if (blockingReads) "/t=1000" else "") + "\r\n").getBytes)
val expectEnd = ByteBuffer.wrap("END\r\n".getBytes)
val expectData = ByteBuffer.wrap(("VALUE " + queueName + " 0 " + data.length + "\r\n" + data + "\r\nEND\r\n").getBytes)
val expecting = new Expecting(expectEnd, expectData)
def get(socket: SocketChannel, queueName: String, n: Int, data: String): Int = {
val req = client.get(queueName, if (blockingReads) Some(1000) else None)
val expectNoData = client.getEmpty(queueName)
val expectData = client.getSuccess(queueName, data)
val expecting = new Expecting(expectNoData, expectData)

var count = 0
var misses = 0
while (count < n) {
send(socket, req)
val got = expecting(socket)
if (got == expectEnd) {
if (got == expectNoData) {
// nothing yet. poop. :(
misses += 1
} else {
Expand All @@ -72,8 +71,11 @@ object Flood extends LoadTesting {
var queueName = "spam"
var prefillItems = 0
var hostname = "localhost"
var port = 22133
var threads = 1
var blockingReads = false
var client: Client = MemcacheClient
var flushFirst = true

def usage() {
Console.println("usage: flood [options]")
Expand All @@ -86,14 +88,20 @@ object Flood extends LoadTesting {
Console.println(" put KILOBYTES per queue item (default: %d)".format(kilobytes))
Console.println(" -q NAME")
Console.println(" use queue NAME (default: %s)".format(queueName))
Console.println(" -p ITEMS")
Console.println(" -P ITEMS")
Console.println(" prefill ITEMS items into the queue before the test (default: %d)".format(prefillItems))
Console.println(" -h HOSTNAME")
Console.println(" use kestrel on HOSTNAME (default: %s)".format(hostname))
Console.println(" -t THREADS")
Console.println(" create THREADS producers and THREADS consumers (default: %d)".format(threads))
Console.println(" -h HOSTNAME")
Console.println(" use kestrel on HOSTNAME (default: %s)".format(hostname))
Console.println(" -p PORT")
Console.println(" use kestrel on PORT (default: %d)".format(port))
Console.println(" -B")
Console.println(" do blocking reads (reads with a timeout)")
Console.println(" --thrift")
Console.println(" use thrift RPC")
Console.println(" -F")
Console.println(" don't flush queue(s) before the test")
}

def parseArgs(args: List[String]): Unit = args match {
Expand All @@ -110,18 +118,28 @@ object Flood extends LoadTesting {
case "-q" :: x :: xs =>
queueName = x
parseArgs(xs)
case "-p" :: x :: xs =>
case "-P" :: x :: xs =>
prefillItems = x.toInt
parseArgs(xs)
case "-t" :: x :: xs =>
threads = x.toInt
parseArgs(xs)
case "-h" :: x :: xs =>
hostname = x
parseArgs(xs)
case "-t" :: x :: xs =>
threads = x.toInt
case "-p" :: x :: xs =>
port = x.toInt
parseArgs(xs)
case "-B" :: xs =>
blockingReads = true
parseArgs(xs)
case "--thrift" :: xs =>
client = ThriftClient
port = 2229
parseArgs(xs)
case "-F" :: xs =>
flushFirst = false
parseArgs(xs)
case _ =>
usage()
System.exit(1)
Expand All @@ -131,9 +149,18 @@ object Flood extends LoadTesting {
parseArgs(args.toList)
val data = DATA * kilobytes

// flush queues first
if (flushFirst) {
println("Flushing queues first.")
val socket = tryHard { SocketChannel.open(new InetSocketAddress(hostname, port)) }
send(socket, client.flush(queueName))
expect(socket, client.flushSuccess())
socket.close()
}

if (prefillItems > 0) {
println("prefill: " + prefillItems + " items of " + kilobytes + "kB")
val socket = SocketChannel.open(new InetSocketAddress(hostname, 22133))
val socket = tryHard { SocketChannel.open(new InetSocketAddress(hostname, port)) }
put(socket, queueName, prefillItems, data)
}

Expand All @@ -146,15 +173,15 @@ object Flood extends LoadTesting {
for (i <- 0 until threads) {
val producerThread = new Thread {
override def run = {
val socket = SocketChannel.open(new InetSocketAddress(hostname, 22133))
val socket = tryHard { SocketChannel.open(new InetSocketAddress(hostname, port)) }
put(socket, queueName, totalItems, data)
}
}

val consumerThread = new Thread {
override def run = {
val socket = SocketChannel.open(new InetSocketAddress(hostname, 22133))
val n = get(socket, queueName, totalItems, data, blockingReads)
val socket = tryHard { SocketChannel.open(new InetSocketAddress(hostname, port)) }
val n = get(socket, queueName, totalItems, data)
misses.addAndGet(n)
}
}
Expand Down
13 changes: 13 additions & 0 deletions src/test/scala/net/lag/kestrel/load/LoadTesting.scala
Expand Up @@ -20,6 +20,7 @@ package net.lag.kestrel.load
import java.net._
import java.nio._
import java.nio.channels._
import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec
import scala.collection.mutable
import com.twitter.conversions.string._
Expand Down Expand Up @@ -91,4 +92,16 @@ trait LoadTesting {
throw new Exception("Unexpected response!")
}
}

val failedConnects = new AtomicInteger(0)

final def tryHard[A](f: => A): A = {
try {
f
} catch {
case e: java.io.IOException =>
failedConnects.incrementAndGet()
tryHard(f)
}
}
}
17 changes: 2 additions & 15 deletions src/test/scala/net/lag/kestrel/load/PutMany.scala
Expand Up @@ -20,7 +20,6 @@ package net.lag.kestrel.load
import java.net._
import java.nio._
import java.nio.channels._
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable
import com.twitter.conversions.string._

Expand Down Expand Up @@ -54,7 +53,7 @@ object PutMany extends LoadTesting {
"run down, with no one to find you\n" +
"we're survivors, here til the end\n"

def put(socket: SocketChannel, client: Client, queueName: String, n: Int, globalTimings: mutable.ListBuffer[Long], data: String) = {
def put(socket: SocketChannel, queueName: String, n: Int, globalTimings: mutable.ListBuffer[Long], data: String) = {
val spam = if (rollup == 1) client.put(queueName, data) else client.putN(queueName, (0 until rollup).map { _ => data })
val expect = if (rollup == 1) client.putSuccess() else client.putNSuccess(rollup)

Expand Down Expand Up @@ -166,18 +165,6 @@ object PutMany extends LoadTesting {
System.exit(1)
}

val failedConnects = new AtomicInteger

def tryHard[A](f: => A): A = {
try {
f
} catch {
case e: java.io.IOException =>
failedConnects.incrementAndGet()
tryHard(f)
}
}

def main(args: Array[String]) = {
parseArgs(args.toList)

Expand Down Expand Up @@ -217,7 +204,7 @@ object PutMany extends LoadTesting {
override def run = {
val socket = tryHard { SocketChannel.open(new InetSocketAddress(hostname, port)) }
val qName = queueName + (if (queueCount > 1) (i % queueCount).toString else "")
put(socket, client, qName, totalItems / clientCount, timings, rawData.toString)
put(socket, qName, totalItems / clientCount, timings, rawData.toString)
}
}
threadList = t :: threadList
Expand Down

0 comments on commit 60dd3cd

Please sign in to comment.