Skip to content

Commit

Permalink
Custom hashCode methods for case classes.
Browse files Browse the repository at this point in the history
No boxing, no MODULE$ indirection.
  • Loading branch information
paulp committed May 9, 2012
1 parent 09f380d commit 0e197e8
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 45 deletions.
1 change: 1 addition & 0 deletions src/compiler/scala/reflect/internal/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ trait Definitions extends reflect.api.StandardDefinitions {
// boxed classes
lazy val ObjectRefClass = requiredClass[scala.runtime.ObjectRef[_]]
lazy val VolatileObjectRefClass = requiredClass[scala.runtime.VolatileObjectRef[_]]
lazy val RuntimeStaticsModule = getRequiredModule("scala.runtime.Statics")
lazy val BoxesRunTimeModule = getRequiredModule("scala.runtime.BoxesRunTime")
lazy val BoxesRunTimeClass = BoxesRunTimeModule.moduleClass
lazy val BoxedNumberClass = getClass(sn.BoxedNumber)
Expand Down
35 changes: 34 additions & 1 deletion src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ trait SyntheticMethods extends ast.TreeDSL {
def forwardToRuntime(method: Symbol): Tree =
forwardMethod(method, getMember(ScalaRunTimeModule, method.name prepend "_"))(mkThis :: _)

def callStaticsMethod(name: String)(args: Tree*): Tree = {
val method = termMember(RuntimeStaticsModule, name)
Apply(gen.mkAttributedRef(method), args.toList)
}

// Any member, including private
def hasConcreteImpl(name: Name) =
clazz.info.member(name).alternatives exists (m => !m.isDeferred && !m.isSynthetic)
Expand Down Expand Up @@ -222,13 +227,41 @@ trait SyntheticMethods extends ast.TreeDSL {
)
}

def hashcodeImplementation(sym: Symbol): Tree = {
sym.tpe.finalResultType.typeSymbol match {
case UnitClass | NullClass => Literal(Constant(0))
case BooleanClass => If(Ident(sym), Literal(Constant(1231)), Literal(Constant(1237)))
case IntClass | ShortClass | ByteClass | CharClass => Ident(sym)
case LongClass => callStaticsMethod("longHash")(Ident(sym))
case DoubleClass => callStaticsMethod("doubleHash")(Ident(sym))
case FloatClass => callStaticsMethod("floatHash")(Ident(sym))
case _ => callStaticsMethod("anyHash")(Ident(sym))
}
}

def specializedHashcode = {
createMethod(nme.hashCode_, Nil, IntClass.tpe) { m =>
val accumulator = m.newVariable(newTermName("acc"), m.pos, SYNTHETIC) setInfo IntClass.tpe
val valdef = ValDef(accumulator, Literal(Constant(0xcafebabe)))
val mixes = accessors map (acc =>
Assign(
Ident(accumulator),
callStaticsMethod("mix")(Ident(accumulator), hashcodeImplementation(acc))
)
)
val finish = callStaticsMethod("finalizeHash")(Ident(accumulator), Literal(Constant(arity)))

Block(valdef :: mixes, finish)
}
}

def valueClassMethods = List(
Any_hashCode -> (() => hashCodeDerivedValueClassMethod),
Any_equals -> (() => equalsDerivedValueClassMethod)
)

def caseClassMethods = productMethods ++ productNMethods ++ Seq(
Object_hashCode -> (() => forwardToRuntime(Object_hashCode)),
Object_hashCode -> (() => specializedHashcode),
Object_toString -> (() => forwardToRuntime(Object_toString)),
Object_equals -> (() => equalsCaseClassMethod)
)
Expand Down
89 changes: 89 additions & 0 deletions src/library/scala/runtime/Statics.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package scala.runtime;

/** Not for public consumption. Usage by the runtime only.
*/

public final class Statics {
public static int mix(int hash, int data) {
int h = mixLast(hash, data);
h = Integer.rotateLeft(h, 13);
return h * 5 + 0xe6546b64;
}

public static int mixLast(int hash, int data) {
int k = data;

k *= 0xcc9e2d51;
k = Integer.rotateLeft(k, 15);
k *= 0x1b873593;

return hash ^ k;
}

public static int finalizeHash(int hash, int length) {
return avalanche(hash ^ length);
}

/** Force all bits of the hash to avalanche. Used for finalizing the hash. */
public static int avalanche(int h) {
h ^= h >>> 16;
h *= 0x85ebca6b;
h ^= h >>> 13;
h *= 0xc2b2ae35;
h ^= h >>> 16;

return h;
}

public static int longHash(long lv) {
if ((int)lv == lv)
return (int)lv;
else
return (int)(lv ^ (lv >>> 32));
}

public static int doubleHash(double dv) {
int iv = (int)dv;
if (iv == dv)
return iv;

float fv = (float)dv;
if (fv == dv)
return java.lang.Float.floatToIntBits(fv);

long lv = (long)dv;
if (lv == dv)
return (int)lv;

lv = Double.doubleToLongBits(dv);
return (int)(lv ^ (lv >>> 32));
}

public static int floatHash(float fv) {
int iv = (int)fv;
if (iv == fv)
return iv;

long lv = (long)fv;
if (lv == fv)
return (int)(lv^(lv>>>32));

return java.lang.Float.floatToIntBits(fv);
}

public static int anyHash(Object x) {
if (x == null)
return 0;

if (x instanceof java.lang.Long)
return longHash(((java.lang.Long)x).longValue());

if (x instanceof java.lang.Double)
return doubleHash(((java.lang.Double)x).doubleValue());

if (x instanceof java.lang.Float)
return floatHash(((java.lang.Float)x).floatValue());

return x.hashCode();
}
}
9 changes: 9 additions & 0 deletions test/files/run/caseClassHash.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Foo(true,-1,-1,d,-5,-10,500.0,500.0,List(),5.0)
Foo(true,-1,-1,d,-5,-10,500.0,500.0,List(),5)
1383698062
1383698062
true
## method 1: 1383698062
## method 2: 1383698062
Murmur 1: 1383698062
Murmur 2: 1383698062
37 changes: 37 additions & 0 deletions test/files/run/caseClassHash.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
case class Foo[T](a: Boolean, b: Byte, c: Short, d: Char, e: Int, f: Long, g: Double, h: Float, i: AnyRef, j: T) { }

object Test {
def mkFoo[T](x: T) = Foo[T](true, -1, -1, 100, -5, -10, 500d, 500f, Nil, x)

def main(args: Array[String]): Unit = {
val foo1 = mkFoo[Double](5.0d)
val foo2 = mkFoo[Long](5l)

List(foo1, foo2, foo1.##, foo2.##, foo1 == foo2) foreach println

println("## method 1: " + foo1.##)
println("## method 2: " + foo2.##)
println(" Murmur 1: " + scala.util.MurmurHash3.productHash(foo1))
println(" Murmur 2: " + scala.util.MurmurHash3.productHash(foo2))
}
}

object Timing {
var hash = 0
def mkFoo(i: Int) = Foo(i % 2 == 0, i.toByte, i.toShort, i.toChar, i, i, 1.1, 1.1f, this, this)

def main(args: Array[String]): Unit = {
val reps = if (args.isEmpty) 100000000 else args(0).toInt
val start = System.nanoTime

println("Warmup.")
1 to 10000 foreach mkFoo

hash = 0
1 to reps foreach (i => hash += mkFoo(i).##)

val end = System.nanoTime
println("hash = " + hash)
println("Elapsed: " + ((end - start) / 1e6) + " ms.")
}
}
Loading

0 comments on commit 0e197e8

Please sign in to comment.