Skip to content

Commit

Permalink
Add in-place shift operations to mutable.BitSet
Browse files Browse the repository at this point in the history
  • Loading branch information
linasm committed Sep 25, 2019
1 parent 8f64810 commit e51ad97
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package scala.collection.decorators

import scala.collection.{BitSetOps, mutable}

class MutableBitSetDecorator(protected val bs: mutable.BitSet) {

import BitSetDecorator._
import BitSetOps._

/**
* Updates this BitSet to the left shift of itself by the given shift distance.
* The shift distance may be negative, in which case this method performs a right shift.
* @param shiftBy shift distance, in bits
* @return the BitSet itself
*/
def <<=(shiftBy: Int): mutable.BitSet = {

if (bs.nwords == 0 || bs.nwords == 1 && bs.word(0) == 0) ()
else if (shiftBy > 0) shiftLeftInPlace(shiftBy)
else if (shiftBy < 0) shiftRightInPlace(-shiftBy)

bs
}

/**
* Updates this BitSet to the right shift of itself by the given shift distance.
* The shift distance may be negative, in which case this method performs a left shift.
* @param shiftBy shift distance, in bits
* @return the BitSet itself
*/
def >>=(shiftBy: Int): mutable.BitSet = {

if (bs.nwords == 0 || bs.nwords == 1 && bs.word(0) == 0) ()
else if (shiftBy > 0) shiftRightInPlace(shiftBy)
else if (shiftBy < 0) shiftLeftInPlace(-shiftBy)

bs
}

private def shiftLeftInPlace(shiftBy: Int): Unit = {

val bitOffset = shiftBy & WordMask
val wordOffset = shiftBy >>> LogWL

var significantWordCount = bs.nwords
while (significantWordCount > 0 && bs.word(significantWordCount - 1) == 0) {
significantWordCount -= 1
}

if (bitOffset == 0) {
val newSize = significantWordCount + wordOffset
require(newSize <= MaxSize)
ensureCapacity(newSize)
System.arraycopy(bs.elems, 0, bs.elems, wordOffset, significantWordCount)
} else {
val revBitOffset = WordLength - bitOffset
val extraBits = bs.elems(significantWordCount - 1) >>> revBitOffset
val extraWordCount = if (extraBits == 0) 0 else 1
val newSize = significantWordCount + wordOffset + extraWordCount
require(newSize <= MaxSize)
ensureCapacity(newSize)
var i = significantWordCount - 1
var previous = bs.elems(i)
while (i > 0) {
val current = bs.elems(i - 1)
bs.elems(i + wordOffset) = (current >>> revBitOffset) | (previous << bitOffset)
previous = current
i -= 1
}
bs.elems(wordOffset) = previous << bitOffset
if (extraWordCount != 0) bs.elems(newSize - 1) = extraBits
}
java.util.Arrays.fill(bs.elems, 0, wordOffset, 0)
}

private def shiftRightInPlace(shiftBy: Int): Unit = {

val bitOffset = shiftBy & WordMask

if (bitOffset == 0) {
val wordOffset = shiftBy >>> LogWL
val newSize = bs.nwords - wordOffset
if (newSize > 0) {
System.arraycopy(bs.elems, wordOffset, bs.elems, 0, newSize)
java.util.Arrays.fill(bs.elems, newSize, bs.nwords, 0)
} else bs.clear()
} else {
val wordOffset = (shiftBy >>> LogWL) + 1
val extraBits = bs.elems(bs.nwords - 1) >>> bitOffset
val extraWordCount = if (extraBits == 0) 0 else 1
val newSize = bs.nwords - wordOffset + extraWordCount
if (newSize > 0) {
val revBitOffset = WordLength - bitOffset
var previous = bs.elems(wordOffset - 1)
var i = wordOffset
while (i < bs.nwords) {
val current = bs.elems(i)
bs.elems(i - wordOffset) = (previous >>> bitOffset) | (current << revBitOffset)
previous = current
i += 1
}
if (extraWordCount != 0) bs.elems(newSize - 1) = extraBits
java.util.Arrays.fill(bs.elems, newSize, bs.nwords, 0)
} else bs.clear()
}
}

protected final def ensureCapacity(idx: Int): Unit = {
// Copied from mutable.BitSet.ensureCapacity (which is inaccessible from here).
require(idx < MaxSize)
if (idx >= bs.nwords) {
var newlen = bs.nwords
while (idx >= newlen) newlen = math.min(newlen * 2, MaxSize)
val elems1 = new Array[Long](newlen)
Array.copy(bs.elems, 0, elems1, 0, bs.nwords)
bs.elems = elems1
}
}

}
3 changes: 3 additions & 0 deletions src/main/scala/scala/collection/decorators/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,7 @@ package object decorators {
implicit def bitSetDecorator[C <: BitSet with BitSetOps[C]](bs: C): BitSetDecorator[C] =
new BitSetDecorator(bs)

implicit def mutableBitSetDecorator(bs: mutable.BitSet): MutableBitSetDecorator =
new MutableBitSetDecorator(bs)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package scala.collection.decorators

import org.junit.{Assert, Test}

import scala.collection.mutable.BitSet

class MutableBitSetDecoratorTest {

import Assert.{assertEquals, assertSame}
import BitSet.empty

@Test
def shiftEmptyLeftInPlace(): Unit = {
for (shiftBy <- 0 to 128) {
val bs = empty
bs <<= shiftBy
assertEquals(empty, bs)
assertEquals(empty.nwords, bs.nwords)
}
}

@Test
def shiftLowestBitLeftInPlace(): Unit = {
for (shiftBy <- 0 to 128) {
val bs = BitSet(0)
bs <<= shiftBy
assertEquals(BitSet(shiftBy), bs)
}
}

@Test
def shiftNegativeLeftInPlace(): Unit = {
val bs = BitSet(1)
bs <<= -1
assertEquals(BitSet(0), bs)
}

@Test
def largeShiftLeftInPlace(): Unit = {
for (shiftBy <- 0 to 128) {
val bs = BitSet(0 to 300 by 5: _*)
val expected = bs.map(_ + shiftBy)
bs <<= shiftBy
assertEquals(expected, bs)
}
}

@Test
def skipZeroWordsOnShiftLeftInPlace(): Unit = {
val bs = BitSet(5 * 64 - 1)
bs <<= 64
assertEquals(BitSet(6 * 64 - 1), bs)
assertEquals(8, bs.nwords)
}

@Test
def shiftEmptyRightInPlace(): Unit = {
for (shiftBy <- 0 to 128) {
val bs = empty
bs >>= shiftBy
assertEquals(empty, bs)
assertEquals(empty.nwords, bs.nwords)
}
}

@Test
def shiftLowestBitRightInPlace(): Unit = {
val bs = BitSet(0)
bs >>= 0
assertEquals(BitSet(0), bs)

for (shiftBy <- 1 to 128) {
val bs = BitSet(0)
bs >>= shiftBy
assertEquals(empty, bs)
}
}

@Test
def shiftToLowestBitRightInPlace(): Unit = {
for (shiftBy <- 0 to 128) {
val bs = BitSet(shiftBy)
bs >>= shiftBy
assertEquals(BitSet(0), bs)
}
}

@Test
def shiftNegativeRightInPlace(): Unit = {
val bs = BitSet(0)
bs >>= -1
assertEquals(BitSet(1), bs)
}

@Test
def largeShiftRightInPlace(): Unit = {
for (shiftBy <- 0 to 128) {
val bs = BitSet(0 to 300 by 5: _*)
val expected = bs.collect {
case b if b >= shiftBy => b - shiftBy
}
bs >>= shiftBy
assertEquals(expected, bs)
}
}

}

0 comments on commit e51ad97

Please sign in to comment.