diff --git a/src/library/scala/collection/concurrent/Map.scala b/src/library/scala/collection/concurrent/Map.scala index f330b40ebe54..2fd444035bf5 100644 --- a/src/library/scala/collection/concurrent/Map.scala +++ b/src/library/scala/collection/concurrent/Map.scala @@ -97,7 +97,7 @@ trait Map[K, V] extends scala.collection.mutable.Map[K, V] { case None => val v = op putIfAbsent(key, v) match { - case Some(nv) => nv + case Some(ov) => ov case None => v } } diff --git a/src/library/scala/collection/convert/JavaCollectionWrappers.scala b/src/library/scala/collection/convert/JavaCollectionWrappers.scala index 362816f7b399..49d6596b1a44 100644 --- a/src/library/scala/collection/convert/JavaCollectionWrappers.scala +++ b/src/library/scala/collection/convert/JavaCollectionWrappers.scala @@ -21,6 +21,7 @@ import java.{lang => jl, util => ju} import scala.jdk.CollectionConverters._ import scala.util.Try import scala.util.chaining._ +import scala.util.control.ControlThrowable /** Wrappers for exposing Scala collections as Java collections and vice-versa */ @SerialVersionUID(3L) @@ -332,7 +333,12 @@ private[collection] object JavaCollectionWrappers extends Serializable { else None } - override def getOrElseUpdate(key: K, op: => V): V = underlying.computeIfAbsent(key, _ => op) + + override def getOrElseUpdate(key: K, op: => V): V = + underlying.computeIfAbsent(key, _ => op) match { + case null => update(key, null.asInstanceOf[V]); null.asInstanceOf[V] + case v => v + } def addOne(kv: (K, V)): this.type = { underlying.put(kv._1, kv._2); this } def subtractOne(key: K): this.type = { underlying remove key; this } @@ -355,8 +361,17 @@ private[collection] object JavaCollectionWrappers extends Serializable { override def update(k: K, v: V): Unit = underlying.put(k, v) - override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = Option { - underlying.compute(key, (_, v) => remappingFunction(Option(v)).getOrElse(null.asInstanceOf[V])) + override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = { + def remap(k: K, v: V): V = + remappingFunction(Option(v)) match { + case Some(null) => throw PutNull + case Some(x) => x + case None => null.asInstanceOf[V] + } + try Option(underlying.compute(key, remap)) + catch { + case PutNull => update(key, null.asInstanceOf[V]); Some(null.asInstanceOf[V]) + } } // support Some(null) if currently bound to null @@ -441,7 +456,11 @@ private[collection] object JavaCollectionWrappers extends Serializable { override def get(k: K) = Option(underlying get k) - override def getOrElseUpdate(key: K, op: => V): V = underlying.computeIfAbsent(key, _ => op) + override def getOrElseUpdate(key: K, op: => V): V = + underlying.computeIfAbsent(key, _ => op) match { + case null => super/*[concurrent.Map]*/.getOrElseUpdate(key, op) + case v => v + } override def isEmpty: Boolean = underlying.isEmpty override def knownSize: Int = if (underlying.isEmpty) 0 else super.knownSize @@ -462,8 +481,17 @@ private[collection] object JavaCollectionWrappers extends Serializable { case _ => Try(last).toOption } - override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = Option { - underlying.compute(key, (_, v) => remappingFunction(Option(v)).getOrElse(null.asInstanceOf[V])) + override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = { + def remap(k: K, v: V): V = + remappingFunction(Option(v)) match { + case Some(null) => throw PutNull // see scala/scala#10129 + case Some(x) => x + case None => null.asInstanceOf[V] + } + try Option(underlying.compute(key, remap)) + catch { + case PutNull => super/*[concurrent.Map]*/.updateWith(key)(remappingFunction) + } } } @@ -572,4 +600,7 @@ private[collection] object JavaCollectionWrappers extends Serializable { override def mapFactory = mutable.HashMap } + + /** Thrown when certain Map operations attempt to put a null value. */ + private val PutNull = new ControlThrowable {} } diff --git a/test/junit/scala/collection/convert/MapWrapperTest.scala b/test/junit/scala/collection/convert/MapWrapperTest.scala index 55fbe4025065..521cec2410c4 100644 --- a/test/junit/scala/collection/convert/MapWrapperTest.scala +++ b/test/junit/scala/collection/convert/MapWrapperTest.scala @@ -2,12 +2,13 @@ package scala.collection.convert import java.{util => jutil} -import org.junit.Assert._ +import org.junit.Assert.{assertEquals, assertFalse, assertTrue} import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import scala.jdk.CollectionConverters._ +import scala.tools.testkit.AssertUtil.assertThrows import scala.util.chaining._ @RunWith(classOf[JUnit4]) @@ -107,4 +108,70 @@ class MapWrapperTest { loki.done = true runner.join() } + @Test def `updateWith and getOrElseUpdate should reflect null policy of update`: Unit = { + val jmap = new jutil.concurrent.ConcurrentHashMap[String, String]() + val wrapped = jmap.asScala + assertThrows[NullPointerException](jmap.put("K", null)) + assertThrows[NullPointerException](jmap.putIfAbsent("K", null)) + assertThrows[NullPointerException](wrapped.put("K", null)) + assertThrows[NullPointerException](wrapped.update("K", null)) + assertThrows[NullPointerException](wrapped.updateWith("K")(_ => Some(null))) + assertThrows[NullPointerException](wrapped.getOrElseUpdate("K", null)) + + var count = 0 + def v = { + count += 1 + null + } + assertThrows[NullPointerException](wrapped.update("K", v)) + assertEquals(1, count) + assertThrows[NullPointerException](wrapped.updateWith("K")(_ => Some(v))) + assertEquals(3, count) // extra count in retry + } + @Test def `more updateWith and getOrElseUpdate should reflect null policy of update`: Unit = { + val jmap = new jutil.HashMap[String, String]() + val wrapped = jmap.asScala + wrapped.put("K", null) + assertEquals(1, wrapped.size) + wrapped.remove("K") + assertEquals(0, wrapped.size) + wrapped.update("K", null) + assertEquals(1, wrapped.size) + wrapped.remove("K") + wrapped.updateWith("K")(_ => Some(null)) + assertEquals(1, wrapped.size) + wrapped.remove("K") + wrapped.getOrElseUpdate("K", null) + assertEquals(1, wrapped.size) + + var count = 0 + def v = { + count += 1 + null + } + wrapped.update("K", v) + assertEquals(1, count) + wrapped.remove("K") + wrapped.updateWith("K")(_ => Some(v)) + assertEquals(2, count) + } + + @Test def `getOrElseUpdate / updateWith support should insert null`: Unit = { + val jmap = new jutil.HashMap[String, String]() + val wrapped = jmap.asScala + + wrapped.getOrElseUpdate("a", null) + assertTrue(jmap.containsKey("a")) + + wrapped.getOrElseUpdate(null, "x") + assertTrue(jmap.containsKey(null)) + + jmap.clear() + + wrapped.updateWith("b")(_ => Some(null)) + assertTrue(jmap.containsKey("b")) + + wrapped.updateWith(null)(_ => Some("x")) + assertTrue(jmap.containsKey(null)) + } }