From 86c76708444bf653b052d9e0ba27e4e0e7a0bacb Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Mon, 25 Feb 2013 16:55:00 -0800 Subject: [PATCH] Cleans up intTimes --- .../twitter/algebird/AdjoinedUnitRing.scala | 37 +++---------- .../scala/com/twitter/algebird/Group.scala | 15 +++++ .../scala/com/twitter/algebird/Monoid.scala | 16 +++++- .../scala/com/twitter/algebird/Ring.scala | 10 ++++ .../com/twitter/algebird/Semigroup.scala | 55 ++++++++++++++++++- .../com/twitter/algebird/BaseProperties.scala | 21 +++++-- .../twitter/algebird/AdJoinedUnitRing.scala | 6 +- 7 files changed, 124 insertions(+), 36 deletions(-) diff --git a/algebird-core/src/main/scala/com/twitter/algebird/AdjoinedUnitRing.scala b/algebird-core/src/main/scala/com/twitter/algebird/AdjoinedUnitRing.scala index 7293ea8b8..68645b59f 100644 --- a/algebird-core/src/main/scala/com/twitter/algebird/AdjoinedUnitRing.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/AdjoinedUnitRing.scala @@ -21,7 +21,9 @@ import scala.annotation.tailrec * This is for the case where your Ring[T] is a Rng (i.e. there is no unit). * @see http://en.wikipedia.org/wiki/Pseudo-ring#Adjoining_an_identity_element */ -case class AdjoinedUnit[T](ones: BigInt, get: T) +case class AdjoinedUnit[T](ones: BigInt, get: T) { + def unwrap: Option[T] = if (ones == 0) Some(get) else None +} object AdjoinedUnit { def apply[T](item: T): AdjoinedUnit[T] = new AdjoinedUnit[T](BigInt(0), item) @@ -33,7 +35,7 @@ class AdjoinedUnitRing[T](implicit ring: Ring[T]) extends Ring[AdjoinedUnit[T]] val zero = AdjoinedUnit[T](ring.zero) override def isNonZero(it: AdjoinedUnit[T]) = - (it.ones != 0) && (ring.isNonZero(it.get)) + (it.ones != 0) || ring.isNonZero(it.get) def plus(left: AdjoinedUnit[T], right: AdjoinedUnit[T]) = AdjoinedUnit(left.ones + right.ones, ring.plus(left.get, right.get)) @@ -43,36 +45,13 @@ class AdjoinedUnitRing[T](implicit ring: Ring[T]) extends Ring[AdjoinedUnit[T]] override def minus(left: AdjoinedUnit[T], right: AdjoinedUnit[T]) = AdjoinedUnit(left.ones - right.ones, ring.minus(left.get, right.get)) - final def intTimes(i: BigInt, v: T): T = { - if(i < 0) { - intTimes(i, ring.negate(v)) - } - else if (i == 0) { - ring.zero - } - else if(i == 1) { - v - } - else { - // i * v == ((i/2) * v + (i/2)*v) + (1/0)*v - val half = i / 2 - val rem = i % 2 - val ht = intTimes(half, v) - val twoV = ring.plus(ht, ht) - if (rem == 0) { - twoV - } - else { - ring.plus(twoV, v) - } - } - } - def times(left: AdjoinedUnit[T], right: AdjoinedUnit[T]) = { // (n1, g1) * (n1, g2) = (n1*n2, (n2*g1) + (n2*g1) + g1*g2)) + import Group.intTimes + val ones = left.ones * right.ones - val part0 = intTimes(left.ones, right.get) - val part1 = intTimes(right.ones, left.get) + val part0 = intTimes(left.ones, right.get)(ring) + val part1 = intTimes(right.ones, left.get)(ring) val part2 = ring.times(left.get, right.get) val nonUnit = ring.plus(part0, ring.plus(part1, part2)) diff --git a/algebird-core/src/main/scala/com/twitter/algebird/Group.scala b/algebird-core/src/main/scala/com/twitter/algebird/Group.scala index 02ad2de4b..9c8789cdd 100755 --- a/algebird-core/src/main/scala/com/twitter/algebird/Group.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/Group.scala @@ -19,6 +19,7 @@ import java.lang.{Integer => JInt, Short => JShort, Long => JLong, Float => JFlo import java.util.{List => JList, Map => JMap} import scala.annotation.implicitNotFound +import scala.math.Equiv /** * Group: this is a monoid that also has subtraction (and negation): * So, you can do (a-b), or -a (which is equal to 0 - a). @@ -50,6 +51,19 @@ object Group extends GeneratedGroupImplicits { // This pattern is really useful for typeclasses def negate[T](x : T)(implicit grp : Group[T]) = grp.negate(x) def minus[T](l : T, r : T)(implicit grp : Group[T]) = grp.minus(l,r) + // nonZero and subtraction give an equiv, useful for Map[K,V] + def equiv[T](implicit grp: Group[T]): Equiv[T] = Equiv.fromFunction[T] { (a, b) => + !grp.isNonZero(grp.minus(a, b)) + } + /** Same as v + v + v .. + v (i times in total) */ + def intTimes[T](i: BigInt, v: T)(implicit grp: Group[T]): T = + if(i < 0) { + Monoid.intTimes(-i, grp.negate(v)) + } + else { + Monoid.intTimes(i, v)(grp) + } + implicit val nullGroup : Group[Null] = NullGroup implicit val unitGroup : Group[Unit] = UnitGroup @@ -60,6 +74,7 @@ object Group extends GeneratedGroupImplicits { implicit val shortGroup : Group[Short] = ShortRing implicit val jshortGroup : Group[JShort] = JShortRing implicit val longGroup : Group[Long] = LongRing + implicit val bigIntGroup : Group[BigInt] = BigIntRing implicit val jlongGroup : Group[JLong] = JLongRing implicit val floatGroup : Group[Float] = FloatField implicit val jfloatGroup : Group[JFloat] = JFloatField diff --git a/algebird-core/src/main/scala/com/twitter/algebird/Monoid.scala b/algebird-core/src/main/scala/com/twitter/algebird/Monoid.scala index 2f391760e..8c0f4ee0a 100755 --- a/algebird-core/src/main/scala/com/twitter/algebird/Monoid.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/Monoid.scala @@ -45,7 +45,7 @@ trait Monoid[@specialized(Int,Long,Float,Double) T] extends Semigroup[T] { None } } - @deprecated("Just use Monoid.sum") + // Override this if there is a more efficient means to implement this def sum(vs: TraversableOnce[T]): T = Monoid.sum(vs)(this) } @@ -133,6 +133,19 @@ object Monoid extends GeneratedMonoidImplicits { def plus(l:T, r:T) = associativeFn(l,r) } + /** Same as v + v + v .. + v (i times in total) + * requires i >= 0, wish we had NonnegativeBigInt as a class + */ + def intTimes[T](i: BigInt, v: T)(implicit mon: Monoid[T]): T = { + require(i >= 0, "Cannot do negative products with a Monoid, try Group.intTimes") + if (i == 0) { + mon.zero + } + else { + Semigroup.intTimes(i, v)(mon) + } + } + implicit val nullMonoid : Monoid[Null] = NullGroup implicit val unitMonoid : Monoid[Unit] = UnitGroup implicit val boolMonoid : Monoid[Boolean] = BooleanField @@ -141,6 +154,7 @@ object Monoid extends GeneratedMonoidImplicits { implicit val jintMonoid : Monoid[JInt] = JIntRing implicit val shortMonoid : Monoid[Short] = ShortRing implicit val jshortMonoid : Monoid[JShort] = JShortRing + implicit val bigIntMonoid : Monoid[BigInt] = BigIntRing implicit val longMonoid : Monoid[Long] = LongRing implicit val jlongMonoid : Monoid[JLong] = JLongRing implicit val floatMonoid : Monoid[Float] = FloatField diff --git a/algebird-core/src/main/scala/com/twitter/algebird/Ring.scala b/algebird-core/src/main/scala/com/twitter/algebird/Ring.scala index 28f6bafac..0182f6a0b 100755 --- a/algebird-core/src/main/scala/com/twitter/algebird/Ring.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/Ring.scala @@ -62,6 +62,15 @@ object LongRing extends Ring[Long] { override def times(l : Long, r : Long) = l * r } +object BigIntRing extends Ring[BigInt] { + override val zero = BigInt(0) + override val one = BigInt(1) + override def negate(v : BigInt) = -v + override def plus(l : BigInt, r : BigInt) = l + r + override def minus(l : BigInt, r : BigInt) = l - r + override def times(l : BigInt, r : BigInt) = l * r +} + object Ring extends GeneratedRingImplicits { // This pattern is really useful for typeclasses def one[T](implicit rng : Ring[T]) = rng.one @@ -85,6 +94,7 @@ object Ring extends GeneratedRingImplicits { implicit val shortRing : Ring[Short] = ShortRing implicit val jshortRing : Ring[JShort] = JShortRing implicit val longRing : Ring[Long] = LongRing + implicit val bigIntRing : Ring[BigInt] = BigIntRing implicit val jlongRing : Ring[JLong] = JLongRing implicit val floatRing : Ring[Float] = FloatField implicit val jfloatRing : Ring[JFloat] = JFloatField diff --git a/algebird-core/src/main/scala/com/twitter/algebird/Semigroup.scala b/algebird-core/src/main/scala/com/twitter/algebird/Semigroup.scala index ee78b167d..ffd31e8f5 100755 --- a/algebird-core/src/main/scala/com/twitter/algebird/Semigroup.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/Semigroup.scala @@ -18,7 +18,8 @@ package com.twitter.algebird import java.lang.{Integer => JInt, Short => JShort, Long => JLong, Float => JFloat, Double => JDouble, Boolean => JBool} import java.util.{List => JList, Map => JMap} -import scala.annotation.implicitNotFound +import scala.collection.mutable.{Map => MMap} +import scala.annotation.{implicitNotFound, tailrec} /** * Semigroup: @@ -71,6 +72,57 @@ object Semigroup extends GeneratedSemigroupImplicits { def from[T](associativeFn: (T,T) => T): Semigroup[T] = new Semigroup[T] { def plus(l:T, r:T) = associativeFn(l,r) } + /** Same as v + v + v .. + v (i times in total) + * requires i > 0, wish we had PositiveBigInt as a class + */ + def intTimes[T](i: BigInt, v: T)(implicit sg: Semigroup[T]): T = { + require(i > 0, "Cannot do non-positive products with a Semigroup, try Monoid/Group.intTimes") + intTimesRec(i-1, v, 0, (v, Vector[T]())) + } + + @tailrec + private def intTimesRec[T](i: BigInt, v: T, pow: Int, vaccMemo: (T,Vector[T]))(implicit sg: Semigroup[T]): T = { + if(i == 0) { + vaccMemo._1 + } + else { + /* i2 = i % 2 + * 2^pow(i*v) + acc == 2^(pow+1)((i/2)*v) + (acc + 2^pow i2 * v) + */ + val half = i / 2 + val rem = i % 2 + val newAccMemo = if (rem == 0) vaccMemo else { + val (res, newMemo) = timesPow2(pow, v, vaccMemo._2) + (sg.plus(vaccMemo._1, res), newMemo) + } + intTimesRec(half, v, pow + 1, newAccMemo) + } + } + + // Returns (2^power) * v = (2^(power - 1) v + 2^(power - 1) v) + private def timesPow2[T](power: Int, v: T, memo: Vector[T])(implicit sg: Semigroup[T]): (T, Vector[T]) = { + val size = memo.size + require(power >= 0, "power cannot be negative") + if(power == 0) { + (v, memo) + } + else if (power <= size) { + (memo(power-1), memo) + } + else { + var item = if(size == 0) v else memo.last + var pow = size + var newMemo = memo + while(pow < power) { + // x = 2*x + item = sg.plus(item, item) + pow += 1 + newMemo = newMemo :+ item + } + (item, newMemo) + } + } + implicit val nullSemigroup : Semigroup[Null] = NullGroup implicit val unitSemigroup : Semigroup[Unit] = UnitGroup implicit val boolSemigroup : Semigroup[Boolean] = BooleanField @@ -80,6 +132,7 @@ object Semigroup extends GeneratedSemigroupImplicits { implicit val shortSemigroup : Semigroup[Short] = ShortRing implicit val jshortSemigroup : Semigroup[JShort] = JShortRing implicit val longSemigroup : Semigroup[Long] = LongRing + implicit val bigIntSemigroup : Semigroup[BigInt] = BigIntRing implicit val jlongSemigroup : Semigroup[JLong] = JLongRing implicit val floatSemigroup : Semigroup[Float] = FloatField implicit val jfloatSemigroup : Semigroup[JFloat] = JFloatField diff --git a/algebird-test/src/main/scala/com/twitter/algebird/BaseProperties.scala b/algebird-test/src/main/scala/com/twitter/algebird/BaseProperties.scala index bb6777638..548caaac3 100644 --- a/algebird-test/src/main/scala/com/twitter/algebird/BaseProperties.scala +++ b/algebird-test/src/main/scala/com/twitter/algebird/BaseProperties.scala @@ -18,7 +18,7 @@ package com.twitter.algebird import org.scalacheck.{ Arbitrary, Properties } import org.scalacheck.Prop.forAll - +import scala.math.Equiv /** * Base properties useful for all tests using Algebird's typeclasses. */ @@ -43,6 +43,17 @@ object BaseProperties { isAssociativeEq[T](eqfn) && isCommutativeEq[T](eqfn) def commutativeSemigroupLaws[T : Semigroup : Arbitrary] = commutativeSemigroupLawsEq[T](defaultEq _) + def isNonZeroWorksMonoid[T:Monoid:Arbitrary:Equiv] = forAll { (a: T, b: T) => + val aIsLikeZero = Equiv[T].equiv(Monoid.plus(a,b), b) + Monoid.isNonZero(a) || aIsLikeZero + } + + def isNonZeroWorksRing[T:Ring:Arbitrary] = forAll { (a: T, b: T) => + implicit val monT: Monoid[T] = implicitly[Ring[T]] + val prodZero = !monT.isNonZero(Ring.times(a,b)) + (Monoid.isNonZero(a) && Monoid.isNonZero(b)) || prodZero + } + def weakZero[T : Monoid : Arbitrary] = forAll { (a : T) => val mon = implicitly[Monoid[T]] val zero = mon.zero @@ -57,7 +68,7 @@ object BaseProperties { } def validZero[T : Monoid : Arbitrary] = validZeroEq[T](defaultEq _) - def monoidLaws[T : Monoid : Arbitrary] = validZero[T] && isAssociative[T] + def monoidLaws[T : Monoid : Arbitrary] = validZero[T] && isAssociative[T] && isNonZeroWorksMonoid[T] def monoidLawsEq[T : Monoid : Arbitrary](eqfn : (T,T) => Boolean) = validZeroEq[T](eqfn) && isAssociativeEq[T](eqfn) def commutativeMonoidLawsEq[T : Monoid : Arbitrary](eqfn : (T,T) => Boolean) = @@ -94,12 +105,14 @@ object BaseProperties { rng.times(a, rng.times(b,c)) == rng.times(rng.times(a,b),c) } def pseudoRingLaws[T:Ring:Arbitrary] = - isDistributive[T] && timesIsAssociative[T] && groupLaws[T] && isCommutative[T] + isDistributive[T] && timesIsAssociative[T] && groupLaws[T] && isCommutative[T] && + isNonZeroWorksRing[T] def semiringLaws[T:Ring:Arbitrary] = isDistributive[T] && timesIsAssociative[T] && validOne[T] && commutativeMonoidLaws[T] && - zeroAnnihilates[T] + zeroAnnihilates[T] && + isNonZeroWorksRing[T] def ringLaws[T : Ring : Arbitrary] = validOne[T] && pseudoRingLaws[T] diff --git a/algebird-test/src/test/scala/com/twitter/algebird/AdJoinedUnitRing.scala b/algebird-test/src/test/scala/com/twitter/algebird/AdJoinedUnitRing.scala index 169b14b91..471338cee 100644 --- a/algebird-test/src/test/scala/com/twitter/algebird/AdJoinedUnitRing.scala +++ b/algebird-test/src/test/scala/com/twitter/algebird/AdJoinedUnitRing.scala @@ -18,6 +18,7 @@ package com.twitter.algebird import org.scalacheck.Arbitrary import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.Prop.forAll import org.scalacheck.Properties import org.scalacheck.Gen.choose @@ -27,7 +28,10 @@ object AdjoinedRingSpecification extends Properties("AdjoinedRing") { implicit def adjoined[T:Arbitrary]: Arbitrary[AdjoinedUnit[T]] = Arbitrary { implicitly[Arbitrary[T]].arbitrary.map { t => AdjoinedUnit(t) } } - + // AdjoinedUnit requires this method to be correct, so it is tested here: + property("intTimes works correctly") = forAll { (bi0: BigInt, bi1: BigInt) => + Group.intTimes(bi0, bi1) == (bi0 * bi1) + } property("AdjoinedUnit[Int] is a Ring") = ringLaws[AdjoinedUnit[Int]] property("AdjoinedUnit[Long] is a Ring") = ringLaws[AdjoinedUnit[Long]] }