# The Forward Rendering Pipeline

The forward rendering pipeline should be combined with that of PhySG rather than directly employing Unity's URP/HDRP. The only part that should be extracted from Unity's pipeline is the specular colour rendering in `GlintForwardLitPass`, as geometry and albedo are nicely handled by PhySG.

Nevertheless, details of PhySG's pipeline can be found [here](https://kai-46.github.io/PhySG-website/) the details of Unity's URP can be found [here](https://nedmakesgames.medium.com/writing-unity-urp-shaders-with-code-part-1-the-graphics-pipeline-and-you-798cbc941cea).

> **Watch Out!!!**
>
> - HLSL uses column major. When detailing with matrices that have more than 1 dimension, remember to check that the calculations are correct. Potentially problematic functions include:
>   - `GetGradientEllipse(duvdx, duvdy)`
>   - `SampleGlintGridSimplex(uv, gridSeed, slope, footprintArea, targetNDF, gridWeight)`


In [1]:
import torch
import numpy as np
import struct
import array
import imageio
import warnings

In [2]:
# constants
EPSILON = 1e-6
DEG2RAD = 0.01745329251
RAD2DEG = 57.2957795131
ZERO = torch.tensor(0.0, requires_grad=True)
ONE = torch.tensor(1.0, requires_grad=True)

## Utility Functions


In [3]:
def is_valid(tensor):
    valid = True
    if torch.any(torch.isnan(tensor)):
        print(f"The input tensor has NaN values.")
        valid = False
    if torch.any(torch.isinf(tensor)):
        print(f"The input tensor has Inf values.")
        valid = False

    return valid

In [4]:
def remove_zeros(tensor):
    res = torch.where(tensor == 0.0, tensor.mean(), tensor)
    assert torch.all(res != 0.0)
    return res

In [5]:
def toIntApprox(tensor):
    return torch.where(tensor >= 0.0, torch.floor(tensor), torch.ceil(tensor))

In [6]:
def normalise(tensor):
    return (tensor - tensor.min()) / torch.clamp(tensor.max() - tensor.min(), EPSILON)

In [7]:
# @param u: float3/4
# @param mu: float
# @param sigma: float
#
# @return float3
def sampleNormalDistribution(u, mu, sigma):
    # 2.0 * u - 1.0 must be within (-1.0, 1.0)
    return (sigma * 1.414213).unsqueeze(-1) * torch.erfinv(
        2.0 * u - 1.0
    ) + mu.unsqueeze(-1)

In [8]:
# @param v: uint3
#
# @return float3
def pcg3dFloat(v):
    v = v * np.uint32(1664525) + np.uint32(1013904223)

    v[..., 0] += v[..., 1] * v[..., 2]
    v[..., 1] += v[..., 2] * v[..., 0]
    v[..., 2] += v[..., 0] * v[..., 1]

    v ^= v >> np.uint32(16)

    v[..., 0] += v[..., 1] * v[..., 2]
    v[..., 1] += v[..., 2] * v[..., 0]
    v[..., 2] += v[..., 0] * v[..., 1]

    return torch.tensor(v * (1.0 / 4294967296.0))


def mt19937_3dFloat(dim, seed):
    g = torch.Generator()  # TODO: remember to add device='cuda' if using GPU
    g.manual_seed(seed)

    return torch.clamp(torch.rand(dim, generator=g, requires_grad=True), EPSILON)

In [9]:
# @param p3: float3
#
# @return float
def HashWithoutSine13(p3):
    p3 = torch.frac(p3 * 0.1031)
    p3 += torch.sum(p3 * (p3[:, [1, 2, 0]] + 33.33), dim=1).unsqueeze(-1)

    return torch.frac((p3[..., 0] + p3[..., 1]) * p3[..., 2])

In [10]:
# @param duvdx: float2
# @param duvdy: float2
# @param ellipseMajor: float2
# @param ellipseMinor: float2
def GetGradientEllipse(duvdx, duvdy):
    # Construct Jacobian matrix
    # Note that HLSL is column major: https://stackoverflow.com/questions/22756121/confuse-with-row-major-and-column-major-matrix-multiplication-in-hlsl#:~:text=HLSL%20uses%20Column%2DMajor%20and%20XNAMath%20uses%20ROW%2DMajor.
    J = torch.transpose(torch.stack([duvdx, duvdy], dim=2), 1, 2)
    # Check if determinant is zero and replace the matrices
    J = torch.where(
        torch.det(J).unsqueeze(1).unsqueeze(2) == 0,
        torch.randn((2, 2), requires_grad=True),
        J,
    )
    J = torch.linalg.inv(J)
    J = torch.matmul(J, torch.transpose(J, 1, 2))

    a = J[..., 0, 0]
    c = J[..., 1, 0]
    d = J[..., 1, 1]

    T = a + d
    D = torch.linalg.det(J)
    # They are meant to be > 0.0
    L1 = remove_zeros(torch.abs(T / 2.0 - torch.pow(T * T / 3.99999 - D, 0.5)))
    L2 = remove_zeros(torch.abs(T / 2.0 + torch.pow(T * T / 3.99999 - D, 0.5)))

    A0 = torch.stack((L1 - d, c), dim=-1)
    A1 = torch.stack((L2 - d, c), dim=-1)
    r0 = 1.0 / torch.sqrt(L1)
    r1 = 1.0 / torch.sqrt(L2)

    ellipseMajor = torch.nn.functional.normalize(A0, dim=-1) * r0.unsqueeze(-1)
    ellipseMinor = torch.nn.functional.normalize(A1, dim=-1) * r1.unsqueeze(-1)

    return ellipseMajor, ellipseMinor

In [11]:
# @param v: float3
#
# @return float2
def VectorToSlope(v):
    return torch.tensor([-v[0] / v[2], -v[1] / v[2]])


# @param s: float2
#
# @return float3
def SlopeToVector(s):
    z = 1 / torch.sqrt(s[0] * s[0] + s[1] * s[1] + 1)
    x = s[0] * z
    y = s[1] * z

    return torch.tensor([x, y, z])

In [12]:
# @param uv: float2
# @param rotation: float
# @param mid: float2
#
# @return float2
def RotateUV(uv, rotation, mid):
    return torch.stack(
        (
            torch.cos(rotation) * (uv[..., 0] - mid[0])
            + torch.sin(rotation) * (uv[..., 1] - mid[1])
            + mid[0],
            torch.cos(rotation) * (uv[..., 1] - mid[1])
            - torch.sin(rotation) * (uv[..., 0] - mid[0])
            + mid[1],
        ),
        dim=-1,
    )

In [13]:
# @param values: float4
# @param valuesLerp: float2
#
# @return float
def BilinearLerp(values, valuesLerp):
    resultX = torch.lerp(values[..., 0], values[..., 2], valuesLerp[..., 0])
    resultY = torch.lerp(values[..., 1], values[..., 3], valuesLerp[..., 0])

    return torch.lerp(resultX, resultY, valuesLerp[..., 1])

In [14]:
# @param s: float
# @param a1: float
# @param a2: float
# @param b1: float
# @param b2: float
#
# @return float
def Remap(s, a1, a2, b1, b2):
    return b1 + (s - a1) * (b2 - b1) / (a2 - a1)


# @param s: float
# @param b1: float
# @param b2: float
#
# @return float
def Remap01To(s, b1, b2):
    return b1 + s * (b2 - b1)


# @param s: float or float4
# @param a1: float or float4
# @param a2: float or float4
#
# @return float or float4
def RemapTo01(s, a1, a2):
    return (s - a1) / (a2 - a1)

In [15]:
# @param p: float3
# @param v1: float3
# @param v2: float3
# @param v3: float3
# @param v4: float3
#
# @return float4
def GetBarycentricWeightsTetrahedron(p, v1, v2, v3, v4):
    c11, c21, c31, c41 = v1 - v4, v2 - v4, v3 - v4, v4 - p

    m1 = c31[..., 1:] / remove_zeros(c31[..., 0]).unsqueeze(-1)
    c12 = c11[..., 1:] - c11[..., 0].unsqueeze(-1) * m1
    c22 = c21[..., 1:] - c21[..., 0].unsqueeze(-1) * m1
    c32 = c41[..., 1:] - c41[..., 0].unsqueeze(-1) * m1

    m2 = c22[..., 1] / remove_zeros(c22[..., 0])
    uvwk_0 = (c32[..., 0] * m2 - c32[..., 1]) / remove_zeros(
        c12[..., 1] - c12[..., 0] * m2
    )
    uvwk_1 = -(c32[..., 0] + c12[..., 0] * uvwk_0) / remove_zeros(c22[..., 0])
    uvwk_2 = -(
        c41[..., 0] + c21[..., 0] * uvwk_1 + c11[..., 0] * uvwk_0
    ) / remove_zeros(c31[..., 0])
    uvwk_3 = 1.0 - uvwk_2 - uvwk_1 - uvwk_0

    # The range of the return values depends on the input vertices and the position of the point p relative to the
    # tetrahedron. In general, the barycentric coordinates should satisfy the condition 0 <= uvwk.x, uvwk.y, uvwk.z, uvwk.w <= 1
    # and uvwk.x + uvwk.y + uvwk.z + uvwk.w = 1. These conditions ensure that the point p lies within the tetrahedron.
    return torch.nn.functional.softmax(
        torch.stack((uvwk_0, uvwk_1, uvwk_2, uvwk_3), dim=-1), dim=-1
    )

Check the [HLSL documentation](https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar) for the number of bits of each data type.


In [16]:
# @param input: uint32
#
# @return float32
def f16tof32(input):
    packed_input = [struct.pack("IIII", x[0], x[1], x[2], x[3]) for x in input]
    return np.float32([struct.unpack("ffff", y) for y in packed_input])


# @param input: float32
#
# return uint32 (stored in the low-half of the uint.)
def f32tof16(input):
    return np.uint32(
        struct.unpack("I", struct.pack("f", input))[0]
    )  # make it explicit as a uint32

In [17]:
# @param input: float4 (here we have tensors)
# @param a: float4
# @param b: float4
def UnpackFloatParallel4(input):
    input_in_bytes = [
        struct.pack("ffff", x[0], x[1], x[2], x[3]) for x in input.tolist()
    ]
    uintInput = np.uint32([array.array("I", x) for x in input_in_bytes])
    a = torch.tensor(f16tof32(uintInput >> 16))
    b = torch.tensor(f16tof32(uintInput))

    return a, b

## Noise Map Initialisation

The noise map is generated from Unity. Since the noise map has randomness involved and stays constant during rendering, it is unnecessary to reimplement this part in Python.


In [18]:
noiseMapExr = imageio.v3.imread("./noise_maps/noise_map_256.exr")
_Glint2023NoiseMap = torch.from_numpy(noiseMapExr).requires_grad_()  # Texture2D<float4>
_Glint2023NoiseMap

tensor([[[6.9222e-05, 1.7004e-09, 3.1239e-07, 8.9268e-19],
         [3.1239e-07, 8.9268e-19, 6.7638e-08, 1.3198e-11],
         [6.7638e-08, 1.3198e-11, 9.3990e-06, 4.0279e-07],
         ...,
         [7.7720e-04, 9.0422e-07, 2.1206e-07, 5.1943e-03],
         [2.1206e-07, 5.1943e-03, 1.8476e-04, 1.9808e-07],
         [1.8476e-04, 1.9808e-07, 6.9222e-05, 1.7004e-09]],

        [[4.0038e-03, 6.9222e-05, 3.0856e-03, 3.1239e-07],
         [3.0856e-03, 3.1239e-07, 5.7582e-06, 6.7638e-08],
         [5.7582e-06, 6.7638e-08, 2.8748e-04, 9.3990e-06],
         ...,
         [2.9713e-03, 7.7720e-04, 1.3319e-04, 2.1206e-07],
         [1.3319e-04, 2.1206e-07, 3.0629e-03, 1.8476e-04],
         [3.0629e-03, 1.8476e-04, 4.0038e-03, 6.9222e-05]],

        [[1.2987e-18, 4.0038e-03, 3.0278e-06, 3.0856e-03],
         [3.0278e-06, 3.0856e-03, 5.1768e-17, 5.7582e-06],
         [5.1768e-17, 5.7582e-06, 6.1151e-05, 2.8748e-04],
         ...,
         [4.8815e-05, 2.9713e-03, 3.4283e-04, 1.3319e-04],
         [

In [19]:
_Glint2023NoiseMapSize = _Glint2023NoiseMap.shape[0]
target_dim = 32

In [20]:
# Generate random indices
random_indices = torch.randint(0, _Glint2023NoiseMapSize, (target_dim, target_dim))
# Use the random indices to select elements from the tensor
selected_values = _Glint2023NoiseMap[random_indices, random_indices]

In [21]:
# update values
_Glint2023NoiseMap = selected_values
_Glint2023NoiseMapSize = _Glint2023NoiseMap.shape[0]

## Glint BRDF


### Sampling Method


In [22]:
# vars
_ScreenSpaceScale = 1.5  # float
_LogMicrofacetDensity = 20.0  # float
_MicrofacetRoughness = 0.025  # float
_DensityRandomization = 1.5  # float

In [23]:
batch_size = 8
num_vals = batch_size * _Glint2023NoiseMapSize * _Glint2023NoiseMapSize

In [24]:
_ScreenSpaceScale = torch.rand((num_vals,), requires_grad=True)
_LogMicrofacetDensity = torch.rand((num_vals,), requires_grad=True)
_MicrofacetRoughness = torch.rand((num_vals,), requires_grad=True)
_DensityRandomization = torch.rand((num_vals,), requires_grad=True)

In [25]:
# @param slope: float2
# @param slopeRandOffset: float2
# @param out outUniform: float4
# @param out outGaussian: float4
# @param out slopeLerp: float2
def CustomRand4Texture(slope, slopeRandOffset):
    slope2 = torch.abs(slope) / _MicrofacetRoughness.unsqueeze(-1)
    slope2 = slope2 + (slopeRandOffset * _Glint2023NoiseMapSize)
    slopeLerp = torch.frac(slope2)
    # slopeCoord = (toIntApprox(torch.floor(slope2)) % _Glint2023NoiseMapSize)
    #
    # packedRead = _Glint2023NoiseMap[slopeCoord[...,0].long(), slopeCoord[...,1].long()]
    # outUniform, outGaussian = UnpackFloatParallel4(packedRead)
    # TODO: Uniform random numbers generated within the range of what would've been generated by UnpackFloatParallel4()
    outUniform = (
        torch.rand((slope.shape[0], 4), requires_grad=True) * (2**31 - 1) - 2**31
    )
    outGaussian = (
        torch.rand((slope.shape[0], 4), requires_grad=True) * (2**31 - 1) - 2**31
    )

    return outUniform, outGaussian, slopeLerp

In [26]:
# @param randB: float4
# @param randG: float4
# @param slopeLerp: float2
# @param footprintOneHitProba: float
# @param binomialSmoothWidth: float
# @param footprintMean: float
# @param footprintSTD: float
# @param microfacetCount: float
#
# @return float
def GenerateAngularBinomialValueForSurfaceCell(
    randB,
    randG,
    slopeLerp,
    footprintOneHitProba,
    binomialSmoothWidth,
    footprintMean,
    footprintSTD,
    microfacetCount,
):
    gating = torch.ones((len(binomialSmoothWidth), 4), requires_grad=True)
    gating = torch.where(
        binomialSmoothWidth.unsqueeze(-1) * gating > EPSILON,
        torch.clamp(
            RemapTo01(
                randB,
                (footprintOneHitProba + binomialSmoothWidth).unsqueeze(-1),
                (footprintOneHitProba - binomialSmoothWidth).unsqueeze(-1),
            ),
            0.0,
            1.0,
        ),
        torch.where(randB < footprintOneHitProba.unsqueeze(-1), ONE, ZERO),
    )

    # Compute gauss
    gauss = randG * footprintSTD.unsqueeze(-1) + footprintMean.unsqueeze(-1)
    gauss = torch.clamp(torch.floor(gauss), ZERO, microfacetCount.unsqueeze(-1))

    # Compute results
    results = gating * (1.0 + gauss)

    # Perform BilinearLerp
    return BilinearLerp(results, slopeLerp)

In [27]:
# @param uv: float2
# @param gridSeed: uint
# @param slope: float2
# @param footprintArea: float
# @param targetNDF: float
# @param gridWeight: float
#
# @return float
def SampleGlintGridSimplex(uv, gridSeed, slope, footprintArea, targetNDF, gridWeight):
    # Get surface space glint simplex grid cell
    gridToSkewedGrid = torch.tensor(
        [[1.0, -0.57735027], [0.0, 1.15470054]], requires_grad=True
    )
    skewedCoord = torch.matmul(gridToSkewedGrid, uv.t()).t()
    # baseId = torch.floor(skewedCoord).to(torch.int)
    baseId = toIntApprox(torch.floor(skewedCoord))
    temp = torch.cat(
        (
            torch.frac(skewedCoord),
            torch.zeros(len(skewedCoord), 1, requires_grad=True),
        ),
        dim=-1,
    )
    temp[..., 2] = 1.0 - temp[..., 0] - temp[..., 1]
    s = torch.where(-temp[..., 2] >= 0.0, ONE, ZERO)
    s2 = 2.0 * s - 1.0
    # glint0 = baseId + torch.stack((s,s), dim=-1).to(torch.int)
    # glint1 = baseId + torch.stack((s, 1.0 - s), dim=-1).to(torch.int)
    # glint2 = baseId + torch.stack((1.0 - s, s), dim=-1).to(torch.int)
    glint0 = baseId + toIntApprox(torch.stack((s, s), dim=-1))
    glint1 = baseId + toIntApprox(torch.stack((s, 1.0 - s), dim=-1))
    glint2 = baseId + toIntApprox(torch.stack((1.0 - s, s), dim=-1))
    barycentrics = torch.stack(
        (-temp[..., 2] * s2, s - temp[..., 1] * s2, s - temp[..., 0] * s2), dim=-1
    )

    # Generate per surface cell random numbers
    # It's produce overflow runtime warning but uint is indeed 32 bits in HLSL and the code is basically copied pasted
    # So I don't think I could do much about it
    # pcg3dFloat already ensured that the param will be converted to np.uint32
    # rand0 = pcg3dFloat(np.hstack((np.uint32(glint0.detach().numpy() + 2147483648), gridSeed.reshape(-1, 1))))
    # rand1 = pcg3dFloat(np.hstack((np.uint32(glint1.detach().numpy() + 2147483648), gridSeed.reshape(-1, 1))))
    # rand2 = pcg3dFloat(np.hstack((np.uint32(glint2.detach().numpy() + 2147483648), gridSeed.reshape(-1, 1))))
    rand0 = mt19937_3dFloat(
        (glint0.shape[0], 3), int((torch.sum(glint0) + gridSeed.mean()).item())
    )
    rand1 = mt19937_3dFloat(
        (glint1.shape[0], 3), int((torch.sum(glint1) + gridSeed.mean()).item())
    )
    rand2 = mt19937_3dFloat(
        (glint2.shape[0], 3), int((torch.sum(glint2) + gridSeed.mean()).item())
    )

    # Get per surface cell per slope cell random numbers
    rand0SlopesB, rand0SlopesG, slopeLerp0 = CustomRand4Texture(slope, rand0[..., 1:])
    rand1SlopesB, rand1SlopesG, slopeLerp1 = CustomRand4Texture(slope, rand1[..., 1:])
    rand2SlopesB, rand2SlopesG, slopeLerp2 = CustomRand4Texture(slope, rand2[..., 1:])

    # Compute microfacet count with randomization
    logDensityRand = torch.clamp(
        sampleNormalDistribution(
            torch.stack((rand0[..., 0], rand1[..., 0], rand2[..., 0]), dim=-1),
            _LogMicrofacetDensity,
            _DensityRandomization,
        ),
        0.0,
        50.0,
    )

    microfacetCount = torch.clamp(
        footprintArea.unsqueeze(-1) * torch.exp(logDensityRand), min=EPSILON
    )

    # Compute binomial properties
    hitProba = torch.clamp(
        _MicrofacetRoughness * targetNDF, 0.0, 1.0
    )  # probability of hitting desired half vector in NDF distribution
    microfacetCountBlended = microfacetCount * gridWeight.unsqueeze(-1)
    microfacetCountBlended = torch.clamp(microfacetCountBlended, min=1.0)

    footprintOneHitProba = 1.0 - torch.pow(
        (1.0 - hitProba).unsqueeze(-1), microfacetCountBlended
    )  # probability of hitting at least one microfacet in footprint
    footprintMean = (microfacetCountBlended - 1.0) * hitProba.unsqueeze(
        -1
    )  # Expected value of number of hits in the footprint given already one hit
    footprintSTD = torch.sqrt(
        torch.clamp(
            (microfacetCountBlended - 1.0)
            * hitProba.unsqueeze(-1)
            * (1.0 - hitProba.unsqueeze(-1)),
            EPSILON,
        )
    )  # Standard deviation of number of hits in the footprint given already one hit, clampped for numerical stability

    binomialSmoothWidth = (
        0.1
        * torch.clamp(footprintOneHitProba * 10, 0.0, 1.0)
        * torch.clamp((1.0 - footprintOneHitProba) * 10, 0.0, 1.0)
    )

    # Generate numbers of reflecting microfacets
    result0 = GenerateAngularBinomialValueForSurfaceCell(
        rand0SlopesB,
        rand0SlopesG,
        slopeLerp0,
        footprintOneHitProba[..., 0],
        binomialSmoothWidth[..., 0],
        footprintMean[..., 0],
        footprintSTD[..., 0],
        microfacetCountBlended[..., 0],
    )
    result1 = GenerateAngularBinomialValueForSurfaceCell(
        rand1SlopesB,
        rand1SlopesG,
        slopeLerp1,
        footprintOneHitProba[..., 1],
        binomialSmoothWidth[..., 1],
        footprintMean[..., 1],
        footprintSTD[..., 1],
        microfacetCountBlended[..., 1],
    )
    result2 = GenerateAngularBinomialValueForSurfaceCell(
        rand2SlopesB,
        rand2SlopesG,
        slopeLerp2,
        footprintOneHitProba[..., 2],
        binomialSmoothWidth[..., 2],
        footprintMean[..., 2],
        footprintSTD[..., 2],
        microfacetCountBlended[..., 2],
    )

    # Interpolate result for glint grid cell
    results = torch.stack((result0, result1, result2), dim=-1) / microfacetCount

    return torch.sum(results * barycentrics, dim=-1)

In [28]:
# @param centerSpecialCase: bool
# @param thetaBinLerp: inout float
# @param ratioLerp: float
# @param lodLerp: float
# @param out p0: out float3
# @param out p1: out float3
# @param out p2: out float3
# @param out p3: out float3
def GetAnisoCorrectingGridTetrahedron(
    centerSpecialCase, thetaBinLerp, ratioLerp, lodLerp
):
    p0 = torch.ones((len(centerSpecialCase), 3), requires_grad=True)
    p1 = torch.ones((len(centerSpecialCase), 3), requires_grad=True)
    p2 = torch.ones((len(centerSpecialCase), 3), requires_grad=True)
    p3 = torch.ones((len(centerSpecialCase), 3), requires_grad=True)

    # vars for centerSpecialCase
    upper_pyramid_mask = torch.where(lodLerp > 1.0 - ratioLerp, ONE, ZERO)
    lower_pyramid_mask = torch.where(lodLerp < 1.0 - ratioLerp, ONE, ZERO)
    left_up_tetrahedron_mask = torch.where(
        RemapTo01(lodLerp, 1.0 - ratioLerp, 1.0) > thetaBinLerp, ONE, ZERO
    )

    # vars for normal case
    normal_case_mask = 1.0 - centerSpecialCase  # flipping the bits using 1.0 - bit
    prismA_mask = torch.where(
        (thetaBinLerp < 0.5) & (thetaBinLerp * 2.0 < ratioLerp), ONE, ZERO
    )
    prismB_mask = torch.where(1.0 - ((thetaBinLerp - 0.5) * 2.0) > ratioLerp, ONE, ZERO)
    prismC_mask = (1.0 - prismA_mask) * (1.0 - prismB_mask)
    left_up_tetrahedron_mask_prismA = torch.where(
        RemapTo01(lodLerp, 1.0 - ratioLerp, 1.0)
        > RemapTo01(thetaBinLerp * 2.0, 0.0, ratioLerp),
        ONE,
        ZERO,
    )
    left_up_tetrahedron_mask_prismB = torch.where(
        RemapTo01(lodLerp, 0.0, 1.0 - ratioLerp)
        > RemapTo01(
            thetaBinLerp,
            0.5 - (1.0 - ratioLerp) * 0.5,
            0.5 + (1.0 - ratioLerp) * 0.5,
        ),
        ONE,
        ZERO,
    )
    left_up_tetrahedron_mask_prismC = torch.where(
        RemapTo01(lodLerp, 1.0 - ratioLerp, 1.0)
        > RemapTo01((thetaBinLerp - 0.5) * 2.0, 1.0 - ratioLerp, 1.0),
        ONE,
        ZERO,
    )

    # if (centerSpecialCase): # SPECIAL CASE (no anisotropy, center of blending pattern, different triangulation)
    a = torch.tensor([0.0, 1.0, 0.0], requires_grad=True)
    b = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
    c = torch.tensor([1.0, 1.0, 0.0], requires_grad=True)
    d = torch.tensor([0.0, 1.0, 1.0], requires_grad=True)
    e = torch.tensor([0.0, 0.0, 1.0], requires_grad=True)
    f = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)

    p0 = torch.where(
        (centerSpecialCase * upper_pyramid_mask * left_up_tetrahedron_mask).unsqueeze(
            -1
        )
        * p0
        == p0,
        a,
        p0,
    )
    p0 = torch.where(
        (
            centerSpecialCase * upper_pyramid_mask * (1.0 - left_up_tetrahedron_mask)
        ).unsqueeze(-1)
        * p0
        == p0,
        f,
        p0,
    )
    # If a condition replaces values with b [0.0,0.0,0.0], it must be put at the end because any mask times
    # b would result in 0.0, so the condition would always be true
    # Need to ensures that at this stage, if condition * p0 == 0.0, it must be because condition == 0.0
    p0 = torch.where(
        (centerSpecialCase * (1.0 - upper_pyramid_mask)).unsqueeze(-1) * p0 == p0,
        b,
        p0,
    )

    # p1
    p1 = torch.where(
        (centerSpecialCase * upper_pyramid_mask).unsqueeze(-1) * p1 == p1, e, p1
    )
    p1 = torch.where(
        (centerSpecialCase * (1.0 - upper_pyramid_mask)).unsqueeze(-1) * p1 == p1,
        a,
        p1,
    )

    # p2
    p2 = torch.where(
        (centerSpecialCase * upper_pyramid_mask * left_up_tetrahedron_mask).unsqueeze(
            -1
        )
        * p2
        == p2,
        d,
        p2,
    )
    p2 = torch.where(
        (
            centerSpecialCase * upper_pyramid_mask * (1.0 - left_up_tetrahedron_mask)
        ).unsqueeze(-1)
        * p2
        == p2,
        c,
        p2,
    )
    p2 = torch.where(
        (centerSpecialCase * (1.0 - upper_pyramid_mask)).unsqueeze(-1) * p2 == p2,
        c,
        p2,
    )

    # p3
    p3 = torch.where(
        (centerSpecialCase * upper_pyramid_mask * left_up_tetrahedron_mask).unsqueeze(
            -1
        )
        * p3
        == p3,
        f,
        p3,
    )
    p3 = torch.where(
        (
            centerSpecialCase * upper_pyramid_mask * (1.0 - left_up_tetrahedron_mask)
        ).unsqueeze(-1)
        * p3
        == p3,
        a,
        p3,
    )
    p3 = torch.where(
        (centerSpecialCase * (1.0 - upper_pyramid_mask)).unsqueeze(-1) * p3 == p3,
        e,
        p3,
    )

    # else: # NORMAL CASE
    c = torch.tensor([0.5, 1.0, 0.0], requires_grad=True)
    d = torch.tensor([1.0, 0.0, 0.0], requires_grad=True)
    e = torch.tensor([1.0, 1.0, 0.0], requires_grad=True)
    f = torch.tensor([0.0, 1.0, 1.0], requires_grad=True)
    g = torch.tensor([0.0, 0.0, 1.0], requires_grad=True)
    h = torch.tensor([0.5, 1.0, 1.0], requires_grad=True)
    i = torch.tensor([1.0, 0.0, 1.0], requires_grad=True)
    j = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)

    # p0 values
    p0 = torch.where(
        (
            normal_case_mask
            * prismA_mask
            * upper_pyramid_mask
            * left_up_tetrahedron_mask_prismA
        ).unsqueeze(-1)
        * p0
        == p0,
        a,
        p0,
    )
    # A tweak to calculate and & operation for floating point 1.0 and 0.0 values
    p0_replace_with_c = (
        normal_case_mask
        * prismA_mask
        * upper_pyramid_mask
        * (1.0 - left_up_tetrahedron_mask_prismA)
        + normal_case_mask
        * prismC_mask
        * upper_pyramid_mask
        * left_up_tetrahedron_mask_prismC
        + normal_case_mask * prismB_mask * (1.0 - lower_pyramid_mask)
    )
    p0 = torch.where(
        p0_replace_with_c.unsqueeze(-1) * p0
        != b,  # if condition * p0 == 0.0, it must be because condition == 0.0
        c,
        p0,
    )
    p0_replace_with_d = normal_case_mask * prismB_mask * lower_pyramid_mask * (
        1.0 - left_up_tetrahedron_mask_prismB
    ) + normal_case_mask * prismC_mask * (1.0 - upper_pyramid_mask)
    p0 = torch.where(p0_replace_with_d.unsqueeze(-1) * p0 != b, d, p0)
    p0 = torch.where(
        (
            normal_case_mask
            * prismC_mask
            * upper_pyramid_mask
            * (1.0 - left_up_tetrahedron_mask_prismC)
        ).unsqueeze(-1)
        * p0
        == p0,
        e,
        p0,
    )

    p0_replace_with_b = (
        normal_case_mask * prismA_mask * (1.0 - upper_pyramid_mask)
        + normal_case_mask
        * prismB_mask
        * lower_pyramid_mask
        * left_up_tetrahedron_mask_prismB
    )
    p0 = torch.where(p0_replace_with_b.unsqueeze(-1) * p0 != b, b, p0)

    # p1
    p1 = torch.where(
        (
            normal_case_mask
            * prismA_mask
            * upper_pyramid_mask
            * left_up_tetrahedron_mask_prismA
        ).unsqueeze(-1)
        * p1
        == p1,
        f,
        p1,
    )
    p1_replace_with_a = normal_case_mask * prismA_mask * upper_pyramid_mask * (
        1.0 - left_up_tetrahedron_mask_prismA
    ) + normal_case_mask * prismA_mask * (1.0 - upper_pyramid_mask)
    p1 = torch.where(p1_replace_with_a.unsqueeze(-1) * p1 != b, a, p1)
    p1_replace_with_g = (
        normal_case_mask
        * prismB_mask
        * lower_pyramid_mask
        * left_up_tetrahedron_mask_prismB
        + normal_case_mask * prismB_mask * (1.0 - lower_pyramid_mask)
    )
    p1 = torch.where(p1_replace_with_g.unsqueeze(-1) * p1 != b, g, p1)
    p1 = torch.where(
        (
            normal_case_mask
            * prismC_mask
            * upper_pyramid_mask
            * left_up_tetrahedron_mask_prismC
        ).unsqueeze(-1)
        * p1
        == p1,
        j,
        p1,
    )
    p1 = torch.where(
        (
            normal_case_mask
            * prismC_mask
            * upper_pyramid_mask
            * (1.0 - left_up_tetrahedron_mask_prismC)
        ).unsqueeze(-1)
        * p1
        == p1,
        i,
        p1,
    )
    p1 = torch.where(
        (normal_case_mask * prismC_mask * (1.0 - upper_pyramid_mask)).unsqueeze(-1) * p1
        == p1,
        e,
        p1,
    )

    p1 = torch.where(
        (
            normal_case_mask
            * prismB_mask
            * lower_pyramid_mask
            * (1.0 - left_up_tetrahedron_mask_prismB)
        ).unsqueeze(-1)
        * p1
        == p1,
        b,
        p1,
    )

    # p2
    p2_replace_with_h = (
        normal_case_mask * prismA_mask * upper_pyramid_mask
        + normal_case_mask * prismB_mask * (1.0 - lower_pyramid_mask)
        + normal_case_mask
        * prismC_mask
        * upper_pyramid_mask
        * left_up_tetrahedron_mask_prismC
    )
    p2 = torch.where(p2_replace_with_h.unsqueeze(-1) * p2 != b, h, p2)
    p2_replace_with_c = (
        normal_case_mask * prismA_mask * (1.0 - upper_pyramid_mask)
        + normal_case_mask
        * prismB_mask
        * lower_pyramid_mask
        * (1.0 - left_up_tetrahedron_mask_prismB)
        + normal_case_mask
        * prismC_mask
        * upper_pyramid_mask
        * (1.0 - left_up_tetrahedron_mask_prismC)
        + normal_case_mask * prismC_mask * (1.0 - upper_pyramid_mask)
    )
    p2 = torch.where(p2_replace_with_c.unsqueeze(-1) * p2 != b, c, p2)
    torch.where(
        (
            normal_case_mask
            * prismB_mask
            * lower_pyramid_mask
            * left_up_tetrahedron_mask_prismB
        ).unsqueeze(-1)
        * p2
        == p2,
        i,
        p2,
    )

    # p3
    p3 = torch.where((normal_case_mask * prismA_mask).unsqueeze(-1) * p3 == p3, g, p3)
    p3 = torch.where(
        (
            normal_case_mask
            * prismB_mask
            * lower_pyramid_mask
            * left_up_tetrahedron_mask_prismB
        ).unsqueeze(-1)
        * p3
        == p3,
        c,
        p3,
    )
    p3_replace_with_i = (
        normal_case_mask
        * prismB_mask
        * lower_pyramid_mask
        * (1.0 - left_up_tetrahedron_mask_prismB)
        + normal_case_mask * prismB_mask * (1.0 - lower_pyramid_mask)
        + normal_case_mask
        * prismC_mask
        * upper_pyramid_mask
        * left_up_tetrahedron_mask_prismC
        + normal_case_mask * prismC_mask * (1.0 - upper_pyramid_mask)
    )
    p3 = torch.where(p3_replace_with_i.unsqueeze(-1) * p3 != b, i, p3)
    p3 = torch.where(
        (
            normal_case_mask
            * prismC_mask
            * upper_pyramid_mask
            * (1.0 - left_up_tetrahedron_mask_prismC)
        ).unsqueeze(-1)
        * p3
        == p3,
        j,
        p3,
    )

    return p0, p1, p2, p3

In [29]:
# @param localHalfVector: float3
# @param targetNDF: float
# @param maxNDF: float
# @param uv: float2
# @param duvdx: float2
# @param duvdy: float2
#
# @return float4
def SampleGlints2023NDF(localHalfVector, targetNDF, maxNDF, uv, duvdx, duvdy):
    ellipseMajor, ellipseMinor = GetGradientEllipse(duvdx, duvdy)
    ellipseRatio = torch.norm(ellipseMajor, dim=-1) / torch.norm(ellipseMinor, dim=-1)

    # SHARED GLINT NDF VALUES
    halfScreenSpaceScaler = _ScreenSpaceScale * 0.5
    slope = localHalfVector[..., :2]  # Orthogrtaphic slope projected grid
    rescaledTargetNDF = targetNDF / maxNDF

    # MANUAL LOD COMPENSATION
    lod = torch.log2(torch.norm(ellipseMinor, dim=-1) * halfScreenSpaceScaler)
    # TODO: lod0 = lod.to(torch.int).to(torch.float) # lod >= 0.0 ? (int)(lod) : (int)(lod - 1.0)
    lod0 = toIntApprox(lod)
    lod1 = lod0 + 1
    divLod0 = torch.pow(torch.tensor(2.0), lod0)
    divLod1 = torch.pow(torch.tensor(2.0), lod1)
    lodLerp = torch.frac(lod)
    footprintAreaLOD0 = torch.pow(torch.exp2(lod0), 2.0)
    footprintAreaLOD1 = torch.pow(torch.exp2(lod1), 2.0)

    # MANUAL ANISOTROPY RATIO COMPENSATION
    # TODO: ratio0 = torch.max(torch.pow(torch.tensor(2.0), torch.log2(ellipseRatio).to(torch.int)), torch.tensor(1.0))
    ratio0 = torch.clamp(
        torch.pow(
            torch.tensor(2.0, requires_grad=True),
            torch.log2(ellipseRatio).to(torch.int),
        ),
        1.0,
    )
    ratio1 = ratio0 * 2.0
    ratioLerp = torch.clamp(Remap(ellipseRatio, ratio0, ratio1, 0.0, 1.0), 0.0, 1.0)

    # MANUAL ANISOTROPY ROTATION COMPENSATION
    v1 = torch.tensor([0.0, 1.0], requires_grad=True)
    v2 = torch.nn.functional.normalize(ellipseMajor, dim=-1)
    theta = (
        torch.atan2(
            v1[0] * v2[..., 1] - v1[1] * v2[..., 0],
            v1[0] * v2[..., 0] + v1[1] * v2[..., 1],
        )
        * RAD2DEG
    )
    thetaGrid = 90.0 / torch.clamp(ratio0, 2.0)
    # TODO: thetaBin = (theta / thetaGrid).to(torch.int) * thetaGrid
    thetaBin = toIntApprox(theta / thetaGrid) * thetaGrid
    thetaBin += thetaGrid / 2.0
    thetaBin0 = torch.where(theta < thetaBin, thetaBin - thetaGrid / 2.0, thetaBin)
    thetaBinH = thetaBin0 + thetaGrid / 4.0
    thetaBin1 = thetaBin0 + thetaGrid / 2.0
    thetaBinLerp = Remap(theta, thetaBin0, thetaBin1, 0.0, 1.0)
    thetaBin0 = torch.where(thetaBin0 <= 0.0, 180.0 + thetaBin0, thetaBin0)

    # TETRAHEDRONIZATION OF ROTATION + RATIO + LOD GRID
    # TODO: centerSpecialCase = (ratio0 == 1.0)
    centerSpecialCase = torch.where(ratio0 == 1.0, ratio0, ZERO)
    divLods = torch.stack((divLod0, divLod1), dim=-1)
    footprintAreas = torch.stack((footprintAreaLOD0, footprintAreaLOD1), dim=-1)
    ratios = torch.stack((ratio0, ratio1), dim=-1)
    thetaBins = torch.stack(
        (thetaBin0, thetaBinH, thetaBin1, torch.zeros(thetaBin0.shape)),
        dim=-1,
    )  # added 0.0 for center singularity case
    tetraA, tetraB, tetraC, tetraD = GetAnisoCorrectingGridTetrahedron(
        centerSpecialCase, thetaBinLerp, ratioLerp, lodLerp
    )
    # TODO: thetaBinLerp[centerSpecialCase] = Remap01To(thetaBinLerp[centerSpecialCase], 0.0, ratioLerp[centerSpecialCase])
    # Account for center singularity in barycentric computation
    thetaBinLerp = torch.where(
        centerSpecialCase == 1.0,
        Remap01To(thetaBinLerp, 0.0, ratioLerp),
        thetaBinLerp,
    )
    tetraBarycentricWeights = GetBarycentricWeightsTetrahedron(
        torch.stack((thetaBinLerp, ratioLerp, lodLerp), dim=-1),
        tetraA,
        tetraB,
        tetraC,
        tetraD,
    )  # Compute barycentric coordinates within chosen tetrahedron

    # PREPARE NEEDED ROTATIONS
    tetraA[..., 0] *= 2
    tetraB[..., 0] *= 2
    tetraC[..., 0] *= 2
    tetraD[..., 0] *= 2

    # if (centerSpecialCase): # Account for center singularity (if center vertex => no rotation)
    # TODO: tetraA[centerSpecialCase][..., 0] = torch.where(tetraA[centerSpecialCase][..., 1] == 0.0, 3.0, tetraA[centerSpecialCase][..., 0])
    # tetraB[centerSpecialCase][..., 0] = torch.where(tetraB[centerSpecialCase][..., 1] == 0.0, 3.0, tetraB[centerSpecialCase][..., 0])
    # tetraC[centerSpecialCase][..., 0] = torch.where(tetraC[centerSpecialCase][..., 1] == 0.0, 3.0, tetraC[centerSpecialCase][..., 0])
    # tetraD[centerSpecialCase][..., 0] = torch.where(tetraD[centerSpecialCase][..., 1] == 0.0, 3.0, tetraD[centerSpecialCase][..., 0])
    three = torch.tensor(3.0, requires_grad=True)
    tetraA[..., 0] = torch.where(
        centerSpecialCase == 1.0,
        torch.where(tetraA[..., 1] == 0.0, three, tetraA[..., 0]),
        tetraA[..., 0],
    )
    tetraB[..., 0] = torch.where(
        centerSpecialCase == 1.0,
        torch.where(tetraB[..., 1] == 0.0, three, tetraB[..., 0]),
        tetraB[..., 0],
    )
    tetraC[..., 0] = torch.where(
        centerSpecialCase == 1.0,
        torch.where(tetraC[..., 1] == 0.0, three, tetraC[..., 0]),
        tetraC[..., 0],
    )
    tetraD[..., 0] = torch.where(
        centerSpecialCase == 1.0,
        torch.where(tetraD[..., 1] == 0.0, three, tetraD[..., 0]),
        tetraD[..., 0],
    )

    # selections based on tetra values
    # TODO: thetaBins_tetraA = thetaBins[torch.arange(len(thetaBins)), tetraA[...,0].to(torch.int)]
    # thetaBins_tetraB = thetaBins[torch.arange(len(thetaBins)), tetraB[...,0].to(torch.int)]
    # thetaBins_tetraC = thetaBins[torch.arange(len(thetaBins)), tetraC[...,0].to(torch.int)]
    # thetaBins_tetraD = thetaBins[torch.arange(len(thetaBins)), tetraD[...,0].to(torch.int)]
    zeros = torch.zeros(tetraA.shape[0], 1, requires_grad=True)
    thetaBins_tetraA = torch.gather(
        thetaBins, dim=-1, index=torch.cat((tetraA, zeros), dim=-1).to(torch.int64)
    )[..., 0]
    thetaBins_tetraB = torch.gather(
        thetaBins, dim=-1, index=torch.cat((tetraB, zeros), dim=-1).to(torch.int64)
    )[..., 0]
    thetaBins_tetraC = torch.gather(
        thetaBins, dim=-1, index=torch.cat((tetraC, zeros), dim=-1).to(torch.int64)
    )[..., 0]
    thetaBins_tetraD = torch.gather(
        thetaBins, dim=-1, index=torch.cat((tetraD, zeros), dim=-1).to(torch.int64)
    )[..., 0]
    # TODO: divLods_tetraA = divLods[torch.arange(len(divLods)), tetraA[...,2].to(torch.int)]
    # divLods_tetraB = divLods[torch.arange(len(divLods)), tetraB[...,2].to(torch.int)]
    # divLods_tetraC = divLods[torch.arange(len(divLods)), tetraC[...,2].to(torch.int)]
    # divLods_tetraD = divLods[torch.arange(len(divLods)), tetraD[...,2].to(torch.int)]
    divLods_tetraA = torch.gather(divLods, dim=-1, index=tetraA[:, 1:].to(torch.int64))[
        ..., 1
    ]
    divLods_tetraB = torch.gather(divLods, dim=-1, index=tetraB[:, 1:].to(torch.int64))[
        ..., 1
    ]
    divLods_tetraC = torch.gather(divLods, dim=-1, index=tetraC[:, 1:].to(torch.int64))[
        ..., 1
    ]
    divLods_tetraD = torch.gather(divLods, dim=-1, index=tetraD[:, 1:].to(torch.int64))[
        ..., 1
    ]
    # TODO: ratios_tetraA = ratios[torch.arange(len(ratios)), tetraA[...,1].to(torch.int)]
    # ratios_tetraB = ratios[torch.arange(len(ratios)), tetraB[...,1].to(torch.int)]
    # ratios_tetraC = ratios[torch.arange(len(ratios)), tetraC[...,1].to(torch.int)]
    # ratios_tetraD = ratios[torch.arange(len(ratios)), tetraD[...,1].to(torch.int)]
    ratios_tetraA = torch.gather(ratios, dim=-1, index=tetraA[:, 1:].to(torch.int64))[
        ..., 0
    ]
    ratios_tetraB = torch.gather(ratios, dim=-1, index=tetraB[:, 1:].to(torch.int64))[
        ..., 0
    ]
    ratios_tetraC = torch.gather(ratios, dim=-1, index=tetraC[:, 1:].to(torch.int64))[
        ..., 0
    ]
    ratios_tetraD = torch.gather(ratios, dim=-1, index=tetraD[:, 1:].to(torch.int64))[
        ..., 0
    ]
    # TODO: footprintAreas_tetraA = footprintAreas[torch.arange(len(footprintAreas)), tetraA[...,2].to(torch.int)]
    # footprintAreas_tetraB = footprintAreas[torch.arange(len(footprintAreas)), tetraB[...,2].to(torch.int)]
    # footprintAreas_tetraC = footprintAreas[torch.arange(len(footprintAreas)), tetraC[...,2].to(torch.int)]
    # footprintAreas_tetraD = footprintAreas[torch.arange(len(footprintAreas)), tetraD[...,2].to(torch.int)]
    footprintAreas_tetraA = torch.gather(
        footprintAreas, dim=-1, index=tetraA[:, 1:].to(torch.int64)
    )[..., 1]
    footprintAreas_tetraB = torch.gather(
        footprintAreas, dim=-1, index=tetraB[:, 1:].to(torch.int64)
    )[..., 1]
    footprintAreas_tetraC = torch.gather(
        footprintAreas, dim=-1, index=tetraC[:, 1:].to(torch.int64)
    )[..., 1]
    footprintAreas_tetraD = torch.gather(
        footprintAreas, dim=-1, index=tetraD[:, 1:].to(torch.int64)
    )[..., 1]

    uvRotA = RotateUV(
        uv,
        thetaBins_tetraA * DEG2RAD,
        torch.full((2,), 0.0, requires_grad=True),
    )
    uvRotB = RotateUV(
        uv,
        thetaBins_tetraB * DEG2RAD,
        torch.full((2,), 0.0, requires_grad=True),
    )
    uvRotC = RotateUV(
        uv,
        thetaBins_tetraC * DEG2RAD,
        torch.full((2,), 0.0, requires_grad=True),
    )
    uvRotD = RotateUV(
        uv,
        thetaBins_tetraD * DEG2RAD,
        torch.full((2,), 0.0, requires_grad=True),
    )

    # SAMPLE GLINT GRIDS
    # a float is returned
    # gridSeedA = np.uint32(HashWithoutSine13(torch.stack((torch.log2(divLods_tetraA), thetaBins_tetraA % 360, ratios_tetraA), dim=-1))* 4294967296.0)
    # gridSeedB = np.uint32(HashWithoutSine13(torch.stack((torch.log2(divLods_tetraB), thetaBins_tetraB % 360, ratios_tetraB), dim=-1))* 4294967296.0)
    # gridSeedC = np.uint32(HashWithoutSine13(torch.stack((torch.log2(divLods_tetraC), thetaBins_tetraC % 360, ratios_tetraC), dim=-1))* 4294967296.0)
    # gridSeedD = np.uint32(HashWithoutSine13(torch.stack((torch.log2(divLods_tetraD), thetaBins_tetraD % 360, ratios_tetraD), dim=-1))* 4294967296.0)
    gridSeedA = (
        HashWithoutSine13(
            torch.stack(
                (torch.log2(divLods_tetraA), thetaBins_tetraA % 360, ratios_tetraA),
                dim=-1,
            )
        )
        * 4294967296.0
    )
    gridSeedB = (
        HashWithoutSine13(
            torch.stack(
                (torch.log2(divLods_tetraB), thetaBins_tetraB % 360, ratios_tetraB),
                dim=-1,
            )
        )
        * 4294967296.0
    )
    gridSeedC = (
        HashWithoutSine13(
            torch.stack(
                (torch.log2(divLods_tetraC), thetaBins_tetraC % 360, ratios_tetraC),
                dim=-1,
            )
        )
        * 4294967296.0
    )
    gridSeedD = (
        HashWithoutSine13(
            torch.stack(
                (torch.log2(divLods_tetraD), thetaBins_tetraD % 360, ratios_tetraD),
                dim=-1,
            )
        )
        * 4294967296.0
    )

    ones = torch.ones(ratios_tetraA.shape, requires_grad=True)
    sampleA = SampleGlintGridSimplex(
        uvRotA
        / divLods_tetraA.unsqueeze(-1)
        / torch.stack((ones, ratios_tetraA), dim=-1),
        gridSeedA,
        slope,
        ratios_tetraA * footprintAreas_tetraA,
        rescaledTargetNDF,
        tetraBarycentricWeights[..., 0],
    )
    sampleB = SampleGlintGridSimplex(
        uvRotB
        / divLods_tetraB.unsqueeze(-1)
        / torch.stack((ones, ratios_tetraB), dim=-1),
        gridSeedB,
        slope,
        ratios_tetraB * footprintAreas_tetraB,
        rescaledTargetNDF,
        tetraBarycentricWeights[..., 1],
    )
    sampleC = SampleGlintGridSimplex(
        uvRotC
        / divLods_tetraC.unsqueeze(-1)
        / torch.stack((ones, ratios_tetraC), dim=-1),
        gridSeedC,
        slope,
        ratios_tetraC * footprintAreas_tetraC,
        rescaledTargetNDF,
        tetraBarycentricWeights[..., 2],
    )
    sampleD = SampleGlintGridSimplex(
        uvRotD
        / divLods_tetraD.unsqueeze(-1)
        / torch.stack((ones, ratios_tetraD), dim=-1),
        gridSeedD,
        slope,
        ratios_tetraD * footprintAreas_tetraD,
        rescaledTargetNDF,
        tetraBarycentricWeights[..., 3],
    )

    res = (
        (sampleA + sampleB + sampleC + sampleD) * (1.0 / _MicrofacetRoughness) * maxNDF
    )
    
    res = normalise(res)

    print(
        f"min: {torch.min(res)}, max: {torch.max(res)}, median: {torch.median(res)}, mean: {torch.mean(res)}"
    )

    assert is_valid(res)
    return res

### Try Calling the Function


In [30]:
localHalfVector = torch.rand((num_vals, 3), requires_grad=True)
targetNDF = torch.rand((num_vals,), requires_grad=True)
maxNDF = torch.rand((num_vals,), requires_grad=True)
uv = torch.rand((num_vals, 2), requires_grad=True)
duvdx = torch.rand((num_vals, 2), requires_grad=True)
duvdy = torch.rand((num_vals, 2), requires_grad=True)

In [31]:
torch.autograd.set_detect_anomaly(True)
res = SampleGlints2023NDF(localHalfVector, targetNDF, maxNDF, uv, duvdx, duvdy)
res

min: 0.0, max: 1.0, median: 0.009127961471676826, mean: 0.011039678007364273


tensor([0.0092, 0.0091, 0.0097,  ..., 0.0091, 0.0092, 0.0091],
       grad_fn=<DivBackward0>)

# Test


In [32]:
def validate_tensor(tensor):
    if torch.any(torch.isnan(tensor)):
        print("Nan found.")
    if torch.any(torch.isinf(tensor)):
        print("Inf found.")
    if tensor.grad_fn is None:
        print("No gradient.")
    if (
        not torch.any(torch.isnan(tensor))
        and not torch.any(torch.isinf(tensor))
        and tensor.grad_fn is not None
    ):
        print("All good.")

In [33]:
validate_tensor(res)

All good.


In [34]:
# torch.autograd.gradcheck(SampleGlints2023NDF, (localHalfVector, targetNDF, maxNDF, uv, duvdx, duvdy))