1
- from memory import UnsafePointer, memset_zero, memcpy, ArcPointer
2
- from sys.info import sizeof
1
+ from memory import UnsafePointer, memset_zero, memcpy, ArcPointer, Span, memset
2
+ from sys.info import sizeof, simdbytewidth
3
3
import math
4
+ from bit import pop_count, count_trailing_zeros
4
5
5
6
6
7
fn _required_bytes (length : Int, T : DType) -> Int:
@@ -12,6 +13,11 @@ fn _required_bytes(length: Int, T: DType) -> Int:
12
13
return math.align_up(size, 64 )
13
14
14
15
16
+ alias simd_width = simdbytewidth()
17
+
18
+ alias simd_widths = (simd_width, simd_width // 2 , 1 )
19
+
20
+
15
21
struct Buffer (Movable ):
16
22
var ptr : UnsafePointer[UInt8]
17
23
var size : Int
@@ -155,15 +161,124 @@ struct Bitmap(Movable, Writable):
155
161
fn unsafe_set (mut self , index : Int, value : Bool) -> None :
156
162
self .buffer.unsafe_set[DType.bool](index, value)
157
163
164
+ @always_inline
158
165
fn length (self ) -> Int:
159
166
return self .buffer.length[DType.bool]()
160
167
168
+ @always_inline
161
169
fn size (self ) -> Int:
162
170
return self .buffer.size
163
171
164
172
fn grow [I : Intable](mut self , target_length : I):
165
173
return self .buffer.grow[DType.bool](target_length)
166
174
175
+ fn bit_count (self ) -> Int:
176
+ """ The number of bits with value 1 in the Bitmap."""
177
+ var start = 0
178
+ var count = 0
179
+ while start < self .buffer.size:
180
+ if self .buffer.size - start > simd_width:
181
+ count += (
182
+ self .buffer.offset(start)
183
+ .load[width=simd_width]()
184
+ .reduce_bit_count()
185
+ )
186
+ start += simd_width
187
+ else :
188
+ count += (
189
+ self .buffer.offset(start).load[width=1 ]().reduce_bit_count()
190
+ )
191
+ start += 1
192
+ return count
193
+
194
+ fn count_leading_bits (self , start : Int = 0 , value : Bool = False ) -> Int:
195
+ """ Count the number of leading bits with the given value in the bitmap, starting at a given posiion.
196
+
197
+ Note that index 0 in the bitmap translates to right most bit in the first byte of the buffer.
198
+ So when we are looking for leading zeros from a bitmap standpoing we need to look at
199
+ trailing zeros in the bitmap's associated buffer.
200
+
201
+ The SIMD API available looks at leading zeros only, we negate the input when needed.
202
+
203
+ Args:
204
+ start: The position where we should start counting.
205
+ value: The value of the bits we want to count.
206
+
207
+ Returns:
208
+ The number of leadinging bits with the given value in the bitmap.
209
+ """
210
+
211
+ var count = 0
212
+ var index = start // 8
213
+ var bit_in_first_byte = start % 8
214
+
215
+ if bit_in_first_byte != 0 :
216
+ # Process the partial first byte by applying a mask.
217
+ var loaded = self .buffer.offset(index).load[width=1 ]()
218
+ if value:
219
+ loaded = ~ loaded
220
+ var mask = (1 << bit_in_first_byte) - 1
221
+ loaded &= ~ mask
222
+ leading_zeros = Int(count_trailing_zeros(loaded))
223
+ if leading_zeros == 0 :
224
+ return count
225
+ count = leading_zeros - bit_in_first_byte
226
+ if leading_zeros != 8 :
227
+ # The first byte has some bits of the other value, just return the count.
228
+ return count
229
+
230
+ index += 1
231
+
232
+ # Process full bytes.
233
+ while index < self .size():
234
+
235
+ @parameter
236
+ for width_index in range (len (simd_widths)):
237
+ alias width = simd_widths[width_index]
238
+ if self .size() - index >= width:
239
+ var loaded = self .buffer.offset(index).load[width=width]()
240
+ if value:
241
+ loaded = ~ loaded
242
+ var leading_zeros = count_trailing_zeros(loaded)
243
+ for i in range (width):
244
+ count += Int(leading_zeros[i])
245
+ if leading_zeros[i] != 8 :
246
+ return count
247
+ index += width
248
+ # break from the simd widths loop
249
+ break
250
+ return count
251
+
252
+ fn count_leading_zeros (self , start : Int = 0 ) -> Int:
253
+ """ Count the number of leading 0s in the given value in the bitmap, starting at a given posiion.
254
+
255
+ Note that index 0 in the bitmap translates to right most bit in the first byte of the buffer.
256
+ So when we are looking for leading zeros from a bitmap standpoing we need to look at
257
+ trailing zeros in the bitmap's associated buffer.
258
+
259
+ Args:
260
+ start: The position where we should start counting.
261
+
262
+ Returns:
263
+ The number of leading zeros in the bitmap.
264
+ """
265
+ return self .count_leading_bits(start, value = False )
266
+
267
+ fn count_leading_ones (self , start : Int = 0 ) -> Int:
268
+ """ Count the number of leading 1s in the given value in the bitmap, starting at a given posiion.
269
+
270
+ Note that index 0 in the bitmap translates to right most bit in the first byte of the buffer.
271
+ So when we are looking for leading zeros from a bitmap standpoing we need to look at
272
+ trailing zeros in the bitmap's associated buffer.
273
+
274
+ Args:
275
+ start: The position where we should start counting.
276
+
277
+ Returns:
278
+ The number of leading ones in the bitmap.
279
+ """
280
+ return self .count_leading_bits(start, value = True )
281
+
167
282
fn extend (
168
283
mut self ,
169
284
other : Bitmap,
@@ -182,3 +297,71 @@ struct Bitmap(Movable, Writable):
182
297
183
298
for i in range (length):
184
299
self .unsafe_set(i + start, other.unsafe_get(i))
300
+
301
+ fn partial_byte_set (
302
+ mut self ,
303
+ byte_index : Int,
304
+ bit_pos_start : Int,
305
+ bit_pos_end : Int,
306
+ value : Bool,
307
+ ) -> None :
308
+ """ Set a range of bits in one specific byte of the bitmap to the specified value.
309
+ """
310
+
311
+ debug_assert(
312
+ bit_pos_start >= 0
313
+ and bit_pos_end <= 8
314
+ and bit_pos_start <= bit_pos_end,
315
+ " Invalid range: " ,
316
+ bit_pos_start,
317
+ " to " ,
318
+ bit_pos_end,
319
+ )
320
+
321
+ # Process the partial byte at the start, if appropriate.
322
+ var mask = (1 << (bit_pos_end - bit_pos_start)) - 1
323
+ mask = mask << bit_pos_start
324
+ var initial_value = self .buffer.unsafe_get[DType.uint8](byte_index)
325
+ var buffer_value = initial_value
326
+ if value:
327
+ buffer_value = buffer_value | mask
328
+ else :
329
+ buffer_value = buffer_value & ~ mask
330
+ self .buffer.unsafe_set[DType.uint8](byte_index, buffer_value)
331
+
332
+ fn unsafe_range_set (mut self , start : Int, length : Int, value : Bool) -> None :
333
+ """ Set a range of bits in the bitmap to the specified value.
334
+
335
+ Args:
336
+ start: The starting index in the bitmap.
337
+ length: The number of bits to set.
338
+ value: The value to set the bits to.
339
+ """
340
+
341
+ # Process the partial byte at the ends.
342
+ var start_index = start // 8
343
+ var bit_pos_start = start % 8
344
+ var end_index = (start + length) // 8
345
+ var bit_pos_end = (start + length) % 8
346
+
347
+ if bit_pos_start != 0 or bit_pos_end != 0 :
348
+ if start_index == end_index:
349
+ self .partial_byte_set(
350
+ start_index, bit_pos_start, bit_pos_end, value
351
+ )
352
+ else :
353
+ if bit_pos_start != 0 :
354
+ self .partial_byte_set(start_index, bit_pos_start, 8 , value)
355
+ start_index += 1
356
+ if bit_pos_end != 0 :
357
+ self .partial_byte_set(end_index, 0 , bit_pos_end, value)
358
+ end_index -= 1
359
+
360
+ # Now take care of the full bytes.
361
+ if end_index > start_index:
362
+ var byte_value = 255 if value else 0
363
+ memset(
364
+ self .buffer.offset(start_index),
365
+ value = byte_value,
366
+ count = end_index - start_index,
367
+ )
0 commit comments