diff --git a/src/library/scala/collection/convert/JavaCollectionWrappers.scala b/src/library/scala/collection/convert/JavaCollectionWrappers.scala index 29c3dcbac5db..f8bc1f670e03 100644 --- a/src/library/scala/collection/convert/JavaCollectionWrappers.scala +++ b/src/library/scala/collection/convert/JavaCollectionWrappers.scala @@ -34,18 +34,33 @@ private[collection] object JavaCollectionWrappers extends Serializable { def hasMoreElements = underlying.hasNext def nextElement() = underlying.next() override def remove() = throw new UnsupportedOperationException + override def equals(other: Any): Boolean = other match { + case that: IteratorWrapper[_] => this.underlying == that.underlying + case _ => false + } + override def hashCode: Int = underlying.hashCode() } @SerialVersionUID(3L) class JIteratorWrapper[A](val underlying: ju.Iterator[A]) extends AbstractIterator[A] with Iterator[A] with Serializable { def hasNext = underlying.hasNext def next() = underlying.next + override def equals(other: Any): Boolean = other match { + case that: JIteratorWrapper[_] => this.underlying == that.underlying + case _ => false + } + override def hashCode: Int = underlying.hashCode() } @SerialVersionUID(3L) class JEnumerationWrapper[A](val underlying: ju.Enumeration[A]) extends AbstractIterator[A] with Iterator[A] with Serializable { def hasNext = underlying.hasMoreElements def next() = underlying.nextElement + override def equals(other: Any): Boolean = other match { + case that: JEnumerationWrapper[_] => this.underlying == that.underlying + case _ => false + } + override def hashCode: Int = underlying.hashCode() } trait IterableWrapperTrait[A] extends ju.AbstractCollection[A] { @@ -57,13 +72,11 @@ private[collection] object JavaCollectionWrappers extends Serializable { @SerialVersionUID(3L) class IterableWrapper[A](val underlying: Iterable[A]) extends ju.AbstractCollection[A] with IterableWrapperTrait[A] with Serializable { - import scala.runtime.Statics._ - override def equals(other: Any): Boolean = - other match { - case other: IterableWrapper[_] => underlying.equals(other.underlying) - case _ => false - } - override def hashCode = finalizeHash(mix(mix(0xcafebabe, "IterableWrapper".hashCode), anyHash(underlying)), 1) + override def equals(other: Any): Boolean = other match { + case that: IterableWrapper[_] => this.underlying == that.underlying + case _ => false + } + override def hashCode: Int = underlying.hashCode() } @SerialVersionUID(3L) @@ -74,6 +87,11 @@ private[collection] object JavaCollectionWrappers extends Serializable { def iterator = underlying.iterator.asScala override def iterableFactory = mutable.ArrayBuffer override def isEmpty: Boolean = !underlying.iterator().hasNext + override def equals(other: Any): Boolean = other match { + case that: JIterableWrapper[_] => this.underlying == that.underlying + case _ => false + } + override def hashCode: Int = underlying.hashCode() } @SerialVersionUID(3L) @@ -86,6 +104,11 @@ private[collection] object JavaCollectionWrappers extends Serializable { override def knownSize: Int = if (underlying.isEmpty) 0 else super.knownSize override def isEmpty = underlying.isEmpty override def iterableFactory = mutable.ArrayBuffer + override def equals(other: Any): Boolean = other match { + case that: JCollectionWrapper[_] => this.underlying == that.underlying + case _ => false + } + override def hashCode: Int = underlying.hashCode() } @SerialVersionUID(3L) @@ -254,7 +277,7 @@ private[collection] object JavaCollectionWrappers extends Serializable { def getKey = k def getValue = v def setValue(v1 : V) = self.put(k, v1) - + // It's important that this implementation conform to the contract // specified in the javadocs of java.util.Map.Entry.hashCode // @@ -529,6 +552,13 @@ private[collection] object JavaCollectionWrappers extends Serializable { } catch { case ex: ClassCastException => null.asInstanceOf[V] } + + override def equals(other: Any): Boolean = other match { + case that: DictionaryWrapper[_, _] => this.underlying == that.underlying + case _ => false + } + + override def hashCode: Int = underlying.hashCode() } @SerialVersionUID(3L) diff --git a/test/junit/scala/collection/convert/EqualsTest.scala b/test/junit/scala/collection/convert/EqualsTest.scala index b3f9ae17176b..c1622f3b87e1 100644 --- a/test/junit/scala/collection/convert/EqualsTest.scala +++ b/test/junit/scala/collection/convert/EqualsTest.scala @@ -4,11 +4,23 @@ package scala.collection.convert import org.junit.Test import org.junit.Assert._ +import java.util.{ + AbstractList, + AbstractMap, + AbstractSet, + Collections, + Collection => JCollection, + HashSet => JHashSet, + List => JList, + Map => JMap, + Set => JSet +} +import java.lang.{Iterable => JIterable} +import java.util.concurrent.{ConcurrentHashMap => JCMap} +import scala.collection.{AbstractIterable, concurrent, mutable} import scala.jdk.CollectionConverters._ import JavaCollectionWrappers._ -import java.util.{AbstractList, AbstractSet, List => JList, Set => JSet} - class JTestList(vs: Int*) extends AbstractList[Int] { def this() = this(Nil: _*) override def size = vs.size @@ -21,58 +33,134 @@ class JTestSet(vs: Int*) extends AbstractSet[Int] { override def iterator = vs.iterator.asJava } +object JTestMap { + case class JTestMapEntry(key: Int, value: String) extends JMap.Entry[Int, String] { + override def getKey: Int = key + override def getValue: String = value + override def setValue(value: String): String = + throw new UnsupportedOperationException("Cannot set value on JTestMapEntry") + } +} + +class JTestMap(vs: (Int, String)*) extends AbstractMap[Int, String] { + import JTestMap._ + override def entrySet(): JSet[JMap.Entry[Int, String]] = { + val entrySet = new JHashSet[JMap.Entry[Int, String]](vs.size); + vs.foreach { case (k, v) => entrySet.add(JTestMapEntry(k, v)) } + entrySet + } +} + /** Test that collection wrappers forward equals and hashCode where appropriate. */ class EqualsTest { - def jlstOf(vs: Int*): JList[Int] = new JTestList(vs: _*) - def jsetOf(vs: Int*): JSet[Int] = new JTestSet(vs: _*) - - // Seq extending AbstractList inherits equals + def jListOf(vs: Int*): JList[Int] = new JTestList(vs: _*) + def jSetOf(vs: Int*): JSet[Int] = new JTestSet(vs: _*) + def jMapOf(vs: (Int, String)*): JMap[Int, String] = new JTestMap(vs: _*) - @Test def `List as JList has equals`: Unit = { - val list = List(1, 2, 3) - val jlst = new SeqWrapper(list) - assertEquals(jlstOf(1, 2, 3), jlst) - assertEquals(jlst, jlstOf(1, 2, 3)) - assertTrue(jlst == jlstOf(1, 2, 3)) - assertEquals(jlst.hashCode, jlst.hashCode) + // SeqWrapper extending util.AbstractList inherits equals + @Test def `Seq as JList has equals`: Unit = { + def seq = Seq(1, 2, 3) + def jList = new SeqWrapper(seq) + assertEquals(jList, jList) + assertEquals(jListOf(1, 2, 3), jList) + assertEquals(jList, jListOf(1, 2, 3)) + assertTrue(jList == jListOf(1, 2, 3)) + assertEquals(jList.hashCode, jList.hashCode) } + // SetWrapper extending util.AbstractSet inherits equals @Test def `Set as JSet has equals`: Unit = { - val set = Set(1, 2, 3) - val jset = new SetWrapper(set) - assertEquals(jsetOf(1, 2, 3), jset) - assertEquals(jset, jsetOf(1, 2, 3)) - assertTrue(jset == jsetOf(1, 2, 3)) - assertEquals(jset.hashCode, jset.hashCode) + def set = Set(1, 2, 3) + def jSet = new SetWrapper(set) + assertEquals(jSet, jSet) + assertEquals(jSetOf(1, 2, 3), jSet) + assertEquals(jSet, jSetOf(1, 2, 3)) + assertTrue(jSet == jSetOf(1, 2, 3)) + assertEquals(jSet.hashCode, jSet.hashCode) } + // MapWrapper extending util.AbstractMap inherits equals @Test def `Map as JMap has equals`: Unit = { - val map = Map(1 -> "one", 2 -> "two", 3 -> "three") - val jmap = new MapWrapper(map) - assertEquals(jmap, jmap) + def map = Map(1 -> "one", 2 -> "two", 3 -> "three") + def jMap = new MapWrapper(map) + assertEquals(jMap, jMap) + assertEquals(jMapOf(1 -> "one", 2 -> "two", 3 -> "three"), jMap) + assertEquals(jMap, jMapOf(1 -> "one", 2 -> "two", 3 -> "three")) + assertTrue(jMap == jMapOf(1 -> "one", 2 -> "two", 3 -> "three")) + assertEquals(jMap.hashCode, jMap.hashCode) } - @Test def `Anything as Collection is equal to Anything`: Unit = { - def set = Set(1, 2, 3) - def jset = new IterableWrapper(set) - assertTrue(jset == jset) - assertEquals(jset, jset) - assertNotEquals(jset, set) - assertEquals(jset.hashCode, jset.hashCode) + @Test def `Iterable as JIterable does not compare equal`: Unit = { + // scala iterable without element equality defined + def iterable: Iterable[Int] = new AbstractIterable[Int] { + override def iterator: Iterator[Int] = Iterator(1, 2, 3) + } + def jIterable = new IterableWrapper(iterable) + assertNotEquals(jIterable, jIterable) + assertNotEquals(jIterable.hashCode, jIterable.hashCode) } - @Test def `Iterator wrapper does not compare equal`: Unit = { - def it = List(1, 2, 3).iterator - def jit = new IteratorWrapper(it) - assertNotEquals(jit, jit) - assertNotEquals(jit.hashCode, jit.hashCode) + @Test def `Iterator as JIterator does not compare equal`: Unit = { + def iterator = Iterator(1, 2, 3) + def jIterator = new IteratorWrapper(iterator) + assertNotEquals(jIterator, jIterator) + assertNotEquals(jIterator.hashCode, jIterator.hashCode) } - @Test def `Anything.asScala Iterable has case equals`: Unit = { - def vs = jlstOf(42, 27, 37) - def it = new JListWrapper(vs) - assertEquals(it, it) - assertEquals(it.hashCode, it.hashCode) + @Test def `All wrapper compare equal if underlying is equal`(): Unit = { + val jList = Collections.emptyList[String]() + assertEquals(jList.asScala, jList.asScala) + + val jIterator = jList.iterator() + assertEquals(jIterator.asScala, jIterator.asScala) + + val jEnumeration = Collections.emptyEnumeration[String]() + assertEquals(jEnumeration.asScala, jEnumeration.asScala) + + val jIterable = jList.asInstanceOf[JIterable[String]] + assertEquals(jIterable.asScala, jIterable.asScala) + + val jCollection = jList.asInstanceOf[JCollection[String]] + assertEquals(jCollection.asScala, jCollection.asScala) + + val jSet = Collections.emptySet[String]() + assertEquals(jSet.asScala, jSet.asScala) + + val jMap = Collections.emptyMap[String, String]() + assertEquals(jMap.asScala, jMap.asScala) + + val jCMap = new JCMap[String, String]() + assertEquals(jCMap.asScala, jCMap.asScala) + + val iterator = Iterator.empty[String] + assertEquals(iterator.asJava, iterator.asJava) + assertEquals(iterator.asJavaEnumeration, iterator.asJavaEnumeration) + + val iterable = Iterable.empty[String] + assertEquals(iterable.asJava, iterable.asJava) + assertEquals(iterable.asJavaCollection, iterable.asJavaCollection) + + val buffer = mutable.Buffer.empty[String] + assertEquals(buffer.asJava, buffer.asJava) + + val seq = mutable.Seq.empty[String] + assertEquals(seq.asJava, seq.asJava) + + val mutableSet = mutable.Set.empty[String] + assertEquals(mutableSet.asJava, mutableSet.asJava) + + val set = Set.empty[String] + assertEquals(set.asJava, set.asJava) + + val mutableMap = mutable.Map.empty[String, String] + assertEquals(mutableMap.asJava, mutableMap.asJava) + assertEquals(mutableMap.asJavaDictionary, mutableMap.asJavaDictionary) + + val map = Map.empty[String, String] + assertEquals(map.asJava, map.asJava) + + val concurrentMap = concurrent.TrieMap.empty[String, String] + assertEquals(concurrentMap.asJava, concurrentMap.asJava) } }