[https://developer.nvidia.com/blog/thinking-parallel-part-iii-tree-construction-gpu/]

In [None]:
import numpy

# 具体解释看链接里的文章
# 效果是后10位每两位中间放两个0
def expandBits(v):
    """Spread the low 10 bits of ``v`` so each is followed by two zero bits.

    Bit ``k`` of the input (``0 <= k < 10``) ends up at bit ``3 * k`` of the
    result. Works on anything that supports ``*`` and ``&`` with integers,
    e.g. Python ints or NumPy unsigned integer arrays.
    """
    # Each stage duplicates the bit groups with a multiply and then masks
    # away the copies that landed in the wrong place.
    stages = (
        (0x00010001, 0xFF0000FF),
        (0x00000101, 0x0F00F00F),
        (0x00000011, 0xC30C30C3),
        (0x00000005, 0x49249249),
    )
    for factor, keep in stages:
        v = (v * factor) & keep
    return v

# 空间坐标的莫顿码
def morton3D(x, y, z):
    """Compute 30-bit Morton codes for 3D points.

    Each coordinate is scaled by 1024, clamped to ``[0, 1023]`` and truncated
    to an integer cell index, so coordinates are expected to lie in the unit
    cube ``[0, 1]``. The three 10-bit cell indices are then interleaved with
    ``x`` in the most significant position of every bit triple.

    Args:
        x, y, z: scalars or array-likes of equal shape holding the coordinates.

    Returns:
        ``numpy.uint32`` value(s) with the same shape as the inputs.
    """
    def quantize(coord):
        # Scale in floating point first: converting the raw input to an
        # unsigned integer before scaling would truncate every fractional
        # coordinate in [0, 1) to 0 and collapse the grid to two cells.
        scaled = numpy.asarray(coord, dtype=numpy.float32) * numpy.float32(1024.0)
        clamped = numpy.minimum(
            numpy.maximum(scaled, numpy.float32(0.0)),
            numpy.float32(1023.0),
        )
        return clamped.astype(numpy.uint32)

    xx = expandBits(quantize(x))
    yy = expandBits(quantize(y))
    zz = expandBits(quantize(z))
    # The expanded bits occupy every third position, so the shifted copies
    # never overlap and the sum acts as a bitwise OR.
    return xx * 4 + yy * 2 + zz

In [None]:
from concurrent.futures import ThreadPoolExecutor
# Word width, in bits, used when counting leading zeros of Morton-code XORs.
int_bit = 32

# 叶子
class leafNode:
    """Leaf of the hierarchy: holds one object ID and a link to its parent."""

    def __init__(self) -> None:
        self.objectID = None
        self.parent = None

    def __str__(self, level=0) -> str:
        """Render the leaf on one line, indented two spaces per ``level``."""
        pad = level * "  "
        if self.parent:
            owner = id(self.parent)
        else:
            owner = "None"
        return f"{pad}leafNode(objectID={self.objectID}, parentID={owner})"

# 非叶子
class internalNode:
    """Inner node of the hierarchy with two children and a parent link."""

    def __init__(self) -> None:
        self.childA = None
        self.childB = None
        self.parent = None

    def __str__(self, level=0) -> str:
        """Render this node and, recursively, both children.

        Every nesting level is indented by two spaces; a missing child is
        shown as ``None`` one level deeper than this node.
        """
        pad = level * "  "
        owner = id(self.parent) if self.parent else "None"
        rows = [f"{pad}internalNode(parentID={owner})"]
        for child in (self.childA, self.childB):
            if child:
                rows.append(child.__str__(level + 1))
            else:
                rows.append(f"{pad}  None")
        return "\n".join(rows)

# 前缀0的个数
# aten里面好像有封装但是torch没提供接口
# numpy也没有相应接口
# 只能先转成int
def clz(num:int):
    """Return the number of leading zero bits of ``num`` in an ``int_bit``-wide word.

    ``num`` is converted with ``int()`` first so NumPy integer scalars are
    accepted as well.
    """
    significant_bits = int(num).bit_length()
    return int_bit - significant_bits

# 二分确定分割点
def findSplit(
    first,
    last,
    sigma,
):
    """Binary-search the split position of the key range ``[first, last]``.

    Args:
        first: index of the first key in the range.
        last: index of the last key in the range (inclusive).
        sigma: callable ``sigma(i, j)`` returning the length of the common
            prefix of keys ``i`` and ``j``.

    Returns:
        The highest index ``s`` in ``[first, last - 1]`` whose key shares more
        than ``sigma(first, last)`` prefix bits with key ``first``; the left
        child covers ``[first, s]`` and the right child ``[s + 1, last]``.
        When the first and last keys are identical the midpoint is returned.
    """
    commonPrefix = sigma(first, last)
    # sigma(first, first) is what sigma reports for two identical keys, so
    # this detects "first and last keys are equal" without depending on a
    # separately configured bit width.
    if commonPrefix == sigma(first, first):
        return (first + last) >> 1

    split = first
    step = last - first

    while True:
        step = (step + 1) >> 1
        newSplit = split + step
        # Probe the *candidate* position: it is accepted only if its key still
        # shares a longer prefix with the first key than the whole range does.
        if newSplit < last and sigma(first, newSplit) > commonPrefix:
            split = newSplit
        if step <= 1:
            break
    return split

# 倍增和二分确定大致范围
def determineRange(
    idx,
    sigma,
):
    """Find the range of keys covered by internal node ``idx``.

    Args:
        idx: index of the internal node.
        sigma: callable ``sigma(i, j)`` giving the common-prefix length of
            keys ``i`` and ``j`` (and a value below every real prefix length
            when ``j`` is out of range).

    Returns:
        ``(low, high)`` with ``low <= high``; one end is ``idx`` itself.
    """
    # Direction of the range: towards the neighbour sharing the longer prefix.
    direction = numpy.sign(sigma(idx, idx + 1) - sigma(idx, idx - 1))
    # Prefix shared with the neighbour on the other side; every key inside
    # the range must beat this value.
    floor_prefix = sigma(idx, idx - direction)

    # Exponential search for an upper bound on the range length.
    bound = 2
    while sigma(idx, idx + bound * direction) > floor_prefix:
        bound <<= 1

    # Binary search for the exact length below that bound.
    length = 0
    probe = bound >> 1
    while probe >= 1:
        if sigma(idx, idx + (length + probe) * direction) > floor_prefix:
            length += probe
        probe >>= 1

    other_end = idx + length * direction
    return (idx, other_end) if other_end > idx else (other_end, idx)

# BVH树
def generateHierarchy(
    sortedMortonCodes:list[int],        # Morton codes, sorted ascending
    sortedObjectIDs:list[int],          # object (facet) ID for each Morton code
    numObjects,                         # number of objects
):
    """Build a binary hierarchy over objects ordered by Morton code.

    Args:
        sortedMortonCodes: integer-indexable sequence of Morton codes in
            sorted order (the notebook passes a NumPy array).
        sortedObjectIDs: object ID belonging to each sorted code.
        numObjects: number of objects to use.

    Returns:
        The root ``internalNode`` when there are at least two objects, the
        single ``leafNode`` when there is exactly one, and ``None`` when
        there are none.

    NOTE(review): ``sigma`` compares the codes only, so duplicate Morton codes
    are not disambiguated by index - confirm the inputs are unique.
    """
    # Length of the common prefix of two codes; -1 marks an out-of-range j.
    def sigma(i, j):
        if j < 0 or j > numObjects - 1:
            return -1
        return clz(sortedMortonCodes[i] ^ sortedMortonCodes[j])

    leafNodes = [leafNode() for _ in range(numObjects)]
    internalNodes = [internalNode() for _ in range(numObjects - 1)]
    for idx, leaf in enumerate(leafNodes):
        leaf.objectID = sortedObjectIDs[idx]

    # With fewer than two objects there is no internal node to return.
    if numObjects < 2:
        return leafNodes[0] if leafNodes else None

    # From here on every internal node can be processed independently.
    # Because of the GIL, threads bring little real speed-up, and separate
    # processes would not share the node lists, so a simple thread pool is
    # used for now.
    def calUnit(idx):
        first, last = determineRange(idx, sigma)
        split = findSplit(first, last, sigma)

        # A child covering a single key is a leaf, otherwise an internal node.
        if split == first:
            childA = leafNodes[split]
        else:
            childA = internalNodes[split]

        if split + 1 == last:
            childB = leafNodes[split + 1]
        else:
            childB = internalNodes[split + 1]

        node = internalNodes[idx]
        node.childA = childA
        node.childB = childB
        childA.parent = node
        childB.parent = node

    with ThreadPoolExecutor() as pool:
        # Drain the result iterator: Executor.map only re-raises a worker's
        # exception when its result is retrieved, so leaving it unread would
        # silently hide failures and return a half-built tree.
        list(pool.map(calUnit, range(numObjects - 1)))

    return internalNodes[0]

In [15]:
# Three sample points given column-wise: (0,0,0), (0,1,0) and (0,0,1).
x = numpy.array([0,0,0])
y = numpy.array([0,1,0])
z = numpy.array([0,0,1])

# Codes are still in input order here, despite the variable name.
sortedMortonCodes = morton3D(x,y,z)
# argsort gives, for each sorted position, the index of the original point.
sortedObjectIDs = sortedMortonCodes.argsort()
sortedMortonCodes = sortedMortonCodes[sortedObjectIDs]
# Print the tree built by the scalar implementation.
print(generateHierarchy(sortedMortonCodes, sortedObjectIDs, len(sortedMortonCodes)))

internalNode(parentID=None)
  internalNode(parentID=136359914125840)
    leafNode(objectID=0, parentID=136359724650448)
    leafNode(objectID=2, parentID=136359724650448)
  leafNode(objectID=1, parentID=136359914125840)


# numpy实现

In [136]:
def dedetermineRangePal(
    numObjects,
    sigmaPal,
):
    """Vectorised range search for all ``numObjects - 1`` internal nodes.

    Args:
        numObjects: number of keys (leaves).
        sigmaPal: callable ``sigmaPal(i, j)`` taking two equally shaped
            integer index arrays and returning, per element, the common
            prefix length of keys ``i`` and ``j``; for an out-of-range ``j``
            it must return a value below every real prefix length.

    Returns:
        ``(first, last)``: two ``int32`` arrays of length ``numObjects - 1``
        with ``first <= last``, the key range covered by each internal node.
    """
    idx = numpy.arange(numObjects - 1, dtype=numpy.int32)
    # Direction of each range: towards the neighbour with the longer prefix.
    d = numpy.sign(sigmaPal(idx, idx + 1) - sigmaPal(idx, idx - 1))
    sigma_min = sigmaPal(idx, idx - d)

    # Exponential search: keep doubling the bound of every node whose probe
    # still lies inside its range.
    l_max = numpy.full_like(idx, 2)
    while True:
        growing = sigmaPal(idx, idx + l_max * d) > sigma_min
        if not numpy.any(growing):
            break
        l_max[growing] <<= 1

    # Binary search for the exact length. The masks must be boolean: an
    # integer 0/1 array used as an index selects elements 0 and 1 instead of
    # acting as a mask.
    l = numpy.zeros_like(idx)
    t = l_max >> 1
    active = t >= 1
    while numpy.any(active):
        accept = active & (sigmaPal(idx, idx + (l + t) * d) > sigma_min)
        l[accept] += t[accept]
        t[active] >>= 1
        active = t >= 1

    j = idx + l * d
    return numpy.minimum(idx, j), numpy.maximum(idx, j)

def findSplitPal(
    first,
    last,
    sigmaPal,
):
    """Vectorised split search for many key ranges at once.

    Args:
        first: integer array with the first key index of every range.
        last: integer array with the last (inclusive) key index of every range.
        sigmaPal: callable ``sigmaPal(i, j)`` returning, per element, the
            common prefix length of keys ``i`` and ``j``; it must tolerate an
            out-of-range ``j``.

    Returns:
        Integer array of split positions. For each range it is the highest
        index whose key shares more prefix bits with key ``first`` than the
        whole range does, or the midpoint when the first and last keys are
        identical. The inputs are not modified.
    """
    commonPrefix = sigmaPal(first, last)
    # sigmaPal(first, first) is the value reported for identical keys, so the
    # comparison needs no separately configured bit width.
    identical = commonPrefix == sigmaPal(first, first)
    searching = ~identical

    split = first.copy()
    step = last - first
    while True:
        step[searching] = (step[searching] + 1) >> 1
        newSplit = split + step
        # Probe the candidate position, not the current split.
        advance = (
            searching
            & (newSplit < last)
            & (sigmaPal(first, newSplit) > commonPrefix)
        )
        split[advance] = newSplit[advance]
        if numpy.all(step[searching] <= 1):
            break

    split[identical] = (first[identical] + last[identical]) >> 1
    return split

def generateHierarchyPal(
    sortedMortonCodes:list[int],        # Morton codes, sorted ascending
    sortedObjectIDs:list[int],          # object (facet) ID for each Morton code
    numObjects,                         # number of objects
):
    """Build the hierarchy with NumPy-vectorised range and split searches.

    Args:
        sortedMortonCodes: Morton codes in sorted order; any array-like of
            non-negative integers below ``2**32`` (converted with
            ``numpy.asarray``).
        sortedObjectIDs: object ID belonging to each sorted code.
        numObjects: number of objects to use.

    Returns:
        The root ``internalNode`` when there are at least two objects, the
        single ``leafNode`` when there is exactly one, and ``None`` when
        there are none.

    NOTE(review): ``sigmaPal`` compares the codes only, so duplicate Morton
    codes are not disambiguated by index - confirm the inputs are unique.
    """
    # Array indexing with index arrays below needs an ndarray, not a list.
    codes = numpy.asarray(sortedMortonCodes)

    # Element-wise count of leading zeros in a 32-bit word.
    def clzPal(numList):
        int_bit = 32
        result = numpy.zeros_like(numList)
        ones = numpy.ones_like(numList)
        # mask stays 1 for an element while every bit seen so far is zero.
        mask = numpy.ones_like(numList)
        for i in range(int_bit):
            mask = mask*(ones - (numList>>(int_bit-i-1)))
            if numpy.all(mask == 0):
                break
            result = result + mask
        return result

    # Element-wise common prefix length; -1 where idx2 is out of range.
    def sigmaPal(idx1, idx2):
        outside = (idx2 < 0) | (idx2 > numObjects - 1)
        # Redirect invalid pairs to index 0 so the lookup itself is safe;
        # their result is overwritten with -1 below.
        safe1 = numpy.where(outside, 0, idx1)
        safe2 = numpy.where(outside, 0, idx2)
        prefix = clzPal(codes[safe1] ^ codes[safe2]).astype(numpy.int32)
        return numpy.where(outside, numpy.int32(-1), prefix)

    leafNodes = [leafNode() for _ in range(numObjects)]
    internalNodes = [internalNode() for _ in range(numObjects - 1)]
    for idx, leaf in enumerate(leafNodes):
        leaf.objectID = sortedObjectIDs[idx]

    # With fewer than two objects there is no internal node to return.
    if numObjects < 2:
        return leafNodes[0] if leafNodes else None

    first, last = dedetermineRangePal(numObjects, sigmaPal)
    split = findSplitPal(first, last, sigmaPal)

    # Link every internal node to its two children; a child covering a single
    # key is a leaf, otherwise it is the internal node with that index.
    for node, lo, hi, mid in zip(
        internalNodes, first.tolist(), last.tolist(), split.tolist()
    ):
        childA = leafNodes[mid] if mid == lo else internalNodes[mid]
        childB = leafNodes[mid + 1] if mid + 1 == hi else internalNodes[mid + 1]

        node.childA = childA
        node.childB = childB
        childA.parent = node
        childB.parent = node
    return internalNodes[0]


In [137]:
# Same three sample points as the scalar demo: (0,0,0), (0,1,0) and (0,0,1).
x = numpy.array([0,0,0])
y = numpy.array([0,1,0])
z = numpy.array([0,0,1])

# Codes are still in input order here, despite the variable name.
sortedMortonCodes = morton3D(x,y,z)
# argsort gives, for each sorted position, the index of the original point.
sortedObjectIDs = sortedMortonCodes.argsort()
sortedMortonCodes = sortedMortonCodes[sortedObjectIDs]
# Print the tree built by the NumPy implementation.
print(generateHierarchyPal(sortedMortonCodes, sortedObjectIDs, len(sortedMortonCodes)))

internalNode(parentID=None)
  internalNode(parentID=128932514664400)
    leafNode(objectID=0, parentID=128932514667536)
    leafNode(objectID=2, parentID=128932514667536)
  leafNode(objectID=1, parentID=128932514664400)
