From 6b8610dc1685f412cd002e4086153a83c97b5692 Mon Sep 17 00:00:00 2001 From: Peilin Yang Date: Fri, 30 Jun 2023 13:27:37 -0700 Subject: [PATCH] fix --- .../com/twitter/algebird/SketchMap.scala | 6 +++- .../com/twitter/algebird/SketchMapTest.scala | 34 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/algebird-core/src/main/scala/com/twitter/algebird/SketchMap.scala b/algebird-core/src/main/scala/com/twitter/algebird/SketchMap.scala index e327ed57c..1352bdb0e 100644 --- a/algebird-core/src/main/scala/com/twitter/algebird/SketchMap.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/SketchMap.scala @@ -18,6 +18,7 @@ package com.twitter.algebird import algebra.CommutativeMonoid import com.twitter.algebird.matrix.AdaptiveMatrix +import com.twitter.algebird.matrix.DenseMatrix /** * A Sketch Map is a generalized version of the Count-Min Sketch that is an approximation of Map[K, V] that @@ -50,7 +51,10 @@ class SketchMapMonoid[K, V](val params: SketchMapParams[K])(implicit SketchMap(AdaptiveMatrix.fill(params.depth, params.width)(monoid.zero), Nil, monoid.zero) override def plus(left: SketchMap[K, V], right: SketchMap[K, V]): SketchMap[K, V] = { - val newValuesTable = Monoid.plus(left.valuesTable, right.valuesTable) + val newValuesTable = right.valuesTable match { + case DenseMatrix(_, _, _) => Monoid.plus(right.valuesTable, left.valuesTable) + case _ => Monoid.plus(left.valuesTable, right.valuesTable) + } val newHeavyHitters = left.heavyHitterKeys.toSet ++ right.heavyHitterKeys SketchMap( diff --git a/algebird-test/src/test/scala/com/twitter/algebird/SketchMapTest.scala b/algebird-test/src/test/scala/com/twitter/algebird/SketchMapTest.scala index 86a4452d6..80e14c614 100644 --- a/algebird-test/src/test/scala/com/twitter/algebird/SketchMapTest.scala +++ b/algebird-test/src/test/scala/com/twitter/algebird/SketchMapTest.scala @@ -1,5 +1,7 @@ package com.twitter.algebird +import com.twitter.algebird.matrix.DenseMatrix +import com.twitter.algebird.matrix.SparseColumnMatrix import org.scalacheck.{Arbitrary, Gen} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -53,6 +55,38 @@ class SketchMapTest extends AnyWordSpec with Matchers { assert(sm.totalValue == totalCount) } + "plus should work commutatively" in { + implicit val m = Monoid.longMonoid + val valueTableLeft = + DenseMatrix( + PARAMS.width, + PARAMS.depth, + rowsByColumns = IndexedSeq(10L) ++ (1 until PARAMS.width * PARAMS.depth).map(_ => 0L) + ) + val testLeft = SketchMap[Int, Long]( + valuesTable = valueTableLeft, + heavyHitterKeys = List(1), + totalValue = 10 + ) + + val valueTableRight = SparseColumnMatrix(rowsByColumns = + IndexedSeq( + SparseVector( + map = Map(1 -> 1L), + sparseValue = 1L, + length = PARAMS.width * PARAMS.depth + ) + ) + ) + val testRight = SketchMap[Int, Long]( + valuesTable = valueTableRight, + heavyHitterKeys = List(1), + totalValue = 1 + ) + + assert(MONOID.plus(testLeft, testRight) == MONOID.plus(testRight, testLeft)) + } + "exactly compute frequencies in a small stream" in { val one = MONOID.create((1, 1L)) val two = MONOID.create((2, 1L))