Skip to content

Commit 586e273

Browse files
feat: add bitmap utility functions (#11)
* Feat: add bitmap utility functions * Address feedback
1 parent fd54225 commit 586e273

File tree

3 files changed

+340
-9
lines changed

3 files changed

+340
-9
lines changed

firebolt/buffers.mojo

Lines changed: 185 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from memory import UnsafePointer, memset_zero, memcpy, ArcPointer
2-
from sys.info import sizeof
1+
from memory import UnsafePointer, memset_zero, memcpy, ArcPointer, Span, memset
2+
from sys.info import sizeof, simdbytewidth
33
import math
4+
from bit import pop_count, count_trailing_zeros
45

56

67
fn _required_bytes(length: Int, T: DType) -> Int:
@@ -12,6 +13,11 @@ fn _required_bytes(length: Int, T: DType) -> Int:
1213
return math.align_up(size, 64)
1314

1415

16+
alias simd_width = simdbytewidth()
17+
18+
alias simd_widths = (simd_width, simd_width // 2, 1)
19+
20+
1521
struct Buffer(Movable):
1622
var ptr: UnsafePointer[UInt8]
1723
var size: Int
@@ -155,15 +161,124 @@ struct Bitmap(Movable, Writable):
155161
fn unsafe_set(mut self, index: Int, value: Bool) -> None:
156162
self.buffer.unsafe_set[DType.bool](index, value)
157163

164+
@always_inline
158165
fn length(self) -> Int:
159166
return self.buffer.length[DType.bool]()
160167

168+
@always_inline
161169
fn size(self) -> Int:
162170
return self.buffer.size
163171

164172
fn grow[I: Intable](mut self, target_length: I):
165173
return self.buffer.grow[DType.bool](target_length)
166174

175+
fn bit_count(self) -> Int:
176+
"""The number of bits with value 1 in the Bitmap."""
177+
var start = 0
178+
var count = 0
179+
while start < self.buffer.size:
180+
if self.buffer.size - start > simd_width:
181+
count += (
182+
self.buffer.offset(start)
183+
.load[width=simd_width]()
184+
.reduce_bit_count()
185+
)
186+
start += simd_width
187+
else:
188+
count += (
189+
self.buffer.offset(start).load[width=1]().reduce_bit_count()
190+
)
191+
start += 1
192+
return count
193+
194+
fn count_leading_bits(self, start: Int = 0, value: Bool = False) -> Int:
195+
"""Count the number of leading bits with the given value in the bitmap, starting at a given posiion.
196+
197+
Note that index 0 in the bitmap translates to right most bit in the first byte of the buffer.
198+
So when we are looking for leading zeros from a bitmap standpoing we need to look at
199+
trailing zeros in the bitmap's associated buffer.
200+
201+
The SIMD API available looks at leading zeros only, we negate the input when needed.
202+
203+
Args:
204+
start: The position where we should start counting.
205+
value: The value of the bits we want to count.
206+
207+
Returns:
208+
The number of leadinging bits with the given value in the bitmap.
209+
"""
210+
211+
var count = 0
212+
var index = start // 8
213+
var bit_in_first_byte = start % 8
214+
215+
if bit_in_first_byte != 0:
216+
# Process the partial first byte by applying a mask.
217+
var loaded = self.buffer.offset(index).load[width=1]()
218+
if value:
219+
loaded = ~loaded
220+
var mask = (1 << bit_in_first_byte) - 1
221+
loaded &= ~mask
222+
leading_zeros = Int(count_trailing_zeros(loaded))
223+
if leading_zeros == 0:
224+
return count
225+
count = leading_zeros - bit_in_first_byte
226+
if leading_zeros != 8:
227+
# The first byte has some bits of the other value, just return the count.
228+
return count
229+
230+
index += 1
231+
232+
# Process full bytes.
233+
while index < self.size():
234+
235+
@parameter
236+
for width_index in range(len(simd_widths)):
237+
alias width = simd_widths[width_index]
238+
if self.size() - index >= width:
239+
var loaded = self.buffer.offset(index).load[width=width]()
240+
if value:
241+
loaded = ~loaded
242+
var leading_zeros = count_trailing_zeros(loaded)
243+
for i in range(width):
244+
count += Int(leading_zeros[i])
245+
if leading_zeros[i] != 8:
246+
return count
247+
index += width
248+
# break from the simd widths loop
249+
break
250+
return count
251+
252+
fn count_leading_zeros(self, start: Int = 0) -> Int:
253+
"""Count the number of leading 0s in the given value in the bitmap, starting at a given posiion.
254+
255+
Note that index 0 in the bitmap translates to right most bit in the first byte of the buffer.
256+
So when we are looking for leading zeros from a bitmap standpoing we need to look at
257+
trailing zeros in the bitmap's associated buffer.
258+
259+
Args:
260+
start: The position where we should start counting.
261+
262+
Returns:
263+
The number of leading zeros in the bitmap.
264+
"""
265+
return self.count_leading_bits(start, value=False)
266+
267+
fn count_leading_ones(self, start: Int = 0) -> Int:
268+
"""Count the number of leading 1s in the given value in the bitmap, starting at a given posiion.
269+
270+
Note that index 0 in the bitmap translates to right most bit in the first byte of the buffer.
271+
So when we are looking for leading zeros from a bitmap standpoing we need to look at
272+
trailing zeros in the bitmap's associated buffer.
273+
274+
Args:
275+
start: The position where we should start counting.
276+
277+
Returns:
278+
The number of leading ones in the bitmap.
279+
"""
280+
return self.count_leading_bits(start, value=True)
281+
167282
fn extend(
168283
mut self,
169284
other: Bitmap,
@@ -182,3 +297,71 @@ struct Bitmap(Movable, Writable):
182297

183298
for i in range(length):
184299
self.unsafe_set(i + start, other.unsafe_get(i))
300+
301+
fn partial_byte_set(
302+
mut self,
303+
byte_index: Int,
304+
bit_pos_start: Int,
305+
bit_pos_end: Int,
306+
value: Bool,
307+
) -> None:
308+
"""Set a range of bits in one specific byte of the bitmap to the specified value.
309+
"""
310+
311+
debug_assert(
312+
bit_pos_start >= 0
313+
and bit_pos_end <= 8
314+
and bit_pos_start <= bit_pos_end,
315+
"Invalid range: ",
316+
bit_pos_start,
317+
" to ",
318+
bit_pos_end,
319+
)
320+
321+
# Process the partial byte at the start, if appropriate.
322+
var mask = (1 << (bit_pos_end - bit_pos_start)) - 1
323+
mask = mask << bit_pos_start
324+
var initial_value = self.buffer.unsafe_get[DType.uint8](byte_index)
325+
var buffer_value = initial_value
326+
if value:
327+
buffer_value = buffer_value | mask
328+
else:
329+
buffer_value = buffer_value & ~mask
330+
self.buffer.unsafe_set[DType.uint8](byte_index, buffer_value)
331+
332+
fn unsafe_range_set(mut self, start: Int, length: Int, value: Bool) -> None:
333+
"""Set a range of bits in the bitmap to the specified value.
334+
335+
Args:
336+
start: The starting index in the bitmap.
337+
length: The number of bits to set.
338+
value: The value to set the bits to.
339+
"""
340+
341+
# Process the partial byte at the ends.
342+
var start_index = start // 8
343+
var bit_pos_start = start % 8
344+
var end_index = (start + length) // 8
345+
var bit_pos_end = (start + length) % 8
346+
347+
if bit_pos_start != 0 or bit_pos_end != 0:
348+
if start_index == end_index:
349+
self.partial_byte_set(
350+
start_index, bit_pos_start, bit_pos_end, value
351+
)
352+
else:
353+
if bit_pos_start != 0:
354+
self.partial_byte_set(start_index, bit_pos_start, 8, value)
355+
start_index += 1
356+
if bit_pos_end != 0:
357+
self.partial_byte_set(end_index, 0, bit_pos_end, value)
358+
end_index -= 1
359+
360+
# Now take care of the full bytes.
361+
if end_index > start_index:
362+
var byte_value = 255 if value else 0
363+
memset(
364+
self.buffer.offset(start_index),
365+
value=byte_value,
366+
count=end_index - start_index,
367+
)

firebolt/test_fixtures/arrays.mojo

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from firebolt.arrays import BoolArray, ArrayData
2+
from memory import ArcPointer
3+
from firebolt.buffers import Buffer, Bitmap
4+
from firebolt.dtypes import uint8
5+
from testing import assert_equal
6+
from builtin._location import __call_location
7+
8+
9+
fn as_bool_array_scalar(value: Bool) -> BoolArray.scalar:
10+
"""Bool conversion function."""
11+
return BoolArray.scalar(Scalar[DType.bool](value))
12+
13+
14+
fn bool_array(*values: Bool) -> BoolArray:
15+
var a = BoolArray(len(values))
16+
for value in values:
17+
a.unsafe_append(as_bool_array_scalar(value))
18+
return a^
19+
20+
21+
def build_array_data(length: Int, nulls: Int) -> ArrayData:
22+
"""Builds an ArrayData object with nulls.
23+
24+
Args:
25+
length: The length of the array.
26+
nulls: The number of nulls to set.
27+
"""
28+
var bitmap = Bitmap.alloc(length)
29+
var buffer = Buffer.alloc[DType.uint8](length)
30+
for i in range(length):
31+
buffer.unsafe_set(i, i % 256)
32+
# Check to see if the current index should be valid or null.
33+
var is_valid = True
34+
if nulls > 0:
35+
if i % (Int(length / nulls)) == 0:
36+
is_valid = False
37+
bitmap.unsafe_set(i, is_valid)
38+
39+
var buffers = List(ArcPointer(buffer^))
40+
return ArrayData(
41+
dtype=uint8,
42+
length=length,
43+
bitmap=ArcPointer(bitmap^),
44+
buffers=buffers,
45+
children=List[ArcPointer[ArrayData]](),
46+
)
47+
48+
49+
@always_inline
50+
def assert_bitmap_set(
51+
bitmap: Bitmap, expected_true_pos: List[Int], message: StringLiteral
52+
) -> None:
53+
var list_pos = 0
54+
for i in range(bitmap.length()):
55+
var expected_value = False
56+
if list_pos < len(expected_true_pos):
57+
if expected_true_pos[list_pos] == i:
58+
expected_value = True
59+
list_pos += 1
60+
var current_value = bitmap.unsafe_get(i)
61+
assert_equal(
62+
current_value,
63+
expected_value,
64+
String(
65+
"{}: Bitmap index {} is {}, expected {} as per list position {}"
66+
).format(message, i, current_value, expected_value, list_pos),
67+
location=__call_location(),
68+
)

0 commit comments

Comments
 (0)