A Python implementation of a bounding volume hierarchy (BVH) for ray casting.

In [1]:
import numpy as np
from typing import Tuple

import igl

# Ray

In [2]:
class Ray:
    """Ray r(t) = origin + t * dir with a unit-length direction.

    Args:
        origin: ray origin. It is copied into a new float array, so later
            changes to the caller's array (or to the shared default) do not
            move this ray.
        dir: ray direction; normalized on construction.
        tMax: largest ray parameter considered by intersection tests.
        t: current ray parameter.

    Raises:
        ValueError: if ``dir`` has zero length and cannot be normalized.
    """

    def __init__(self, origin:np.ndarray = np.zeros(3), dir:np.ndarray = np.array([0,0,1]),
                 tMax:float = float("inf"), t:float = 0.):
        # Copy: the default array is created once at definition time, so
        # storing it directly would share one mutable origin across every
        # Ray built without an explicit origin.
        self.origin = np.array(origin, dtype=float)

        direction = np.asarray(dir, dtype=float)
        length = np.linalg.norm(direction)
        if length == 0:
            raise ValueError("Ray direction must have non-zero length")
        self.dir = direction / length

        self.tMax= tMax
        self.t   = t

    def __repr__(self) -> str:
        return f"Origin: {self.origin}, Direction: {self.dir}, tMax: {self.tMax}, t: {self.t}"

    def __call__(self, t) -> np.ndarray:
        """Point on the ray at parameter t: origin + t * dir."""
        return self.origin + t*self.dir

In [3]:
ray = Ray(dir = np.array([0.,0., 2]))
ray(10.0)

array([ 0.,  0., 10.])

# Bound

In [4]:
class Bound3D:
    """Axis-aligned 3D bounding box.

    Args:
        p1, p2: two opposite corners in any order; their per-axis minimum and
            maximum are stored as ``pMin`` and ``pMax``.
    """

    def __init__(self, p1: np.ndarray, p2: np.ndarray):
        self.pMin = np.minimum(p1, p2)
        self.pMax = np.maximum(p1, p2)

    def __repr__(self) -> str:
        return f"pMin: {self.pMin}, pMax: {self.pMax}"

    def diag(self) -> np.ndarray:
        """Diagonal vector of the box, ``pMax - pMin``."""
        return self.pMax - self.pMin

    def inside(self, p: np.ndarray) -> bool:
        """Whether point p lies inside the box (boundary included)."""
        x, y, z = p
        return self.pMin[0] <= x <= self.pMax[0] \
               and self.pMin[1] <= y <= self.pMax[1] \
               and self.pMin[2] <= z <= self.pMax[2]

    def maximumExtent(self) -> int:
        """Index of the axis along which the box is longest."""
        return int(np.argmax(self.diag()))

    def centroid(self) -> np.ndarray:
        """Center point of the box."""
        return 0.5*(self.pMin+self.pMax)

    def offset(self, p: np.ndarray) -> np.ndarray:
        """Position of p relative to the box: 0 at pMin and 1 at pMax per axis.

        Axes along which the box has zero extent are returned as ``p - pMin``
        without scaling.
        """
        # Work in floating point: with integer inputs the original in-place
        # element division wrote fractions back into an int array (truncating them).
        o = np.asarray(p, dtype=float) - self.pMin
        extent = self.pMax - self.pMin
        np.divide(o, extent, out=o, where=extent > 0)
        return o

    def surfaceArea(self) -> float:
        """Total surface area of the six faces of the box."""
        d = self.diag()
        return 2*(d[0]*d[1] + d[1]*d[2] + d[2]*d[0])

    def intersect(self, ray: "Ray") -> bool:
        """Slab test against a ray.

        Args:
            ray: a ray; its ``origin``, ``dir`` and ``tMax`` are read.

        Return:
            hit: whether the parameter range [0, ray.tMax] of the ray overlaps
                the box
        """
        t0 = 0.
        t1 = ray.tMax

        for i in range(3):
            d = ray.dir[i]
            lo = self.pMin[i] - ray.origin[i]
            hi = self.pMax[i] - ray.origin[i]

            if d == 0:
                # Parallel to this slab: the ray is inside it for every t
                # (origin between the planes) or for none.
                if lo > 0 or hi < 0:
                    return False
                continue

            # Divide directly instead of biasing the direction by a small
            # epsilon, which flipped the sign of tiny negative components and
            # faked hits for axis-parallel rays outside a slab.
            tNear = lo / d
            tFar = hi / d
            if tNear > tFar:
                tNear, tFar = tFar, tNear

            t0 = max(t0, tNear)
            t1 = min(t1, tFar)
            if t0 > t1:
                return False

        return True

class EmptyBound3D(Bound3D):
    """Empty 3D bounding box: contains no point and is hit by no ray.

    ``pMin`` is +inf and ``pMax`` is -inf on every axis. Queries with no
    meaningful answer for an empty box (diagonal, extent, centroid, offset)
    raise ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__(np.zeros(3), np.zeros(3))

        # np.inf rather than np.Infinity, which was removed in NumPy 2.0.
        self.pMin = np.full(3, np.inf)
        self.pMax = np.full(3, -np.inf)

    def __repr__(self) -> str:
        return "Empty bounding box"

    def diag(self) -> np.ndarray:
        """Diagonal of the bounding box (undefined for an empty box)."""
        raise NotImplementedError

    def inside(self, p: np.ndarray) -> bool:
        """No point is inside an empty box."""
        return False

    def maximumExtent(self) -> int:
        """Axis with maximum extent (undefined for an empty box)."""
        raise NotImplementedError

    def centroid(self) -> np.ndarray:
        """Centroid of the bounding box (undefined for an empty box)."""
        raise NotImplementedError

    def offset(self, p: np.ndarray) -> np.ndarray:
        """Relative offset of p from the box (undefined for an empty box)."""
        raise NotImplementedError

    def surfaceArea(self) -> float:
        """An empty box has zero surface area."""
        return 0.

    def intersect(self, ray: Ray) -> bool:
        """Intersection with a given ray.

        Args:
            ray: a ray

        Return:
            hit: always False; an empty box cannot be hit
        """
        return False

def unionBound(b1: Bound3D, b2: Bound3D) -> Bound3D:
    """Smallest bounding box enclosing both inputs.

    Args:
        b1: the first bounding box
        b2: the second bounding box

    Return:
        box: the union of b1 and b2. When either input is an EmptyBound3D the
            other input object itself is returned, not a copy.
    """
    # an empty box contributes nothing to the union
    if isinstance(b1, EmptyBound3D):
        return b2
    if isinstance(b2, EmptyBound3D):
        return b1

    return Bound3D(np.minimum(b1.pMin, b2.pMin), np.maximum(b1.pMax, b2.pMax))

# def emptyBound() -> Bound3D:
#     """Create an empty bounding box.

#     Return:
#         box: an empty bounding box
#     """

#     box = Bound3D(np.zeros(3), np.zeros(3))
#     box.pMin = np.array([np.Infinity, np.Infinity,np.Infinity])
#     box.pMax =-np.array([np.Infinity, np.Infinity,np.Infinity])

#     return box

In [5]:
p1 = np.array([3,4,5])
p2 = np.array([6,3,1])

bbox = Bound3D(p1, p2)
bbox.diag()

bbox.inside(np.array([0,0,0]))

False

In [6]:
bbox

pMin: [3 3 1], pMax: [6 4 5]

In [7]:
unionBound(EmptyBound3D(), bbox)

pMin: [3 3 1], pMax: [6 4 5]

In [8]:
p1 = np.array([0,0,0])
p2 = np.array([1,1,1])

bbox = Bound3D(p1, p2)

ray = Ray(origin=np.array([0.5,0.5,0.5]), dir=np.array([1,0,1]))
bbox.intersect(ray)

True

# Triangle

In [9]:
class Triangle:
    """Triangle primitive.

    Args:
        v: vertex array; unpacked into three vertices v0, v1, v2 (presumably
            shape (3, 3), one row per vertex — TODO confirm for other callers).
        faceIdx: index of the face this triangle came from.
    """

    def __init__(self, v: np.ndarray, faceIdx: int):
        self.v       = v
        self.faceIdx = faceIdx

    def __repr__(self) -> str:
        return f"faceIdx: {self.faceIdx}"

    def bound(self) -> "Bound3D":
        """Axis-aligned bounding box of the three vertices."""
        return Bound3D(np.min(self.v, axis=0), np.max(self.v, axis=0))

    def centroid(self) -> np.ndarray:
        """Mean of the three vertices."""
        return np.mean(self.v, axis=0)

    def intersect(self, ray: "Ray") -> Tuple[bool, float]:
        """Intersection with a given ray.

        Args:
            ray: a ray; ``origin``, ``dir``, ``tMax`` and ``ray(t)`` are used.

        Return:
            hit: whether the ray hits the triangle at some t in [0, ray.tMax]
            tHit: ray parameter of the hit (0. when there is no hit)
        """
        v0, v1, v2 = self.v

        ## plane normal; a zero-length normal means a degenerate triangle,
        ## which would otherwise propagate NaNs and report a bogus hit
        N = np.cross(v1 - v0, v2 - v0)
        length = np.linalg.norm(N)
        if length == 0:
            return False, 0.
        N = N / length  # not in-place: N is an int array for int vertices

        ## a ray parallel to the plane never hits it
        NdotDir = np.dot(N, ray.dir)
        if np.isclose(NdotDir, 0.):
            return False, 0.

        ## ray parameter of the plane hit; must lie within [0, tMax]
        t = np.dot(N, v0 - ray.origin) / NdotDir
        if t < 0 or t > ray.tMax:
            return False, 0.

        ## inside-outside test: P must lie on the inner side of every edge
        P = ray(t)
        for a, b in ((v0, v1), (v1, v2), (v2, v0)):
            if np.dot(N, np.cross(b - a, P - a)) < 0:
                return False, 0.

        return True, t

In [10]:
v0 = np.array([1,0,0], dtype=float)
v1 = np.array([0,1,0], dtype=float)
v2 = np.array([0,0,1], dtype=float)

v = np.array([v0, v1, v2])

tri = Triangle(v, 0)
tri.bound()
tri.centroid()

array([0.33333333, 0.33333333, 0.33333333])

In [11]:
ray = Ray(dir=np.array([0,1,-1]))
tri.intersect(ray)

(False, 0.0)

In [12]:
V, F = igl.read_triangle_mesh("standard_sphere_2.obj")
ray = Ray(dir=np.array([0.,0.,1.]))

# for i in range(len(F)):
#     v = V[F[i]]

#     tri = Triangle(v, i)
#     print(tri.intersect(ray))

# BVH

In [13]:
# class BVHNode:
#     """BVH node implementation.
#     """

#     def __init__(self, bound:Bound3D, splitAxis:int = 0, left=None, right=None):
#         self.bound = bound
#         self.splitAxis = splitAxis

#         self.left = left
#         self.right= right


# def initLeaf(bound: Bound3D) -> BVHNode:
#     """Initialize a leaf node.
#     """
#     node = BVHNode(bound)

#     return node

# def initInterior(splitAxis: int, left: BVHNode, right: BVHNode) -> BVHNode:
#     """Initialize an interior node.
#     """
#     bound= unionBound(left.bound, right.bound)
#     node = BVHNode(bound, splitAxis, left, right)

#     return node

In [14]:
class BVHNode:
    """Common base of BVH nodes: a bounding box plus the axis it splits on."""

    def __init__(self, bound:Bound3D, splitAxis:int):
        # bound encloses everything stored below this node
        self.bound, self.splitAxis = bound, splitAxis

class leafNode(BVHNode):
    """Leaf node referring to a contiguous run of the BVH's ordered triangles.

    Args:
        first: index of the leaf's first triangle in the ordered triangle list.
        nTriangles: number of consecutive triangles in the leaf.
        bound: bounding box of those triangles.
    """

    def __init__(self, first: int, nTriangles: int, bound: Bound3D):
        # leaves do not split; -1 marks "no split axis"
        super().__init__(bound=bound, splitAxis=-1)

        self.nTriangles = nTriangles
        self.firstTriangleOffset = first

class interiorNode(BVHNode):
    """Interior node with two children; its bound is the union of theirs.

    Args:
        splitAxis: axis along which the children were separated.
        left, right: child nodes.
    """

    def __init__(self, splitAxis: int, left: BVHNode, right: BVHNode):
        super().__init__(unionBound(left.bound, right.bound), splitAxis)
        self.left, self.right = left, right

In [15]:
class BucketInfo:
    """Per-bucket triangle counts and bounding boxes used for SAH splitting.

    Args:
        nBuckets: number of buckets.
    """

    def __init__(self, nBuckets: int):
        self.nBuckets = nBuckets

        self.count = np.zeros(self.nBuckets, dtype=int)
        self.bounds= [EmptyBound3D() for _ in range(self.nBuckets)]

    def push(self, idx: int, bound: Bound3D) -> None:
        """Add one primitive with bounding box ``bound`` to bucket ``idx``."""
        self.count[idx] += 1
        self.bounds[idx] = unionBound(self.bounds[idx], bound)

    def SAH(self) -> np.ndarray:
        """Cost of each split between adjacent buckets, by the surface area heuristic.

        ``cost[i]`` prices putting buckets ``0..i`` on one side and
        ``i+1..nBuckets-1`` on the other:
        ``0.125 + (count0 * area0 + count1 * area1) / totalArea``.
        ``totalArea`` is the surface area of the union of all buckets plus
        1e-7 (to avoid dividing by zero).

        Return:
            cost[nBuckets-1]: cost of splitting after each bucket
        """
        n = self.nBuckets
        cost = np.zeros(n - 1)
        totalSurfaceArea = self.boundsArea() + 1e-7

        ## forward sweep: union area and count of buckets 0..i for each split i
        belowArea = np.zeros(n - 1)
        belowCount = np.zeros(n - 1, dtype=int)
        box, count = EmptyBound3D(), 0
        for i in range(n - 1):
            box = unionBound(box, self.bounds[i])
            count += self.count[i]
            belowArea[i] = box.surfaceArea()
            belowCount[i] = count

        ## backward sweep: union of buckets i+1..n-1, combined with the forward
        ## results. This replaces re-unioning both sides for every split (O(n^2))
        ## with two linear passes; min/max unions are exact, so costs are identical.
        box, count = EmptyBound3D(), 0
        for i in range(n - 2, -1, -1):
            box = unionBound(box, self.bounds[i + 1])
            count += self.count[i + 1]
            cost[i] = 0.125 + (belowCount[i]*belowArea[i] + count*box.surfaceArea()) / totalSurfaceArea

        return cost

    def boundsArea(self) -> float:
        """Calculate total surface area of all buckets.

        Return:
            A: surface area of the union of all boxes
        """
        box = EmptyBound3D()
        for bound in self.bounds:
            box = unionBound(box, bound)

        return box.surfaceArea()

In [16]:
from typing import Callable

def partition(lst: list, first: int, last: int, fun: Callable) -> int:
    """In-place, unstable partition of ``lst[first:last]`` (like ``std::partition``).

    Elements for which ``fun`` is truthy are swapped to the front of the range.

    Return:
        index of the first element of the second group (``first`` if none match)
    """
    boundary = first
    for idx in range(first, last):
        if not fun(lst[idx]):
            continue
        # move the matching element to the end of the "true" group
        lst[boundary], lst[idx] = lst[idx], lst[boundary]
        boundary += 1
    return boundary

In [17]:
class BVH:
    """Bounding volume hierarchy over triangles, built top-down with SAH splits.

    At each node the triangles are split along the axis where their centroids
    spread the most. For more than two triangles the split comes from the
    surface area heuristic over ``nBuckets`` buckets. Every non-empty leaf
    holds one triangle.

    Args:
        triangles: triangles to index. The list is copied, so the caller's
            list is not reordered.
        nBuckets: number of SAH buckets per split.

    Attributes:
        root: root node of the tree.
        triangles: triangles in leaf order. ``leafNode.firstTriangleOffset``
            and the index returned by ``intersect`` refer to this list.
    """

    def __init__(self, triangles:list[Triangle], nBuckets:int=12):
        self.triangles = triangles[:]
        self.nBuckets  = nBuckets

        ## leaves append their triangles to orderedTriangles as they are created
        orderedTriangles = []
        self.root = self._recursiveBuild(0, len(self.triangles), orderedTriangles)
        self.triangles = orderedTriangles

    def _recursiveBuild(self, start: int, end: int, orderedTriangles: list[Triangle]) -> BVHNode:
        """Build the subtree for ``self.triangles[start:end]``.

        Reorders ``self.triangles[start:end]`` in place while choosing splits,
        and appends each leaf's triangle to ``orderedTriangles``.
        """
        nTriangles = end - start

        ## an empty range (only for an empty mesh) becomes an empty leaf;
        ## its EmptyBound3D never reports a ray hit
        if nTriangles == 0:
            return leafNode(len(orderedTriangles), 0, EmptyBound3D())

        ## a single triangle becomes a leaf
        if nTriangles == 1:
            triangle = self.triangles[start]
            first = len(orderedTriangles)
            orderedTriangles.append(triangle)
            return leafNode(first, 1, triangle.bound())

        ## bound of the triangle centroids; split along its longest axis
        centroidBounds = EmptyBound3D()
        for i in range(start, end):
            centroid = self.triangles[i].centroid()
            centroidBounds = unionBound(centroidBounds, Bound3D(centroid, centroid))
        dim = centroidBounds.maximumExtent()

        if centroidBounds.pMax[dim] == centroidBounds.pMin[dim]:
            ## all centroids coincide, so no axis separates them: halve the range
            mid = (start + end) // 2
        elif nTriangles == 2:
            ## order the pair along dim and put one triangle on each side
            mid = start + 1
            self.triangles[start: end] = sorted(self.triangles[start: end], key=lambda triangle: triangle.centroid()[dim])
        else:
            mid = self._splitSAH(start, end, centroidBounds, dim)

        return interiorNode(dim,
                            self._recursiveBuild(start, mid, orderedTriangles),
                            self._recursiveBuild(mid,   end, orderedTriangles))

    def _bucketIndex(self, centroidBounds: Bound3D, dim: int, triangle: Triangle) -> int:
        """SAH bucket of a triangle: its centroid's relative position along dim, scaled to [0, nBuckets)."""
        b = self.nBuckets * centroidBounds.offset(triangle.centroid())[dim]
        return min(int(b), self.nBuckets - 1)

    def _splitSAH(self, start: int, end: int, centroidBounds: Bound3D, dim: int) -> int:
        """Partition ``self.triangles[start:end]`` at the SAH-cheapest bucket boundary; return the split index."""
        if self.nBuckets < 2:
            # no bucket boundary to evaluate (SAH() would be empty): halve the range
            return (start + end) // 2

        ## distribute bounding boxes to buckets
        buckets = BucketInfo(self.nBuckets)
        for i in range(start, end):
            triangle = self.triangles[i]
            buckets.push(self._bucketIndex(centroidBounds, dim, triangle), triangle.bound())

        ## split after the bucket that minimizes the SAH cost
        minCostSplitBucket = np.argmin(buckets.SAH())
        mid = partition(self.triangles, start, end,
                        lambda triangle: self._bucketIndex(centroidBounds, dim, triangle) <= minCostSplitBucket)

        # The extreme centroids fall in the first and last buckets, so both
        # sides should be non-empty; guard anyway so recursion always progresses.
        if mid == start or mid == end:
            mid = (start + end) // 2
        return mid

    def intersect(self, ray: Ray) -> Tuple[bool, float, int]:
        """Closest triangle hit along a ray.

        Return:
            hit: whether the ray hits any triangle
            t: ray parameter of the closest hit (0. on a miss)
            idx: index into ``self.triangles`` of the hit triangle (-1 on a miss)
        """
        return self._intersect(ray, self.root)

    def _intersect(self, ray: Ray, node: BVHNode) -> Tuple[bool, float, int]:
        """Find the closest intersection of a ray within the subtree at ``node``.

        Args:
            ray: ray instance
            node: current node

        Return:
            hit: whether the ray hits a triangle
            t: ray parameter of the hit (0. on a miss)
            idx: triangle index (-1 on a miss)
        """
        if not node.bound.intersect(ray):
            return False, 0., -1

        if isinstance(node, leafNode):
            return self._intersectLeaf(ray, node)

        hit1, t1, idx1 = self._intersect(ray, node.left)
        hit2, t2, idx2 = self._intersect(ray, node.right)

        ## keep the nearer of the two child hits (right child wins ties)
        if hit1 and (not hit2 or t1 < t2):
            return True, t1, idx1
        if hit2:
            return True, t2, idx2
        return False, 0., -1

    def _intersectLeaf(self, ray: Ray, node: leafNode) -> Tuple[bool, float, int]:
        """Closest hit among all ``node.nTriangles`` triangles stored in a leaf."""
        hit, tHit, idxHit = False, 0., -1
        first = node.firstTriangleOffset
        for idx in range(first, first + node.nTriangles):
            h, t = self.triangles[idx].intersect(ray)
            if h and (not hit or t < tHit):
                hit, tHit, idxHit = True, t, idx
        return hit, tHit, idxHit

In [18]:
# V = np.array([
#     [0, 0, 0],
#     [1, 0, 0],
#     [2, 0, 0],
#     [0, 1, 0],
#     [1, 1, 0],
#     [2, 1, 0]
# ], dtype=float)

# F = np.array([
#     [0, 1, 4],
#     [0, 4, 3],
#     [1, 2, 5],
#     [1, 5, 4]
# ])

# triangles = [Triangle(V[F[i]], i) for i in range(len(F))]
# bvh = BVH(triangles)

In [19]:
V, F = igl.read_triangle_mesh("standard_sphere_4.obj")
## one Triangle per face row of F, tagged with that row's index
triangles = [Triangle(V[face], faceIdx) for faceIdx, face in enumerate(F)]

bvh = BVH(triangles)

In [20]:
## 1000 rays from the origin; direction components drawn from [0, 1) before normalization
rays = [Ray(np.zeros(3), np.random.random(3)) for _ in range(1000)]

In [21]:
## brute-force reference: CLOSEST hit over all triangles for each ray.
## Stopping at the first hit (as before) only matches the BVH on meshes where
## each ray crosses a single triangle.
t1 = []
F1 = []
for ray in rays:
    tClosest, fClosest = float("inf"), -1
    for i, triangle in enumerate(triangles):
        hit, t = triangle.intersect(ray)
        if hit and t < tClosest:
            tClosest, fClosest = t, i
    if fClosest >= 0:
        t1.append(tClosest)
        F1.append(fClosest)

In [22]:
## BVH results; record only actual hits. On a miss idx is -1, which would
## index the last triangle and desynchronize these lists from the reference.
t2 = []
F2 = []
for ray in rays:
    hit, t, idx = bvh.intersect(ray)
    if hit:
        t2.append(t)
        F2.append(bvh.triangles[idx].faceIdx)

In [23]:
np.allclose(t1, t2)

True

In [25]:
np.allclose(F1, F2)

True