Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve null policy in wrapped Java Maps #10129

Merged
merged 1 commit into from Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/library/scala/collection/concurrent/Map.scala
Expand Up @@ -97,7 +97,7 @@ trait Map[K, V] extends scala.collection.mutable.Map[K, V] {
case None =>
val v = op
putIfAbsent(key, v) match {
case Some(nv) => nv
case Some(ov) => ov
case None => v
}
}
Expand Down
43 changes: 37 additions & 6 deletions src/library/scala/collection/convert/JavaCollectionWrappers.scala
Expand Up @@ -21,6 +21,7 @@ import java.{lang => jl, util => ju}
import scala.jdk.CollectionConverters._
import scala.util.Try
import scala.util.chaining._
import scala.util.control.ControlThrowable

/** Wrappers for exposing Scala collections as Java collections and vice-versa */
@SerialVersionUID(3L)
Expand Down Expand Up @@ -332,7 +333,12 @@ private[collection] object JavaCollectionWrappers extends Serializable {
else
None
}
override def getOrElseUpdate(key: K, op: => V): V = underlying.computeIfAbsent(key, _ => op)

override def getOrElseUpdate(key: K, op: => V): V =
underlying.computeIfAbsent(key, _ => op) match {
case null => update(key, null.asInstanceOf[V]); null.asInstanceOf[V]
case v => v
}

def addOne(kv: (K, V)): this.type = { underlying.put(kv._1, kv._2); this }
def subtractOne(key: K): this.type = { underlying remove key; this }
Expand All @@ -355,8 +361,17 @@ private[collection] object JavaCollectionWrappers extends Serializable {

override def update(k: K, v: V): Unit = underlying.put(k, v)

override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = Option {
underlying.compute(key, (_, v) => remappingFunction(Option(v)).getOrElse(null.asInstanceOf[V]))
override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = {
def remap(k: K, v: V): V =
remappingFunction(Option(v)) match {
case Some(null) => throw PutNull
case Some(x) => x
case None => null.asInstanceOf[V]
}
try Option(underlying.compute(key, remap))
catch {
case PutNull => update(key, null.asInstanceOf[V]); Some(null.asInstanceOf[V])
}
}

// support Some(null) if currently bound to null
Expand Down Expand Up @@ -441,7 +456,11 @@ private[collection] object JavaCollectionWrappers extends Serializable {

override def get(k: K) = Option(underlying get k)

override def getOrElseUpdate(key: K, op: => V): V = underlying.computeIfAbsent(key, _ => op)
override def getOrElseUpdate(key: K, op: => V): V =
underlying.computeIfAbsent(key, _ => op) match {
case null => super/*[concurrent.Map]*/.getOrElseUpdate(key, op)
case v => v
}

override def isEmpty: Boolean = underlying.isEmpty
override def knownSize: Int = if (underlying.isEmpty) 0 else super.knownSize
Expand All @@ -462,8 +481,17 @@ private[collection] object JavaCollectionWrappers extends Serializable {
case _ => Try(last).toOption
}

override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = Option {
underlying.compute(key, (_, v) => remappingFunction(Option(v)).getOrElse(null.asInstanceOf[V]))
override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = {
def remap(k: K, v: V): V =
remappingFunction(Option(v)) match {
case Some(null) => throw PutNull // see scala/scala#10129
case Some(x) => x
case None => null.asInstanceOf[V]
}
try Option(underlying.compute(key, remap))
catch {
case PutNull => super/*[concurrent.Map]*/.updateWith(key)(remappingFunction)
}
}
}

Expand Down Expand Up @@ -572,4 +600,7 @@ private[collection] object JavaCollectionWrappers extends Serializable {

override def mapFactory = mutable.HashMap
}

/** Thrown when certain Map operations attempt to put a null value. */
private val PutNull = new ControlThrowable {}
}
69 changes: 68 additions & 1 deletion test/junit/scala/collection/convert/MapWrapperTest.scala
Expand Up @@ -2,12 +2,13 @@ package scala.collection.convert

import java.{util => jutil}

import org.junit.Assert._
import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

import scala.jdk.CollectionConverters._
import scala.tools.testkit.AssertUtil.assertThrows
import scala.util.chaining._

@RunWith(classOf[JUnit4])
Expand Down Expand Up @@ -107,4 +108,70 @@ class MapWrapperTest {
loki.done = true
runner.join()
}
@Test def `updateWith and getOrElseUpdate should reflect null policy of update`: Unit = {
val jmap = new jutil.concurrent.ConcurrentHashMap[String, String]()
val wrapped = jmap.asScala
assertThrows[NullPointerException](jmap.put("K", null))
assertThrows[NullPointerException](jmap.putIfAbsent("K", null))
assertThrows[NullPointerException](wrapped.put("K", null))
assertThrows[NullPointerException](wrapped.update("K", null))
assertThrows[NullPointerException](wrapped.updateWith("K")(_ => Some(null)))
assertThrows[NullPointerException](wrapped.getOrElseUpdate("K", null))

var count = 0
def v = {
count += 1
null
}
assertThrows[NullPointerException](wrapped.update("K", v))
assertEquals(1, count)
assertThrows[NullPointerException](wrapped.updateWith("K")(_ => Some(v)))
assertEquals(3, count) // extra count in retry
}
@Test def `more updateWith and getOrElseUpdate should reflect null policy of update`: Unit = {
val jmap = new jutil.HashMap[String, String]()
val wrapped = jmap.asScala
wrapped.put("K", null)
assertEquals(1, wrapped.size)
wrapped.remove("K")
assertEquals(0, wrapped.size)
wrapped.update("K", null)
assertEquals(1, wrapped.size)
wrapped.remove("K")
wrapped.updateWith("K")(_ => Some(null))
assertEquals(1, wrapped.size)
wrapped.remove("K")
wrapped.getOrElseUpdate("K", null)
assertEquals(1, wrapped.size)

var count = 0
def v = {
count += 1
null
}
wrapped.update("K", v)
assertEquals(1, count)
wrapped.remove("K")
wrapped.updateWith("K")(_ => Some(v))
assertEquals(2, count)
}

@Test def `getOrElseUpdate / updateWith support should insert null`: Unit = {
val jmap = new jutil.HashMap[String, String]()
val wrapped = jmap.asScala

wrapped.getOrElseUpdate("a", null)
assertTrue(jmap.containsKey("a"))

wrapped.getOrElseUpdate(null, "x")
assertTrue(jmap.containsKey(null))

jmap.clear()

wrapped.updateWith("b")(_ => Some(null))
assertTrue(jmap.containsKey("b"))

wrapped.updateWith(null)(_ => Some("x"))
assertTrue(jmap.containsKey(null))
}
}