Skip to content

Commit

Permalink
Use branchless code in generic union2by2 loop
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Lemire <lemire@gmail.com>
  • Loading branch information
lemire authored and puzpuzpuz committed Jul 15, 2021
1 parent 1477e28 commit f127594
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 48 deletions.
3 changes: 2 additions & 1 deletion benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package roaring
import (
"bytes"
"fmt"
"github.com/stretchr/testify/require"
"math/rand"
"runtime"
"testing"

"github.com/stretchr/testify/require"

"github.com/willf/bitset"
)

Expand Down
10 changes: 10 additions & 0 deletions real_data_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ func BenchmarkRealDataParOr(b *testing.B) {
})
}

func BenchmarkRealDataOr(b *testing.B) {
benchmarkRealDataAggregate(b, func(bitmaps []*Bitmap) uint64 {
t := uint64(0)
for i := 1; i < len(bitmaps); i++ {
t += Or(bitmaps[i-1], bitmaps[i]).GetCardinality()
}
return t
})
}

func BenchmarkRealDataParHeapOr(b *testing.B) {
benchmarkRealDataAggregate(b, func(bitmaps []*Bitmap) uint64 {
return ParHeapOr(0, bitmaps...).GetCardinality()
Expand Down
16 changes: 16 additions & 0 deletions setutil.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package roaring

import "unsafe"

func equal(a, b []uint16) bool {
if len(a) != len(b) {
return false
Expand Down Expand Up @@ -548,3 +550,17 @@ func binarySearch(array []uint16, ikey uint16) int {
}
return -(low + 1)
}

// compareuint16 compares two number in a branchless manner.
// Returns -1 if s1 < s2, zero otherwise.
func compareuint16(x, y uint16) int {
return (int(x) - int(y)) >> 63
}

// uint16SlicePtr returns a pointer at the given slice
// index avoiding bound checks. Use cautiously.
func uint16SlicePtr(slice []uint16, idx uint) *uint16 {
p := unsafe.Pointer(&slice[0])
indexp := (unsafe.Pointer)(uintptr(p) + 2*uintptr(idx))
return (*uint16)(indexp)
}
75 changes: 29 additions & 46 deletions setutil_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
package roaring

func union2by2(set1 []uint16, set2 []uint16, buffer []uint16) int {
pos := 0
k1 := 0
k2 := 0
if 0 == len(set2) {
buffer = buffer[:len(set1)]
copy(buffer, set1[:])
Expand All @@ -16,48 +13,34 @@ func union2by2(set1 []uint16, set2 []uint16, buffer []uint16) int {
copy(buffer, set2[:])
return len(set2)
}
s1 := set1[k1]
s2 := set2[k2]
buffer = buffer[:cap(buffer)]
for {
if s1 < s2 {
buffer[pos] = s1
pos++
k1++
if k1 >= len(set1) {
copy(buffer[pos:], set2[k2:])
pos += len(set2) - k2
break
}
s1 = set1[k1]
} else if s1 == s2 {
buffer[pos] = s1
pos++
k1++
k2++
if k1 >= len(set1) {
copy(buffer[pos:], set2[k2:])
pos += len(set2) - k2
break
}
if k2 >= len(set2) {
copy(buffer[pos:], set1[k1:])
pos += len(set1) - k1
break
}
s1 = set1[k1]
s2 = set2[k2]
} else { // if (set1[k1]>set2[k2])
buffer[pos] = s2
pos++
k2++
if k2 >= len(set2) {
copy(buffer[pos:], set1[k1:])
pos += len(set1) - k1
break
}
s2 = set2[k2]
}
var s1, s2 uint16
pos := uint(0)
k1 := uint(0)
k2 := uint(0)
len1 := uint(len(set1))
len2 := uint(len(set2))
buffer = buffer[:len1+len2]
for k1 < len1 && k2 < len2 {
s1 = *uint16SlicePtr(set1, k1)
s2 = *uint16SlicePtr(set2, k2)

sflag := compareuint16(s1, s2) // -1 if s1 < s2, zero otherwise
lflag := compareuint16(s2, s1) // -1 if s2 < s1, zero otherwise
*uint16SlicePtr(buffer, pos) = uint16(-sflag)*s1 + uint16(1+sflag)*s2

pos++
k1 += uint(1 + lflag)
k2 += uint(1 + sflag)
}
if k1 >= len1 {
copy(buffer[pos:], set2[k2:])
pos += len2 - k2
return int(pos)
}
if k2 >= len2 {
copy(buffer[pos:], set1[k1:])
pos += len1 - k1
return int(pos)
}
return pos
return int(pos)
}
85 changes: 84 additions & 1 deletion setutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package roaring
// to run just these tests: go test -run TestSetUtil*

import (
"github.com/stretchr/testify/assert"
"math/rand"
"sort"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSetUtilDifference(t *testing.T) {
Expand Down Expand Up @@ -41,6 +44,19 @@ func TestSetUtilDifference(t *testing.T) {
assert.Equal(t, expectedresult, result)
}

func TestCompareuint16(t *testing.T) {
assert.Equal(t, 0, compareuint16(42, 42))
assert.Equal(t, 0, compareuint16(42, 1))
assert.Equal(t, -1, compareuint16(1, 42))
}

func TestUint16SlicePtr(t *testing.T) {
slice := []uint16{42, 41, 1, 2, 3}
for i := range slice {
assert.Equal(t, slice[i], *uint16SlicePtr(slice, uint(i)))
}
}

func TestSetUtilUnion(t *testing.T) {
data1 := []uint16{0, 1, 2, 3, 4, 9}
data2 := []uint16{2, 3, 4, 5, 8, 9, 11}
Expand Down Expand Up @@ -136,3 +152,70 @@ func TestSetUtilBinarySearch(t *testing.T) {
}
}
}

// go test -bench BenchmarkUnion2by2 -run -
func BenchmarkUnion2by2(b *testing.B) {
r := rand.New(rand.NewSource(123456))

// this is important: we pre-generate a large amount of randomized
// sorted arrays in order to disable the effects branch prediction,
// making benchmarks against non-branchless implementations
// more realistic.

sarrsnum := 1024
sz1 := 1024
sarrs := make([][]uint16, sarrsnum)
for i := 0; i < sarrsnum; i++ {
sarrs[i] = make([]uint16, sz1)
for j := 0; j < sz1; j++ {
sarrs[i][j] = uint16(r.Intn(MaxUint16))
}
sort.Sort(uint16Slice(sarrs[i]))
}

sz2 := 1024
s2 := make([]uint16, sz2)

sz3 := 1024
s3 := make([]uint16, sz3)

sz4 := 1024
s4 := make([]uint16, sz4)

// We are going to populate our arrays with random data.
// Importantly, we need to sort. There might be a few
// duplicates, by random chance, but it should not affect
// results too much.

for i := 0; i < sz2; i++ {
s2[i] = uint16(r.Intn(MaxUint16))
}
sort.Sort(uint16Slice(s2))

for i := 0; i < sz3; i++ {
s3[i] = uint16(r.Intn(MaxUint16))
}
sort.Sort(uint16Slice(s3))

for i := 0; i < sz4; i++ {
s4[i] = uint16(r.Intn(MaxUint16))
}
sort.Sort(uint16Slice(s4))

buf := make([]uint16, sz1+sz2+sz3+sz4)

b.Run("union2by2", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for i := 0; i < sarrsnum; i++ {
union2by2(sarrs[i], s2, buf)
union2by2(sarrs[i], s3, buf)
union2by2(sarrs[i], s4, buf)
}
}
})

// the old, non-branchless implementation for performance
// comparison can be found here:
// https://github.com/RoaringBitmap/roaring/blob/ff33c3b226c3ac033bf1a0b0f3ed647fc9cd2efa/setutil_generic.go
}

0 comments on commit f127594

Please sign in to comment.