In [54]:
import math

def bit_reverse(x, bits):
    """Return ``x`` with its binary digits written in reverse order.

    ``x`` is rendered in binary, left-padded with zeros to at least ``bits``
    digits, the digit string is reversed, and the result is parsed back as a
    base-2 integer.

    Note that ``bits`` is only a minimum width: a value that needs more than
    ``bits`` digits is reversed over its full length.
    """
    padded = format(x, f"0{bits}b")
    return int(padded[::-1], 2)

def print_dif_nr_ntt(n):
    """Print the butterfly schedule labelled "DIF-NR + GS NTT" for length ``n``.

    There are ``int(log2(n))`` stages. In stage ``s`` the block size is
    ``m = 2**(log2(n) - s)``; every block pairs position ``p`` with
    ``p + m/2`` and uses twiddle exponent ``(p - block_start) * (n // m)``.
    Each butterfly is printed as ``BU(idx1, idx2, w^tw)``.

    Args:
        n: Transform length. The stage count is the truncated base-2
            logarithm, so ``n`` is expected to be a power of two.
    """
    num_stages = int(math.log2(n))
    print("\n=== DIF-NR + GS NTT (Fixed) ===")
    for stage in range(num_stages):
        # Block size halves from n down to 2 as the stages progress.
        span = 1 << (num_stages - stage)
        gap = span // 2
        step = n // span
        print(f"\n[Stage {stage}] m={span}, twiddle stride={step}")
        for base in range(0, n, span):
            for upper in range(base, base + gap):
                print(f"  BU({upper}, {upper + gap}, w^{(upper - base) * step})")


def print_dit_rn_intt(n):
    """Print the butterfly schedule labelled "DIT-RN + GS INTT" for length ``n``.

    Stage ``s`` uses block size ``m = 2**(s + 1)``. Within each block,
    position ``p`` is paired with ``p + m/2`` and the twiddle exponent is
    ``(p - block_start) * (n // m)``. The two positions are *not* printed
    directly: each is first translated through a bit-reversal table over
    ``int(log2(n))`` bits, so the trace shows bit-reversed labels.

    Args:
        n: Transform length. The stage count is the truncated base-2
            logarithm, so ``n`` is expected to be a power of two.
    """
    num_stages = int(math.log2(n))
    print("\n=== DIT-RN + GS INTT (Fixed) ===")

    # Position -> bit-reversed label, computed by reversing the zero-padded
    # binary string of each position.
    reversed_label = [
        int(format(pos, f"0{num_stages}b")[::-1], 2) for pos in range(n)
    ]

    for stage in range(num_stages):
        span = 2 << stage  # block size: 2, 4, 8, ...
        gap = span // 2
        step = n // span
        print(f"\n[Stage {stage}] m={span}, twiddle stride={step}")
        for base in range(0, n, span):
            for k in range(gap):
                first = reversed_label[base + k]
                second = reversed_label[base + k + gap]
                print(f"  BU({first}, {second}, w^{k * step})")


In [128]:
# Trace the full schedule for a 128-point transform (7 stages).
print_dif_nr_ntt(128)


=== DIF-NR + GS NTT (Fixed) ===

[Stage 0] m=128, twiddle stride=1
  BU(0, 64, w^0)
  BU(1, 65, w^1)
  BU(2, 66, w^2)
  BU(3, 67, w^3)
  BU(4, 68, w^4)
  BU(5, 69, w^5)
  BU(6, 70, w^6)
  BU(7, 71, w^7)
  BU(8, 72, w^8)
  BU(9, 73, w^9)
  BU(10, 74, w^10)
  BU(11, 75, w^11)
  BU(12, 76, w^12)
  BU(13, 77, w^13)
  BU(14, 78, w^14)
  BU(15, 79, w^15)
  BU(16, 80, w^16)
  BU(17, 81, w^17)
  BU(18, 82, w^18)
  BU(19, 83, w^19)
  BU(20, 84, w^20)
  BU(21, 85, w^21)
  BU(22, 86, w^22)
  BU(23, 87, w^23)
  BU(24, 88, w^24)
  BU(25, 89, w^25)
  BU(26, 90, w^26)
  BU(27, 91, w^27)
  BU(28, 92, w^28)
  BU(29, 93, w^29)
  BU(30, 94, w^30)
  BU(31, 95, w^31)
  BU(32, 96, w^32)
  BU(33, 97, w^33)
  BU(34, 98, w^34)
  BU(35, 99, w^35)
  BU(36, 100, w^36)
  BU(37, 101, w^37)
  BU(38, 102, w^38)
  BU(39, 103, w^39)
  BU(40, 104, w^40)
  BU(41, 105, w^41)
  BU(42, 106, w^42)
  BU(43, 107, w^43)
  BU(44, 108, w^44)
  BU(45, 109, w^45)
  BU(46, 110, w^46)
  BU(47, 111, w^47)
  BU(48, 112, w^48)
  BU(49,

In [125]:
# Smallest multi-stage case (n=4, two stages) — easy to verify by hand.
print_dif_nr_ntt(4)


=== DIF-NR + GS NTT (Fixed) ===

[Stage 0] m=4, twiddle stride=1
  BU(0, 2, w^0)
  BU(1, 3, w^1)

[Stage 1] m=2, twiddle stride=2
  BU(0, 1, w^0)
  BU(2, 3, w^0)


In [66]:
# 16-point trace; butterfly indices are printed as bit-reversed labels.
print_dit_rn_intt(16)


=== DIT-RN + GS INTT (Fixed) ===

[Stage 0] m=2, twiddle stride=8
  BU(0, 8, w^0)
  BU(4, 12, w^0)
  BU(2, 10, w^0)
  BU(6, 14, w^0)
  BU(1, 9, w^0)
  BU(5, 13, w^0)
  BU(3, 11, w^0)
  BU(7, 15, w^0)

[Stage 1] m=4, twiddle stride=4
  BU(0, 4, w^0)
  BU(8, 12, w^4)
  BU(2, 6, w^0)
  BU(10, 14, w^4)
  BU(1, 5, w^0)
  BU(9, 13, w^4)
  BU(3, 7, w^0)
  BU(11, 15, w^4)

[Stage 2] m=8, twiddle stride=2
  BU(0, 2, w^0)
  BU(8, 10, w^2)
  BU(4, 6, w^4)
  BU(12, 14, w^6)
  BU(1, 3, w^0)
  BU(9, 11, w^2)
  BU(5, 7, w^4)
  BU(13, 15, w^6)

[Stage 3] m=16, twiddle stride=1
  BU(0, 1, w^0)
  BU(8, 9, w^1)
  BU(4, 5, w^2)
  BU(12, 13, w^3)
  BU(2, 3, w^4)
  BU(10, 11, w^5)
  BU(6, 7, w^6)
  BU(14, 15, w^7)


In [106]:
import math

def print_split_stage_dif_nr_phase1(N, S):
    """Print the phase-1 butterfly trace of the split-stage schedule.

    Phase 1 covers stages ``1..S``. The positions are split into
    ``stride = 2**(log2(N) - S)`` interleaved groups (``offset``,
    ``offset + stride``, ``offset + 2*stride``, ...), and a ``2**S``-point
    stage sequence is traced for each group in turn.

    Within a group, stage ``k`` uses block size ``m = 2**(S - k + 1)`` and
    ``r = 2**S // m`` blocks. Each butterfly is printed as
    ``BU(e_idx, o_idx, w^w_idx)`` with global indices; the twiddle exponent
    is ``(stride * j + offset) * r``.

    Args:
        N: Transform length. Only ``int(math.log2(N))`` is used, so ``N`` is
            expected to be a power of two.
        S: Number of stages traced in this phase.
    """
    L = int(math.log2(N))
    print("\n=== Split-Stage DIF-NR + GS Trace ===")
    print(f"N = {N}, log2(N) = {L}, S = {S}")

    # Phase 1: strided segment pass (2^S-point, repeated `stride` times).
    stride = 2 ** (L - S)
    print(f"\n--- Phase 1: Strided NTT ({2**S}-point × {stride}) ---")
    for offset in range(stride):
        print(f"\n[Offset {offset}]")
        for stage in range(1, S + 1):
            m = 2 ** (S - stage + 1)
            r = (2 ** S) // m
            half = m // 2
            print(f" Stage {stage} (m={m}, r={r})")
            for b in range(r):
                for j in range(half):
                    # Local index (b*m + j) inside the group, mapped to its
                    # global position by the interleaving stride and offset.
                    e_idx = stride * (b * m + j) + offset
                    o_idx = stride * (b * m + j + half) + offset
                    w_idx = (stride * j + offset) * r
                    print(f"  BU({e_idx}, {o_idx}, w^{w_idx})")

def print_split_stage_dif_nr_phase2(N, S):
    """Print the phase-2 butterfly trace of the split-stage schedule.

    Phase 2 covers the remaining ``log2(N) - S`` stages. The positions are
    cut into ``block_num = N // block_size`` contiguous blocks of
    ``block_size = 2**(log2(N) - S)`` entries, and the stages are traced
    block by block.

    Inside a block, stage ``k`` uses sub-block size
    ``m = 2**((log2(N) - S) - k + 1)`` with ``r = block_size // m``
    sub-blocks. Each butterfly is printed as ``BU(e_idx, o_idx, w^w_idx)``
    with global indices and twiddle exponent ``j * block_num * r``.

    Args:
        N: Transform length. Only ``int(math.log2(N))`` is used, so ``N`` is
            expected to be a power of two.
        S: Number of stages already handled by phase 1.
    """
    log_n = int(math.log2(N))
    tail_stages = log_n - S
    block_size = 2 ** tail_stages
    block_num = N // block_size
    print("\n=== Phase 2: Sequential Segment DIF-NR + GS Trace ===")
    print(f"N = {N}, log2(N) = {log_n}, S = {S}, block_size = {block_size}, block_num = {block_num}")

    for block in range(block_num):
        print(f"\n[Block {block}]")
        origin = block * block_size  # first global index of this block
        for stage in range(1, tail_stages + 1):
            width = 2 ** (tail_stages - stage + 1)
            groups = block_size // width
            reach = width // 2
            print(f" Stage {stage} (m={width}, r={groups})")
            for start in range(origin, origin + groups * width, width):
                for k in range(reach):
                    upper = start + k
                    print(f"  BU({upper}, {upper + reach}, w^{k * block_num * groups})")

In [129]:
# Phase 1 of the split schedule for N=128: the first S=4 stages, traced as 8 strided 16-point passes.
print_split_stage_dif_nr_phase1(N=128, S=4)



=== Split-Stage DIF-NR + GS Trace ===
N = 128, log2(N) = 7, S = 4

--- Phase 1: Strided NTT (16-point × 8) ---

[Offset 0]
 Stage 1 (m=16, r=1)
  BU(0, 64, w^0)
  BU(8, 72, w^8)
  BU(16, 80, w^16)
  BU(24, 88, w^24)
  BU(32, 96, w^32)
  BU(40, 104, w^40)
  BU(48, 112, w^48)
  BU(56, 120, w^56)
 Stage 2 (m=8, r=2)
  BU(0, 32, w^0)
  BU(8, 40, w^16)
  BU(16, 48, w^32)
  BU(24, 56, w^48)
  BU(64, 96, w^0)
  BU(72, 104, w^16)
  BU(80, 112, w^32)
  BU(88, 120, w^48)
 Stage 3 (m=4, r=4)
  BU(0, 16, w^0)
  BU(8, 24, w^32)
  BU(32, 48, w^0)
  BU(40, 56, w^32)
  BU(64, 80, w^0)
  BU(72, 88, w^32)
  BU(96, 112, w^0)
  BU(104, 120, w^32)
 Stage 4 (m=2, r=8)
  BU(0, 8, w^0)
  BU(16, 24, w^0)
  BU(32, 40, w^0)
  BU(48, 56, w^0)
  BU(64, 72, w^0)
  BU(80, 88, w^0)
  BU(96, 104, w^0)
  BU(112, 120, w^0)

[Offset 1]
 Stage 1 (m=16, r=1)
  BU(1, 65, w^1)
  BU(9, 73, w^9)
  BU(17, 81, w^17)
  BU(25, 89, w^25)
  BU(33, 97, w^33)
  BU(41, 105, w^41)
  BU(49, 113, w^49)
  BU(57, 121, w^57)
 Stage 2 (m=8, 

In [130]:
# Phase 2 of the same split: the remaining 3 stages, traced inside each of the 16 contiguous blocks of 8.
print_split_stage_dif_nr_phase2(N=128, S=4)


=== Phase 2: Sequential Segment DIF-NR + GS Trace ===
N = 128, log2(N) = 7, S = 4, block_size = 8, block_num = 16

[Block 0]
 Stage 1 (m=8, r=1)
  BU(0, 4, w^0)
  BU(1, 5, w^16)
  BU(2, 6, w^32)
  BU(3, 7, w^48)
 Stage 2 (m=4, r=2)
  BU(0, 2, w^0)
  BU(1, 3, w^32)
  BU(4, 6, w^0)
  BU(5, 7, w^32)
 Stage 3 (m=2, r=4)
  BU(0, 1, w^0)
  BU(2, 3, w^0)
  BU(4, 5, w^0)
  BU(6, 7, w^0)

[Block 1]
 Stage 1 (m=8, r=1)
  BU(8, 12, w^0)
  BU(9, 13, w^16)
  BU(10, 14, w^32)
  BU(11, 15, w^48)
 Stage 2 (m=4, r=2)
  BU(8, 10, w^0)
  BU(9, 11, w^32)
  BU(12, 14, w^0)
  BU(13, 15, w^32)
 Stage 3 (m=2, r=4)
  BU(8, 9, w^0)
  BU(10, 11, w^0)
  BU(12, 13, w^0)
  BU(14, 15, w^0)

[Block 2]
 Stage 1 (m=8, r=1)
  BU(16, 20, w^0)
  BU(17, 21, w^16)
  BU(18, 22, w^32)
  BU(19, 23, w^48)
 Stage 2 (m=4, r=2)
  BU(16, 18, w^0)
  BU(17, 19, w^32)
  BU(20, 22, w^0)
  BU(21, 23, w^32)
 Stage 3 (m=2, r=4)
  BU(16, 17, w^0)
  BU(18, 19, w^0)
  BU(20, 21, w^0)
  BU(22, 23, w^0)

[Block 3]
 Stage 1 (m=8, r=1)
  BU(24,

In [111]:
def collect_bu_set_from_dif_nr(n):
    """Collect every butterfly of the full (unsplit) schedule into one set.

    Follows the same stage/block/offset enumeration as ``print_dif_nr_ntt``
    but records each butterfly instead of printing it.

    Args:
        n: Transform length; the stage count is ``int(math.log2(n))``.

    Returns:
        A set of ``(low_index, high_index, twiddle_exponent)`` tuples covering
        all stages. Stage membership is not recorded.
    """
    stage_count = int(math.log2(n))
    butterflies = set()
    span = 2 ** stage_count  # block size of the first stage
    for _ in range(stage_count):
        gap = span // 2
        step = n // span
        for base in range(0, n, span):
            for k in range(gap):
                a = base + k
                b = a + gap
                butterflies.add((min(a, b), max(a, b), k * step))
        span = gap  # the next stage works on blocks half as large
    return butterflies

def collect_bu_set_from_split_stage(n, S):
    """Collect every butterfly of the two-phase split schedule into one set.

    Phase 1 enumerates ``S`` stages over ``2**(log2(n) - S)`` interleaved
    groups; phase 2 enumerates the remaining ``log2(n) - S`` stages over
    contiguous blocks. The index and twiddle formulas are the same ones the
    ``print_split_stage_dif_nr_phase1`` / ``..._phase2`` traces print.

    Args:
        n: Transform length; the stage count is ``int(math.log2(n))``.
        S: Number of stages assigned to phase 1.

    Returns:
        A set of ``(low_index, high_index, twiddle_exponent)`` tuples from
        both phases. Stage membership is not recorded.
    """
    log_n = int(math.log2(n))
    tail_stages = log_n - S
    butterflies = set()

    # Phase 1: one pass per interleaved group (lane).
    lanes = 2 ** tail_stages
    for lane in range(lanes):
        for level in range(S):
            width = 2 ** (S - level)
            groups = (2 ** S) // width
            reach = width // 2
            for g in range(groups):
                for k in range(reach):
                    local = g * width + k
                    a = lanes * local + lane
                    b = lanes * (local + reach) + lane
                    twiddle = (lanes * k + lane) * groups
                    butterflies.add((min(a, b), max(a, b), twiddle))

    # Phase 2: one pass per contiguous block.
    block_len = 2 ** tail_stages
    block_count = n // block_len
    for block in range(block_count):
        origin = block * block_len
        for level in range(tail_stages):
            width = 2 ** (tail_stages - level)
            groups = block_len // width
            reach = width // 2
            for g in range(groups):
                for k in range(reach):
                    a = origin + (g * width + k)
                    b = origin + (g * width + k + reach)
                    twiddle = k * block_count * groups
                    butterflies.add((min(a, b), max(a, b), twiddle))

    return butterflies


def compare_dif_and_split(n, S):
    """Report whether the full and split schedules contain the same butterflies.

    Builds both butterfly sets and prints either a success line or the
    sorted entries that appear on only one side.

    The comparison is over the union of all stages, so it checks that the
    same ``(index, index, twiddle)`` triples exist but not which stage each
    one belongs to.

    Args:
        n: Transform length passed to both collectors.
        S: Phase-1 stage count passed to the split collector.
    """
    reference = collect_bu_set_from_dif_nr(n)
    candidate = collect_bu_set_from_split_stage(n, S)

    missing_from_split = reference - candidate
    extra_in_split = candidate - reference

    if not (missing_from_split or extra_in_split):
        print("✅ BU sets are identical — split version is functionally correct.")
        return

    print("❌ Difference found!")
    report = (
        ("Only in Full DIF-NR:", missing_from_split),
        ("Only in Split NTT:", extra_in_split),
    )
    for heading, entries in report:
        if entries:
            print(heading)
            for entry in sorted(entries):
                print(" ", entry)


In [114]:
# Check that the split schedule (S=2) produces the same set of butterflies as the full 32-point schedule.
compare_dif_and_split(n=32, S=2)


✅ BU sets are identical — split version is functionally correct.


In [115]:
from collections import defaultdict

def collect_stagewise_bu_from_dif(n):
    """Collect the full schedule's butterflies, grouped by stage.

    Args:
        n: Transform length; the stage count is ``int(math.log2(n))``.

    Returns:
        A ``defaultdict(set)`` mapping the 0-based stage number to the set of
        ``(low_index, high_index, twiddle_exponent)`` tuples of that stage.
    """
    stage_count = int(math.log2(n))
    by_stage = defaultdict(set)
    span = 2 ** stage_count  # block size of stage 0
    for stage in range(stage_count):
        gap = span // 2
        step = n // span
        for base in range(0, n, span):
            for k in range(gap):
                a = base + k
                b = a + gap
                by_stage[stage].add((min(a, b), max(a, b), k * step))
        span = gap  # halve the block size for the next stage
    return by_stage

from collections import defaultdict
import math

def collect_stagewise_bu_from_split(n, S):
    """Collect the split schedule's butterflies, grouped by global stage.

    Phase-1 stage ``k`` (1-based) is filed under global stage ``k - 1``;
    phase-2 stage ``k`` is filed under global stage ``S + k - 1``, so the
    keys line up with the 0-based stages of the full schedule.

    Args:
        n: Transform length; the stage count is ``int(math.log2(n))``.
        S: Number of stages assigned to phase 1.

    Returns:
        A ``defaultdict(set)`` mapping global stage number to the set of
        ``(low_index, high_index, twiddle_exponent)`` tuples of that stage.
    """
    log_n = int(math.log2(n))
    tail_stages = log_n - S
    by_stage = defaultdict(set)

    # Phase 1: interleaved groups; local level == global stage.
    lanes = 2 ** tail_stages
    for lane in range(lanes):
        for level in range(S):
            width = 2 ** (S - level)
            groups = (2 ** S) // width
            reach = width // 2
            for g in range(groups):
                for k in range(reach):
                    local = g * width + k
                    a = lanes * local + lane
                    b = lanes * (local + reach) + lane
                    twiddle = (lanes * k + lane) * groups
                    by_stage[level].add((min(a, b), max(a, b), twiddle))

    # Phase 2: contiguous blocks; global stage continues from S.
    block_len = 2 ** tail_stages
    block_count = n // block_len
    for block in range(block_count):
        origin = block * block_len
        for level in range(tail_stages):
            width = 2 ** (tail_stages - level)
            groups = block_len // width
            reach = width // 2
            for g in range(groups):
                for k in range(reach):
                    a = origin + (g * width + k)
                    b = origin + (g * width + k + reach)
                    twiddle = k * block_count * groups
                    by_stage[S + level].add((min(a, b), max(a, b), twiddle))

    return by_stage



def compare_per_stage(n, S):
    """Compare the full and split schedules one stage at a time.

    For each of the ``int(math.log2(n))`` global stages, prints a success
    line when both schedules contain exactly the same butterflies for that
    stage, otherwise prints the butterflies found on only one side.

    Args:
        n: Transform length passed to both stage-wise collectors.
        S: Phase-1 stage count passed to the split collector.
    """
    stage_count = int(math.log2(n))
    reference = collect_stagewise_bu_from_dif(n)
    candidate = collect_stagewise_bu_from_split(n, S)

    for stage in range(stage_count):
        expected = reference[stage]
        actual = candidate[stage]
        if expected == actual:
            print(f"✅ Stage {stage} is correct.")
            continue
        print(f"❌ Stage {stage} mismatch:")
        print("  In Full Only:", expected - actual)
        print("  In Split Only:", actual - expected)


In [122]:
# Stage-by-stage check for n=512 with the first S=5 stages assigned to phase 1.
compare_per_stage(n=512, S=5)


✅ Stage 0 is correct.
✅ Stage 1 is correct.
✅ Stage 2 is correct.
✅ Stage 3 is correct.
✅ Stage 4 is correct.
✅ Stage 5 is correct.
✅ Stage 6 is correct.
✅ Stage 7 is correct.
✅ Stage 8 is correct.
