Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Merge branch 'master' of git@github.com:robey/scarling

  • Loading branch information...
commit f1101f698d569693f67265b288480d84b4dcea1a 2 parents 2afb9b5 + e7e013f
Steve Jenson authored
View
10 ivy/ivy.xml
@@ -2,7 +2,7 @@
<info
organisation="com.twitter"
module="scarling"
- revision="0.6"
+ revision="0.8"
e:testclass="com.twitter.TestRunner"
e:jarclassname="com.twitter.scarling.Scarling"
e:buildpackage="com/twitter/scarling"
@@ -15,11 +15,11 @@
</configurations>
<dependencies>
<!-- i guess scala-compiler.jar is needed for MainGenericRunner -->
- <dependency org="org.scala-lang" name="scala-compiler" rev="2.7.2-rc6" conf="bootstrap->*; test->*" />
- <dependency org="org.scala-lang" name="scala-library" rev="2.7.2-rc6" />
- <dependency org="org.scala-tools" name="vscaladoc" rev="1.1" conf="bootstrap->*" />
+ <dependency org="org.scala-lang" name="scala-compiler" rev="2.7.2" conf="bootstrap->*; test->*" />
+ <dependency org="org.scala-lang" name="scala-library" rev="2.7.2" />
+ <dependency org="org.scala-tools" name="vscaladoc" rev="1.2-SNAPSHOT" conf="bootstrap->*" />
<dependency org="org.specs" name="specs" rev="1.3.1" />
- <dependency org="net.lag" name="configgy" rev="1.2-rc3" />
+ <dependency org="net.lag" name="configgy" rev="1.2" />
<dependency org="org.apache.mina" name="mina-core" rev="1.1.7" />
<dependency org="org.slf4j" name="slf4j-jdk14" rev="1.5.0" />
</dependencies>
View
3  ivy/ivysettings.xml
@@ -5,6 +5,9 @@
<url name="scala-tools.org" m2compatible="true">
<artifact pattern="http://scala-tools.org/repo-releases/[organisation]/[module]/[revision]/[artifact]-[revision].[ext]" />
</url>
+ <url name="scala-tools.org-snapshots" m2compatible="true">
+ <artifact pattern="http://scala-tools.org/repo-snapshots/[organisation]/[module]/[revision]/[artifact]-[revision].[ext]" />
+ </url>
<ibiblio name="maven2" m2compatible="true"/>
<url name="lag.net" m2compatible="true">
<artifact pattern="http://www.lag.net/repo/[organisation]/[module]/[revision]/[artifact]-[revision].[ext]" />
View
24 src/main/scala/com/twitter/scarling/Event.scala
@@ -1,24 +0,0 @@
-package com.twitter.scarling
-
-// i'm a little weirded out that java doesn't have Event yet.
-class Event {
- private var _flag = false
- private val _lock = new Object
-
- def clear: Unit = _lock.synchronized { _flag = false }
-
- def set: Unit = _lock.synchronized {
- _flag = true
- _lock.notifyAll
- }
-
- def isSet = _lock.synchronized { _flag }
-
- def waitFor(timeout: Int): Unit = _lock.synchronized {
- if (! _flag) {
- _lock.wait(timeout)
- }
- }
-
- def waitFor: Unit = waitFor(0)
-}
View
312 src/main/scala/com/twitter/scarling/Journal.scala
@@ -0,0 +1,312 @@
+package com.twitter.scarling
+
+import net.lag.logging.Logger
+import java.io._
+import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.channels.FileChannel
+
+
+// returned from journal replay
+abstract case class JournalItem
+object JournalItem {
+ case class Add(item: QItem) extends JournalItem
+ case object Remove extends JournalItem
+ case object RemoveTentative extends JournalItem
+ case class SavedXid(xid: Int) extends JournalItem
+ case class Unremove(xid: Int) extends JournalItem
+ case class ConfirmRemove(xid: Int) extends JournalItem
+ case object EndOfFile extends JournalItem
+}
+
+
+/**
+ * Codes for working with the journal file for a PersistentQueue.
+ */
+class Journal(queuePath: String) {
+
+ /* in theory, you might want to sync the file after each
+ * transaction. however, the original starling doesn't.
+ * i think if you can cope with a truncated journal file,
+ * this is fine, because a non-synced file only matters on
+ * catastrophic disk/machine failure.
+ */
+
+ private val log = Logger.get
+
+ private var writer: FileChannel = null
+ private var reader: Option[FileChannel] = None
+ private var replayer: Option[FileChannel] = None
+
+ var size: Long = 0
+
+ // small temporary buffer for formatting operations into the journal:
+ private val buffer = new Array[Byte](16)
+ private val byteBuffer = ByteBuffer.wrap(buffer)
+ byteBuffer.order(ByteOrder.LITTLE_ENDIAN)
+
+ private val CMD_ADD = 0
+ private val CMD_REMOVE = 1
+ private val CMD_ADDX = 2
+ private val CMD_REMOVE_TENTATIVE = 3
+ private val CMD_SAVE_XID = 4
+ private val CMD_UNREMOVE = 5
+ private val CMD_CONFIRM_REMOVE = 6
+
+
+ def open(): Unit = {
+ writer = new FileOutputStream(queuePath, true).getChannel
+ }
+
+ def roll(): Unit = {
+ writer.close
+ val backupFile = new File(queuePath + "." + Time.now)
+ new File(queuePath).renameTo(backupFile)
+ open
+ size = 0
+ backupFile.delete
+ }
+
+ def close(): Unit = {
+ writer.close
+ for (r <- reader) r.close
+ reader = None
+ }
+
+ def inReadBehind(): Boolean = reader.isDefined
+
+ def add(item: QItem) = {
+ val blob = ByteBuffer.wrap(pack(item))
+ byteBuffer.clear
+ byteBuffer.put(CMD_ADDX.toByte)
+ byteBuffer.putInt(blob.limit)
+ byteBuffer.flip
+ do {
+ writer.write(byteBuffer)
+ } while (byteBuffer.position < byteBuffer.limit)
+ do {
+ writer.write(blob)
+ } while (blob.position < blob.limit)
+ size += (5 + blob.limit)
+ }
+
+ def remove() = {
+ byteBuffer.clear
+ byteBuffer.put(CMD_REMOVE.toByte)
+ byteBuffer.flip
+ while (byteBuffer.position < byteBuffer.limit) {
+ writer.write(byteBuffer)
+ }
+ size += 1
+ }
+
+ def removeTentative() = {
+ byteBuffer.clear
+ byteBuffer.put(CMD_REMOVE_TENTATIVE.toByte)
+ byteBuffer.flip
+ while (byteBuffer.position < byteBuffer.limit) {
+ writer.write(byteBuffer)
+ }
+ size += 1
+ }
+
+ def saveXid(xid: Int) = {
+ byteBuffer.clear
+ byteBuffer.put(CMD_SAVE_XID.toByte)
+ byteBuffer.putInt(xid)
+ byteBuffer.flip
+ while (byteBuffer.position < byteBuffer.limit) {
+ writer.write(byteBuffer)
+ }
+ size += 5
+ }
+
+ def unremove(xid: Int) = {
+ byteBuffer.clear
+ byteBuffer.put(CMD_UNREMOVE.toByte)
+ byteBuffer.putInt(xid)
+ byteBuffer.flip
+ while (byteBuffer.position < byteBuffer.limit) {
+ writer.write(byteBuffer)
+ }
+ size += 5
+ }
+
+ def confirmRemove(xid: Int) = {
+ byteBuffer.clear
+ byteBuffer.put(CMD_CONFIRM_REMOVE.toByte)
+ byteBuffer.putInt(xid)
+ byteBuffer.flip
+ while (byteBuffer.position < byteBuffer.limit) {
+ writer.write(byteBuffer)
+ }
+ size += 5
+ }
+
+ def startReadBehind(): Unit = {
+ val pos = if (replayer.isDefined) replayer.get.position else writer.position
+ val rj = new FileInputStream(queuePath).getChannel
+ rj.position(pos)
+ reader = Some(rj)
+ }
+
+ def fillReadBehind(f: QItem => Unit): Unit = {
+ val pos = if (replayer.isDefined) replayer.get.position else writer.position
+ for (rj <- reader) {
+ if (rj.position == pos) {
+ // we've caught up.
+ rj.close
+ reader = None
+ } else {
+ readJournalEntry(rj, false) match {
+ case JournalItem.Add(item) => f(item)
+ case _ =>
+ }
+ }
+ }
+ }
+
+ def replay(name: String)(f: JournalItem => Unit): Unit = {
+ size = 0
+ try {
+ val in = new FileInputStream(queuePath).getChannel
+ replayer = Some(in)
+ var done = false
+ do {
+ readJournalEntry(in, true) match {
+ case JournalItem.EndOfFile => done = true
+ case x: JournalItem => f(x)
+ }
+ } while (!done)
+ } catch {
+ case e: FileNotFoundException =>
+ log.info("No transaction journal for '%s'; starting with empty queue.", name)
+ case e: IOException =>
+ log.error(e, "Exception replaying journal for '%s'", name)
+ log.error("DATA MAY HAVE BEEN LOST!")
+ // this can happen if the server hardware died abruptly in the middle
+ // of writing a journal. not awesome but we should recover.
+ }
+ replayer = None
+ }
+
+ private def readJournalEntry(in: FileChannel, replaying: Boolean): JournalItem = {
+ byteBuffer.rewind
+ byteBuffer.limit(1)
+ var x: Int = 0
+ do {
+ x = in.read(byteBuffer)
+ } while (byteBuffer.position < byteBuffer.limit && x >= 0)
+
+ if (x < 0) {
+ JournalItem.EndOfFile
+ } else {
+ buffer(0) match {
+ case CMD_ADD =>
+ readBlock(in) match {
+ case None => JournalItem.EndOfFile
+ case Some(data) =>
+ if (replaying) size += 5 + data.length
+ JournalItem.Add(unpackOldAdd(data))
+ }
+ case CMD_REMOVE =>
+ if (replaying) size += 1
+ JournalItem.Remove
+ case CMD_ADDX =>
+ readBlock(in) match {
+ case None => JournalItem.EndOfFile
+ case Some(data) =>
+ if (replaying) size += 5 + data.length
+ JournalItem.Add(unpack(data))
+ }
+ case CMD_REMOVE_TENTATIVE =>
+ if (replaying) size += 1
+ JournalItem.RemoveTentative
+ case CMD_SAVE_XID =>
+ readInt(in) match {
+ case None => JournalItem.EndOfFile
+ case Some(xid) =>
+ if (replaying) size += 5
+ JournalItem.SavedXid(xid)
+ }
+ case CMD_UNREMOVE =>
+ readInt(in) match {
+ case None => JournalItem.EndOfFile
+ case Some(xid) =>
+ if (replaying) size += 5
+ JournalItem.Unremove(xid)
+ }
+ case CMD_CONFIRM_REMOVE =>
+ readInt(in) match {
+ case None => JournalItem.EndOfFile
+ case Some(xid) =>
+ if (replaying) size += 5
+ JournalItem.ConfirmRemove(xid)
+ }
+ case n =>
+ throw new IOException("invalid opcode in journal: " + n.toInt)
+ }
+ }
+ }
+
+ private def readBlock(in: FileChannel): Option[Array[Byte]] = {
+ readInt(in) match {
+ case None => None
+ case Some(size) =>
+ val data = new Array[Byte](size)
+ val dataBuffer = ByteBuffer.wrap(data)
+ var x: Int = 0
+ do {
+ x = in.read(dataBuffer)
+ } while (dataBuffer.position < dataBuffer.limit && x >= 0)
+ if (x < 0) {
+ None
+ } else {
+ Some(data)
+ }
+ }
+ }
+
+ private def readInt(in: FileChannel): Option[Int] = {
+ byteBuffer.rewind
+ byteBuffer.limit(4)
+ var x: Int = 0
+ do {
+ x = in.read(byteBuffer)
+ } while (byteBuffer.position < byteBuffer.limit && x >= 0)
+ if (x < 0) {
+ None
+ } else {
+ byteBuffer.rewind
+ Some(byteBuffer.getInt())
+ }
+ }
+
+ private def pack(item: QItem): Array[Byte] = {
+ val bytes = new Array[Byte](item.data.length + 16)
+ val buffer = ByteBuffer.wrap(bytes)
+ buffer.order(ByteOrder.LITTLE_ENDIAN)
+ buffer.putLong(item.addTime)
+ buffer.putLong(item.expiry)
+ buffer.put(item.data)
+ bytes
+ }
+
+ private def unpack(data: Array[Byte]): QItem = {
+ val buffer = ByteBuffer.wrap(data)
+ val bytes = new Array[Byte](data.length - 16)
+ buffer.order(ByteOrder.LITTLE_ENDIAN)
+ val addTime = buffer.getLong
+ val expiry = buffer.getLong
+ buffer.get(bytes)
+ return QItem(addTime, expiry, bytes, 0)
+ }
+
+ private def unpackOldAdd(data: Array[Byte]): QItem = {
+ val buffer = ByteBuffer.wrap(data)
+ val bytes = new Array[Byte](data.length - 4)
+ buffer.order(ByteOrder.LITTLE_ENDIAN)
+ val expiry = buffer.getInt
+ buffer.get(bytes)
+ return QItem(Time.now, if (expiry == 0) 0 else expiry * 1000, bytes, 0)
+ }
+}
View
427 src/main/scala/com/twitter/scarling/PersistentQueue.scala
@@ -3,71 +3,25 @@ package com.twitter.scarling
import java.io._
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.channels.FileChannel
+import java.util.concurrent.CountDownLatch
import scala.actors.{Actor, TIMEOUT}
import scala.collection.mutable
import net.lag.configgy.{Config, Configgy, ConfigMap}
import net.lag.logging.Logger
-// why does java make this so hard? :/
-// this is not threadsafe so may only be used as a local var.
-class IntReader(private val order: ByteOrder) {
- val buffer = new Array[Byte](4)
- val byteBuffer = ByteBuffer.wrap(buffer)
- byteBuffer.order(order)
-
- def readInt(in: DataInputStream) = {
- in.readFully(buffer)
- byteBuffer.rewind
- byteBuffer.getInt
- }
-}
-
-
-// again, must be either behind a lock, or in a local var
-class IntWriter(private val order: ByteOrder) {
- val buffer = new Array[Byte](4)
- val byteBuffer = ByteBuffer.wrap(buffer)
- byteBuffer.order(order)
-
- def writeInt(out: DataOutputStream, n: Int) = {
- byteBuffer.rewind
- byteBuffer.putInt(n)
- out.write(buffer)
- }
-}
+case class QItem(addTime: Long, expiry: Long, data: Array[Byte], var xid: Int)
class PersistentQueue(private val persistencePath: String, val name: String,
val config: ConfigMap) {
- private case class QItem(addTime: Long, expiry: Long, data: Array[Byte]) {
- def pack: Array[Byte] = {
- val bytes = new Array[Byte](data.length + 16)
- val buffer = ByteBuffer.wrap(bytes)
- buffer.order(ByteOrder.LITTLE_ENDIAN)
- buffer.putLong(addTime)
- buffer.putLong(expiry)
- buffer.put(data)
- bytes
- }
- }
-
- private case class JournalItem(command: Int, length: Int, item: Option[QItem])
-
private case class Waiter(actor: Actor)
-
private case object ItemArrived
private val log = Logger.get
- private val CMD_ADD = 0
- private val CMD_REMOVE = 1
- private val CMD_ADDX = 2
-
- private val queuePath: String = new File(persistencePath, name).getCanonicalPath()
-
// current size of all data in the queue:
private var queueSize: Long = 0
@@ -83,21 +37,15 @@ class PersistentQueue(private val persistencePath: String, val name: String,
// # of items in the queue (including those not in memory)
private var queueLength: Long = 0
- private var queue = new mutable.Queue[QItem]
- private var journal: FileOutputStream = null
- private var _journalSize: Long = 0
+ private var queue = new mutable.Queue[QItem] {
+ // scala's Queue doesn't (yet?) have a way to put back.
+ def unget(item: QItem) = prependElem(item)
+ }
private var _memoryBytes: Long = 0
- private var readJournal: Option[FileInputStream] = None
-
- // small temporary buffer for formatting ADD transactions into the journal:
- private var byteBuffer = new ByteArrayOutputStream(16)
- private var buffer = new DataOutputStream(byteBuffer)
-
- // sad way to write little-endian ints when journaling queue adds
- private val intWriter = new IntWriter(ByteOrder.LITTLE_ENDIAN)
+ private var journal = new Journal(new File(persistencePath, name).getCanonicalPath)
// force get/set operations to block while we're replaying any existing journal
- private val initialized = new Event
+ private val initialized = new CountDownLatch(1)
private var closed = false
// attempting to add an item after the queue reaches this size will fail.
@@ -109,8 +57,9 @@ class PersistentQueue(private val persistencePath: String, val name: String,
// clients waiting on an item in this queue
private val waiters = new mutable.ArrayBuffer[Waiter]
- config.subscribe(configure _)
- configure(Some(config))
+ // track tentative removals
+ private var xidCounter: Int = 0
+ private val openTransactions = new mutable.HashMap[Int, QItem]
def length: Long = synchronized { queueLength }
@@ -118,7 +67,7 @@ class PersistentQueue(private val persistencePath: String, val name: String,
def bytes: Long = synchronized { queueSize }
- def journalSize: Long = synchronized { _journalSize }
+ def journalSize: Long = synchronized { journal.size }
def totalExpired: Long = synchronized { _totalExpired }
@@ -127,7 +76,11 @@ class PersistentQueue(private val persistencePath: String, val name: String,
// mostly for unit tests.
def memoryLength: Long = synchronized { queue.size }
def memoryBytes: Long = synchronized { _memoryBytes }
- def inReadBehind = synchronized { readJournal.isDefined }
+ def inReadBehind = synchronized { journal.inReadBehind }
+
+
+ config.subscribe(configure _)
+ configure(Some(config))
def configure(c: Option[ConfigMap]) = synchronized {
for (config <- c) {
@@ -136,26 +89,6 @@ class PersistentQueue(private val persistencePath: String, val name: String,
}
}
-
- private def unpack(data: Array[Byte]): QItem = {
- val buffer = ByteBuffer.wrap(data)
- val bytes = new Array[Byte](data.length - 16)
- buffer.order(ByteOrder.LITTLE_ENDIAN)
- val addTime = buffer.getLong
- val expiry = buffer.getLong
- buffer.get(bytes)
- return QItem(addTime, expiry, bytes)
- }
-
- private def unpackOldAdd(data: Array[Byte]): QItem = {
- val buffer = ByteBuffer.wrap(data)
- val bytes = new Array[Byte](data.length - 4)
- buffer.order(ByteOrder.LITTLE_ENDIAN)
- val expiry = buffer.getInt
- buffer.get(bytes)
- return QItem(System.currentTimeMillis, if (expiry == 0) 0 else expiry * 1000, bytes)
- }
-
private final def adjustExpiry(expiry: Long): Long = {
if (maxAge > 0) {
if (expiry > 0) (expiry min maxAge) else maxAge
@@ -168,94 +101,65 @@ class PersistentQueue(private val persistencePath: String, val name: String,
* Add a value to the end of the queue, transactionally.
*/
def add(value: Array[Byte], expiry: Long): Boolean = {
- initialized.waitFor
+ initialized.await
synchronized {
if (closed || queueLength >= maxItems) {
- return false
- }
-
- if (!readJournal.isDefined && queueSize >= PersistentQueue.maxMemorySize) {
- startReadBehind(journal.getChannel.position)
- }
-
- val item = QItem(System.currentTimeMillis, adjustExpiry(expiry), value)
- val blob = item.pack
-
- byteBuffer.reset()
- buffer.write(CMD_ADDX)
- intWriter.writeInt(buffer, blob.length)
- byteBuffer.writeTo(journal)
- journal.write(blob)
- /* in theory, you might want to sync the file after each
- * transaction. however, the original starling doesn't.
- * i think if you can cope with a truncated journal file,
- * this is fine, because a non-synced file only matters on
- * catastrophic disk/machine failure.
- */
- //journal.getFD.sync
- _journalSize += (5 + blob.length)
-
- _totalItems += 1
- queueLength += 1
- queueSize += value.length
- if (! readJournal.isDefined) {
- queue += item
- _memoryBytes += value.length
- }
-
- if (waiters.size > 0) {
- waiters.remove(0).actor ! ItemArrived
+ false
+ } else {
+ val item = QItem(Time.now, adjustExpiry(expiry), value, 0)
+ if (!journal.inReadBehind && queueSize >= PersistentQueue.maxMemorySize) {
+ log.info("Dropping to read-behind for queue '%s' (%d bytes)", name, queueSize)
+ journal.startReadBehind
+ }
+ _add(item)
+ journal.add(item)
+ if (waiters.size > 0) {
+ waiters.remove(0).actor ! ItemArrived
+ }
+ true
}
- true
}
}
def add(value: Array[Byte]): Boolean = add(value, 0)
/**
- * Remove an item from the queue, transactionally. If no item is
- * available, an empty byte array is returned.
+ * Remove an item from the queue. If no item is available, an empty byte
+ * array is returned.
+ *
+ * @param transaction true if this should be considered the first part
+ * of a transaction, to be committed or rolled back (put back at the
+ * head of the queue)
*/
- def remove(): Option[Array[Byte]] = {
- initialized.waitFor
+ def remove(transaction: Boolean): Option[QItem] = {
+ initialized.await
synchronized {
if (closed || queueLength == 0) {
- return None
- }
-
- journal.write(CMD_REMOVE)
- journal.getFD.sync
- _journalSize += 1
-
- val now = System.currentTimeMillis
- val item = queue.dequeue
- queueLength -= 1
- queueSize -= item.data.length
- _memoryBytes -= item.data.length
-
- if ((queueLength == 0) && (_journalSize >= PersistentQueue.maxJournalSize)) {
- rollJournal
- }
- // if we're in read-behind mode, scan forward in the journal to keep memory as full as
- // possible. this amortizes the disk overhead across all reads.
- while (readJournal.isDefined && _memoryBytes < PersistentQueue.maxMemorySize) {
- fillReadBehind(journal.getChannel.position)
- }
-
- val realExpiry = adjustExpiry(item.expiry)
- if ((realExpiry == 0) || (realExpiry >= now)) {
- _currentAge = now - item.addTime
- Some(item.data)
+ None
} else {
- _totalExpired += 1
- remove
+ val item = _remove(transaction)
+ if (transaction) journal.removeTentative() else journal.remove()
+
+ if ((queueLength == 0) && (journal.size >= PersistentQueue.maxJournalSize) &&
+ (openTransactions.size == 0)) {
+ log.info("Rolling journal file for '%s'", name)
+ journal.roll
+ journal.saveXid(xidCounter)
+ }
+ item
}
}
}
- def remove(timeoutAbsolute: Long)(f: Option[Array[Byte]] => Unit): Unit = {
+ /**
+ * Remove an item from the queue. If no item is available, an empty byte
+ * array is returned.
+ */
+ def remove(): Option[QItem] = remove(false)
+
+ def remove(timeoutAbsolute: Long, transaction: Boolean)(f: Option[QItem] => Unit): Unit = {
synchronized {
- val item = remove()
+ val item = remove(transaction)
if (item.isDefined) {
f(item)
} else if (timeoutAbsolute == 0) {
@@ -263,14 +167,14 @@ class PersistentQueue(private val persistencePath: String, val name: String,
} else {
val w = Waiter(Actor.self)
waiters += w
- Actor.self.reactWithin((timeoutAbsolute - System.currentTimeMillis) max 0) {
- case ItemArrived => remove(timeoutAbsolute)(f)
+ Actor.self.reactWithin((timeoutAbsolute - Time.now) max 0) {
+ case ItemArrived => remove(timeoutAbsolute, transaction)(f)
case TIMEOUT => synchronized {
waiters -= w
// race: someone could have done an add() between the timeout and grabbing the lock.
Actor.self.reactWithin(0) {
- case ItemArrived => f(remove())
- case TIMEOUT => f(remove())
+ case ItemArrived => f(remove(transaction))
+ case TIMEOUT => f(remove(transaction))
}
}
}
@@ -279,6 +183,33 @@ class PersistentQueue(private val persistencePath: String, val name: String,
}
/**
+ * Return a transactionally-removed item to the queue. This is a rolled-
+ * back transaction.
+ */
+ def unremove(xid: Int): Unit = {
+ initialized.await
+ synchronized {
+ if (!closed) {
+ journal.unremove(xid)
+ _unremove(xid)
+ if (waiters.size > 0) {
+ waiters.remove(0).actor ! ItemArrived
+ }
+ }
+ }
+ }
+
+ def confirmRemove(xid: Int): Unit = {
+ initialized.await
+ synchronized {
+ if (!closed) {
+ journal.confirmRemove(xid)
+ openTransactions.removeKey(xid)
+ }
+ }
+ }
+
+ /**
* Close the queue's journal file. Not safe to call on an active queue.
*/
def close = synchronized {
@@ -286,127 +217,107 @@ class PersistentQueue(private val persistencePath: String, val name: String,
journal.close()
}
- def setup: Unit = synchronized {
+ def setup(): Unit = synchronized {
+ queueSize = 0
replayJournal
- initialized.set
+ initialized.countDown
}
+ private def nextXid(): Int = {
+ do {
+ xidCounter += 1
+ } while (openTransactions contains xidCounter)
+ xidCounter
+ }
- private def startReadBehind(pos: Long): Unit = {
- log.info("Dropping to read-behind for queue '%s' (%d bytes)", name, queueSize)
- val rj = new FileInputStream(queuePath)
- rj.getChannel.position(pos)
- readJournal = Some(rj)
+ def fillReadBehind(): Unit = {
+ journal.fillReadBehind { item =>
+ queue += item
+ _memoryBytes += item.data.length
+ }
}
- private def fillReadBehind(pos: Long): Unit = {
- for (rj <- readJournal) {
- if (rj.getChannel.position == pos) {
- // we've caught up.
- log.info("Coming out of read-behind for queue '%s'", name)
- rj.close
- readJournal = None
- } else {
- readJournalEntry(new DataInputStream(readJournal.get),
- new IntReader(ByteOrder.LITTLE_ENDIAN)) match {
- case JournalItem(CMD_ADDX, _, Some(item)) =>
- queue += item
- _memoryBytes += item.data.length
- case JournalItem(_, _, _) =>
+ def replayJournal(): Unit = {
+ log.info("Replaying transaction journal for '%s'", name)
+ xidCounter = 0
+
+ journal.replay(name) {
+ case JournalItem.Add(item) =>
+ _add(item)
+ // when processing the journal, this has to happen after:
+ if (!journal.inReadBehind && queueSize >= PersistentQueue.maxMemorySize) {
+ log.info("Dropping to read-behind for queue '%s' (%d bytes)", name, queueSize)
+ journal.startReadBehind
}
- }
+ case JournalItem.Remove => _remove(false)
+ case JournalItem.RemoveTentative => _remove(true)
+ case JournalItem.SavedXid(xid) => xidCounter = xid
+ case JournalItem.Unremove(xid) => _unremove(xid)
+ case JournalItem.ConfirmRemove(xid) => openTransactions.removeKey(xid)
+ case x => log.error("Unexpected item in journal: %s", x)
}
+ log.info("Finished transaction journal for '%s' (%d items, %d bytes)", name, queueLength,
+ journal.size)
+ journal.open
}
- private def openJournal: Unit = {
- journal = new FileOutputStream(queuePath, true)
- }
- private def rollJournal: Unit = {
- log.info("Rolling journal file for '%s'", name)
- journal.close
+ // ----- internal implementations
- val backupFile = new File(queuePath + "." + System.currentTimeMillis)
- new File(queuePath).renameTo(backupFile)
- openJournal
- _journalSize = 0
- backupFile.delete
+ private def _add(item: QItem): Unit = {
+ if (!journal.inReadBehind) {
+ queue += item
+ _memoryBytes += item.data.length
+ }
+ _totalItems += 1
+ queueSize += item.data.length
+ queueLength += 1
}
- private def replayJournal: Unit = {
- queueSize = 0
-
- try {
- val fileIn = new FileInputStream(queuePath)
- val in = new DataInputStream(fileIn)
- var offset: Long = 0
- val intReader = new IntReader(ByteOrder.LITTLE_ENDIAN)
-
- log.info("Replaying transaction journal for '%s'", name)
- var done = false
- do {
- readJournalEntry(in, intReader) match {
- case JournalItem(CMD_ADDX, length, Some(item)) =>
- if (!readJournal.isDefined) {
- queue += item
- _memoryBytes += item.data.length
- }
- queueSize += item.data.length
- queueLength += 1
- offset += length
- if (!readJournal.isDefined && queueSize >= PersistentQueue.maxMemorySize) {
- startReadBehind(fileIn.getChannel.position)
- }
- case JournalItem(CMD_REMOVE, length, _) =>
- val len = queue.dequeue.data.length
- queueSize -= len
- _memoryBytes -= len
- queueLength -= 1
- offset += length
- while (readJournal.isDefined && _memoryBytes < PersistentQueue.maxMemorySize) {
- fillReadBehind(fileIn.getChannel.position)
- }
- case JournalItem(-1, _, _) =>
- done = true
- }
- } while (!done)
- _journalSize = offset
- log.info("Finished transaction journal for '%s' (%d items, %d bytes)", name, queueLength, offset)
- } catch {
- case e: FileNotFoundException =>
- log.info("No transaction journal for '%s'; starting with empty queue.", name)
- case e: IOException =>
- log.error(e, "Exception replaying journal for '%s'", name)
- log.error("DATA MAY HAVE BEEN LOST!")
- // this can happen if the server hardware died abruptly in the middle
- // of writing a journal. not awesome but we should recover.
+ private def _remove(transaction: Boolean): Option[QItem] = {
+ if (queue.isEmpty) return None
+
+ val now = Time.now
+ val item = queue.dequeue
+ val len = item.data.length
+ queueSize -= len
+ _memoryBytes -= len
+ queueLength -= 1
+ val xid = if (transaction) nextXid else 0
+
+ // if we're in read-behind mode, scan forward in the journal to keep memory as full as
+ // possible. this amortizes the disk overhead across all reads.
+ while (journal.inReadBehind && _memoryBytes < PersistentQueue.maxMemorySize) {
+ fillReadBehind
+ if (!journal.inReadBehind) {
+ log.info("Coming out of read-behind for queue '%s'", name)
+ }
}
- openJournal
+ val realExpiry = adjustExpiry(item.expiry)
+ if ((realExpiry == 0) || (realExpiry >= now)) {
+ _currentAge = now - item.addTime
+ if (transaction) {
+ item.xid = xid
+ openTransactions(xid) = item
+ }
+ Some(item)
+ } else {
+ _totalExpired += 1
+ _remove(transaction)
+ }
}
- private def readJournalEntry(in: DataInputStream, intReader: IntReader): JournalItem = {
- in.read() match {
- case -1 =>
- JournalItem(-1, 0, None)
- case CMD_ADD =>
- val size = intReader.readInt(in)
- val data = new Array[Byte](size)
- in.readFully(data)
- JournalItem(CMD_ADDX, 5 + data.length, Some(unpackOldAdd(data)))
- case CMD_REMOVE =>
- JournalItem(CMD_REMOVE, 1, None)
- case CMD_ADDX =>
- val size = intReader.readInt(in)
- val data = new Array[Byte](size)
- in.readFully(data)
- JournalItem(CMD_ADDX, 5 + data.length, Some(unpack(data)))
- case n =>
- throw new IOException("invalid opcode in journal: " + n.toInt)
- }
+ private def _unremove(xid: Int) = {
+ val item = openTransactions.removeKey(xid).get
+ queueLength += 1
+ queueSize += item.data.length
+ queue unget item
+ _memoryBytes += item.data.length
}
}
+
object PersistentQueue {
@volatile var maxJournalSize: Long = 16 * 1024 * 1024
@volatile var maxMemorySize: Long = 128 * 1024 * 1024
View
38 src/main/scala/com/twitter/scarling/QueueCollection.scala
@@ -104,7 +104,7 @@ class QueueCollection(private val queueFolder: String, private var queueConfigs:
queue(key) match {
case None => false
case Some(q) =>
- val now = System.currentTimeMillis
+ val now = Time.now
val normalizedExpiry: Long = if (expiry == 0) {
0
} else if (expiry < 1000000) {
@@ -130,23 +130,23 @@ class QueueCollection(private val queueFolder: String, private var queueConfigs:
* Retrieve an item from a queue and pass it to a continuation. If no item is available within
* the requested time, or the server is shutting down, None is passed.
*/
- def remove(key: String, timeout: Int)(f: Option[Array[Byte]] => Unit): Unit = {
+ def remove(key: String, timeout: Int, transaction: Boolean)(f: Option[QItem] => Unit): Unit = {
queue(key) match {
case None =>
synchronized { _queueMisses += 1 }
f(None)
case Some(q) =>
- q.remove(if (timeout == 0) timeout else System.currentTimeMillis + timeout) {
+ q.remove(if (timeout == 0) timeout else Time.now + timeout, transaction) {
case None =>
synchronized { _queueMisses += 1 }
f(None)
- case item @ Some(x) =>
+ case Some(item) =>
synchronized {
_queueHits += 1
- _currentBytes -= x.length
+ _currentBytes -= item.data.length
_currentItems -= 1
}
- f(item)
+ f(Some(item))
}
}
}
@@ -155,14 +155,34 @@ class QueueCollection(private val queueFolder: String, private var queueConfigs:
def receive(key: String): Option[Array[Byte]] = {
var rv: Option[Array[Byte]] = None
val latch = new CountDownLatch(1)
- remove(key, 0) { v =>
- rv = v
- latch.countDown
+ remove(key, 0, false) {
+ case None =>
+ rv = None
+ latch.countDown
+ case Some(v) =>
+ rv = Some(v.data)
+ latch.countDown
}
latch.await
rv
}
+ def unremove(key: String, xid: Int): Unit = {
+ queue(key) match {
+ case None =>
+ case Some(q) =>
+ q.unremove(xid)
+ }
+ }
+
+ def confirmRemove(key: String, xid: Int): Unit = {
+ queue(key) match {
+ case None =>
+ case Some(q) =>
+ q.confirmRemove(xid)
+ }
+ }
+
case class Stats(items: Long, bytes: Long, totalItems: Long, journalSize: Long,
totalExpired: Long, currentAge: Long, memoryItems: Long, memoryBytes: Long)
View
5 src/main/scala/com/twitter/scarling/Scarling.scala
@@ -43,7 +43,7 @@ object Scarling {
var queues: QueueCollection = null
private val _expiryStats = new mutable.HashMap[String, Int]
- private val _startTime = System.currentTimeMillis
+ private val _startTime = Time.now
ByteBuffer.setUseDirectBuffers(false)
ByteBuffer.setAllocator(new SimpleByteBufferAllocator())
@@ -98,9 +98,10 @@ object Scarling {
acceptor.unbindAll
Scheduler.shutdown
acceptorExecutor.shutdown
+ // the line below causes a 1s pause in unit tests. :(
acceptorExecutor.awaitTermination(5, TimeUnit.SECONDS)
deathSwitch.countDown
}
- def uptime = (System.currentTimeMillis - _startTime) / 1000
+ def uptime = (Time.now - _startTime) / 1000
}
View
72 src/main/scala/com/twitter/scarling/ScarlingHandler.scala
@@ -19,6 +19,11 @@ class ScarlingHandler(val session: IoSession, val config: Config) extends Actor
private val sessionID = ScarlingStats.sessionID.incr
private val remoteAddress = session.getRemoteAddress.asInstanceOf[InetSocketAddress]
+ private var pendingTransaction: Option[(String, Int)] = None
+
+ // used internally to indicate a client error: tried to close a transaction on the wrong queue.
+ private class MismatchedQueueException extends Exception
+
if (session.getTransportType == TransportType.SOCKET) {
session.getConfig.asInstanceOf[SocketSessionConfig].setReceiveBufferSize(2048)
@@ -56,6 +61,7 @@ class ScarlingHandler(val session: IoSession, val config: Config) extends Actor
case MinaMessage.SessionClosed =>
log.debug("End of session %d", sessionID)
+ abortAnyTransaction
ScarlingStats.sessions.decr
exit()
@@ -103,20 +109,74 @@ class ScarlingHandler(val session: IoSession, val config: Config) extends Actor
private def get(name: String): Unit = {
var key = name
var timeout = 0
+ var closing = false
+ var opening = false
if (name contains '/') {
val options = name.split("/")
key = options(0)
for (i <- 1 until options.length) {
- val opt = options(1)
+ val opt = options(i)
if (opt startsWith "t=") {
timeout = opt.substring(2).toInt
}
+ if (opt == "close") closing = true
+ if (opt == "open") opening = true
+ }
+ }
+ log.debug("get q=%s t=%d open=%s close=%s", key, timeout, opening, closing)
+
+ try {
+ if (closing) {
+ if (!closeTransaction(key)) {
+ log.warning("Attempt to close a non-existent transaction on '%s' (sid %d, %s:%d)",
+ key, sessionID, remoteAddress.getHostName, remoteAddress.getPort)
+ writeResponse("ERROR\r\n")
+ session.close
+ } else {
+ writeResponse("END\r\n")
+ }
+ } else {
+ if (opening) closeTransaction(key)
+ ScarlingStats.getRequests.incr
+ Scarling.queues.remove(key, timeout, opening) {
+ case None =>
+ writeResponse("END\r\n")
+ case Some(item) =>
+ log.debug("get %s", item)
+ if (opening) pendingTransaction = Some((key, item.xid))
+ writeResponse("VALUE " + key + " 0 " + item.data.length + "\r\n", item.data)
+ }
}
+ } catch {
+ case e: MismatchedQueueException =>
+ log.warning("Attempt to close a transaction on the wrong queue '%s' (sid %d, %s:%d)",
+ key, sessionID, remoteAddress.getHostName, remoteAddress.getPort)
+ writeResponse("ERROR\r\n")
+ session.close
}
- ScarlingStats.getRequests.incr
- Scarling.queues.remove(key, timeout) {
- case None => writeResponse("END\r\n")
- case Some(data) => writeResponse("VALUE " + key + " 0 " + data.length + "\r\n", data)
+ }
+
+ // returns true if a transaction was actually closed.
+ private def closeTransaction(name: String): Boolean = {
+ pendingTransaction match {
+ case None => false
+ case Some((qname, xid)) =>
+ if (qname != name) {
+ throw new MismatchedQueueException
+ } else {
+ Scarling.queues.confirmRemove(qname, xid)
+ pendingTransaction = None
+ }
+ true
+ }
+ }
+
+ private def abortAnyTransaction() = {
+ pendingTransaction match {
+ case None =>
+ case Some((qname, xid)) =>
+ Scarling.queues.unremove(qname, xid)
+ pendingTransaction = None
}
}
@@ -132,7 +192,7 @@ class ScarlingHandler(val session: IoSession, val config: Config) extends Actor
private def stats = {
var report = new mutable.ArrayBuffer[(String, String)]
report += (("uptime", Scarling.uptime.toString))
- report += (("time", (System.currentTimeMillis / 1000).toString))
+ report += (("time", (Time.now / 1000).toString))
report += (("version", Scarling.runtime.jarVersion))
report += (("curr_items", Scarling.queues.currentItems.toString))
report += (("total_items", Scarling.queues.totalAdded.toString))
View
44 src/main/scala/com/twitter/scarling/Timer.scala
@@ -1,44 +0,0 @@
-package com.twitter.scarling
-
-import scala.collection.Map
-import scala.collection.mutable
-
-
-class Timer(val name: String, val reportAt: Int) {
- var counter = 0
- var total = 0
- var sum: Long = 0
-
- def add(timing: Long) = synchronized {
- counter += 1
- total += 1
- sum += timing
-
- if (counter == reportAt) {
- val average: Double = sum.asInstanceOf[Double] / total
- Console.println("TIMER " + name + " = " + (average / 1000) + " usec")
- counter = 0
- }
- }
-}
-
-
-object Timer {
- private val timers = new mutable.HashMap[String, Timer]
-
- def get(name: String, count: Int): Timer = synchronized {
- timers.get(name) match {
- case None => { timers(name) = new Timer(name, count); timers(name) }
- case Some(t) => t
- }
- }
-
- def run[T](name: String, count: Int)(f: => T) = {
- val timer = get(name, count)
- val startTime = System.nanoTime
- val result = f
- val timing = System.nanoTime - startTime
- timer.add(timing)
- result
- }
-}
View
85 src/test/scala/com/twitter/scarling/PersistentQueueSpec.scala
@@ -10,6 +10,20 @@ import org.specs._
object PersistentQueueSpec extends Specification with TestHelper {
+ def dumpJournal(folderName: String, qname: String): String = {
+ var rv = List[JournalItem]()
+ new Journal(new File(folderName, qname).getCanonicalPath).replay(qname) { item => rv = item :: rv }
+ rv.reverse map {
+ case JournalItem.Add(item) => "add(%s)".format(new String(item.data))
+ case JournalItem.Remove => "remove"
+ case JournalItem.RemoveTentative => "remove-tentative"
+ case JournalItem.SavedXid(xid) => "xid(%d)".format(xid)
+ case JournalItem.Unremove(xid) => "unremove(%d)".format(xid)
+ case JournalItem.ConfirmRemove(xid) => "confirm-remove(%d)".format(xid)
+ } mkString ", "
+ }
+
+
"PersistentQueue" should {
"add and remove one item" in {
withTempFolder {
@@ -28,7 +42,7 @@ object PersistentQueueSpec extends Specification with TestHelper {
q.bytes mustEqual 11
q.journalSize mustEqual 32
- new String(q.remove.get) mustEqual "hello kitty"
+ new String(q.remove.get.data) mustEqual "hello kitty"
q.length mustEqual 0
q.totalItems mustEqual 1
@@ -71,8 +85,8 @@ object PersistentQueueSpec extends Specification with TestHelper {
q.length mustEqual 0
q.totalItems mustEqual 2
q.bytes mustEqual 0
- q.journalSize mustEqual 0
- new File(folderName, "rolling").length mustEqual 0
+ q.journalSize mustEqual 5 // saved xid.
+ new File(folderName, "rolling").length mustEqual 5
PersistentQueue.maxJournalSize = 16 * 1024 * 1024
}
@@ -84,14 +98,14 @@ object PersistentQueueSpec extends Specification with TestHelper {
q.setup
q.add("first".getBytes)
q.add("second".getBytes)
- new String(q.remove.get) mustEqual "first"
+ new String(q.remove.get.data) mustEqual "first"
q.journalSize mustEqual 5 + 6 + 16 + 16 + 5 + 5 + 1
q.close
val q2 = new PersistentQueue(folderName, "rolling", Config.fromMap(Map.empty))
q2.setup
q2.journalSize mustEqual 5 + 6 + 16 + 16 + 5 + 5 + 1
- new String(q2.remove.get) mustEqual "second"
+ new String(q2.remove.get.data) mustEqual "second"
q2.journalSize mustEqual 5 + 6 + 16 + 16 + 5 + 5 + 1 + 1
q2.length mustEqual 0
q2.close
@@ -120,13 +134,13 @@ object PersistentQueueSpec extends Specification with TestHelper {
q.setup
q.add("sunny".getBytes) mustEqual true
q.length mustEqual 1
- Thread.sleep(1000)
+ Time.advance(1000)
q.remove mustEqual None
config("max_age") = 60
q.add("rainy".getBytes) mustEqual true
config("max_age") = 1
- Thread.sleep(1000)
+ Time.advance(1000)
q.remove mustEqual None
}
}
@@ -149,7 +163,7 @@ object PersistentQueueSpec extends Specification with TestHelper {
q.memoryBytes mustEqual 1024
// read 1 item. queue should pro-actively read the next item in from disk.
- val d0 = q.remove.get
+ val d0 = q.remove.get.data
d0(0) mustEqual 0
q.inReadBehind mustBe true
q.length mustEqual 9
@@ -168,7 +182,7 @@ object PersistentQueueSpec extends Specification with TestHelper {
q.memoryBytes mustEqual 1024
// read again.
- val d1 = q.remove.get
+ val d1 = q.remove.get.data
d1(0) mustEqual 1
q.inReadBehind mustBe true
q.length mustEqual 9
@@ -177,7 +191,7 @@ object PersistentQueueSpec extends Specification with TestHelper {
q.memoryBytes mustEqual 1024
// and again.
- val d2 = q.remove.get
+ val d2 = q.remove.get.data
d2(0) mustEqual 2
q.inReadBehind mustBe true
q.length mustEqual 8
@@ -186,7 +200,7 @@ object PersistentQueueSpec extends Specification with TestHelper {
q.memoryBytes mustEqual 1024
for (i <- 3 until 11) {
- val d = q.remove.get
+ val d = q.remove.get.data
d(0) mustEqual i
q.inReadBehind mustBe false
q.length mustEqual 10 - i
@@ -225,7 +239,7 @@ object PersistentQueueSpec extends Specification with TestHelper {
q2.memoryBytes mustEqual 1024
for (i <- 0 until 10) {
- val d = q2.remove.get
+ val d = q2.remove.get.data
d(0) mustEqual i
q2.inReadBehind mustEqual (i < 2)
q2.length mustEqual 9 - i
@@ -286,8 +300,8 @@ object PersistentQueueSpec extends Specification with TestHelper {
var rv: String = null
val latch = new CountDownLatch(1)
actor {
- q.remove(System.currentTimeMillis + 250) { item =>
- rv = new String(item.get)
+ q.remove(Time.now + 250, false) { item =>
+ rv = new String(item.get.data)
latch.countDown
}
}
@@ -295,5 +309,48 @@ object PersistentQueueSpec extends Specification with TestHelper {
rv mustEqual "hello"
}
}
+
+ "correctly interleave transactions in the journal" in {
+ withTempFolder {
+ PersistentQueue.maxMemorySize = 1024
+ val q = new PersistentQueue(folderName, "things", Config.fromMap(Map.empty))
+ q.setup
+ q.add("house".getBytes)
+ q.add("cat".getBytes)
+ q.journalSize mustEqual 2 * 21 + 8
+
+ val house = q.remove(true).get
+ new String(house.data) mustEqual "house"
+ house.xid mustEqual 1
+ q.journalSize mustEqual 2 * 21 + 8 + 1
+
+ val cat = q.remove(true).get
+ new String(cat.data) mustEqual "cat"
+ cat.xid mustEqual 2
+ q.journalSize mustEqual 2 * 21 + 8 + 1 + 1
+
+ q.unremove(house.xid)
+ q.journalSize mustEqual 2 * 21 + 8 + 1 + 1 + 5
+
+ q.confirmRemove(cat.xid)
+ q.journalSize mustEqual 2 * 21 + 8 + 1 + 1 + 5 + 5
+ q.length mustEqual 1
+ q.bytes mustEqual 5
+
+ new String(q.remove.get.data) mustEqual "house"
+ q.length mustEqual 0
+ q.bytes mustEqual 0
+
+ q.close
+ dumpJournal(folderName, "things") mustEqual
+ "add(house), add(cat), remove-tentative, remove-tentative, unremove(1), confirm-remove(2), remove"
+
+ // and journal is replayed correctly.
+ val q2 = new PersistentQueue(folderName, "things", Config.fromMap(Map.empty))
+ q2.setup
+ q2.length mustEqual 0
+ q2.bytes mustEqual 0
+ }
+ }
}
}
View
102 src/test/scala/com/twitter/scarling/ServerSpec.scala
@@ -66,19 +66,113 @@ object ServerSpec extends Specification with TestHelper {
val v = (Math.random * 0x7fffffff).toInt
val client = new TestClient("localhost", 22122)
client.get("test_set_with_expiry") mustEqual ""
- client.set("test_set_with_epxiry", (v + 2).toString, (System.currentTimeMillis / 1000).toInt) mustEqual "STORED"
+ client.set("test_set_with_expiry", (v + 2).toString, (Time.now / 1000).toInt) mustEqual "STORED"
client.set("test_set_with_expiry", v.toString) mustEqual "STORED"
- Thread.sleep(1000)
+ Time.advance(1000)
client.get("test_set_with_expiry") mustEqual v.toString
}
}
+ "commit a transactional get" in {
+ withTempFolder {
+ makeServer
+ val v = (Math.random * 0x7fffffff).toInt
+ val client = new TestClient("localhost", 22122)
+ client.set("commit", v.toString) mustEqual "STORED"
+
+ val client2 = new TestClient("localhost", 22122)
+ val client3 = new TestClient("localhost", 22122)
+ var stats = client3.stats
+ stats("queue_commit_items") mustEqual "1"
+ stats("queue_commit_total_items") mustEqual "1"
+ stats("queue_commit_bytes") mustEqual v.toString.length.toString
+
+ client2.get("commit/open") mustEqual v.toString
+ stats = client3.stats
+ stats("queue_commit_items") mustEqual "0"
+ stats("queue_commit_total_items") mustEqual "1"
+ stats("queue_commit_bytes") mustEqual "0"
+
+ client2.get("commit/close") mustEqual ""
+ stats = client3.stats
+ stats("queue_commit_items") mustEqual "0"
+ stats("queue_commit_total_items") mustEqual "1"
+ stats("queue_commit_bytes") mustEqual "0"
+
+ client2.disconnect
+ Thread.sleep(10)
+ stats = client3.stats
+ stats("queue_commit_items") mustEqual "0"
+ stats("queue_commit_total_items") mustEqual "1"
+ stats("queue_commit_bytes") mustEqual "0"
+ }
+ }
+
+ "auto-rollback a transaction on disconnect" in {
+ withTempFolder {
+ makeServer
+ val v = (Math.random * 0x7fffffff).toInt
+ val client = new TestClient("localhost", 22122)
+ client.set("auto-rollback", v.toString) mustEqual "STORED"
+
+ val client2 = new TestClient("localhost", 22122)
+ client2.get("auto-rollback/open") mustEqual v.toString
+ val client3 = new TestClient("localhost", 22122)
+ client3.get("auto-rollback") mustEqual ""
+ var stats = client3.stats
+ stats("queue_auto-rollback_items") mustEqual "0"
+ stats("queue_auto-rollback_total_items") mustEqual "1"
+ stats("queue_auto-rollback_bytes") mustEqual "0"
+
+ // oops, client2 dies before committing!
+ client2.disconnect
+ Thread.sleep(10)
+ stats = client3.stats
+ stats("queue_auto-rollback_items") mustEqual "1"
+ stats("queue_auto-rollback_total_items") mustEqual "1"
+ stats("queue_auto-rollback_bytes") mustEqual v.toString.length.toString
+
+ // subsequent fetch must get the same data item back.
+ client3.get("auto-rollback/open") mustEqual v.toString
+ stats = client3.stats
+ stats("queue_auto-rollback_items") mustEqual "0"
+ stats("queue_auto-rollback_total_items") mustEqual "1"
+ stats("queue_auto-rollback_bytes") mustEqual "0"
+ }
+ }
+
+ "auto-commit cycles of transactional gets" in {
+ withTempFolder {
+ makeServer
+ val v = (Math.random * 0x7fffffff).toInt
+ val client = new TestClient("localhost", 22122)
+ client.set("auto-commit", v.toString) mustEqual "STORED"
+ client.set("auto-commit", (v + 1).toString) mustEqual "STORED"
+ client.set("auto-commit", (v + 2).toString) mustEqual "STORED"
+
+ val client2 = new TestClient("localhost", 22122)
+ client2.get("auto-commit/open") mustEqual v.toString
+ client2.get("auto-commit/open") mustEqual (v + 1).toString
+ client2.get("auto-commit/open") mustEqual (v + 2).toString
+ client2.disconnect
+ Thread.sleep(10)
+
+ val client3 = new TestClient("localhost", 22122)
+ client3.get("auto-commit") mustEqual (v + 2).toString
+
+ var stats = client3.stats
+ stats("queue_auto-commit_items") mustEqual "0"
+ stats("queue_auto-commit_total_items") mustEqual "3"
+ stats("queue_auto-commit_bytes") mustEqual "0"
+ }
+ }
+
"age" in {
withTempFolder {
makeServer
val client = new TestClient("localhost", 22122)
client.set("test_age", "nibbler") mustEqual "STORED"
- Thread.sleep(1000)
+ Time.advance(1000)
client.get("test_age") mustEqual "nibbler"
client.stats.contains("queue_test_age_age") mustEqual true
client.stats("queue_test_age_age").toInt >= 1000 mustEqual true
@@ -104,7 +198,7 @@ object ServerSpec extends Specification with TestHelper {
client.set("test_log_rotation", v) mustEqual "STORED"
new File(folderName + "/test_log_rotation").length mustEqual 2 * (8192 + 16 + 5) + 1
client.get("test_log_rotation") mustEqual v
- new File(folderName + "/test_log_rotation").length mustEqual 0
+ new File(folderName + "/test_log_rotation").length mustEqual 5
new File(folderName).listFiles.length mustEqual 1
}
}
View
128 src/test/scala/com/twitter/scarling/TestClient.scala
@@ -8,81 +8,81 @@ import scala.collection.mutable
class TestClient(host: String, port: Int) {
- var socket: Socket = null
- var out: OutputStream = null
- var in: DataInputStream = null
+ var socket: Socket = null
+ var out: OutputStream = null
+ var in: DataInputStream = null
- connect
+ connect
- def connect = {
- socket = new Socket(host, port)
- out = socket.getOutputStream
- in = new DataInputStream(socket.getInputStream)
- }
+ def connect = {
+ socket = new Socket(host, port)
+ out = socket.getOutputStream
+ in = new DataInputStream(socket.getInputStream)
+ }
- def disconnect = {
- socket.close
- }
+ def disconnect = {
+ socket.close
+ }
- private def readline = {
- // this isn't meant to be efficient, just simple.
- val out = new StringBuilder
- var done = false
- while (!done) {
- val ch: Int = in.read
- if ((ch < 0) || (ch == 10)) {
- done = true
- } else if (ch != 13) {
- out += ch.toChar
- }
- }
- out.toString
+ private def readline = {
+ // this isn't meant to be efficient, just simple.
+ val out = new StringBuilder
+ var done = false
+ while (!done) {
+ val ch: Int = in.read
+ if ((ch < 0) || (ch == 10)) {
+ done = true
+ } else if (ch != 13) {
+ out += ch.toChar
+ }
}
+ out.toString
+ }
- def set(key: String, value: String): String = {
- out.write(("set " + key + " 0 0 " + value.length + "\r\n" + value + "\r\n").getBytes)
- readline
- }
+ def set(key: String, value: String): String = {
+ out.write(("set " + key + " 0 0 " + value.length + "\r\n" + value + "\r\n").getBytes)
+ readline
+ }
- def set(key: String, value: String, expiry: Int) = {
- out.write(("set " + key + " 0 " + expiry + " " + value.length + "\r\n" + value + "\r\n").getBytes)
- readline
- }
+ def set(key: String, value: String, expiry: Int) = {
+ out.write(("set " + key + " 0 " + expiry + " " + value.length + "\r\n" + value + "\r\n").getBytes)
+ readline
+ }
- def get(key: String): String = {
- out.write(("get " + key + "\r\n").getBytes)
- val line = readline
- if (line == "END") {
- return ""
- }
- // VALUE <name> <flags> <length>
- val len = line.split(" ")(3).toInt
- val buffer = new Array[Byte](len)
- in.readFully(buffer)
- readline
- readline // "END"
- new String(buffer)
+ def get(key: String): String = {
+ out.write(("get " + key + "\r\n").getBytes)
+ val line = readline
+ if (line == "END") {
+ return ""
}
+ // VALUE <name> <flags> <length>
+ val len = line.split(" ")(3).toInt
+ val buffer = new Array[Byte](len)
+ in.readFully(buffer)
+ readline
+ readline // "END"
+ new String(buffer)
+ }
- def add(key: String, value: String) = {
- out.write(("add " + key + " 0 0 " + value.length + "\r\n" + value + "\r\n").getBytes)
- readline
- }
+ def add(key: String, value: String) = {
+ out.write(("add " + key + " 0 0 " + value.length + "\r\n" + value + "\r\n").getBytes)
+ readline
+ }
- def stats: Map[String, String] = {
- out.write("stats\r\n".getBytes)
- var done = false
- val map = new mutable.HashMap[String, String]
- while (!done) {
- val line = readline
- if (line startsWith "STAT") {
- val args = line.split(" ")
- map(args(1)) = args(2)
- } else if (line == "END") {
- done = true
- }
- }
- map
+ def stats: Map[String, String] = {
+ out.write("stats\r\n".getBytes)
+ var done = false
+ val map = new mutable.HashMap[String, String]
+ while (!done) {
+ val line = readline
+ if (line startsWith "STAT") {
+ val args = line.split(" ")
+ map(args(1)) = args(2)
+ } else if (line == "END") {
+ done = true
+ }
}
+ map
+ }
}
Please sign in to comment.
Something went wrong with that request. Please try again.