In [None]:
import drjit as dr
import mitsuba as mi
import numpy as np
import matplotlib.pyplot as pp

# Select Mitsuba's CUDA backend with automatic differentiation and RGB spectra.
# Must run before any other mi.* type is used below.
mi.set_variant('cuda_ad_rgb')

In [None]:
# Tabulation resolution: number of linspace samples per canonical axis.
# The view direction and the normal each use two canonical axes.
view_resolution = 100
normal_resolution = 100

# Canonical coordinates in [0, 1]. dr.linspace includes both endpoints, so node k of an
# axis with R samples sits at k / (R - 1).
# The 4-D meshgrid enumerates every (normal, view direction) combination as one flat lane,
# i.e. normal_resolution**2 * view_resolution**2 lanes per returned array.
# NOTE(review): later lookups assume dr.meshgrid flattens like numpy's 'xy' meshgrid
# (N_Y slowest, then N_X, Y, and X fastest) — confirm.
N_X, N_Y, Y, X = dr.meshgrid(
    dr.linspace(mi.Float, 0, 1.0, normal_resolution),
    dr.linspace(mi.Float, 0, 1.0, normal_resolution),
    dr.linspace(mi.Float, 0, 1.0, view_resolution),
    dr.linspace(mi.Float, 0, 1.0, view_resolution)
)

# Cylindrical (Lambert equal-area) parameterization of the unit sphere:
# canonical (u, v) in [0, 1]^2 -> cos(theta) = 2u - 1, phi = 2*pi*v.
# https://web.ma.utexas.edu/users/m408m/Display15-10-8.shtml
def canonical_to_dir(p: mi.Point2f) -> mi.Vector3f:
    """Map a canonical point in [0, 1]^2 to a unit direction on the sphere."""
    cos_theta = 2 * p.x - 1
    # Clamp so rounding near |cos_theta| == 1 cannot feed a negative value to sqrt.
    sin_theta = dr.sqrt(dr.maximum(0, 1 - cos_theta**2))
    azimuth = 2 * dr.pi * p.y
    x = sin_theta * dr.cos(azimuth)
    y = sin_theta * dr.sin(azimuth)
    return mi.Vector3f(x, y, cos_theta)

def dir_to_canonical(wi: mi.Vector3f) -> mi.Point2f:
    """Inverse of `canonical_to_dir`: map a unit direction to its canonical point in [0, 1]^2."""
    azimuth = dr.atan2(wi.y, wi.x)
    # atan2 yields negative angles for y < 0; shift them into [0, 2*pi).
    azimuth = dr.select(azimuth < 0, azimuth + 2 * dr.pi, azimuth)
    u = (wi.z + 1) * 0.5
    v = azimuth / (2 * dr.pi)
    return mi.Point2f(u, v)

# Compute the view and normal directions: map the canonical (X, Y) and (N_X, N_Y)
# meshgrid arrays to unit vectors with canonical_to_dir, one direction per lane.
wi = canonical_to_dir(mi.Point2f(X, Y))

n = canonical_to_dir(mi.Point2f(N_X, N_Y))

In [None]:
def rotate_align_scalar(v1, v2):
    """Return a ScalarTransform4f whose rotation maps direction `v1` onto direction `v2`.

    Both inputs are normalized first. In the general case the rotation is
        R = cos(a) * I + axis axis^T / (1 + cos(a)) + [axis]_x,   axis = v1 x v2,
    where [axis]_x is the cross-product (skew) matrix. That formula divides by
    (1 + cos(a)), which vanishes when v1 and v2 are antiparallel, so that case uses a
    180-degree rotation about an axis perpendicular to v1 instead.

    Args:
        v1: source direction (3-element array-like; need not be unit length).
        v2: target direction (3-element array-like; need not be unit length).

    Returns:
        mi.ScalarTransform4f with the rotation in its upper-left 3x3 block and no
        translation.
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    v1 = v1 / np.linalg.norm(v1)
    v2 = v2 / np.linalg.norm(v2)
    axis = np.cross(v1, v2)
    cos_a = float(np.dot(v1, v2))

    antiparallel_eps = 1e-6
    if cos_a < -1.0 + antiparallel_eps:
        # 180-degree rotation about a unit axis u perpendicular to v1: R = 2 u u^T - I.
        # Pick a helper vector that is guaranteed not to be parallel to v1.
        helper = np.array([1.0, 0.0, 0.0]) if abs(v1[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
        u = np.cross(v1, helper)
        u = u / np.linalg.norm(u)
        rotation = 2.0 * np.outer(u, u) - np.eye(3)
    else:
        k = 1.0 / (1.0 + cos_a)
        skew = np.array([
            [0.0, -axis[2], axis[1]],
            [axis[2], 0.0, -axis[0]],
            [-axis[1], axis[0], 0.0],
        ])
        rotation = cos_a * np.eye(3) + k * np.outer(axis, axis) + skew

    matrix = np.eye(4)
    matrix[:3, :3] = rotation
    return mi.ScalarTransform4f(matrix.tolist())

def scene_from_normal(normal):
    """Load a scene whose sphere mesh is rotated so its +Z axis points along `normal`.

    The scene holds an environment-map light and the mesh with a rough (GGX, alpha 0.1)
    aluminium conductor BSDF.

    Args:
        normal: 3-element direction that the mesh's +Z axis is rotated onto.

    Returns:
        The loaded mi.Scene.
    """
    to_world = rotate_align_scalar(np.array([0, 0, 1]), normal)

    rough_aluminium = {
        'type': 'roughconductor',
        'material': 'Al',
        'distribution': 'ggx',
        'alpha': 0.1,
        'sample_visible' : True,
    }
    scene_description = {
        'type': 'scene',
        'light': {'type': 'envmap', 'filename': './scenes/studio_small_09_1k.exr'},
        'rectangle' : {
            'type': 'obj',
            'filename': './scenes/meshes/sphere.obj',
            'to_world': to_world,
            'bsdf': rough_aluminium,
        },
    }
    return mi.load_dict(scene_description)

def compute_image(scene: mi.Scene, directions: mi.Vector3f, spps: int):
    """Estimate radiance reflected by the object at the origin towards each of `directions`.

    Each lane shoots a primary ray from the point `directions` towards the origin. At the
    first hit, `spps` BSDF samples are drawn. Each sampled direction is traced, and the
    emitted radiance it reaches is weighted by the BSDF sample weight. This is pure BSDF
    sampling with no MIS. Lanes whose primary ray misses are masked via `si.is_valid()`.

    Args:
        scene: scene containing the object and an emitter, e.g. from `scene_from_normal`.
        directions: per-lane primary-ray origins; their negation is the ray direction.
        spps: number of BSDF samples per lane.

    Returns:
        mi.TensorXf of shape (view_resolution, view_resolution, 3), built from the LAST
        view_resolution**2 lanes of the estimate (reads the global `view_resolution`).
    """
    rng = mi.PCG32(size=dr.width(directions))
    result = mi.Color3f(0, 0, 0)
    bsdf_context = mi.BSDFContext()

    # Primary rays start at `directions` and point back towards the origin.
    # NOTE(review): if the mesh is a unit sphere these origins lie on its surface — confirm scale.
    ray = mi.Ray3f(o=directions, d=-directions)
    si = scene.ray_intersect(ray)
    valid = si.is_valid()  # loop-invariant; hoisted out of the loop
    i = mi.UInt32(0)

    loop = mi.Loop(name="", state=lambda: (rng, i, result))

    while loop(i < spps):
        # TODO: Add better MIS (e.g. bidirectional approach)
        # Sample the BSDF; the returned weight already includes the 1/pdf and cosine factors.
        bsdf_sample, bsdf_weight = si.bsdf().sample(bsdf_context, si,
                                                    rng.next_float32(),
                                                    mi.Point2f(rng.next_float32(), rng.next_float32()),
                                                    valid)

        # spawn_ray offsets the origin off the surface so the secondary ray cannot
        # immediately re-intersect the surface it leaves (a ray from si.p can).
        new_direction = si.sh_frame.to_world(bsdf_sample.wo)
        new_ray = si.spawn_ray(new_direction)
        si_new = scene.ray_intersect(new_ray)

        # Monte Carlo estimator: weight * emitted radiance along the sampled direction.
        result += bsdf_weight * si_new.emitter(scene, valid).eval(si_new)

        i = i + 1

    result = result / spps

    # Keep only the last view_resolution**2 lanes and lay them out as an RGB image.
    count = dr.width(result)
    image = result.numpy()[count - view_resolution * view_resolution:]
    image = mi.TensorXf(image.reshape(view_resolution, view_resolution, 3))

    return image

In [None]:
def rotate_align(v1: mi.Vector3f, v2: mi.Vector3f):
    """Build, per lane, the 3x3 matrix that maps direction v1 onto direction v2.

    Uses M = cos(a) * I + k * axis axis^T + [axis]_x with axis = v1 x v2 and
    k = 1 / (1 + cos(a)), where [axis]_x is the skew (cross-product) matrix. The rows are
    given as nested lists.

    No normalization is done here. NOTE(review): the formula assumes v1 and v2 are unit
    length — confirm at call sites.

    Nearly antiparallel lanes (dot(v1, v2) below about -0.9999) get k = 1 instead of the
    diverging 1 / (1 + cos). Because axis is ~0 there, M is ~cos(a) * I ~ -I: a point
    reflection, not a proper rotation. It still maps v1 to -v1 = v2.
    """
    # 1.0 when dot(v1, v2) + 1 - 0.0001 is positive, 0.0 when it is negative
    # (the value exactly at zero depends on dr.sign's convention — confirm).
    check = (1 + dr.sign(dr.dot(v1, v2) + 1 - 0.0001)) / 2
    axis = dr.cross(v1, v2)
    
    cosA = dr.dot(v1, v2)
    # For masked-out (antiparallel) lanes check == 0, so k == 1 and the division stays finite.
    k = mi.Float(1.0 / (1.0 + cosA * check))
    
    return mi.Matrix3f(
        [[(axis[0] * axis[0] * k) + cosA, (axis[1] * axis[0] * k) - axis[2], (axis[2] * axis[0] * k) + axis[1]],
        [(axis[0] * axis[1] * k) + axis[2], (axis[1] * axis[1] * k) + cosA, (axis[2] * axis[1] * k) - axis[0]], 
        [(axis[0] * axis[2] * k) - axis[1], (axis[1] * axis[2] * k) + axis[0], (axis[2] * axis[2] * k) + cosA]] 
    )

def matmul(rotmat : mi.Matrix3f, vec: mi.Vector3f):
    """Multiply a 3x3 matrix with a vector, lane-wise.

    Output component i is sum_j rotmat[j][i] * vec[j].

    NOTE(review): if `rotmat[j]` indexes row j (as the nested-list construction in
    `rotate_align` suggests), this computes rotmat^T @ vec, i.e. the INVERSE rotation.
    Confirm Dr.Jit's Matrix indexing convention before relying on the direction of the
    mapping.
    """
    new_vector = mi.Vector3f(
        rotmat[0][0] * vec[0] + rotmat[1][0] * vec[1] + rotmat[2][0] * vec[2],
        rotmat[0][1] * vec[0] + rotmat[1][1] * vec[1] + rotmat[2][1] * vec[2],
        rotmat[0][2] * vec[0] + rotmat[1][2] * vec[1] + rotmat[2][2] * vec[2]
    )
    return new_vector

def compute_grid(normals: mi.Vector3f, directions: mi.Vector3f, spps: int):
    """Tabulate BSDF-sampled outgoing radiance for every (normal, view direction) lane.

    For each lane:
      1. The view direction is rotated into a local frame where that lane's normal is +Z.
      2. `spps` samples of a rough (GGX, alpha 0.1) aluminium conductor are drawn.
      3. Each sampled direction is rotated back to world space, and the environment-map
         radiance along it is accumulated.

    Args:
        normals: per-lane world-space surface normals (assumed unit length — TODO confirm).
        directions: per-lane world-space view directions, same width as `normals`.
        spps: number of BSDF samples per lane.

    Returns:
        mi.Color3f with one averaged RGB estimate per lane.
    """
    given_normal = mi.Vector3f(0, 0, 1)
    rotation_mat = rotate_align(normals, given_normal)
    inverse_mat = rotate_align(given_normal, normals)
    local_view_directions = matmul(rotation_mat, directions)

    # Synthetic interaction at the origin. The BSDF is queried with the local-frame view direction.
    si = mi.SurfaceInteraction3f()
    si.n = normals
    si.p = mi.Point3f(0, 0, 0)
    si.wi = local_view_directions

    envmap_scene = mi.load_dict({
        'type': 'scene',
        'light': {'type': 'envmap', 'filename': './scenes/studio_small_09_1k.exr'}
    })

    material = mi.load_dict({
        'type': 'roughconductor',
        'material': 'Al',
        'distribution': 'ggx',
        'alpha': 0.1,
        'sample_visible' : True,
    })

    # Integer loop counter, matching compute_image.
    i = mi.UInt32(0)
    rng = mi.PCG32(size=dr.width(directions))

    result = mi.Color3f(0, 0, 0)
    bsdf_context = mi.BSDFContext()

    grid_loop = mi.Loop(name="", state=lambda: (rng, i, result))

    while grid_loop(i < spps):
        # TODO: Add better MIS (e.g. bidirectional approach)
        # Sample the BSDF in the local frame.
        bsdf_sample, bsdf_weight = material.sample(bsdf_context, si,
                                                   rng.next_float32(),
                                                   mi.Point2f(rng.next_float32(), rng.next_float32()),
                                                   True)

        # Rotate the sampled direction back to world space and look up the environment.
        world_direction = matmul(inverse_mat, bsdf_sample.wo)
        env_ray = mi.Ray3f(o=si.p, d=world_direction)
        si_env = envmap_scene.ray_intersect(env_ray)

        # Monte Carlo estimator.
        result += bsdf_weight * si_env.emitter(envmap_scene, True).eval(si_env)

        i = i + 1

    return result / spps

def display_grid_seg(grid, n_x, n_y):
    """Return the tile at 1-based row `n_y` and 1-based column `n_x` of a nested grid."""
    row = grid[n_y - 1]
    return row[n_x - 1]

def display_full_grid(grid, resolution=None):
    """Show every tile of a square nested grid as a resolution x resolution subplot mosaic.

    Args:
        grid: nested indexable with grid[i][j] an image accepted by `pp.imshow`.
        resolution: number of rows/columns of tiles; defaults to the global
            `normal_resolution`.

    Returns:
        The created matplotlib Figure.
    """
    if resolution is None:
        resolution = normal_resolution
    figure = pp.figure(figsize=(20, 5))
    for i in range(resolution):
        for j in range(resolution):
            # matplotlib subplot indices are 1-based; index 0 raises ValueError.
            pp.subplot(resolution, resolution, i * resolution + j + 1)
            pp.imshow(grid[i][j])
            pp.axis("off")
    return figure

In [None]:
# Test scene: the sphere mesh with its +Z axis kept at +Z (identity rotation).
normal = np.array([0, 0, 1])
test_scene = scene_from_normal(normal)
# Tabulate outgoing radiance for every (normal n, view direction wi) lane, 10240 BSDF samples each.
precomputed_grid = compute_grid(n, wi, 10240)
# Optional inspection of a single tile (the view-direction image for one normal):
#grid_numpy = precomputed_grid.numpy().reshape(normal_resolution, normal_resolution, view_resolution, view_resolution, 3)
#precomputed_outgoing = display_grid_seg(grid_numpy, normal_resolution, normal_resolution)
#pp.imshow(precomputed_outgoing)

In [None]:
#display_full_grid(precomputed_grid)

In [None]:
# Do a rendering of the scene: path-traced reference image, compared against the cached render below.
# Camera at (0, 1.5, 0) looking at the origin with +Z up, 90 degree FOV, 512x512 HDR film.
sensor = mi.load_dict({
    "type": "perspective",
    "film": {
        "type": "hdrfilm",
        "width": 512,
        "height": 512,
        "rfilter": {"type": "box"}
    },
    "sampler": {
        "type": "independent",
        # Overridden by the spp argument passed to render() below.
        "sample_count": 128
    },
    "to_world" : mi.ScalarTransform4f.look_at(origin=[0, 1.5, 0], target=[0, 0, 0], up=[0, 0, 1]),
    "fov": 90,
    "near_clip": 0.1,
    "far_clip": 1000
})

# Standard Mitsuba path tracer used as the ground truth.
path_int = mi.load_dict({
    "type": "path"
})

# Rebuilt here so this cell does not depend on an earlier test_scene.
test_scene = scene_from_normal(np.array([0, 0, 1]))
img_path = path_int.render(test_scene, sensor=sensor, spp=1280)
pp.imshow(img_path.numpy().reshape(512, 512, 3))

In [8]:
class NormalIntegrator(mi.SamplingIntegrator):
    """Sampling integrator that shades hits by interpolating a precomputed radiance table.

    The table is a flat mi.Color3f with one RGB value per (normal, view direction) grid
    node, as produced by `compute_grid(n, wi, ...)`. Flat index layout, slowest to
    fastest: normal-phi index, normal-cos(theta) index, view-phi index, view-cos(theta)
    index. NOTE(review): this relies on `dr.meshgrid(N_X, N_Y, Y, X)` flattening like
    numpy's 'xy' meshgrid — confirm.

    The table was sampled with `dr.linspace(..., R)`, which includes both endpoints. Node
    k of an axis with R nodes therefore sits at canonical coordinate k / (R - 1), so
    lookups scale canonical coordinates by (R - 1), not R. On the periodic phi axes the
    last node (phi = 2*pi) duplicates the first, so clamping indices is sufficient and no
    wrap-around is needed. Requires resolutions >= 2.
    """

    def __init__(self, arg0: mi.Properties, grid):
        mi.SamplingIntegrator.__init__(self, arg0)
        # Host-side copy, kept so code that inspects `precomputed_grid` keeps working.
        self.precomputed_grid = grid.numpy()
        # Device-side table used for lookups, so rendering never round-trips through numpy.
        self._grid = mi.Color3f(grid)
        print(np.shape(self.precomputed_grid))

    @staticmethod
    def _cell_index(u, resolution):
        """Lower node index of the grid cell containing canonical coordinate u in [0, 1]."""
        return dr.clamp(mi.Int(dr.floor(u * (resolution - 1))), 0, resolution - 2)

    def get_grid(self, n_y, n_x, w_y, w_x):
        """Gather the tabulated RGB value at integer node indices, clamped into range."""
        n_y = dr.clamp(n_y, 0, normal_resolution - 1)
        n_x = dr.clamp(n_x, 0, normal_resolution - 1)
        w_y = dr.clamp(w_y, 0, view_resolution - 1)
        w_x = dr.clamp(w_x, 0, view_resolution - 1)
        # Equivalent to n_y*Rn*Rv^2 + n_x*Rv^2 + w_y*Rv + w_x.
        flat_index = mi.UInt32(
            ((n_y * normal_resolution + n_x) * view_resolution + w_y) * view_resolution + w_x
        )
        return mi.Color3f(
            dr.gather(mi.Float, self._grid[0], flat_index),
            dr.gather(mi.Float, self._grid[1], flat_index),
            dr.gather(mi.Float, self._grid[2], flat_index),
        )

    def quadrilinear_interp(self, n_y, n_x, w_y, w_x, ny, nx, py, px):
        """Interpolate the table over the 4-D cell whose lower corner is (n_y, n_x, w_y, w_x).

        (ny, nx) and (py, px) are the continuous canonical coordinates of the normal and
        the view direction; the integer arguments are their lower node indices.
        """
        # Fractional position of the query inside its cell along each axis.
        t_wx = px * (view_resolution - 1) - mi.Float(w_x)
        t_wy = py * (view_resolution - 1) - mi.Float(w_y)
        t_nx = nx * (normal_resolution - 1) - mi.Float(n_x)
        t_ny = ny * (normal_resolution - 1) - mi.Float(n_y)

        def view_bilerp(node_y, node_x):
            # Bilinear interpolation over the view axes for one fixed normal node.
            low = dr.lerp(self.get_grid(node_y, node_x, w_y, w_x),
                          self.get_grid(node_y, node_x, w_y, w_x + 1), t_wx)
            high = dr.lerp(self.get_grid(node_y, node_x, w_y + 1, w_x),
                           self.get_grid(node_y, node_x, w_y + 1, w_x + 1), t_wx)
            return dr.lerp(low, high, t_wy)

        along_nx_low = dr.lerp(view_bilerp(n_y, n_x), view_bilerp(n_y, n_x + 1), t_nx)
        along_nx_high = dr.lerp(view_bilerp(n_y + 1, n_x), view_bilerp(n_y + 1, n_x + 1), t_nx)
        return dr.lerp(along_nx_low, along_nx_high, t_ny)

    def sample(self, scene: mi.Scene, sampler: mi.Sampler, ray: mi.RayDifferential3f, medium=None, active=True):
        """Return (radiance, valid mask, aovs) for one camera ray.

        Hits are shaded by interpolating the table at (shading normal, world-space
        direction towards the viewer). Misses return the radiance of the emitter seen
        along the ray. The sampler is not used.
        """
        si = scene.ray_intersect(ray)
        valid = si.is_valid()

        # Canonical coordinates of the shading normal and of the world-space view direction.
        n = dir_to_canonical(si.sh_frame.n)
        p = dir_to_canonical(si.to_world(si.wi))

        n_y = self._cell_index(n.y, normal_resolution)
        n_x = self._cell_index(n.x, normal_resolution)
        w_y = self._cell_index(p.y, view_resolution)
        w_x = self._cell_index(p.x, view_resolution)
        grid_sample = self.quadrilinear_interp(n_y, n_x, w_y, w_x, n.y, n.x, p.y, p.x)

        background = si.emitter(scene, True).eval(si)
        color = dr.select(valid, grid_sample, background)
        return (color, valid, [])

# Wrap the precomputed table in the lookup-based integrator (empty Properties for the base class).
normal_int = NormalIntegrator(mi.Properties(), precomputed_grid)

In [None]:
# Render with the table-lookup integrator; its sample() never draws from the sampler, so spp=1.
img_cached = normal_int.render(test_scene, sensor=sensor, spp=1)
pp.imshow(img_cached.numpy().reshape(512, 512, 3))

In [None]:
# Compare the path-traced reference with the cached-lookup render.
# `diff` is the per-pixel channel mean of (cached - reference). `np_diff` shows its
# positive part in red (cached brighter) and its negative part in green (cached darker).
diff = np.mean((img_cached - img_path).numpy(), axis=2)
zeros = np.zeros_like(diff)
np_diff = np.zeros_like(img_cached.numpy())
np_diff[..., 0] = np.maximum(zeros, diff)
np_diff[..., 1] = np.maximum(zeros, -diff)

# Large horizontal figure with 4 subplots
figure = pp.figure(figsize=(20, 5))
panels = [
    ("Path Tracing", img_path.numpy().reshape(512, 512, 3)),
    ("Normal Integrator", img_cached.numpy().reshape(512, 512, 3)),
    ("Difference", diff),
    ("Bias", np_diff),
]
for panel_index, (panel_title, panel_image) in enumerate(panels, start=1):
    pp.subplot(1, 4, panel_index)
    pp.imshow(panel_image)
    pp.title(panel_title)