Skip to content

Commit

Permalink
HashMap bulk operations should retain existing keys
Browse files Browse the repository at this point in the history
... when the operand overwrites mapping with a key that is == but ne.

This is consistent with the super class implementation, which
was overridden in 2.12.11 for efficiency.

I also found a pair of ClassCastExceptions in the new implementations
of `HashMap.++:`, for instance:

```
case class C(a: Int)(override val toString: String); implicit val Ordering_C: Ordering[C] = Ordering.by(_.a); val c0l = C(0)("l"); val c0r = C(0)("r"); import collection.immutable._; println(HashMap((c0l, ())).++:(TreeMap((c0r, ()))))'; done
v2.12.10
Map(r -> ())
v2.12.11
java.lang.ClassCastException: scala.collection.immutable.HashMap$adder$1$ cannot be cast to scala.collection.immutable.HashMap
	at scala.collection.immutable.HashMap$adder$1$.<init>(HashMap.scala:215)
```
  • Loading branch information
retronym committed Jun 23, 2020
1 parent 141efea commit 4fe011f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 55 deletions.
120 changes: 71 additions & 49 deletions src/library/scala/collection/immutable/HashMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ sealed class HashMap[A, +B] extends AbstractMap[A, B]
with CustomParallelizable[(A, B), ParHashMap[A, B]]
with HasForeachEntry[A, B]
{
import HashMap.{nullToEmpty, bufferSize}
import HashMap.{bufferSize, concatMerger, nullToEmpty}

override def size: Int = 0

Expand Down Expand Up @@ -159,7 +159,7 @@ sealed class HashMap[A, +B] extends AbstractMap[A, B]
override def values: scala.collection.Iterable[B] = new HashMapValues

override final def transform[W, That](f: (A, B) => W)(implicit bf: CanBuildFrom[HashMap[A, B], (A, W), That]): That =
if ((bf eq Map.canBuildFrom) || (bf eq HashMap.canBuildFrom)) transformImpl(f).asInstanceOf[That]
if ((bf eq Map.canBuildFrom) || (bf eq HashMap.canBuildFrom)) castToThat(transformImpl(f))
else super.transform(f)(bf)

/* `transform` specialized to return a HashMap */
Expand All @@ -179,17 +179,18 @@ sealed class HashMap[A, +B] extends AbstractMap[A, B]
override def ++[C >: (A, B), That](that: GenTraversableOnce[C])(implicit bf: CanBuildFrom[HashMap[A, B], C, That]): That = {
if (isCompatibleCBF(bf)) {
//here we know that That =:= HashMap[_, _], or compatible with it
if (this eq that.asInstanceOf[AnyRef]) that.asInstanceOf[That]
else if (that.isEmpty) this.asInstanceOf[That]
else that match {
case thatHash: HashMap[A, B] =>
//default Merge prefers to keep than replace
//so we merge from thatHash
(thatHash.merged(this) (null) ).asInstanceOf[That]
case that =>
var result: HashMap[Any, _] = this.asInstanceOf[HashMap[Any, _]]
that foreach { case kv: (_, _) => result = result + kv }
result.asInstanceOf[That]
if (this eq that.asInstanceOf[AnyRef]) castToThat(that)
else if (that.isEmpty) castToThat(this)
else {
val result: HashMap[A, B] = that match {
case thatHash: HashMap[A, B] =>
this.merge0(thatHash, 0, concatMerger[A, B])
case that =>
var result: HashMap[A, B] = this
that.asInstanceOf[GenTraversableOnce[(A, B)]] foreach { case kv: (_, _) => result = result + kv }
result
}
castToThat(result)
}
} else super.++(that)(bf)
}
Expand All @@ -203,42 +204,45 @@ sealed class HashMap[A, +B] extends AbstractMap[A, B]
if (isCompatibleCBF(bf)) addSimple(that)
else super.++:(that)
}
private def addSimple[C >: (A, B), That](that: TraversableOnce[C]): That = {
private def addSimple[C >: (A, B), That](that: TraversableOnce[C])(implicit bf: CanBuildFrom[HashMap[A, B], C, That]): That = {
//here we know that That =:= HashMap[_, _], or compatible with it
if (this eq that.asInstanceOf[AnyRef]) that.asInstanceOf[That]
else if (that.isEmpty) this.asInstanceOf[That]
else that match {
case thatHash: HashMap[A, B] =>
val merger: Merger[A, B] = HashMap.liftMerger[A, B](null)
// merger prefers to keep than replace
// so we invert
(this.merge0(thatHash, 0, merger.invert)).asInstanceOf[That]

case that:HasForeachEntry[A, B] =>
object adder extends Function2[A, B, Unit] {
var result: HashMap[A, B] = this.asInstanceOf[HashMap[A, B]]
val merger = HashMap.liftMerger[A, B](null)

override def apply(key: A, value: B): Unit = {
result = result.updated0(key, computeHash(key), 0, value, null, merger)
if (this eq that.asInstanceOf[AnyRef]) castToThat(that)
else if (that.isEmpty) castToThat(this)
else {
val merger = HashMap.concatMerger[A, B].invert
val result: HashMap[A, B] = that match {
case thatHash: HashMap[A, B] =>
this.merge0(thatHash, 0, HashMap.concatMerger[A, B].invert)

case that:HasForeachEntry[A, B] =>
object adder extends Function2[A, B, Unit] {
var result: HashMap[A, B] = HashMap.this
override def apply(key: A, value: B): Unit = {
result = result.updated0(key, computeHash(key), 0, value, null, merger)
}
}
}
that foreachEntry adder
adder.result.asInstanceOf[That]
case that =>
object adder extends Function1[(A,B), Unit] {
var result: HashMap[A, B] = this.asInstanceOf[HashMap[A, B]]
val merger = HashMap.liftMerger[A, B](null)

override def apply(kv: (A, B)): Unit = {
val key = kv._1
result = result.updated0(key, computeHash(key), 0, kv._2, kv, merger)
that foreachEntry adder
adder.result
case that =>
object adder extends Function1[(A,B), Unit] {
var result: HashMap[A, B] = HashMap.this
override def apply(kv: (A, B)): Unit = {
val key = kv._1
result = result.updated0(key, computeHash(key), 0, kv._2, kv, merger)
}
}
}
that.asInstanceOf[scala.Traversable[(A,B)]] foreach adder
adder.result.asInstanceOf[That]
that.asInstanceOf[scala.Traversable[(A,B)]] foreach adder
adder.result
}
castToThat(result)
}
}
private[this] def castToThat[C, That](m: HashMap[A, B])(implicit bf: CanBuildFrom[HashMap[A, B], C, That]): That = {
m.asInstanceOf[That]
}
private[this] def castToThat[C, That](m: GenTraversableOnce[C])(implicit bf: CanBuildFrom[HashMap[A, B], C, That]): That = {
m.asInstanceOf[That]
}
}

/** $factoryInfo
Expand All @@ -261,9 +265,10 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
private type MergeFunction[A1, B1] = ((A1, B1), (A1, B1)) => (A1, B1)

private def liftMerger[A1, B1](mergef: MergeFunction[A1, B1]): Merger[A1, B1] =
if (mergef == null) defaultMerger.asInstanceOf[Merger[A1, B1]] else liftMerger0(mergef)
if (mergef == null) defaultMerger[A1, B1] else liftMerger0(mergef)

private val defaultMerger : Merger[Any, Any] = new Merger[Any, Any] {
private def defaultMerger[A, B]: Merger[A, B] = _defaultMerger.asInstanceOf[Merger[A, B]]
private[this] val _defaultMerger : Merger[Any, Any] = new Merger[Any, Any] {
override def apply(a: (Any, Any), b: (Any, Any)): (Any, Any) = a
override def retainIdentical: Boolean = true
override val invert: Merger[Any, Any] = new Merger[Any, Any] {
Expand All @@ -273,6 +278,23 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
}
}

private def concatMerger[A, B]: Merger[A, B] = _concatMerger.asInstanceOf[Merger[A, B]]
private[this] val _concatMerger : Merger[Any, Any] = new Merger[Any, Any] {
override def apply(a: (Any, Any), b: (Any, Any)): (Any, Any) = {
if (a._1.asInstanceOf[AnyRef] eq b._1.asInstanceOf[AnyRef]) b
else (a._1, b._2)
}
override def retainIdentical: Boolean = true
override val invert: Merger[Any, Any] = new Merger[Any, Any] {
override def apply(a: (Any, Any), b: (Any, Any)): (Any, Any) = {
if (b._1.asInstanceOf[AnyRef] eq a._1.asInstanceOf[AnyRef]) a
else (b._1, a._2)
}
override def retainIdentical: Boolean = true
override def invert = concatMerger
}
}

private[this] def liftMerger0[A1, B1](mergef: MergeFunction[A1, B1]): Merger[A1, B1] = new Merger[A1, B1] {
self =>
def apply(kv1: (A1, B1), kv2: (A1, B1)): (A1, B1) = mergef(kv1, kv2)
Expand Down Expand Up @@ -1049,7 +1071,7 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {

/** The root node of the partially build hashmap */
private var rootNode: HashMap[A, B] = HashMap.empty
private def plusPlusMerger = liftMerger[A, B](null).invert

private def isMutable(hs: HashMap[A, B]) = {
hs.isInstanceOf[HashTrieMap[A, B]] && hs.size == -1
}
Expand Down Expand Up @@ -1249,11 +1271,11 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
if (toNode eq toBeAdded) toNode
else toBeAdded match {
case bLeaf: HashMap1[A, B] =>
if (toNodeHash == bLeaf.hash) toNode.merge0(bLeaf, level, plusPlusMerger)
if (toNodeHash == bLeaf.hash) toNode.merge0(bLeaf, level, concatMerger[A, B])
else makeMutableTrie(toNode, bLeaf, level)

case bLeaf: HashMapCollision1[A, B] =>
if (toNodeHash == bLeaf.hash) toNode.merge0(bLeaf, level, plusPlusMerger)
if (toNodeHash == bLeaf.hash) toNode.merge0(bLeaf, level, concatMerger[A, B])
else makeMutableTrie(toNode, bLeaf, level)

case bTrie: HashTrieMap[A, B] =>
Expand Down
31 changes: 30 additions & 1 deletion test/junit/scala/collection/immutable/HashMapTest.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package scala.collection.immutable

import java.util.Collections

import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
Expand Down Expand Up @@ -203,7 +205,7 @@ class HashMapTest extends AllocationTest {
"i" -> 9,
"j" -> 10
)
assertSame(nonEmpty2, nonAllocating(nonEmpty1 ++ nonEmpty2))
assertSame(nonEmpty1, nonAllocating(nonEmpty1 ++ nonEmpty2))
}

@Test
Expand Down Expand Up @@ -269,4 +271,31 @@ class HashMapTest extends AllocationTest {
//calls GenTraversableOnce.++
assertEquals(1, (m2 ++ m1).apply(2))
}

@Test
def retainLeft(): Unit = {
case class C(a: Int)(override val toString: String)
implicit val ordering: Ordering[C] = Ordering.by(_.a)
val c0l = C(0)("l")
val c0r = C(0)("r")
def assertIdenticalKeys(expected: Map[C, Unit], actual: Map[C, Unit]): Unit = {
val expected1, actual1 = Collections.newSetFromMap[C](new java.util.IdentityHashMap())
expected.keys.foreach(expected1.add)
actual.keys.foreach(actual1.add)
assertEquals(expected1, actual1)
}
assertIdenticalKeys(Map((c0l, ())), HashMap((c0l, ())).updated(c0r, ()))

def check(factory: Seq[(C, Unit)] => Map[C, Unit]): Unit = {
val c0LMap = factory(Seq((c0l, ())))
val c0RMap = factory(Seq((c0r, ())))
assertIdenticalKeys(Map((c0l, ())), HashMap((c0l, ())).++(c0RMap))
assertIdenticalKeys(Map((c0l, ())), HashMap.newBuilder[C, Unit].++=(HashMap((c0l, ()))).++=(c0RMap).result())
assertIdenticalKeys(Map((c0l, ())), HashMap((c0l, ())).++(c0RMap))
assertIdenticalKeys(Map((c0l, ())), c0LMap ++: HashMap((c0r, ())))
}
check(cs => HashMap(cs: _*)) // exercise special case for HashMap/HashMap
check(cs => TreeMap(cs: _*)) // exercise special case for HashMap/HasForEachEntry
check(cs => HashMap(cs: _*).withDefault(_ => ???)) // default cases
}
}
8 changes: 3 additions & 5 deletions test/junit/scala/collection/immutable/TreeMapTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class TreeMapTest extends AllocationTest {
}

@Test
def unionAndIntersectRetainLeft(): Unit = {
def retainLeft(): Unit = {
case class C(a: Int)(override val toString: String)
implicit val ordering: Ordering[C] = Ordering.by(_.a)
val c0l = C(0)("l")
Expand All @@ -188,14 +188,12 @@ class TreeMapTest extends AllocationTest {
assertEquals(expected1, actual1)
}

// This holds in 2.13.x only
//assertIdenticalKeys(Map((c0l, ())), HashMap((c0l, ())).++(HashMap((c0r, ()))))
assertIdenticalKeys(Map((c0l, ())), HashMap((c0l, ())).++(HashMap((c0r, ()))))

assertIdenticalKeys(Map((c0l, ())), TreeMap((c0l, ())).++(HashMap((c0r, ()))))
assertIdenticalKeys(Map((c0l, ())), TreeMap((c0l, ())).++(TreeMap((c0r, ()))))

// This holds in 2.13.x only
//assertIdenticalKeys(Map((c0l, ())), HashMap.newBuilder[C, Unit].++=(HashMap((c0l, ()))).++=(HashMap((c0r, ()))).result())
assertIdenticalKeys(Map((c0l, ())), HashMap.newBuilder[C, Unit].++=(HashMap((c0l, ()))).++=(HashMap((c0r, ()))).result())

assertIdenticalKeys(Map((c0l, ())), TreeMap.newBuilder[C, Unit].++=(TreeMap((c0l, ()))).++=(HashMap((c0r, ()))).result())
assertIdenticalKeys(Map((c0l, ())), TreeMap.newBuilder[C, Unit].++=(TreeMap((c0l, ()))).++=(TreeMap((c0r, ()))).result())
Expand Down

0 comments on commit 4fe011f

Please sign in to comment.