Permalink
Browse files

Custom hashCode methods for case classes.

No boxing, no MODULE$ indirection.
  • Loading branch information...
paulp committed May 9, 2012
1 parent 09f380d commit 0e197e89ac96ec0dd8b136de8e07ad2e15f94371
@@ -896,6 +896,7 @@ trait Definitions extends reflect.api.StandardDefinitions {
// boxed classes
lazy val ObjectRefClass = requiredClass[scala.runtime.ObjectRef[_]]
lazy val VolatileObjectRefClass = requiredClass[scala.runtime.VolatileObjectRef[_]]
lazy val RuntimeStaticsModule = getRequiredModule("scala.runtime.Statics")
lazy val BoxesRunTimeModule = getRequiredModule("scala.runtime.BoxesRunTime")
lazy val BoxesRunTimeClass = BoxesRunTimeModule.moduleClass
lazy val BoxedNumberClass = getClass(sn.BoxedNumber)
@@ -89,6 +89,11 @@ trait SyntheticMethods extends ast.TreeDSL {
def forwardToRuntime(method: Symbol): Tree =
forwardMethod(method, getMember(ScalaRunTimeModule, method.name prepend "_"))(mkThis :: _)
def callStaticsMethod(name: String)(args: Tree*): Tree = {
val method = termMember(RuntimeStaticsModule, name)
Apply(gen.mkAttributedRef(method), args.toList)
}
// Any member, including private
def hasConcreteImpl(name: Name) =
clazz.info.member(name).alternatives exists (m => !m.isDeferred && !m.isSynthetic)
@@ -222,13 +227,41 @@ trait SyntheticMethods extends ast.TreeDSL {
)
}
def hashcodeImplementation(sym: Symbol): Tree = {
sym.tpe.finalResultType.typeSymbol match {
case UnitClass | NullClass => Literal(Constant(0))
case BooleanClass => If(Ident(sym), Literal(Constant(1231)), Literal(Constant(1237)))
case IntClass | ShortClass | ByteClass | CharClass => Ident(sym)
case LongClass => callStaticsMethod("longHash")(Ident(sym))
case DoubleClass => callStaticsMethod("doubleHash")(Ident(sym))
case FloatClass => callStaticsMethod("floatHash")(Ident(sym))
case _ => callStaticsMethod("anyHash")(Ident(sym))
}
}
def specializedHashcode = {
createMethod(nme.hashCode_, Nil, IntClass.tpe) { m =>
val accumulator = m.newVariable(newTermName("acc"), m.pos, SYNTHETIC) setInfo IntClass.tpe
val valdef = ValDef(accumulator, Literal(Constant(0xcafebabe)))
val mixes = accessors map (acc =>
Assign(
Ident(accumulator),
callStaticsMethod("mix")(Ident(accumulator), hashcodeImplementation(acc))
)
)
val finish = callStaticsMethod("finalizeHash")(Ident(accumulator), Literal(Constant(arity)))
Block(valdef :: mixes, finish)
}
}
def valueClassMethods = List(
Any_hashCode -> (() => hashCodeDerivedValueClassMethod),
Any_equals -> (() => equalsDerivedValueClassMethod)
)
def caseClassMethods = productMethods ++ productNMethods ++ Seq(
Object_hashCode -> (() => forwardToRuntime(Object_hashCode)),
Object_hashCode -> (() => specializedHashcode),
Object_toString -> (() => forwardToRuntime(Object_toString)),
Object_equals -> (() => equalsCaseClassMethod)
)
@@ -0,0 +1,89 @@
package scala.runtime;
/** Not for public consumption. Usage by the runtime only.
*/
public final class Statics {
public static int mix(int hash, int data) {
int h = mixLast(hash, data);
h = Integer.rotateLeft(h, 13);
return h * 5 + 0xe6546b64;
}
public static int mixLast(int hash, int data) {
int k = data;
k *= 0xcc9e2d51;
k = Integer.rotateLeft(k, 15);
k *= 0x1b873593;
return hash ^ k;
}
public static int finalizeHash(int hash, int length) {
return avalanche(hash ^ length);
}
/** Force all bits of the hash to avalanche. Used for finalizing the hash. */
public static int avalanche(int h) {
h ^= h >>> 16;
h *= 0x85ebca6b;
h ^= h >>> 13;
h *= 0xc2b2ae35;
h ^= h >>> 16;
return h;
}
public static int longHash(long lv) {
if ((int)lv == lv)
return (int)lv;
else
return (int)(lv ^ (lv >>> 32));
}
public static int doubleHash(double dv) {
int iv = (int)dv;
if (iv == dv)
return iv;
float fv = (float)dv;
if (fv == dv)
return java.lang.Float.floatToIntBits(fv);
long lv = (long)dv;
if (lv == dv)
return (int)lv;
lv = Double.doubleToLongBits(dv);
return (int)(lv ^ (lv >>> 32));
}
public static int floatHash(float fv) {
int iv = (int)fv;
if (iv == fv)
return iv;
long lv = (long)fv;
if (lv == fv)
return (int)(lv^(lv>>>32));
return java.lang.Float.floatToIntBits(fv);
}
public static int anyHash(Object x) {
if (x == null)
return 0;
if (x instanceof java.lang.Long)
return longHash(((java.lang.Long)x).longValue());
if (x instanceof java.lang.Double)
return doubleHash(((java.lang.Double)x).doubleValue());
if (x instanceof java.lang.Float)
return floatHash(((java.lang.Float)x).floatValue());
return x.hashCode();
}
}
@@ -0,0 +1,9 @@
Foo(true,-1,-1,d,-5,-10,500.0,500.0,List(),5.0)
Foo(true,-1,-1,d,-5,-10,500.0,500.0,List(),5)
1383698062
1383698062
true
## method 1: 1383698062
## method 2: 1383698062
Murmur 1: 1383698062
Murmur 2: 1383698062
@@ -0,0 +1,37 @@
case class Foo[T](a: Boolean, b: Byte, c: Short, d: Char, e: Int, f: Long, g: Double, h: Float, i: AnyRef, j: T) { }
object Test {
def mkFoo[T](x: T) = Foo[T](true, -1, -1, 100, -5, -10, 500d, 500f, Nil, x)
def main(args: Array[String]): Unit = {
val foo1 = mkFoo[Double](5.0d)
val foo2 = mkFoo[Long](5l)
List(foo1, foo2, foo1.##, foo2.##, foo1 == foo2) foreach println
println("## method 1: " + foo1.##)
println("## method 2: " + foo2.##)
println(" Murmur 1: " + scala.util.MurmurHash3.productHash(foo1))
println(" Murmur 2: " + scala.util.MurmurHash3.productHash(foo2))
}
}
object Timing {
var hash = 0
def mkFoo(i: Int) = Foo(i % 2 == 0, i.toByte, i.toShort, i.toChar, i, i, 1.1, 1.1f, this, this)
def main(args: Array[String]): Unit = {
val reps = if (args.isEmpty) 100000000 else args(0).toInt
val start = System.nanoTime
println("Warmup.")
1 to 10000 foreach mkFoo
hash = 0
1 to reps foreach (i => hash += mkFoo(i).##)
val end = System.nanoTime
println("hash = " + hash)
println("Elapsed: " + ((end - start) / 1e6) + " ms.")
}
}
Oops, something went wrong.

0 comments on commit 0e197e8

Please sign in to comment.