Permalink
Browse files

tests passing

  • Loading branch information...
1 parent 7f0c88e commit 698b9f0e7135cdb480c102d29607167aa8fa9f19 @ningtwitter ningtwitter committed Apr 3, 2012
View
@@ -69,6 +69,7 @@ Please report any bugs to: <https://github.com/twitter/cassovary/issues>
* [Aneesh Sharma](https://twitter.com/aneeshs)
* [Ashish Goel](https://twitter.com/ashishgoel)
* [Mengqiu Wang](https://twitter.com/4ad)
+* [Ning Liang](https://twitter.com/ningliang)
## License
Copyright 2012 Twitter, Inc.
@@ -13,9 +13,8 @@
*/
package com.twitter.cassovary.graph
-import it.unimi.dsi.fastutil.{Arrays, Swapper}
-import it.unimi.dsi.fastutil.objects.{Object2IntOpenHashMap, ObjectArrayPriorityQueue}
-import it.unimi.dsi.fastutil.ints.{Int2ObjectOpenHashMap, IntComparator}
+import it.unimi.dsi.fastutil.objects._
+import it.unimi.dsi.fastutil.ints._
import java.util.Comparator
@@ -55,13 +54,13 @@ class DirectedPathCollection {
}
/**
- * @return an array of top DirectedPaths by occurrence
- * ending at {@code node}
+ * @return a map from each top DirectedPath ending at {@code node} to its
+ * occurrence count, in decreasing order of occurrence
*/
- def topPathsTill(node: Int, num: Int): Array[DirectedPath] = {
+ def topPathsTill(node: Int, num: Int): Object2IntMap[DirectedPath] = {
val pathCountMap = pathCountsPerIdWithDefault(node)
val pathCount = pathCountMap.size
- val pathArray = new Array[DirectedPath](scala.math.min(num, pathCount))
+ val returnMap = new Object2IntArrayMap[DirectedPath]
priQ.synchronized {
comparator.setNode(node)
@@ -73,22 +72,21 @@ class DirectedPathCollection {
priQ.enqueue(path)
}
- var counter = 0
- while (counter < num && !priQ.isEmpty) {
- pathArray(counter) = priQ.dequeue()
- counter += 1
+ while (returnMap.size < num && !priQ.isEmpty) {
+ val path = priQ.dequeue()
+ returnMap.put(path, pathCountMap.get(path))
}
}
- pathArray
+ returnMap
}
/**
* @param num the number of top paths to return for a node
* @return a map from each node id to a map of the top paths ending at that node, with their occurrence counts
*/
- def topPathsPerNodeId(num: Int): Int2ObjectOpenHashMap[Array[DirectedPath]] = {
- val topPathMap = new Int2ObjectOpenHashMap[Array[DirectedPath]]
+ def topPathsPerNodeId(num: Int): Int2ObjectOpenHashMap[Object2IntMap[DirectedPath]] = {
+ val topPathMap = new Int2ObjectOpenHashMap[Object2IntMap[DirectedPath]]
val nodeIterator = pathCountsPerId.keySet.iterator
while (nodeIterator.hasNext) {
@@ -18,6 +18,7 @@ import com.twitter.cassovary.graph.tourist._
import com.twitter.ostrich.stats.Stats
import it.unimi.dsi.fastutil.ints.{Int2IntMap, Int2ObjectMap}
+import it.unimi.dsi.fastutil.objects.Object2IntMap
import net.lag.logging.Logger
import scala.util.Random
@@ -133,7 +134,7 @@ class GraphUtils(val graph: Graph) {
* in the form of (P as a {@link DirectedPath}, frequency of walking P).
*/
def calculatePersonalizedReputation(startNodeIds: Seq[Int], walkParams: RandomWalkParams):
- (Int2IntMap, Option[Int2ObjectMap[Array[DirectedPath]]]) = {
+ (Int2IntMap, Option[Int2ObjectMap[Object2IntMap[DirectedPath]]]) = {
Stats.time ("%s_total".format("PTC")) {
val (visitsCounter, pathsCounterOption) = randomWalk(walkParams.dir, startNodeIds, walkParams)
val topPathsOption = pathsCounterOption flatMap { counter => Some(counter.infoAllNodes) }
@@ -142,7 +143,7 @@ class GraphUtils(val graph: Graph) {
}
def calculatePersonalizedReputation(startNodeId: Int, walkParams: RandomWalkParams):
- (Int2IntMap, Option[Int2ObjectMap[Array[DirectedPath]]]) = {
+ (Int2IntMap, Option[Int2ObjectMap[Object2IntMap[DirectedPath]]]) = {
calculatePersonalizedReputation(Seq(startNodeId), walkParams)
}
@@ -15,6 +15,7 @@ package com.twitter.cassovary.graph.tourist
import com.twitter.cassovary.graph.{DirectedPath, DirectedPathCollection}
import it.unimi.dsi.fastutil.ints.Int2ObjectMap
+import it.unimi.dsi.fastutil.objects.Object2IntMap
/**
* A tourist that keeps track of the paths ending at each node. It keeps
@@ -25,7 +26,8 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectMap
*/
class PathsCounter(numTopPathsPerNode: Int, homeNodeIds: Seq[Int])
- extends NodeTourist with InfoKeeper[Int, Array[DirectedPath], Int2ObjectMap[Array[DirectedPath]]] {
+ extends NodeTourist with InfoKeeper[Int, Object2IntMap[DirectedPath],
+ Int2ObjectMap[Object2IntMap[DirectedPath]]] {
def this() = this(0, Nil)
@@ -42,15 +44,15 @@ class PathsCounter(numTopPathsPerNode: Int, homeNodeIds: Seq[Int])
// NOOP use visit
}
- def infoOfNode(id: Int): Option[Array[DirectedPath]] = {
+ def infoOfNode(id: Int): Option[Object2IntMap[DirectedPath]] = {
if (paths.containsNode(id)) {
Some(paths.topPathsTill(id, numTopPathsPerNode))
} else {
None
}
}
- def infoAllNodes: Int2ObjectMap[Array[DirectedPath]] = paths.topPathsPerNodeId(numTopPathsPerNode)
+ def infoAllNodes: Int2ObjectMap[Object2IntMap[DirectedPath]] = paths.topPathsPerNodeId(numTopPathsPerNode)
def clear() {
paths.clear()
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2012 Twitter, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
+ * file except in compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed
+ * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package com.twitter.cassovary.graph.util
+
+import it.unimi.dsi.fastutil.ints.Int2IntMap
+import it.unimi.dsi.fastutil.objects.Object2IntMap
+
+/**
+ * Helpers that copy fastutil maps into plain Scala arrays of (key, value)
+ * tuples, following the iteration order of each map's key set.
+ */
+object FastUtilConversion {
+
+  /**
+   * Copies an object-to-int map into an array of (key, value) tuples.
+   *
+   * @param map the fastutil map to copy
+   * @return one tuple per entry of {@code map}, in the iteration order of
+   *         {@code map.keySet}
+   */
+  def object2IntMapToArray[T](map: Object2IntMap[T]): Array[(T, Int)] = {
+    val result = new Array[(T, Int)](map.size)
+
+    val iterator = map.keySet.iterator
+    var counter = 0
+    while (iterator.hasNext) {
+      val key = iterator.next
+      result(counter) = (key, map.getInt(key))
+      counter += 1
+    }
+
+    result
+  }
+
+  /**
+   * Copies an int-to-int map into an array of (key, value) tuples.
+   *
+   * @param map the fastutil map to copy
+   * @return one tuple per entry of {@code map}, in the iteration order of
+   *         {@code map.keySet}
+   */
+  def int2IntMapToArray(map: Int2IntMap): Array[(Int, Int)] = {
+    val result = new Array[(Int, Int)](map.size)
+
+    val iterator = map.keySet.iterator
+    var counter = 0
+    while (iterator.hasNext) {
+      val key = iterator.nextInt
+      result(counter) = (key, map.get(key))
+      // Advance the write position. Without this every entry was written to
+      // slot 0 and the remaining slots of the array were left null.
+      counter += 1
+    }
+
+    result
+  }
+
+}
@@ -13,6 +13,8 @@
*/
package com.twitter.cassovary.graph
+import com.twitter.cassovary.graph.util.FastUtilConversion
+import it.unimi.dsi.fastutil.objects.Object2IntMap
import org.specs.Specification
class DirectedPathCollectionSpec extends Specification {
@@ -36,9 +38,9 @@ class DirectedPathCollectionSpec extends Specification {
paths.resetCurrentPath()
addPath(paths, List(0,1,2))
- paths.topPathsTill(testPathIds(0), 10).toSeq mustEqual Array(getPath(0)).toSeq
- paths.topPathsTill(testPathIds(1), 10).toSeq mustEqual Array(getPath(0, 1)).toSeq
- paths.topPathsTill(testPathIds(2), 10).toSeq mustEqual Array(getPath(0, 1, 2)).toSeq
+ pathMapToSeq(paths.topPathsTill(testPathIds(0), 10)) mustEqual Array((getPath(0), times)).toSeq
+ pathMapToSeq(paths.topPathsTill(testPathIds(1), 10)) mustEqual Array((getPath(0, 1), times)).toSeq
+ pathMapToSeq(paths.topPathsTill(testPathIds(2), 10)) mustEqual Array((getPath(0, 1, 2), times)).toSeq
List(0,1,2) foreach { id =>
paths.numUniquePathsTill(testPathIds(id)) mustEqual 1
@@ -58,25 +60,29 @@ class DirectedPathCollectionSpec extends Specification {
paths.resetCurrentPath()
addPath(paths, List(1,2,3))
- //println(paths.topPathsTill(testPathIds(0), 10))
- //println(List((getPath(0), 1), (getPath(1, 0), 1)))
- paths.topPathsTill(testPathIds(0), 10).toSeq mustEqual Array(getPath(1, 0), getPath(0)).toSeq
- paths.topPathsTill(testPathIds(1), 10).toSeq mustEqual Array(getPath(1), getPath(0, 1)).toSeq
- paths.topPathsTill(testPathIds(2), 10).toSeq mustEqual Array(
- getPath(1, 2),
- getPath(1, 0, 3, 2),
- getPath(0, 1, 2)
+ pathMapToSeq(paths.topPathsTill(testPathIds(0), 10)) mustEqual Array(
+ (getPath(1, 0), 1),
+ (getPath(0), 1)
).toSeq
- paths.topPathsTill(testPathIds(3), 10).toSeq mustEqual Array(
- getPath(1, 2, 3),
- getPath(1, 0, 3, 2, 3),
- getPath(1, 0, 3)
+ pathMapToSeq(paths.topPathsTill(testPathIds(1), 10)) mustEqual Array(
+ (getPath(1), 3),
+ (getPath(0, 1), 1)
+ ).toSeq
+ pathMapToSeq(paths.topPathsTill(testPathIds(2), 10)) mustEqual Array(
+ (getPath(1, 2), 2),
+ (getPath(1, 0, 3, 2), 1),
+ (getPath(0, 1, 2), 1)
+ ).toSeq
+ pathMapToSeq(paths.topPathsTill(testPathIds(3), 10)) mustEqual Array(
+ (getPath(1, 2, 3), 2),
+ (getPath(1, 0, 3, 2, 3), 1),
+ (getPath(1, 0, 3), 1)
).toSeq
- paths.topPathsTill(testPathIds(2), 3).toSeq mustEqual Array(
- getPath(1, 2),
- getPath(1, 0, 3, 2),
- getPath(0, 1, 2)
+ pathMapToSeq(paths.topPathsTill(testPathIds(2), 3)) mustEqual Array(
+ (getPath(1, 2), 2),
+ (getPath(1, 0, 3, 2), 1),
+ (getPath(0, 1, 2), 1)
).toSeq
paths.numUniquePathsTill(testPathIds(0)) mustEqual 2
@@ -87,4 +93,8 @@ class DirectedPathCollectionSpec extends Specification {
paths.totalNumPaths mustEqual 10
}
}
+
+ def pathMapToSeq(map: Object2IntMap[DirectedPath]) = {
+ FastUtilConversion.object2IntMapToArray(map).toSeq
+ }
}
@@ -14,8 +14,10 @@
package com.twitter.cassovary.graph
import com.twitter.cassovary.graph.GraphDir._
+import com.twitter.cassovary.graph.util.FastUtilConversion
import com.twitter.util.Duration
import it.unimi.dsi.fastutil.ints.Int2IntMap
+import it.unimi.dsi.fastutil.objects.Object2IntMap
import org.specs.Specification
// TODO add a fake random so that the random walk tests can be controlled
@@ -61,8 +63,8 @@ class GraphUtilsSpec extends Specification {
visitsCountMap.get(2) mustEqual 1
val pathsCountMap = pathsCounterOption.get.infoAllNodes
- pathsCountMap.get(1).toSeq mustEqual Array(DirectedPath(Array(1))).toSeq
- pathsCountMap.get(2).toSeq mustEqual Array(DirectedPath(Array(1, 2))).toSeq
+ pathMapToSeq(pathsCountMap.get(1)) mustEqual Array((DirectedPath(Array(1)), 1)).toSeq
+ pathMapToSeq(pathsCountMap.get(2)) mustEqual Array((DirectedPath(Array(1, 2)), 1)).toSeq
// random walk but no top paths maintained
val (visitsCounter2, pathsCounterOption2) = graphUtils.randomWalk(OutDir, Seq(1),
@@ -214,7 +216,11 @@ class GraphUtilsSpec extends Specification {
}
}
- private def checkMapApproximatelyEquals(visitsPerNode: Int2IntMap, visitsPerNode2: Int2IntMap, delta: Int) {
+ def pathMapToSeq(map: Object2IntMap[DirectedPath]) = {
+ FastUtilConversion.object2IntMapToArray(map).toSeq
+ }
+
+ def checkMapApproximatelyEquals(visitsPerNode: Int2IntMap, visitsPerNode2: Int2IntMap, delta: Int) {
visitsPerNode.size mustEqual visitsPerNode2.size
val nodeIterator = visitsPerNode.keySet.iterator
@@ -14,6 +14,8 @@
package com.twitter.cassovary.graph
import com.twitter.cassovary.graph.tourist.{VisitsCounter, PathsCounter}
+import com.twitter.cassovary.graph.util.FastUtilConversion
+import it.unimi.dsi.fastutil.objects.Object2IntMap
import org.specs.Specification
class NodeTouristSpec extends Specification {
@@ -43,17 +45,21 @@ class NodeTouristSpec extends Specification {
}
val info = visitor.infoAllNodes
- info.get(1).toSeq mustEqual Array(DirectedPath(Array(1))).toSeq
- info.get(2).toSeq mustEqual Array(DirectedPath(Array(2))).toSeq
- info.get(3).toSeq mustEqual Array(
- DirectedPath(Array(2, 3)),
- DirectedPath(Array(2, 3, 4, 3)),
- DirectedPath(Array(1, 3))
+ pathMapToSeq(info.get(1)) mustEqual Array((DirectedPath(Array(1)), 5)).toSeq
+ pathMapToSeq(info.get(2)) mustEqual Array((DirectedPath(Array(2)), 3)).toSeq
+ pathMapToSeq(info.get(3)) mustEqual Array(
+ (DirectedPath(Array(2, 3)), 3),
+ (DirectedPath(Array(2, 3, 4, 3)), 1),
+ (DirectedPath(Array(1, 3)), 1)
).toSeq
- info.get(4).toSeq mustEqual Array(
- DirectedPath(Array(2, 3, 4)),
- DirectedPath(Array(1, 4))
+ pathMapToSeq(info.get(4)) mustEqual Array(
+ (DirectedPath(Array(2, 3, 4)), 2),
+ (DirectedPath(Array(1, 4)), 1)
).toSeq
}
}
+
+ def pathMapToSeq(map: Object2IntMap[DirectedPath]) = {
+ FastUtilConversion.object2IntMapToArray(map).toSeq
+ }
}

0 comments on commit 698b9f0

Please sign in to comment.