In [None]:
%%writefile simd_divergence.c
#include <immintrin.h>
#include <stdio.h>

int main(void) {
    // LANES: number of 32-bit integers in one 256-bit vector.
    // THRESHOLD: bound used by the simulated branch "if (index < THRESHOLD)".
    enum { LANES = 8, THRESHOLD = 4 };

    // 256-bit vector holding 8 x 32-bit integers; each lane stores its own index.
    const __m256i values = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);

    // Build the branch predicate as a per-lane mask instead of hard-coding it.
    // _mm256_cmpgt_epi32(a, b) sets a lane to all ones (-1) where a > b and to
    // 0 otherwise, so "THRESHOLD > values" produces
    //   [-1, -1, -1, -1, 0, 0, 0, 0]  ==  [TRUE x4, FALSE x4]
    const __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(THRESHOLD), values);

    // Divergence: both sides of the if/else are computed for ALL 8 lanes,
    // including the lanes whose result is thrown away afterwards.
    const __m256i path_a = _mm256_mullo_epi32(values, _mm256_set1_epi32(2)); // "then" side: value * 2
    const __m256i path_b = _mm256_mullo_epi32(values, _mm256_set1_epi32(3)); // "else" side: value * 3

    // The "branch" is a per-lane select: where the mask is true take path_a,
    // otherwise take path_b. _mm256_blendv_epi8 selects byte by byte using the
    // top bit of each mask byte; because every 32-bit mask lane is either all
    // ones or all zeros, whole lanes are selected together.
    const __m256i result = _mm256_blendv_epi8(path_b, path_a, mask);

    // Unaligned store: a plain int array is not guaranteed to be 32-byte aligned.
    int out[LANES];
    _mm256_storeu_si256((__m256i *)out, result);

    for (int i = 0; i < LANES; i++) {
        printf("%d | %d\n", i, out[i]);
    }

    return 0;
}

Writing simd_divergence.c


In [None]:
# Compile the file written by the previous cell, then run the resulting binary.
# -mavx2 enables the AVX2 instruction set that the _mm256_* integer intrinsics
# used in simd_divergence.c require; -Wall -Wextra turn on compiler warnings.
!gcc -O3 -Wall -Wextra -mavx2  -o simd_divergence simd_divergence.c
!./simd_divergence

0 | 0
1 | 2
2 | 4
3 | 6
4 | 12
5 | 15
6 | 18
7 | 21


In [None]:
# __global__ void warpDivergeExample() {
#    int tid = threadIdx.x;
#
#   // This if-else causes warp divergence
#   if (tid < 16) {
#       // First 16 threads (0-15) execute Path A
#       tid = tid * 2;
#   } else {
#       // Last 16 threads (16-31) execute Path B
#       tid = tid * 3;
#   }
#}

In [None]:
# Visual representation:
#
# values:  [0,  1,  2,  3,  4,  5,  6,  7]
#              ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓
#          ┌───────────────────────────────────┐
#          │  Compute path_a (* 2)             │ ← Processes ALL 8
#          │  [0, 2, 4, 6, 8, 10, 12, 14]      │
#          └───────────────────────────────────┘
#              ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓
#          ┌───────────────────────────────────┐
#          │  Compute path_b (* 3)             │ ← Processes ALL 8
#          │  [0, 3, 6, 9, 12, 15, 18, 21]     │
#          └───────────────────────────────────┘
#              ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓
#          ┌───────────────────────────────────┐
#          │  Blend based on mask              │
#          │  mask: [-1,-1,-1,-1, 0, 0, 0, 0]  │
#          └───────────────────────────────────┘
#              ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓
# result:  [0, 2, 4, 6, 12, 15, 18, 21]
#          └──from A──┘ └───from B────┘