From 6b7539ee15e4fea84331aca31aad732e36e7543a Mon Sep 17 00:00:00 2001 From: Martin Valigursky Date: Tue, 31 Mar 2026 12:30:40 +0100 Subject: [PATCH] perf: reduce register pressure in compute GSplat TileCount projection Fuse the 2D covariance matrix chain to avoid materializing Vrk, J, and W intermediate mat3x3f matrices, reducing peak register usage. Return viewDepth from computeSplatCov to eliminate a redundant mat4*vec4 multiply. Remove unused radius/radiusFactor fields from SplatCov2D struct. Made-with: Cursor --- .../chunks/gsplat/compute-gsplat-common.js | 55 +++++++++---------- .../gsplat/compute-gsplat-local-tile-count.js | 3 +- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-common.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-common.js index 88121dda35f..b077ef1a72b 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-common.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-common.js @@ -23,8 +23,7 @@ struct SplatCov2D { a: f32, b: f32, c: f32, - radius: vec2f, - radiusFactor: f32, + viewDepth: f32, valid: bool, } @@ -68,37 +67,36 @@ fn computeSplatCov( s.z * vec3f(rot[2]) )); - let covA = vec3f(dot(M[0], M[0]), dot(M[0], M[1]), dot(M[0], M[2])); - let covB = vec3f(dot(M[1], M[1]), dot(M[1], M[2]), dot(M[2], M[2])); - - let Vrk = mat3x3f( - vec3f(covA.x, covA.y, covA.z), - vec3f(covA.y, covB.x, covB.y), - vec3f(covA.z, covB.y, covB.z) - ); - let ortho = isOrtho == 1u; let v = select(viewCenter.xyz, vec3f(0.0, 0.0, 1.0), ortho); let vz = select(min(v.z, -nearClip), v.z, ortho); let J1 = focal / vz; let J2 = -J1 / vz * v.xy; - let J = mat3x3f( - vec3f(J1, 0.0, J2.x), - vec3f(0.0, J1, J2.y), - vec3f(0.0, 0.0, 0.0) - ); - - let W = transpose(mat3x3f( - viewMatrix[0].xyz, - viewMatrix[1].xyz, - viewMatrix[2].xyz - )); - let TT = W * J; - let cov = transpose(TT) * Vrk * TT; - let a = cov[0][0] + 0.3; - let b = cov[0][1]; - let c = cov[1][1] + 0.3; + 
// Compute TT columns directly without materializing full J and W matrices. + // Original code: + // let J = mat3x3f(vec3f(J1, 0.0, J2.x), vec3f(0.0, J1, J2.y), vec3f(0.0, 0.0, 0.0)); + // let W = transpose(mat3x3f(viewMatrix[0].xyz, viewMatrix[1].xyz, viewMatrix[2].xyz)); + // let TT = W * J; + let w0 = vec3f(viewMatrix[0].x, viewMatrix[1].x, viewMatrix[2].x); + let w1 = vec3f(viewMatrix[0].y, viewMatrix[1].y, viewMatrix[2].y); + let w2 = vec3f(viewMatrix[0].z, viewMatrix[1].z, viewMatrix[2].z); + let tt0 = J1 * w0 + J2.x * w2; + let tt1 = J1 * w1 + J2.y * w2; + + // Fused covariance: cov = TT^T * Vrk * TT = TT^T * (M^T * M) * TT = (M * TT)^T * (M * TT). + // Compute B = M * TT then cov = B^T * B, avoiding the intermediate Vrk (mat3x3f) matrix. + // Original code: + // let covA = vec3f(dot(M[0], M[0]), dot(M[0], M[1]), dot(M[0], M[2])); + // let covB = vec3f(dot(M[1], M[1]), dot(M[1], M[2]), dot(M[2], M[2])); + // let Vrk = mat3x3f(vec3f(covA.x, covA.y, covA.z), ...); + // let cov = transpose(TT) * Vrk * TT; + let b0 = M * tt0; + let b1 = M * tt1; + + let a = dot(b0, b0) + 0.3; + let b = dot(b0, b1); + let c = dot(b1, b1) + 0.3; let det = a * c - b * b; if (det <= 0.0) { @@ -156,8 +154,7 @@ fn computeSplatCov( result.a = scaledCov.x; result.b = scaledCov.y; result.c = scaledCov.z; - result.radius = vec2f(radiusX, radiusY); - result.radiusFactor = radiusFactor; + result.viewDepth = -viewCenter.z; result.valid = true; return result; } diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-count.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-count.js index dbb723d60c8..75ac7200ccf 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-count.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-count.js @@ -95,8 +95,7 @@ fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(num_workgroups) numW projCache[base + 6u] = pack2x16float(vec2f(rgb.z, opacity)); #endif - let 
viewDepth = -(uniforms.viewMatrix * vec4f(center, 1.0)).z; - projCache[base + 7u] = bitcast<u32>(viewDepth); + projCache[base + 7u] = bitcast<u32>(proj.viewDepth); let eval = computeSplatTileEval(proj.screen, coeffX, coeffY, coeffXY, half(opacity), uniforms.viewportWidth, uniforms.viewportHeight,