improve masked load/store for bytes and words. #2821

Open
MkazemAkhgary opened this issue Mar 29, 2024 · 8 comments
Labels: Bugs, Performance (All issues related to performance/code generation)


MkazemAkhgary commented Mar 29, 2024

Storing values in divergent control flow can be inefficient for byte and word data types such as bool, int8, int16, uint8, and uint16. Sometimes ispc performs a gather/scatter without emitting a performance warning. This can occur when working with these types in masked regions, particularly when:

  • Assigning values to varying byte/word lvalue references.
  • Writing to or reading from arrays.

Here is an example of the first case:

void SetValue(int8& dest, int8 src)
{
    dest = src;
}
Compiled for AVX2:
SetValue___REFvybvyb:                   # @SetValue___REFvybvyb
        vextracti128    xmm2, ymm0, 1
        vmovdqa xmm3, xmmword ptr [rip + .LCPI1_0] # xmm3 = <0,4,8,12,u,u,u,u,u,u,u,u,u,u,u,u>
        vpshufb xmm2, xmm2, xmm3
        vpshufb xmm0, xmm0, xmm3
        vpunpckldq      xmm0, xmm0, xmm2        # xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
        vpand   xmm0, xmm0, xmmword ptr [rip + .LCPI1_1]
        vmovmskps       eax, ymm1
        test    al, 1
        jne     .LBB1_1
        test    al, 2
        jne     .LBB1_3
.LBB1_4:
        test    al, 4
        jne     .LBB1_5
.LBB1_6:
        test    al, 8
        jne     .LBB1_7
.LBB1_8:
        test    al, 16
        jne     .LBB1_9
.LBB1_10:
        test    al, 32
        jne     .LBB1_11
.LBB1_12:
        test    al, 64
        jne     .LBB1_13
.LBB1_14:
        test    al, al
        js      .LBB1_15
.LBB1_16:
        vzeroupper
        ret
.LBB1_1:
        vpextrb byte ptr [rdi], xmm0, 0
        test    al, 2
        je      .LBB1_4
.LBB1_3:
        vpextrb byte ptr [rdi + 1], xmm0, 1
        test    al, 4
        je      .LBB1_6
.LBB1_5:
        vpextrb byte ptr [rdi + 2], xmm0, 2
        test    al, 8
        je      .LBB1_8
.LBB1_7:
        vpextrb byte ptr [rdi + 3], xmm0, 3
        test    al, 16
        je      .LBB1_10
.LBB1_9:
        vpextrb byte ptr [rdi + 4], xmm0, 4
        test    al, 32
        je      .LBB1_12
.LBB1_11:
        vpextrb byte ptr [rdi + 5], xmm0, 5
        test    al, 64
        je      .LBB1_14
.LBB1_13:
        vpextrb byte ptr [rdi + 6], xmm0, 6
        test    al, al
        jns     .LBB1_16
.LBB1_15:
        vpextrb byte ptr [rdi + 7], xmm0, 7
        vzeroupper
        ret

The following optimization is possible, and it can also be applied to the other types (uint8, int16, uint16, bool):

inline void SetValue(int8& dest, int8 src)
{
    if ((((1 << TARGET_WIDTH) - 1) ^ lanemask()) == 0)
    {
        unmasked
        {
            dest = src;
        }
    }
    else
    {
#if TARGET_WIDTH <= 32
        typedef uint32 TMask;
#else
        typedef uint64 TMask;
#endif

        uniform TMask mask = lanemask();
        unmasked
        {
            dest = select((((TMask)1 << programIndex) & mask) != 0, src, dest);
        }
    }
}

For AVX2, this would compile to:

SetValue___REFvytvyt:                  # @SetValue___REFvytvyt
        vmovmskps       eax, ymm1
        cmp     eax, 255
        je      .LBB1_2
        vmovd   xmm1, eax
        vpbroadcastd    ymm1, xmm1
        vpand   ymm1, ymm1, ymmword ptr [rip + .LCPI1_0]
        vpxor   xmm2, xmm2, xmm2
        vpcmpeqd        ymm1, ymm1, ymm2
        vextracti128    xmm2, ymm1, 1
        vpackssdw       xmm1, xmm1, xmm2
        vpacksswb       xmm1, xmm1, xmm1
        vmovq   xmm2, qword ptr [rdi]           # xmm2 = mem[0],zero
        vpblendvb       xmm0, xmm0, xmm2, xmm1
.LBB1_2:
        vmovq   qword ptr [rdi], xmm0
        vzeroupper
        ret

In the case of accessing byte/word arrays:

  • A performance warning should be emitted when gather/scatter is present.

  • It would be beneficial to have a set of functions for loading from and storing to byte/word arrays that assume the array size is a multiple of 4 or 2. These functions could use more efficient instructions such as vpmaskmov or vpgatherdd. If the array size is known at compile time, the compiler could automatically use these functions. However, invoking these functions with an incorrect array size may result in a memory access violation.

Here is an implementation for loading from an int8 array when the index is not contiguous:

inline varying int8 FastLoadByte(const uniform int8 * const uniform arr, varying int32 Index)
{
    // return arr[Index]; // this is 30 instructions with AVX2 which loads bytes one by one.
    
    // This is 19 instructions with AVX2 using vpgatherdd, faster than loading bytes one by one.
    varying int32 DwInd = Index >> 2; // Index/4
    varying int32 Shift = (Index & 3) << 3; // (Index%4)*8
    varying uint32 Dword = ((uniform uint32 * uniform) & arr[0])[DwInd];
    return (Dword >> Shift) & 0xFF;
}

I think vpmaskmov could be used when the index is contiguous, but I haven't been able to implement that yet.
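In the meantime, here is a rough, untested sketch of how a contiguous masked byte store could avoid the per-byte extracts: read the whole gang of bytes, blend in the active lanes, and write the whole gang back. FastStoreBytes is just an illustrative name; the sketch assumes TARGET_WIDTH <= 32 and, like the functions above, that it is safe to read and rewrite all TARGET_WIDTH bytes at dest.

inline void FastStoreBytes(uniform int8 * uniform dest, varying int8 src)
{
    uniform uint32 mask = lanemask();
    unmasked
    {
        // Unmasked contiguous load and store of the whole gang of bytes;
        // inactive lanes get their old value written back unchanged.
        varying int8 old = dest[programIndex];
        dest[programIndex] = select((((uint32)1 << programIndex) & mask) != 0, src, old);
    }
}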

  • Alternatively, instead of specific functions, the assume keyword could be used to inform the compiler of a safe upper bound, allowing it to avoid gather/scatter by using vpmaskmov, or to perform a more efficient gather as shown in the FastLoadByte function; see the usage sketch after the snippet below.
// Inform the compiler it's safe to read from or write to byteArray as int32, compiler should be able to use `vpmaskmov` or `vpgatherdd`.
assume(&((uniform int32 * uniform)byteArray)[reduce_max(i)/sizeof(uniform int32)] != NULL);
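To make that concrete, here is a hypothetical sketch of how such a hint might be written (LoadByteWithAssume, byteArray, and i are illustrative names; whether the compiler actually exploits the hint this way is exactly what this issue is requesting):

varying int8 LoadByteWithAssume(uniform int8 * uniform byteArray, varying int32 i)
{
    // Proposed hint: reading byteArray at int32 granularity up to the highest
    // touched index cannot fault (the allocation is padded to a multiple of 4 bytes).
    assume(&((uniform int32 * uniform)byteArray)[reduce_max(i) / sizeof(uniform int32)] != NULL);
    // With the hint honored, this byte gather could be lowered to a dword gather
    // (vpgatherdd) plus shifts, as in FastLoadByte above.
    return byteArray[i];
}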

MkazemAkhgary commented Mar 30, 2024

Here is a somewhat faster implementation of FastLoadByte. The return type is changed to int32 to avoid unnecessary unpacking.

varying int32 FastLoadByte(const uniform int8 * const uniform arr, varying int32 Index)
{
    varying int32 DwInd = Index >> 2; // Index/4
    varying int32 Shift = ((uint32)Index << 30) >> 27; // (Index%4)*8
    varying uint32 Dword = ((uniform uint32 * uniform) & arr[0])[DwInd];
    return (Dword >> Shift) & 0xFF;
}
Compiled for AVX2:

        vmovmskps       eax, ymm1
        vpbroadcastd    ymm2, dword ptr [rip + .LCPI1_0] # ymm2 = [4294967292,4294967292,4294967292,4294967292,4294967292,4294967292,4294967292,4294967292]
        vpand   ymm3, ymm0, ymm2
        cmp     eax, 255
        jne     .LBB1_2
        vpcmpeqd        ymm1, ymm1, ymm1
.LBB1_2:
        vpxor   xmm2, xmm2, xmm2
        vpgatherdd      ymm2, ymmword ptr [rdi + ymm3], ymm1
        vpslld  ymm0, ymm0, 3
        vpbroadcastd    ymm1, dword ptr [rip + .LCPI1_1] # ymm1 = [24,24,24,24,24,24,24,24]
        vpand   ymm0, ymm0, ymm1
        vpsrlvd ymm0, ymm2, ymm0
        vpand   ymm0, ymm0, ymmword ptr [rip + .LCPI1_2]
        ret

This is hand-tuned, and I think it should work for any mask.

      vpand   ymm3, ymm0, dword ptr [rip + .LCPI1_0] # [0xfffffffc,...]
      vpxor   xmm2, xmm2, xmm2
      vpgatherdd      ymm2, ymmword ptr [rdi + ymm3], ymm1
      vpslld  ymm0, ymm0, 30
      vpsrld  ymm0, ymm0, 27
      vpsrlvd ymm0, ymm2, ymm0
      vpand   ymm0, ymm0, ymmword ptr [rip + .LCPI1_1] # [0xff,...]
      ret

I don't know what the reason is behind using vpxor before vpgatherdd. Maybe it can be avoided with -O3? That would save a register too.

      vpand   ymm2, ymm0, dword ptr [rip + .LCPI1_0] # [0xfffffffc,...]
      vpgatherdd      ymm2, ymmword ptr [rdi + ymm2], ymm1
      vpslld  ymm0, ymm0, 30
      vpsrld  ymm0, ymm0, 27
      vpsrlvd ymm0, ymm2, ymm0
      vpand   ymm0, ymm0, ymmword ptr [rip + .LCPI1_1] # [0xff,...]
      ret

The other two vpand instructions can also be replaced with shifts to remove the constants: (x & 0xfffffffc) == ((x >> 2) << 2) and (x & 0xff) == ((uint32)(x << 24) >> 24). I'm not sure it would be faster, though.
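Expressed back in ISPC, the shift-only variant would look roughly like this (an untested sketch; FastLoadByteShifts is just an illustrative name, and LLVM may well fold the shifts back into the same constants):

varying int32 FastLoadByteShifts(const uniform int8 * const uniform arr, varying int32 Index)
{
    varying int32 DwInd = Index >> 2;                   // dword index; the *4 from uint32 indexing is what became the vpand
    varying int32 Shift = ((uint32)Index << 30) >> 27;  // (Index % 4) * 8
    varying uint32 Dword = ((uniform uint32 * uniform) & arr[0])[DwInd];
    return ((Dword >> Shift) << 24) >> 24;              // (x & 0xff) expressed with shifts
}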


MkazemAkhgary commented Apr 1, 2024

I think the SetValue function can also be improved. Here is the code compiled by ispc:

SetValue___REFvytvyt:                  # @SetValue___REFvytvyt
        vmovmskps       eax, ymm1
        cmp     eax, 255
        je      .LBB1_2
        vmovd   xmm1, eax
        vpbroadcastd    ymm1, xmm1
        vpand   ymm1, ymm1, ymmword ptr [rip + .LCPI1_0]
        vpxor   xmm2, xmm2, xmm2
        vpcmpeqd        ymm1, ymm1, ymm2
        vextracti128    xmm2, ymm1, 1
        vpackssdw       xmm1, xmm1, xmm2
        vpacksswb       xmm1, xmm1, xmm1
        vmovq   xmm2, qword ptr [rdi]           # xmm2 = mem[0],zero
        vpblendvb       xmm0, xmm0, xmm2, xmm1
.LBB1_2:
        vmovq   qword ptr [rdi], xmm0
        vzeroupper
        ret

For the masked part, it's doing some extra work that I think could be more straightforward: instead of rebuilding an inverted mask from eax, convert the ymm1 mask down to xmm1 and use that mask for vpblendvb.

  vextracti128    xmm2, ymm1, 1      # extract the upper 128 bits of the mask
  vpackssdw       xmm1, xmm1, xmm2   # narrow the 256-bit dword mask to a 128-bit word mask
  vpacksswb       xmm1, xmm1, xmm1   # narrow again to a byte mask in the low 64 bits
  vmovq   xmm2, qword ptr [rdi]      # xmm2 = mem[0],zero
  vpblendvb       xmm0, xmm2, xmm0, xmm1   # where the mask is set keep the new bytes (xmm0), elsewhere take the old bytes (xmm2)

Together with the rest of the code:

SetValue___REFvytvyt:                  # @SetValue___REFvytvyt
       vmovmskps       eax, ymm1
       cmp     eax, 255
       je      .LBB1_2
       vextracti128    xmm2, ymm1, 1
       vpackssdw       xmm1, xmm1, xmm2
       vpacksswb       xmm1, xmm1, xmm1
       vmovq   xmm2, qword ptr [rdi]      # xmm2 = mem[0],zero
       vpblendvb       xmm0, xmm2, xmm0, xmm1
.LBB1_2:
       vmovq   qword ptr [rdi], xmm0
       vzeroupper
       ret

I tried a few ways and came up with this; it's still doing some redundant work. 😞

inline void SetByte(int8& dest, int8 src)
{
    if ((((1 << TARGET_WIDTH) - 1) ^ lanemask()) == 0)
    {
        unmasked
        {
            dest = src;
        }
    }
    else
    {
        uniform uint32<TARGET_WIDTH> bmask;
        unmasked
        {
            bmask[programIndex] = 0;
        }
        bmask[programIndex] = ~0;
        unmasked
        {
            dest = select(bmask[programIndex] == 0, dest, src);
        }
    }
}
For AVX2, this compiles to:

SetByte___REFvyTvyT:                    # @SetByte___REFvyTvyT
        vmovmskps       eax, ymm1
        cmp     al, -1
        je      .LBB1_2
        vpcmpeqd        ymm2, ymm2, ymm2        # unnecessary
        vxorps          xmm3, xmm3, xmm3        # unnecessary
        vblendvps       ymm1, ymm3, ymm2, ymm1  # unnecessary
        vpcmpeqd        ymm1, ymm1, ymm3        # unnecessary
        vextracti128    xmm2, ymm1, 1
        vpackssdw       xmm1, xmm1, xmm2
        vpacksswb       xmm1, xmm1, xmm1
        vmovq   xmm2, qword ptr [rdi]           # xmm2 = mem[0],zero
        vpblendvb       xmm0, xmm0, xmm2, xmm1
.LBB1_2:                                # %common.ret
        vmovq   qword ptr [rdi], xmm0
        vzeroupper
        ret


MkazemAkhgary commented Apr 2, 2024

I noticed that there is a __mask variable.

While debugging, a variable, __mask, is available to provide the current program execution mask at the current point in the program

Looks like a very useful variable. Unlike lanemask, it gives direct access to the execution mask. It also worked for non-debug builds and gave the best results so far. Please don't remove it! 😄

void SetByte(int8& dest, int8 src)
{
    varying uint32 mask = __mask;
    unmasked
    {
        dest = mask ? src : dest;
    }
}

masked

SetByte___REFvyTvyT:                    # @SetByte___REFvyTvyT
        vpxor   xmm2, xmm2, xmm2          # unnecessary
        vpcmpeqd        ymm1, ymm1, ymm2  # unnecessary
        vextracti128    xmm2, ymm1, 1
        vpackssdw       xmm1, xmm1, xmm2
        vpacksswb       xmm1, xmm1, xmm1
        vmovq   xmm2, qword ptr [rdi]           # xmm2 = mem[0],zero
        vpblendvb       xmm0, xmm0, xmm2, xmm1
        vmovq   qword ptr [rdi], xmm0
        vzeroupper
        ret

unmasked

SetByte___UM_REFvyTvyT:                 # @SetByte___UM_REFvyTvyT
        vmovlps qword ptr [rdi], xmm0
        ret

This looks much cleaner. There are still two unnecessary instructions, vpxor and vpcmpeqd, but I'm going to leave it at that.


nurmukhametov commented Apr 2, 2024

I noticed that there is a __mask variable.

I have not read the whole thread, but I need to comment about __mask. At the moment, as I understand it, it is deliberately an internal entity. I am not quite sure why, and why lanemask is not enough for your use case.

I would not discourage you from using it if you need to, but you at least need to understand that it is defined in a different way for different targets. See lAddMaskToSymbolTable in ispc/src/parse.yy (line 3008 at 45d66e9):

static void lAddMaskToSymbolTable(SourcePos pos) {

and

#define UIntMaskType unsigned int64


MkazemAkhgary commented Apr 2, 2024

I would not discourage you to use it if you need, but you at least need to understand that it is defined in a different way for different targets.

@nurmukhametov Thank you for pointing that out. I think I found a bug!

void SetByte(int8& dest, int8 src)
{
#if TARGET_ELEMENT_WIDTH == 1
    typedef uint8 TMask;
#elif TARGET_ELEMENT_WIDTH == 2
    typedef uint16 TMask;
#elif TARGET_ELEMENT_WIDTH == 4
    typedef uint32 TMask;
#elif TARGET_ELEMENT_WIDTH == 8
    typedef uint64 TMask;
#else
#error "unknown mask"
#endif
    varying TMask mask = __mask;
    unmasked
    {
        dest = mask ? src : dest;
    }
}

I tested this for all CPU targets on godbolt and the output assembly seems right, except for sse4.1-i8x16 and sse4.2-i8x16: the generated output does nothing, or perhaps I've made an error.


I am not quite sure why and why lanemask is not enough for your use case.

I used lanemask but the compiler is not able to optimize it very well.

I may try switching to int32 later. That quadruples the memory usage, but SIMD instructions are friendlier to int32, though I'll have to profile to see which way performs better.


TL;DR of this thread

I'm working with small types to perform some expensive computations. Unfortunately, ispc sometimes gives up and uses gather/scatter operations when working with types like uint8 or bool. I've opened this issue in the hope of improving performance with these types.

@nurmukhametov

@nurmukhametov Thank you for pointing that out. I think I found a bug!

void SetByte(int8& dest, int8 src)
{
#if TARGET_ELEMENT_WIDTH == 1
    typedef uint8 TMask;
#elif TARGET_ELEMENT_WIDTH == 2
    typedef uint16 TMask;
#elif TARGET_ELEMENT_WIDTH == 4
    typedef uint32 TMask;
#elif TARGET_ELEMENT_WIDTH == 8
    typedef uint64 TMask;
#else
#error "unknown mask"
#endif
    varying TMask mask = __mask;
    unmasked
    {
        dest = mask ? src : dest;
    }
}

This doesn't look right because TARGET_ELEMENT_WIDTH is not ISPC_MASK_BITS. I think idiomatically you need to write this code instead:

void SetByte(int8& dest, int8 src) {
    varying bool mask = __mask;
    unmasked {
        dest = mask ? src : dest;
    }
}

@MkazemAkhgary

This doesn't look right because TARGET_ELEMENT_WIDTH is not ISPC_MASK_BITS. I think idiomatically you need to write this code instead:

void SetByte(int8& dest, int8 src) {
    varying bool mask = __mask;
    unmasked {
        dest = mask ? src : dest;
    }
}

I see, yeah, this looks correct and much cleaner. However, the issue with sse4.1-i8x16 and sse4.2-i8x16 remains: the function just returns.

SetByte___REFvytvyt:                    # @SetByte___REFvytvyt
        ret

The same issue occurs with this code:

void SetByte(int8& dest, int8 src)
{
    uniform uint32 mask = lanemask();
    unmasked
    {
        dest = select(((1 << programIndex) & mask) != 0, src, dest);
    }
}

@nurmukhametov

I see, yeah this looks correct and much cleaner. how ever the issue with sse4.1-i8x16 and sse4.2-i8x16 remains. the function just returns.

I have created issue #2824 for that.
