Skip to content

Commit

Permalink
Fixed leastNBitsMask, blitElement, at and more for various usage conditions.
Browse files Browse the repository at this point in the history
Turns out integer promotion makes things behave in complicated ways when the same code is used to manipulate integers of different widths. The implemented solution feels a little hacky, but it fixes the immediate errors (and probably ones we didn't even notice) (#93)
  • Loading branch information
Scottbruceheart committed Jun 15, 2024
1 parent be78698 commit be280d3
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 25 deletions.
35 changes: 17 additions & 18 deletions inc/zoo/swar/SWAR.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ constexpr __uint128_t lsbIndex(__uint128_t v) noexcept {
}
#endif

// This is placed ahead of the other bare bit-manipulation functions so that
// the constants of SWAR and related code can be created from it.

/// Builds a mask of type T whose NBits least significant bits are set.
///
/// \tparam NBits how many low bits to set; must not be negative. When NBits
///         is at least the bit width of T, every bit of T is set.
/// \tparam T integer type of the resulting mask.
/// \return the mask, always as a T.
///
/// Two details matter here:
///  - For T narrower than int, `T{1} << NBits` and the subtraction are carried
///    out in int because of integral promotion. The result is cast back so the
///    function yields a T for every width instead of leaking an int.
///  - Shifting by the full width of the type is not a valid shift, so the
///    "all bits" case is produced with `~T{0}` instead of a shift.
template<int NBits, typename T>
constexpr T leastNBitsMask() {
    // Signed width: keeps the comparison below free of signed/unsigned
    // conversion, which would turn a negative NBits into a huge value.
    constexpr int type_bits = static_cast<int>(sizeof(T) * 8);
    static_assert(0 <= NBits, "leastNBitsMask: NBits must not be negative");
    if constexpr (NBits < type_bits) {
        return static_cast<T>((T{1} << NBits) - 1);
    } else {
        return static_cast<T>(~T{0});
    }
}

/// Core abstraction around SIMD Within A Register (SWAR). Specifies 'lanes'
/// of NBits width against a type T, and provides an abstraction for performing
/// SIMD operations against that primitive type T treated as a SIMD register.
Expand All @@ -86,13 +101,7 @@ struct SWAR {
LeastSignificantBit = meta::BitmaskMaker<T, std::make_unsigned_t<T>{1ull}, NBits>::value,
// Simply shifting over Least causes problems with lanes that don't fit the SWAR exactly.
MostSignificantBit = meta::BitmaskMaker<T, std::make_unsigned_t<T>{1ull<<(NBits - 1)}, NBits>::value,
// Computing LeastSignificantLaneMask with uint16_t results in an
// unknown narrowing or type conversion that causes use of .at() and
// similar with 16bit sized SWARS to fail.
LeastSignificantLaneMask =
sizeof(T) * 8 == NBits ? // needed to avoid shifting all bits
~T(0) :
~(~T(0) << NBits),
LeastSignificantLaneMask = leastNBitsMask<NBits, T>(),
// Use LowerBits in favor of ~MostSignificantBit to not pollute
// "don't care" bits when non-power-of-two bit lane sizes are supported
LowerBits = MostSignificantBit - LeastSignificantBit,
Expand Down Expand Up @@ -223,7 +232,7 @@ constexpr auto horizontalEquality(SWAR<NBits, T> left, SWAR<NBits, T> right) {
// TODO(scottbruceheart) Attempting to use binary not (~) results in negative shift warnings.
/// Keeps only the NBits least significant bits of `pattern`.
///
/// \tparam NBits number of low bits that survive.
/// \tparam T integer type of the pattern (defaults to uint64_t).
/// \return `pattern` bitwise-and-ed with leastNBitsMask<NBits, T>().
template<int NBits, typename T = uint64_t>
constexpr auto isolate(T pattern) {
    const auto lowBits = leastNBitsMask<NBits, T>();
    return pattern & lowBits;
}

/// Clears the least bit set in type T
Expand All @@ -238,16 +247,6 @@ constexpr auto isolateLSB(T v) {
return v & ~clearLSB(v);
}

/// Mask with the NBits lowest bits set, computed as (1 << NBits) - 1.
///
/// The arithmetic is done on T{1}, so for types narrower than int the result
/// has the promoted type rather than T. NBits must be smaller than the width
/// of the type the shift is performed in.
template<int NBits, typename T>
constexpr auto leastNBitsMask() {
    const auto bitAboveMask = T{1} << NBits;
    return bitAboveMask - 1;
}

/// Overload of leastNBitsMask whose second template argument is a uint64_t
/// value rather than a type; the mask is always computed in 64 bits.
///
/// \tparam NBits number of low bits to set. 64 or more yields all 64 bits.
/// \tparam T unused; retained so existing instantiations keep compiling.
/// \return an unsigned 64-bit value with the low NBits bits set.
///
/// The mask is built by shifting all-ones left and inverting. Shifting a zero
/// (as this function used to do) always produces zero, which made the result
/// all-ones regardless of NBits. The full-width case is handled separately
/// because a shift by 64 is not a valid shift for a 64-bit operand.
template<int NBits, uint64_t T>
constexpr auto leastNBitsMask() {
    if constexpr (NBits >= 64) {
        return ~0ull;
    } else {
        return ~(~0ull << NBits);
    }
}

template<int NBits, typename T = uint64_t>
constexpr T mostNBitsMask() {
return ~leastNBitsMask<sizeof(T)*8-NBits, T>();
Expand Down
1 change: 0 additions & 1 deletion test/map/RobinHood.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ auto showMetadata(std::size_t index, MD *md) {
//WARN(__VA_ARGS__)

using SMap = zoo::rh::RH_Frontend_WithSkarupkeTail<std::string, int, 255, 5, 3>;

/// Treats `p` as a pointer to an SMap and returns what `.value()` yields for
/// the entry at position `index` of its `values_` member.
auto valueInvoker(void *p, std::size_t index) {
    auto *const map = static_cast<SMap *>(p);
    return map->values_[index].value();
}
Expand Down
50 changes: 44 additions & 6 deletions test/swar/BasicOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,23 +349,49 @@ static_assert(0xF000'0000ul == mostNBitsMask<4, u32>());
// mostNBitsMask<N, u32>() is expected to set the N highest bits of a 32-bit
// value: 0xF800'0000 has the top 5 bits set, 0xFC00'0000 the top 6.
static_assert(0xF800'0000ul == mostNBitsMask<5, u32>());
static_assert(0xFC00'0000ul == mostNBitsMask<6, u32>());



// leastNBitsMask over u8: the expected value has the N lowest bits set.
// (u8 is presumably an 8-bit unsigned alias declared elsewhere in this test.)
static_assert(0x01u == leastNBitsMask<1, u8>());
static_assert(0x03u == leastNBitsMask<2, u8>());
static_assert(0x07u == leastNBitsMask<3, u8>());
static_assert(0x0Fu == leastNBitsMask<4, u8>());
static_assert(0x1Fu == leastNBitsMask<5, u8>());

// The same five widths requested as u32.
static_assert(0x0000'01ul == leastNBitsMask<1, u32>());
static_assert(0x0000'03ul == leastNBitsMask<2, u32>());
static_assert(0x0000'07ul == leastNBitsMask<3, u32>());
static_assert(0x0000'0Ful == leastNBitsMask<4, u32>());
static_assert(0x0000'1Ful == leastNBitsMask<5, u32>());
// Remaining u8 widths, ending with NBits == 8 where every bit of 0xFF is set.
static_assert(0x3Fu == leastNBitsMask<6, u8>());
static_assert(0x7Fu == leastNBitsMask<7, u8>());
static_assert(0xFFu == leastNBitsMask<8, u8>());

// leastNBitsMask over u16 for every width from 1 through 16. The last line is
// the case where all 16 bits (0xFFFF) are requested.
// (u16 is presumably a 16-bit unsigned alias declared elsewhere in this test.)
static_assert(0x0001u == leastNBitsMask<1, u16>());
static_assert(0x0003u == leastNBitsMask<2, u16>());
static_assert(0x0007u == leastNBitsMask<3, u16>());
static_assert(0x000Fu == leastNBitsMask<4, u16>());
static_assert(0x001Fu == leastNBitsMask<5, u16>());
static_assert(0x003Fu == leastNBitsMask<6, u16>());
static_assert(0x007Fu == leastNBitsMask<7, u16>());
static_assert(0x00FFu == leastNBitsMask<8, u16>());
static_assert(0x01FFu == leastNBitsMask<9, u16>());
static_assert(0x03FFu == leastNBitsMask<10, u16>());
static_assert(0x07FFu == leastNBitsMask<11, u16>());
static_assert(0x0FFFu == leastNBitsMask<12, u16>());
static_assert(0x1FFFu == leastNBitsMask<13, u16>());
static_assert(0x3FFFu == leastNBitsMask<14, u16>());
static_assert(0x7FFFu == leastNBitsMask<15, u16>());
static_assert(0xFFFFu == leastNBitsMask<16, u16>());

// leastNBitsMask over u32: a few small widths, then the two boundary cases of
// 31 bits (all but the top bit) and 32 bits (all ones).
static_assert(0x0000'0001ul == leastNBitsMask<1, u32>());
static_assert(0x0000'0003ul == leastNBitsMask<2, u32>());
static_assert(0x0000'0007ul == leastNBitsMask<3, u32>());
static_assert(0x0000'000Ful == leastNBitsMask<4, u32>());
static_assert(0x0000'001Ful == leastNBitsMask<5, u32>());
static_assert(0x7FFF'FFFFul == leastNBitsMask<31, u32>());
static_assert(0xFFFF'FFFFul == leastNBitsMask<32, u32>());

// leastNBitsMask over u64: a few small widths, then the two boundary cases of
// 63 bits (all but the top bit) and 64 bits (all ones).
static_assert(0x01ull == leastNBitsMask<1, u64>());
static_assert(0x03ull == leastNBitsMask<2, u64>());
static_assert(0x07ull == leastNBitsMask<3, u64>());
static_assert(0x0Full == leastNBitsMask<4, u64>());
static_assert(0x1Full == leastNBitsMask<5, u64>());
static_assert(0x7FFF'FFFF'FFFF'FFFFull == leastNBitsMask<63, u64>());
static_assert(0xFFFF'FFFF'FFFF'FFFFull == leastNBitsMask<64, u64>());

// isolate<N>(v) is expected to keep only the N lowest bits of v: the low
// nibble (0xB) and the low byte (0xAB) of the same 64-bit pattern.
static_assert(0xB == isolate<4>(0x1337'BDBC'2448'ACABull));
static_assert(0xAB == isolate<8>(0x1337'BDBC'2448'ACABull));
Expand Down Expand Up @@ -414,6 +440,18 @@ static_assert(8 == lsbIndex(1<<8));
// For an input with exactly one bit set at position k, lsbIndex is expected
// to report k.
static_assert(17 == lsbIndex(1<<17));
static_assert(30 == lsbIndex(1<<30));

// blitElement(i, v) on 3-bit lanes of a u16: the result is expected to equal
// the original with lane i replaced by v and every other lane unchanged.
// In these literals lane 0 is the last listed element and lane 4 the first.
// (S3_16 is presumably SWAR<3, u16> -- confirm against its declaration.)
static_assert(S3_16{Literals<3,u16>, {5,4,3,2,5}}.value() == S3_16{Literals<3,u16>, {5,4,3,2,1}}.blitElement(0, 5).value());
static_assert(S3_16{Literals<3,u16>, {5,4,3,5,1}}.value() == S3_16{Literals<3,u16>, {5,4,3,2,1}}.blitElement(1, 5).value());
static_assert(S3_16{Literals<3,u16>, {5,4,5,2,1}}.value() == S3_16{Literals<3,u16>, {5,4,3,2,1}}.blitElement(2, 5).value());
static_assert(S3_16{Literals<3,u16>, {5,1,3,2,1}}.value() == S3_16{Literals<3,u16>, {5,4,3,2,1}}.blitElement(3, 1).value());
static_assert(S3_16{Literals<3,u16>, {1,4,3,2,1}}.value() == S3_16{Literals<3,u16>, {5,4,3,2,1}}.blitElement(4, 1).value());

// at(i) is expected to read lane i of a 3-bit-lane u16 SWAR. Per these
// assertions the initializer list is written highest lane first: the last
// listed element (1) is lane 0 and the first listed element (5) is lane 4.
static_assert(1 == S3_16{Literals<3,u16>, {5,4,3,2,1}}.at(0));
static_assert(2 == S3_16{Literals<3,u16>, {5,4,3,2,1}}.at(1));
static_assert(3 == S3_16{Literals<3,u16>, {5,4,3,2,1}}.at(2));
static_assert(4 == S3_16{Literals<3,u16>, {5,4,3,2,1}}.at(3));
static_assert(5 == S3_16{Literals<3,u16>, {5,4,3,2,1}}.at(4));

// Compile-time check helper: builds SWAR<4, u32> values from `left` and
// `right`, applies greaterEqual_MSB_off<4, u32> to them, and asserts that the
// .value() of the outcome equals `result`.
#define GE_MSB_TEST(left, right, result) static_assert(result == greaterEqual_MSB_off<4, u32>(SWAR<4, u32>(left), SWAR<4, u32>(right)).value());

GE_MSB_TEST(
Expand Down

0 comments on commit be280d3

Please sign in to comment.