From c8b7f62af775442fe8cd9c01bb2b4d395a44cc90 Mon Sep 17 00:00:00 2001 From: Lukas Rytz Date: Tue, 19 Dec 2023 14:38:52 +0100 Subject: [PATCH] [backport] Fix RedBlackTree.doFrom / doTo / doUntil `upd` may return a red tree with a red child. Need to use `maybeBlacken` when such a tree is not expected. --- .../collection/immutable/RedBlackTree.scala | 23 ++++- .../collection/immutable/SortedSetTest.scala | 84 +++++++++++++++++++ test/scalacheck/redblacktree.scala | 2 +- 3 files changed, 104 insertions(+), 5 deletions(-) diff --git a/src/library/scala/collection/immutable/RedBlackTree.scala b/src/library/scala/collection/immutable/RedBlackTree.scala index 15cbbab64cc4..effbb86db29b 100644 --- a/src/library/scala/collection/immutable/RedBlackTree.scala +++ b/src/library/scala/collection/immutable/RedBlackTree.scala @@ -25,6 +25,21 @@ import scala.annotation.tailrec * optimizations behind a reasonably clean API. */ private[collection] object NewRedBlackTree { + def validate[A](tree: Tree[A, _])(implicit ordering: Ordering[A]): tree.type = { + def impl(tree: Tree[A, _], keyProp: A => Boolean): Int = { + assert(keyProp(tree.key), s"key check failed: $tree") + if (tree.isRed) { + assert(tree.left == null || tree.left.isBlack, s"red-red left $tree") + assert(tree.right == null || tree.right.isBlack, s"red-red right $tree") + } + val leftBlacks = if (tree.left == null) 0 else impl(tree.left, k => keyProp(k) && ordering.compare(k, tree.key) < 0) + val rightBlacks = if (tree.right == null) 0 else impl(tree.right, k => keyProp(k) && ordering.compare(k, tree.key) > 0) + assert(leftBlacks == rightBlacks, s"not balanced: $tree") + leftBlacks + (if (tree.isBlack) 1 else 0) + } + if (tree != null) impl(tree, _ => true) + tree + } def isEmpty(tree: Tree[_, _]): Boolean = tree eq null @@ -447,7 +462,7 @@ private[collection] object NewRedBlackTree { if (ordering.lt(tree.key, from)) return doFrom(tree.right, from) val newLeft = doFrom(tree.left, from) if (newLeft eq tree.left) tree - else if (newLeft eq null) upd(tree.right, tree.key, tree.value, overwrite = false) + else if (newLeft eq null) maybeBlacken(upd(tree.right, tree.key, tree.value, overwrite = false)) else join(newLeft, tree.key, tree.value, tree.right) } private[this] def doTo[A, B](tree: Tree[A, B], to: A)(implicit ordering: Ordering[A]): Tree[A, B] = { @@ -455,15 +470,15 @@ private[collection] object NewRedBlackTree { if (ordering.lt(to, tree.key)) return doTo(tree.left, to) val newRight = doTo(tree.right, to) if (newRight eq tree.right) tree - else if (newRight eq null) upd(tree.left, tree.key, tree.value, overwrite = false) - else join (tree.left, tree.key, tree.value, newRight) + else if (newRight eq null) maybeBlacken(upd(tree.left, tree.key, tree.value, overwrite = false)) + else join(tree.left, tree.key, tree.value, newRight) } private[this] def doUntil[A, B](tree: Tree[A, B], until: A)(implicit ordering: Ordering[A]): Tree[A, B] = { if (tree eq null) return null if (ordering.lteq(until, tree.key)) return doUntil(tree.left, until) val newRight = doUntil(tree.right, until) if (newRight eq tree.right) tree - else if (newRight eq null) upd(tree.left, tree.key, tree.value, overwrite = false) + else if (newRight eq null) maybeBlacken(upd(tree.left, tree.key, tree.value, overwrite = false)) else join(tree.left, tree.key, tree.value, newRight) } diff --git a/test/junit/scala/collection/immutable/SortedSetTest.scala b/test/junit/scala/collection/immutable/SortedSetTest.scala index 21bc235d5133..d3be3652e560 100644 --- a/test/junit/scala/collection/immutable/SortedSetTest.scala +++ b/test/junit/scala/collection/immutable/SortedSetTest.scala @@ -1,8 +1,10 @@ package scala.collection.immutable +import org.junit.Assert.assertEquals import org.junit.Test import scala.tools.testing.AllocationTest +import scala.tools.testing.AssertUtil.assertThrows class SortedSetTest extends AllocationTest{ @@ -23,4 +25,86 @@ class SortedSetTest extends AllocationTest{ val ord = Ordering[String] exactAllocates(168)(SortedSet("a", "b")(ord)) } + + @Test def redBlackValidate(): Unit = { + import NewRedBlackTree._ + def redLeaf(x: Int) = RedTree(x, null, null, null) + def blackLeaf(x: Int) = BlackTree(x, null, null, null) + + validate(redLeaf(1)) + validate(blackLeaf(1)) + assertThrows[AssertionError](validate(RedTree(2, null, redLeaf(1), null)), _.contains("red-red")) + assertThrows[AssertionError](validate(RedTree(2, null, blackLeaf(1), null)), _.contains("not balanced")) + validate(RedTree(2, null, blackLeaf(1), blackLeaf(3))) + validate(BlackTree(2, null, blackLeaf(1), blackLeaf(3))) + assertThrows[AssertionError](validate(RedTree(4, null, blackLeaf(1), blackLeaf(3))), _.contains("key check")) + } + + @Test def t12921(): Unit = { + val s1 = TreeSet(6, 1, 11, 9, 10, 8) + NewRedBlackTree.validate(s1.tree) + + val s2 = s1.from(2) + NewRedBlackTree.validate(s2.tree) + assertEquals(Set(6, 8, 9, 10, 11), s2) + + val s3 = s2 ++ Seq(7,3,5) + NewRedBlackTree.validate(s3.tree) + assertEquals(Set(3, 5, 6, 7, 8, 9, 10, 11), s3) + + val s4 = s3.from(4) + NewRedBlackTree.validate(s4.tree) + assertEquals(Set(5, 6, 7, 8, 9, 10, 11), s4) + } + + @Test def t12921b(): Unit = { + import NewRedBlackTree._ + val t = BlackTree( + 5, + null, + BlackTree( + 3, + null, + RedTree(1, null, null, null), + RedTree(4, null, null, null) + ), + BlackTree(7, null, RedTree(6, null, null, null), null) + ) + validate(t) + validate(from(t, 2)) + } + + @Test def t12921c(): Unit = { + import NewRedBlackTree._ + val t = BlackTree( + 8, + null, + BlackTree(4, null, null, RedTree(6, null, null, null)), + BlackTree( + 12, + null, + RedTree(10, null, null, null), + RedTree(14, null, null, null) + ) + ) + validate(t) + validate(to(t, 13)) + } + + @Test def t12921d(): Unit = { + import NewRedBlackTree._ + val t = BlackTree( + 8, + null, + BlackTree(4, null, null, RedTree(6, null, null, null)), + BlackTree( + 12, + null, + RedTree(10, null, null, null), + RedTree(14, null, null, null) + ) + ) + validate(t) + validate(until(t, 13)) + } } diff --git a/test/scalacheck/redblacktree.scala b/test/scalacheck/redblacktree.scala index 038c608bebc0..eaa54a72c494 100644 --- a/test/scalacheck/redblacktree.scala +++ b/test/scalacheck/redblacktree.scala @@ -65,7 +65,7 @@ abstract class RedBlackTreeTest extends Properties("RedBlackTree") { def genInput: Gen[(Tree[String, Int], ModifyParm, Tree[String, Int])] = for { tree <- genTree parm <- genParm(tree) - } yield (tree, parm, modify(tree, parm)) + } yield (tree, parm, validate(modify(tree, parm))) } trait RedBlackTreeInvariants {