diff --git a/examples/src/examples/gaussian-splatting/lod-streaming.controls.mjs b/examples/src/examples/gaussian-splatting/lod-streaming.controls.mjs index f5ed4212f8a..d488a36b31c 100644 --- a/examples/src/examples/gaussian-splatting/lod-streaming.controls.mjs +++ b/examples/src/examples/gaussian-splatting/lod-streaming.controls.mjs @@ -171,7 +171,7 @@ export const controls = ({ observer, ReactPCUI, React, jsx, fragment }) => { binding: new BindingTwoWay(), link: { observer, path: 'splatBudget' }, min: 0, - max: 20, + max: 40, precision: 1, step: 0.1 }) diff --git a/src/scene/gsplat-unified/gsplat-frustum-culler.js b/src/scene/gsplat-unified/gsplat-frustum-culler.js index 1579d85fdf9..1b9fd288abe 100644 --- a/src/scene/gsplat-unified/gsplat-frustum-culler.js +++ b/src/scene/gsplat-unified/gsplat-frustum-culler.js @@ -1,26 +1,24 @@ import { Frustum } from '../../core/shape/frustum.js'; import { Mat4 } from '../../core/math/mat4.js'; -import { Vec2 } from '../../core/math/vec2.js'; -import { PIXELFORMAT_R32U, PIXELFORMAT_RGBA32F } from '../../platform/graphics/constants.js'; -import { RenderTarget } from '../../platform/graphics/render-target.js'; -import { Texture } from '../../platform/graphics/texture.js'; -import { TextureUtils } from '../../platform/graphics/texture-utils.js'; -import { GSplatNodeCullRenderPass } from './gsplat-node-cull-render-pass.js'; +import { BUFFERUSAGE_COPY_DST } from '../../platform/graphics/constants.js'; +import { StorageBuffer } from '../../platform/graphics/storage-buffer.js'; /** * @import { GraphicsDevice } from '../../platform/graphics/graphics-device.js' * @import { GSplatInfo } from "./gsplat-info.js" */ -const tmpSize = new Vec2(); const _viewProjMat = new Mat4(); const _frustum = new Frustum(); -const _frustumPlanes = new Float32Array(24); + +// 8 u32/f32 elements per BoundsEntry (matches WGSL struct layout): +// [centerX, centerY, centerZ, radius, transformIndex, pad, pad, pad] +const BOUNDS_ENTRY_FLOATS = 8; /** - * GPU frustum culling for GSplat octree nodes. Manages bounding-sphere and transform - * textures, runs a render-pass that tests each sphere against camera frustum planes, - * and produces a bit-packed visibility texture consumed by interval compaction. + * Frustum culling data for GSplat octree nodes. Manages bounding-sphere and + * transform storage buffers and computes frustum planes from camera matrices. + * The actual culling test is performed inline by the interval compaction compute shader. * * @ignore */ @@ -29,61 +27,53 @@ class GSplatFrustumCuller { device; /** - * RGBA32F texture storing local-space bounding spheres for all selected nodes - * across all GSplatInfos. Each texel is (center.x, center.y, center.z, radius). - * Created lazily on first use and resized as needed. + * Storage buffer holding interleaved BoundsEntry structs (center.xyz, radius, + * transformIndex, pad x3). 32 bytes per entry. * - * @type {Texture|null} + * @type {StorageBuffer|null} */ - boundsSphereTexture = null; + boundsBuffer = null; /** - * R32U texture mapping each bounds entry to its GSplatInfo index (for transform lookup). - * Same dimensions as boundsSphereTexture. Created lazily on first use and resized as needed. + * Total number of bounds entries across all GSplatInfos. * - * @type {Texture|null} + * @type {number} */ - boundsTransformIndexTexture = null; + totalBoundsEntries = 0; - /** - * R32U texture storing per-node visibility as packed bitmasks. - * Each texel packs 32 visibility bits, so width is boundsSphereTexture.width / 32. - * Written by the culling render pass. - * - * @type {Texture|null} - */ - nodeVisibilityTexture = null; + /** @type {number} */ + _allocatedBoundsEntries = 0; - /** - * Render target wrapping nodeVisibilityTexture for the culling pass. - * - * @type {RenderTarget|null} - */ - cullingRenderTarget = null; + /** @type {Float32Array|null} */ + _boundsFloatView = null; - /** - * GPU frustum culling render pass. Created lazily on first use. - * - * @type {GSplatNodeCullRenderPass|null} - */ - cullingPass = null; + /** @type {Uint32Array|null} */ + _boundsUintView = null; + + /** @type {Float32Array|null} */ + _tmpSpheres = null; /** - * Total number of bounds entries across all GSplatInfos. + * Storage buffer holding world matrices as vec4f triplets (3 vec4f per matrix, + * rows of a 4x3 affine matrix). 48 bytes per matrix. * - * @type {number} + * @type {StorageBuffer|null} */ - totalBoundsEntries = 0; + transformsBuffer = null; + + /** @type {number} */ + _allocatedTransformCount = 0; + + /** @type {Float32Array|null} */ + _transformsData = null; /** - * RGBA32F texture storing world matrices (3 texels per GSplatInfo, rows of a 4x3 - * affine matrix) for transforming local bounding spheres to world space during - * GPU frustum culling. - * Created lazily on first use and resized as needed. + * Packed frustum planes (6 planes x 4 floats: nx, ny, nz, distance). + * Updated by {@link computeFrustumPlanes} and consumed by the interval cull shader. * - * @type {Texture|null} + * @type {Float32Array} */ - transformsTexture = null; + frustumPlanes = new Float32Array(24); /** * @param {GraphicsDevice} device - The graphics device. @@ -93,22 +83,17 @@ class GSplatFrustumCuller { } destroy() { - this.boundsSphereTexture?.destroy(); - this.boundsTransformIndexTexture?.destroy(); - this.nodeVisibilityTexture?.destroy(); - this.cullingRenderTarget?.destroy(); - this.cullingPass?.destroy(); - this.transformsTexture?.destroy(); + this.boundsBuffer?.destroy(); + this.transformsBuffer?.destroy(); } /** - * Updates the bounds sphere texture with local-space bounding spheres from pre-built - * bounds groups. Each group contributes one set of sphere entries and maps to one - * transform index. + * Updates the bounds buffer with local-space bounding spheres and transform + * indices from pre-built bounds groups. * * @param {Array<{splat: GSplatInfo, boundsBaseIndex: number, numBoundsEntries: number}>} boundsGroups - Pre-built bounds groups. */ - updateBoundsTexture(boundsGroups) { + updateBoundsData(boundsGroups) { let totalEntries = 0; for (let i = 0; i < boundsGroups.length; i++) { totalEntries += boundsGroups[i].numBoundsEntries; @@ -118,65 +103,62 @@ class GSplatFrustumCuller { if (totalEntries === 0) return; - // Width is multiple of 32 so that 32 consecutive spheres always land on the same - // texture row, allowing the bit-packed culling shader to avoid per-iteration modulo/division. - const { x: width, y: height } = TextureUtils.calcTextureSize(totalEntries, tmpSize, 32); + if (totalEntries > this._allocatedBoundsEntries) { + this.boundsBuffer?.destroy(); + this._allocatedBoundsEntries = totalEntries; + this.boundsBuffer = new StorageBuffer(this.device, totalEntries * BOUNDS_ENTRY_FLOATS * 4, BUFFERUSAGE_COPY_DST); - // Create/resize bounds sphere texture (RGBA32F: center.xyz, radius) - if (!this.boundsSphereTexture) { - this.boundsSphereTexture = Texture.createDataTexture2D(this.device, 'boundsSphereTexture', width, height, PIXELFORMAT_RGBA32F); - } else { - this.boundsSphereTexture.resize(width, height); + const ab = new ArrayBuffer(totalEntries * BOUNDS_ENTRY_FLOATS * 4); + this._boundsFloatView = new Float32Array(ab); + this._boundsUintView = new Uint32Array(ab); + this._tmpSpheres = new Float32Array(totalEntries * 4); } - // Create/resize transform index texture (R32U: group index per bounds entry) - if (!this.boundsTransformIndexTexture) { - this.boundsTransformIndexTexture = Texture.createDataTexture2D(this.device, 'boundsTransformIndexTexture', width, height, PIXELFORMAT_R32U); - } else { - this.boundsTransformIndexTexture.resize(width, height); - } - - const sphereData = this.boundsSphereTexture.lock(); - const indexData = /** @type {Uint32Array} */ (this.boundsTransformIndexTexture.lock()); + const floatView = this._boundsFloatView; + const uintView = this._boundsUintView; + const tmpSpheres = this._tmpSpheres; for (let i = 0; i < boundsGroups.length; i++) { const group = boundsGroups[i]; const base = group.boundsBaseIndex; const count = group.numBoundsEntries; - group.splat.writeBoundsSpheres(sphereData, base * 4); + group.splat.writeBoundsSpheres(tmpSpheres, base * 4); for (let j = 0; j < count; j++) { - indexData[base + j] = i; + const src = (base + j) * 4; + const dst = (base + j) * BOUNDS_ENTRY_FLOATS; + floatView[dst + 0] = tmpSpheres[src + 0]; + floatView[dst + 1] = tmpSpheres[src + 1]; + floatView[dst + 2] = tmpSpheres[src + 2]; + floatView[dst + 3] = tmpSpheres[src + 3]; + uintView[dst + 4] = i; + // [dst+5..dst+7] are zero-initialized by ArrayBuffer } } - this.boundsSphereTexture.unlock(); - this.boundsTransformIndexTexture.unlock(); + this.boundsBuffer.write(0, floatView); } /** - * Updates the transforms texture with one world matrix per bounds group. - * Each matrix uses 3 texels (RGBA32F per row) in the texture. + * Updates the transforms buffer with one world matrix per bounds group. + * Each matrix is stored as 3 vec4f (rows of a 4x3 affine matrix). * * @param {Array<{splat: GSplatInfo, boundsBaseIndex: number, numBoundsEntries: number}>} boundsGroups - Pre-built bounds groups. */ - updateTransformsTexture(boundsGroups) { + updateTransformsData(boundsGroups) { const numMatrices = boundsGroups.length; if (numMatrices === 0) return; - // 3 texels per matrix (rows of a 4x3 affine matrix). Width is a multiple of 3 so all 3 - // texels of a matrix always land on the same texture row. - const totalTexels = numMatrices * 3; - const { x: width, y: height } = TextureUtils.calcTextureSize(totalTexels, tmpSize, 3); - - if (!this.transformsTexture) { - this.transformsTexture = Texture.createDataTexture2D(this.device, 'transformsTexture', width, height, PIXELFORMAT_RGBA32F); - } else { - this.transformsTexture.resize(width, height); + if (numMatrices > this._allocatedTransformCount) { + this.transformsBuffer?.destroy(); + this._allocatedTransformCount = numMatrices; + // 3 vec4f per matrix = 12 floats = 48 bytes + this.transformsBuffer = new StorageBuffer(this.device, numMatrices * 12 * 4, BUFFERUSAGE_COPY_DST); + this._transformsData = new Float32Array(numMatrices * 12); } - const data = this.transformsTexture.lock(); + const data = this._transformsData; // Write world matrices as 3 rows of a 4x3 matrix (row-major, 12 floats per matrix). // Mat4.data is column-major: [col0(4), col1(4), col2(4), col3(4)]. @@ -184,7 +166,6 @@ class GSplatFrustumCuller { // row0 = data[0], data[4], data[8], data[12] // row1 = data[1], data[5], data[9], data[13] // row2 = data[2], data[6], data[10], data[14] - // The shader reconstructs the mat4 by transposing + appending (0,0,0,1). let offset = 0; for (let i = 0; i < boundsGroups.length; i++) { const m = boundsGroups[i].splat.node.getWorldTransform().data; @@ -196,72 +177,27 @@ class GSplatFrustumCuller { data[offset++] = m[2]; data[offset++] = m[6]; data[offset++] = m[10]; data[offset++] = m[14]; } - this.transformsTexture.unlock(); + this.transformsBuffer.write(0, data); } /** - * Runs the GPU frustum culling pass to generate the node visibility texture. - * Computes the view-projection matrix, extracts frustum planes, and tests each - * bounding sphere against them. + * Computes frustum planes from camera matrices and stores them in + * {@link frustumPlanes} for use by the interval cull compute shader. * * @param {Mat4} projectionMatrix - The camera projection matrix. * @param {Mat4} viewMatrix - The camera view matrix. */ - updateNodeVisibility(projectionMatrix, viewMatrix) { - if (this.totalBoundsEntries === 0 || !this.boundsSphereTexture || !this.boundsTransformIndexTexture || !this.transformsTexture) { - return; - } - - // Compute view-projection matrix and extract frustum planes + computeFrustumPlanes(projectionMatrix, viewMatrix) { _viewProjMat.mul2(projectionMatrix, viewMatrix); _frustum.setFromMat4(_viewProjMat); + const planes = this.frustumPlanes; for (let p = 0; p < 6; p++) { const plane = _frustum.planes[p]; - _frustumPlanes[p * 4 + 0] = plane.normal.x; - _frustumPlanes[p * 4 + 1] = plane.normal.y; - _frustumPlanes[p * 4 + 2] = plane.normal.z; - _frustumPlanes[p * 4 + 3] = plane.distance; - } - - // Visibility texture is 32x smaller: each texel stores 32 sphere results as bits. - // Since boundsTextureWidth is a multiple of 32, the visibility texture is exactly - // (boundsWidth/32) x boundsHeight, keeping a 1:1 row correspondence and allowing - // the shader to derive visWidth = boundsTextureWidth / 32 without extra uniforms. - const width = this.boundsSphereTexture.width / 32; - const height = this.boundsSphereTexture.height; - - // Create/resize visibility texture (R32U: bit-packed, 32 spheres per texel) - if (!this.nodeVisibilityTexture) { - this.nodeVisibilityTexture = Texture.createDataTexture2D(this.device, 'nodeVisibilityTexture', width, height, PIXELFORMAT_R32U); - - this.cullingRenderTarget = new RenderTarget({ - name: 'NodeCullingRT', - colorBuffer: this.nodeVisibilityTexture, - depth: false - }); - } else if (this.nodeVisibilityTexture.width !== width || this.nodeVisibilityTexture.height !== height) { - this.nodeVisibilityTexture.resize(width, height); - /** @type {RenderTarget} */ (this.cullingRenderTarget).resize(width, height); + planes[p * 4 + 0] = plane.normal.x; + planes[p * 4 + 1] = plane.normal.y; + planes[p * 4 + 2] = plane.normal.z; + planes[p * 4 + 3] = plane.distance; } - - // Lazily create the culling render pass - if (!this.cullingPass) { - this.cullingPass = new GSplatNodeCullRenderPass(this.device); - this.cullingPass.init(this.cullingRenderTarget); - this.cullingPass.colorOps.clear = true; - this.cullingPass.colorOps.clearValue.set(0, 0, 0, 0); - } - - // Set up uniforms and execute - this.cullingPass.setup( - this.boundsSphereTexture, - this.boundsTransformIndexTexture, - this.transformsTexture, - this.totalBoundsEntries, - _frustumPlanes - ); - - this.cullingPass.render(); } } diff --git a/src/scene/gsplat-unified/gsplat-interval-compaction.js b/src/scene/gsplat-unified/gsplat-interval-compaction.js index b579245a209..2e0a138d6a8 100644 --- a/src/scene/gsplat-unified/gsplat-interval-compaction.js +++ b/src/scene/gsplat-unified/gsplat-interval-compaction.js @@ -2,15 +2,15 @@ import { Debug } from '../../core/debug.js'; import { Compute } from '../../platform/graphics/compute.js'; import { Shader } from '../../platform/graphics/shader.js'; import { StorageBuffer } from '../../platform/graphics/storage-buffer.js'; -import { BindGroupFormat, BindStorageBufferFormat, BindTextureFormat, BindUniformBufferFormat } from '../../platform/graphics/bind-group-format.js'; +import { BindGroupFormat, BindStorageBufferFormat, BindUniformBufferFormat } from '../../platform/graphics/bind-group-format.js'; import { UniformBufferFormat, UniformFormat } from '../../platform/graphics/uniform-buffer-format.js'; import { BUFFERUSAGE_COPY_DST, BUFFERUSAGE_COPY_SRC, - SAMPLETYPE_UINT, SHADERLANGUAGE_WGSL, SHADERSTAGE_COMPUTE, - UNIFORMTYPE_UINT + UNIFORMTYPE_UINT, + UNIFORMTYPE_VEC4 } from '../../platform/graphics/constants.js'; import { computeGsplatIntervalCullSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-interval-cull.js'; import { computeGsplatIntervalScatterSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-interval-scatter.js'; @@ -21,8 +21,8 @@ import { GSplatResourceBase } from '../gsplat/gsplat-resource-base.js'; /** * @import { GraphicsDevice } from '../../platform/graphics/graphics-device.js' + * @import { GSplatFrustumCuller } from './gsplat-frustum-culler.js' * @import { GSplatWorldState } from './gsplat-world-state.js' - * @import { Texture } from '../../platform/graphics/texture.js' */ const WORKGROUP_SIZE = 256; @@ -178,11 +178,6 @@ class GSplatIntervalCompaction { _createUniformBufferFormats() { const device = this.device; - this._cullUniformBufferFormat = new UniformBufferFormat(device, [ - new UniformFormat('numIntervals', UNIFORMTYPE_UINT), - new UniformFormat('visWidth', UNIFORMTYPE_UINT) - ]); - this._scatterUniformBufferFormat = new UniformBufferFormat(device, [ new UniformFormat('numIntervals', UNIFORMTYPE_UINT), new UniformFormat('pad0', UNIFORMTYPE_UINT), @@ -221,10 +216,22 @@ class GSplatIntervalCompaction { new BindStorageBufferFormat('countBuffer', SHADERSTAGE_COMPUTE, false) ]; if (cullingEnabled) { - entries.push(new BindTextureFormat('nodeVisibilityTexture', SHADERSTAGE_COMPUTE, undefined, SAMPLETYPE_UINT, false)); + entries.push(new BindStorageBufferFormat('boundsBuffer', SHADERSTAGE_COMPUTE, true)); + entries.push(new BindStorageBufferFormat('transformsBuffer', SHADERSTAGE_COMPUTE, true)); } this._cullBindGroupFormat = new BindGroupFormat(device, entries); + if (cullingEnabled) { + this._cullUniformBufferFormat = new UniformBufferFormat(device, [ + new UniformFormat('frustumPlanes', UNIFORMTYPE_VEC4, 6), + new UniformFormat('numIntervals', UNIFORMTYPE_UINT) + ]); + } else { + this._cullUniformBufferFormat = new UniformBufferFormat(device, [ + new UniformFormat('numIntervals', UNIFORMTYPE_UINT) + ]); + } + /** @type {Map} */ const cdefines = new Map([ ['{WORKGROUP_SIZE}', WORKGROUP_SIZE.toString()] @@ -292,7 +299,8 @@ class GSplatIntervalCompaction { const cdefines = new Map([ ['{INSTANCE_SIZE}', GSplatResourceBase.instanceSize], ['{KEYGEN_THREADS_PER_WORKGROUP}', 256], - ['{SORT_ELEMENTS_PER_WORKGROUP}', SORT_ELEMENTS_PER_WORKGROUP] + ['{SORT_ELEMENTS_PER_WORKGROUP}', SORT_ELEMENTS_PER_WORKGROUP], + ['{MAX_WORKGROUPS_PER_DIM}', device.limits.maxComputeWorkgroupsPerDimension || 65535] ]); const shader = new Shader(device, { @@ -386,12 +394,12 @@ class GSplatIntervalCompaction { /** * Runs the full interval compaction pipeline: cull+count, prefix sum, scatter. * - * @param {Texture|null} nodeVisibilityTexture - Bit-packed visibility texture (when culling). + * @param {GSplatFrustumCuller|null} frustumCuller - Frustum culler providing bounds/transforms storage buffers and frustum planes (when culling). * @param {number} numIntervals - Total number of intervals. * @param {number} totalActiveSplats - Total active splats across all intervals. * @param {boolean} cullingEnabled - Whether frustum culling is active. */ - dispatchCompact(nodeVisibilityTexture, numIntervals, totalActiveSplats, cullingEnabled) { + dispatchCompact(frustumCuller, numIntervals, totalActiveSplats, cullingEnabled) { if (numIntervals === 0) return; this._ensureCapacity(numIntervals, totalActiveSplats); @@ -403,11 +411,13 @@ class GSplatIntervalCompaction { cullCompute.setParameter('intervals', this.intervalsBuffer); cullCompute.setParameter('countBuffer', this.countBuffer); if (cullingEnabled) { - cullCompute.setParameter('nodeVisibilityTexture', nodeVisibilityTexture); + Debug.assert(frustumCuller, 'frustumCuller must be provided when cullingEnabled is true'); + cullCompute.setParameter('boundsBuffer', frustumCuller.boundsBuffer); + cullCompute.setParameter('transformsBuffer', frustumCuller.transformsBuffer); + cullCompute.setParameter('frustumPlanes[0]', frustumCuller.frustumPlanes); } cullCompute.setParameter('numIntervals', numIntervals); - cullCompute.setParameter('visWidth', cullingEnabled ? nodeVisibilityTexture.width : 0); const cullWorkgroups = Math.ceil(numIntervals / WORKGROUP_SIZE); cullCompute.setupDispatch(cullWorkgroups); diff --git a/src/scene/gsplat-unified/gsplat-manager.js b/src/scene/gsplat-unified/gsplat-manager.js index c5ec66a224e..a3e7a956a1d 100644 --- a/src/scene/gsplat-unified/gsplat-manager.js +++ b/src/scene/gsplat-unified/gsplat-manager.js @@ -824,11 +824,11 @@ class GSplatManager { this.workBuffer.resize(textureSize); } - // Bounds and transforms textures are needed for frustum culling. + // Bounds and transforms storage buffers are needed for frustum culling. // These index splats sequentially, so always use the full splats array. if (this.scene.gsplat.culling) { - this.workBuffer.frustumCuller.updateBoundsTexture(worldState.boundsGroups); - this.workBuffer.frustumCuller.updateTransformsTexture(worldState.boundsGroups); + this.workBuffer.frustumCuller.updateBoundsData(worldState.boundsGroups); + this.workBuffer.frustumCuller.updateTransformsData(worldState.boundsGroups); } // Render splats to work buffer: full rebuild renders all, partial renders only changed @@ -1593,8 +1593,8 @@ class GSplatManager { // Always run interval compaction (culling or not) const numIntervals = worldState.totalIntervals; const totalActiveSplats = worldState.totalActiveSplats; - const nodeVisibilityTexture = cullingEnabled ? this.workBuffer.frustumCuller.nodeVisibilityTexture : null; - this.intervalCompaction.dispatchCompact(nodeVisibilityTexture, numIntervals, totalActiveSplats, cullingEnabled); + const frustumCuller = cullingEnabled ? this.workBuffer.frustumCuller : null; + this.intervalCompaction.dispatchCompact(frustumCuller, numIntervals, totalActiveSplats, cullingEnabled); // Allocate indirect draw/dispatch slots and write args from visible count this.allocateAndWriteIntervalIndirectArgs(numIntervals); @@ -1653,8 +1653,8 @@ class GSplatManager { const numIntervals = worldState.totalIntervals; const totalActiveSplats = worldState.totalActiveSplats; - const nodeVisibilityTexture = cullingEnabled ? this.workBuffer.frustumCuller.nodeVisibilityTexture : null; - this.intervalCompaction.dispatchCompact(nodeVisibilityTexture, numIntervals, totalActiveSplats, cullingEnabled); + const frustumCuller = cullingEnabled ? this.workBuffer.frustumCuller : null; + this.intervalCompaction.dispatchCompact(frustumCuller, numIntervals, totalActiveSplats, cullingEnabled); // Extract the visible count from the prefix sum into sortElementCountBuffer. // writeIndirectArgs is the only path that does this; the indirect draw/dispatch @@ -1747,17 +1747,18 @@ class GSplatManager { } /** - * Runs GPU frustum culling: updates the transforms texture and renders the - * node visibility pass, producing the bit-packed nodeVisibilityTexture. + * Prepares frustum culling data: updates the GPU transform buffers and computes + * frustum planes from the camera. The actual culling test runs inline in the + * interval compaction compute shader. * * @param {GSplatWorldState} worldState - The world state whose splats provide transforms. * @private */ _runFrustumCulling(worldState) { - this.workBuffer.frustumCuller.updateTransformsTexture(worldState.boundsGroups); + this.workBuffer.frustumCuller.updateTransformsData(worldState.boundsGroups); const cam = this.cameraNode.camera; - this.workBuffer.frustumCuller.updateNodeVisibility(cam.projectionMatrix, cam.viewMatrix); + this.workBuffer.frustumCuller.computeFrustumPlanes(cam.projectionMatrix, cam.viewMatrix); } /** diff --git a/src/scene/gsplat-unified/gsplat-node-cull-render-pass.js b/src/scene/gsplat-unified/gsplat-node-cull-render-pass.js deleted file mode 100644 index f81a1111eaf..00000000000 --- a/src/scene/gsplat-unified/gsplat-node-cull-render-pass.js +++ /dev/null @@ -1,88 +0,0 @@ -import { SEMANTIC_POSITION } from '../../platform/graphics/constants.js'; -import { RenderPassShaderQuad } from '../graphics/render-pass-shader-quad.js'; -import { ShaderUtils } from '../shader-lib/shader-utils.js'; -import glslGsplatNodeCullingPS from '../shader-lib/glsl/chunks/gsplat/frag/gsplatNodeCulling.js'; -import wgslGsplatNodeCullingPS from '../shader-lib/wgsl/chunks/gsplat/frag/gsplatNodeCulling.js'; - -/** - * @import { GraphicsDevice } from '../../platform/graphics/graphics-device.js' - * @import { Texture } from '../../platform/graphics/texture.js' - */ - -/** - * Render pass for GPU frustum culling of bounding spheres. Reads local-space spheres and - * transform indices, reconstructs world matrices, and tests against camera frustum planes. - * Outputs a bit-packed R32U visibility texture (each texel holds 32 sphere results as bits). - * - * @ignore - */ -class GSplatNodeCullRenderPass extends RenderPassShaderQuad { - /** @type {Texture} */ - _boundsSphereTexture; - - /** @type {Texture} */ - _boundsTransformIndexTexture; - - /** @type {Texture} */ - _transformsTexture; - - /** @type {number} */ - _totalBoundsEntries = 0; - - /** @type {Float32Array} */ - _frustumPlanes; - - /** - * @param {GraphicsDevice} device - The graphics device. - */ - constructor(device) { - super(device); - - this.shader = ShaderUtils.createShader(device, { - uniqueName: 'GSplatNodeCulling', - attributes: { aPosition: SEMANTIC_POSITION }, - vertexChunk: 'quadVS', - fragmentGLSL: glslGsplatNodeCullingPS, - fragmentWGSL: wgslGsplatNodeCullingPS, - fragmentOutputTypes: ['uint'] - }); - - // Resolve uniform scope IDs - this.boundsSphereTextureId = device.scope.resolve('boundsSphereTexture'); - this.boundsTransformIndexTextureId = device.scope.resolve('boundsTransformIndexTexture'); - this.transformsTextureId = device.scope.resolve('transformsTexture'); - this.boundsTextureWidthId = device.scope.resolve('boundsTextureWidth'); - this.transformsTextureWidthId = device.scope.resolve('transformsTextureWidth'); - this.totalBoundsEntriesId = device.scope.resolve('totalBoundsEntries'); - this.frustumPlanesId = device.scope.resolve('frustumPlanes[0]'); - } - - /** - * @param {Texture} boundsSphereTexture - The bounds sphere texture. - * @param {Texture} boundsTransformIndexTexture - The transform index texture. - * @param {Texture} transformsTexture - The transforms texture. - * @param {number} totalBoundsEntries - Total number of bounds entries. - * @param {Float32Array} frustumPlanes - 24 floats: 6 planes x (nx, ny, nz, distance). - */ - setup(boundsSphereTexture, boundsTransformIndexTexture, transformsTexture, totalBoundsEntries, frustumPlanes) { - this._boundsSphereTexture = boundsSphereTexture; - this._boundsTransformIndexTexture = boundsTransformIndexTexture; - this._transformsTexture = transformsTexture; - this._totalBoundsEntries = totalBoundsEntries; - this._frustumPlanes = frustumPlanes; - } - - execute() { - this.boundsSphereTextureId.setValue(this._boundsSphereTexture); - this.boundsTransformIndexTextureId.setValue(this._boundsTransformIndexTexture); - this.transformsTextureId.setValue(this._transformsTexture); - this.boundsTextureWidthId.setValue(this._boundsSphereTexture.width); - this.transformsTextureWidthId.setValue(this._transformsTexture.width); - this.totalBoundsEntriesId.setValue(this._totalBoundsEntries); - this.frustumPlanesId.setValue(this._frustumPlanes); - - super.execute(); - } -} - -export { GSplatNodeCullRenderPass }; diff --git a/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatNodeCulling.js b/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatNodeCulling.js deleted file mode 100644 index 474aebd1ccf..00000000000 --- a/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatNodeCulling.js +++ /dev/null @@ -1,89 +0,0 @@ -// Fragment shader for GPU frustum culling of bounding spheres. -// Each fragment processes 32 consecutive spheres and outputs a packed bitmask -// (bit b = 1 means sphere baseIndex+b is visible). The visibility texture is -// 32x smaller than the bounds texture. -export default /* glsl */` -uniform sampler2D boundsSphereTexture; -uniform usampler2D boundsTransformIndexTexture; -uniform sampler2D transformsTexture; - -uniform int boundsTextureWidth; -uniform int transformsTextureWidth; -uniform int totalBoundsEntries; -uniform vec4 frustumPlanes[6]; - -void main(void) { - // Linear texel index in the (small) visibility texture - int visWidth = boundsTextureWidth / 32; - int texelIndex = int(gl_FragCoord.y) * visWidth + int(gl_FragCoord.x); - - // Base sphere index for this group of 32 - int baseIndex = texelIndex * 32; - - // Since boundsTextureWidth is a multiple of 32, all 32 spheres are on the same row. - // Compute row coordinates once. - int baseX = baseIndex % boundsTextureWidth; - int boundsY = baseIndex / boundsTextureWidth; - - uint visBits = 0u; - uint cachedTransformIdx = 0xFFFFFFFFu; - mat4 worldMatrix; - vec4 row0, row1, row2; - - for (int b = 0; b < 32; b++) { - int sphereIndex = baseIndex + b; - if (sphereIndex >= totalBoundsEntries) break; - - ivec2 boundsCoord = ivec2(baseX + b, boundsY); - - // Read local-space bounding sphere (center.xyz, radius) - vec4 sphere = texelFetch(boundsSphereTexture, boundsCoord, 0); - vec3 localCenter = sphere.xyz; - float radius = sphere.w; - - // Read GSplatInfo transform index - uint transformIdx = texelFetch(boundsTransformIndexTexture, boundsCoord, 0).r; - - // Reconstruct world matrix only when transform index changes. - // The texture stores 3 texels per matrix (rows of a 4x3 affine matrix). - // Transpose back to column-major mat4 and append the implicit (0,0,0,1) row. - if (transformIdx != cachedTransformIdx) { - cachedTransformIdx = transformIdx; - int baseTexel = int(transformIdx) * 3; - int tx = baseTexel % transformsTextureWidth; - int ty = baseTexel / transformsTextureWidth; - row0 = texelFetch(transformsTexture, ivec2(tx, ty), 0); - row1 = texelFetch(transformsTexture, ivec2(tx + 1, ty), 0); - row2 = texelFetch(transformsTexture, ivec2(tx + 2, ty), 0); - worldMatrix = mat4( - row0.x, row1.x, row2.x, 0, - row0.y, row1.y, row2.y, 0, - row0.z, row1.z, row2.z, 0, - row0.w, row1.w, row2.w, 1 - ); - } - - // Transform sphere center to world space - vec3 worldCenter = (worldMatrix * vec4(localCenter, 1.0)).xyz; - - // World-space radius (uniform scale: all column lengths are equal) - float worldRadius = radius * length(vec3(row0.x, row1.x, row2.x)); - - // Test against 6 frustum planes - bool visible = true; - for (int p = 0; p < 6; p++) { - float dist = dot(frustumPlanes[p].xyz, worldCenter) + frustumPlanes[p].w; - if (dist <= -worldRadius) { - visible = false; - break; - } - } - - if (visible) { - visBits |= (1u << uint(b)); - } - } - - gl_FragColor = visBits; -} -`; diff --git a/src/scene/shader-lib/wgsl/chunks/common/comp/dispatch-core.js b/src/scene/shader-lib/wgsl/chunks/common/comp/dispatch-core.js new file mode 100644 index 00000000000..38c2554ec7b --- /dev/null +++ b/src/scene/shader-lib/wgsl/chunks/common/comp/dispatch-core.js @@ -0,0 +1,12 @@ +// Compute 2D dispatch dimensions from a linear workgroup count, staying within the +// per-dimension limit. Mirrors Compute.calcDispatchSize on the CPU side. +export default /* wgsl */` +fn calcDispatch2D(count: u32, maxDim: u32) -> vec2u { + if (count <= maxDim) { + return vec2u(count, 1u); + } + let y = (count + maxDim - 1u) / maxDim; + let x = (count + y - 1u) / y; + return vec2u(x, y); +} +`; diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-interval-cull.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-interval-cull.js index a274a015c1c..aa1ee94f214 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-interval-cull.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-interval-cull.js @@ -1,18 +1,14 @@ // Compute shader for interval-based culling and counting. // -// Replaces the per-pixel flag pass for the GPU sort path. Instead of testing -// every work-buffer pixel, this shader tests one bounding sphere per interval -// against the nodeVisibilityTexture and writes the interval's splat count -// (or 0 if culled) into a count buffer. +// Each thread processes one interval. When CULLING_ENABLED is defined, the +// shader performs an inline sphere-vs-frustum test per interval (reading +// bounding spheres and transforms from storage buffers) and writes the +// interval's splat count (or 0 if culled) into a count buffer. Otherwise +// all intervals are visible (count is copied directly). // // After an exclusive prefix sum over numIntervals + 1 elements: // - prefixSum[i] gives the output offset for interval i's splats // - prefixSum[numIntervals] equals the total visible splat count -// -// When CULLING_ENABLED is defined, reads the bit-packed nodeVisibilityTexture -// to determine per-interval visibility. Otherwise all intervals are visible -// (count is copied directly), making this a trivial O(numIntervals) pass that -// still produces the prefix-sum input needed by the scatter shader. export const computeGsplatIntervalCullSource = /* wgsl */` @@ -23,10 +19,28 @@ struct Interval { pad: u32 }; -struct CullUniforms { - numIntervals: u32, - visWidth: u32 -}; +#ifdef CULLING_ENABLED + struct BoundsEntry { + centerX: f32, + centerY: f32, + centerZ: f32, + radius: f32, + transformIndex: u32, + pad0: u32, + pad1: u32, + pad2: u32 + }; + + struct CullUniforms { + frustumPlanes: array, + numIntervals: u32 + }; +#else + struct CullUniforms { + numIntervals: u32 + }; +#endif + @group(0) @binding(0) var uniforms: CullUniforms; @group(0) @binding(1) var intervals: array; @@ -34,7 +48,8 @@ struct CullUniforms { @group(0) @binding(2) var countBuffer: array; #ifdef CULLING_ENABLED -@group(0) @binding(3) var nodeVisibilityTexture: texture_2d; +@group(0) @binding(3) var boundsBuffer: array; +@group(0) @binding(4) var transformsBuffer: array; #endif @compute @workgroup_size({WORKGROUP_SIZE}) @@ -44,13 +59,33 @@ fn main(@builtin(global_invocation_id) gid: vec3u) { let interval = intervals[idx]; #ifdef CULLING_ENABLED - let boundsIdx = interval.boundsIndex; - let texelIdx = boundsIdx >> 5u; - let bitIdx = boundsIdx & 31u; - let visW = uniforms.visWidth; - let visCoord = vec2i(i32(texelIdx % visW), i32(texelIdx / visW)); - let visBits = textureLoad(nodeVisibilityTexture, visCoord, 0).r; - let visible = (visBits & (1u << bitIdx)) != 0u; + let entry = boundsBuffer[interval.boundsIndex]; + let localCenter = vec3f(entry.centerX, entry.centerY, entry.centerZ); + let radius = entry.radius; + + let base = entry.transformIndex * 3u; + let row0 = transformsBuffer[base]; + let row1 = transformsBuffer[base + 1u]; + let row2 = transformsBuffer[base + 2u]; + + let worldCenter = vec3f( + dot(vec4f(localCenter, 1.0), row0), + dot(vec4f(localCenter, 1.0), row1), + dot(vec4f(localCenter, 1.0), row2) + ); + + let worldRadius = radius * length(vec3f(row0.x, row1.x, row2.x)); + + var visible = true; + for (var p = 0; p < 6; p++) { + let plane = uniforms.frustumPlanes[p]; + let dist = dot(plane.xyz, worldCenter) + plane.w; + if (dist <= -worldRadius) { + visible = false; + break; + } + } + countBuffer[idx] = select(0u, interval.splatCount, visible); #else countBuffer[idx] = interval.splatCount; diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-classify.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-classify.js index c9d2a2145cf..32c5f7e0feb 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-classify.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-classify.js @@ -6,10 +6,12 @@ // Single workgroup (256 threads) — each thread processes ceil(numTiles/256) tiles. import indirectCoreCS from '../common/comp/indirect-core.js'; +import dispatchCoreCS from '../common/comp/dispatch-core.js'; export const computeGsplatLocalClassifySource = /* wgsl */` ${indirectCoreCS} +${dispatchCoreCS} const MAX_TILE_ENTRIES: u32 = 4096u; const CLASSIFY_WORKGROUP: u32 = 256u; @@ -77,24 +79,21 @@ fn main(@builtin(local_invocation_index) localIdx: u32) { let maxDim = uniforms.maxWorkgroupsPerDim; // Slot 0: small tile sort — 1 workgroup per tile - var sy = (smallCount + maxDim - 1u) / maxDim; - sy = max(sy, 1u); - indirectDispatchArgs[off + 0u] = (smallCount + sy - 1u) / sy; - indirectDispatchArgs[off + 1u] = sy; + let smallDim = calcDispatch2D(smallCount, maxDim); + indirectDispatchArgs[off + 0u] = smallDim.x; + indirectDispatchArgs[off + 1u] = smallDim.y; indirectDispatchArgs[off + 2u] = 1u; // Slot 1: bucket pre-sort — 1 workgroup per large tile - var ly = (largeCount + maxDim - 1u) / maxDim; - ly = max(ly, 1u); - indirectDispatchArgs[off + 3u] = (largeCount + ly - 1u) / ly; - indirectDispatchArgs[off + 4u] = ly; + let largeDim = calcDispatch2D(largeCount, maxDim); + indirectDispatchArgs[off + 3u] = largeDim.x; + indirectDispatchArgs[off + 4u] = largeDim.y; indirectDispatchArgs[off + 5u] = 1u; // Slot 2: rasterize — 1 workgroup per non-empty tile - var ry = (rasterizeCount + maxDim - 1u) / maxDim; - ry = max(ry, 1u); - indirectDispatchArgs[off + 6u] = (rasterizeCount + ry - 1u) / ry; - indirectDispatchArgs[off + 7u] = ry; + let rasterDim = calcDispatch2D(rasterizeCount, maxDim); + indirectDispatchArgs[off + 6u] = rasterDim.x; + indirectDispatchArgs[off + 7u] = rasterDim.y; indirectDispatchArgs[off + 8u] = 1u; // Indirect draw args for tile-based composite: 6 vertices per tile quad diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-copy.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-copy.js index 886010966c5..e43b342d83d 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-copy.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-copy.js @@ -5,8 +5,13 @@ // - The chunkSortIndirect args written here are visible to the chunk sort (next pass). // Clamps the count to maxChunks (matching the chunkRanges buffer budget) and writes a 2D // dispatch to stay within maxComputeWorkgroupsPerDimension. + +import dispatchCoreCS from '../common/comp/dispatch-core.js'; + export const computeGsplatLocalCopySource = /* wgsl */` +${dispatchCoreCS} + @group(0) @binding(0) var totalChunks: array; @group(0) @binding(1) var chunkSortIndirect: array; @@ -19,11 +24,9 @@ struct Uniforms { @compute @workgroup_size(1) fn main() { let count = min(totalChunks[0], uniforms.maxChunks); - let maxDim = uniforms.maxWorkgroupsPerDim; - var y = (count + maxDim - 1u) / maxDim; - y = max(y, 1u); - chunkSortIndirect[0] = (count + y - 1u) / y; - chunkSortIndirect[1] = y; + let dim = calcDispatch2D(count, uniforms.maxWorkgroupsPerDim); + chunkSortIndirect[0] = dim.x; + chunkSortIndirect[1] = dim.y; chunkSortIndirect[2] = 1u; } `; diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-write-indirect-args.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-write-indirect-args.js index 4cfd829bb18..9fe9a12a0a7 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-write-indirect-args.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-write-indirect-args.js @@ -7,10 +7,12 @@ // index N equals the total number of visible splats. import indirectCoreCS from '../common/comp/indirect-core.js'; +import dispatchCoreCS from '../common/comp/dispatch-core.js'; export const computeGsplatWriteIndirectArgsSource = /* wgsl */` ${indirectCoreCS} +${dispatchCoreCS} // Prefix sum buffer (flagBuffer after in-place exclusive scan) @group(0) @binding(0) var prefixSumBuffer: array; @@ -55,17 +57,20 @@ fn main(@builtin(global_invocation_id) gid: vec3u) { // Write numSplats for vertex shader numSplatsBuf[0] = count; - // Write indirect dispatch args: slot 0 = key gen, slot 1 = sort + // Write indirect dispatch args: slot 0 = key gen, slot 1 = sort. + // Use 2D layout to stay within maxComputeWorkgroupsPerDimension. let dispatchOffset = uniforms.dispatchSlotOffset; let keygenWorkgroupCount = (count + {KEYGEN_THREADS_PER_WORKGROUP}u - 1u) / {KEYGEN_THREADS_PER_WORKGROUP}u; - indirectDispatchArgs[dispatchOffset + 0u] = keygenWorkgroupCount; - indirectDispatchArgs[dispatchOffset + 1u] = 1u; + let keygenDim = calcDispatch2D(keygenWorkgroupCount, {MAX_WORKGROUPS_PER_DIM}u); + indirectDispatchArgs[dispatchOffset + 0u] = keygenDim.x; + indirectDispatchArgs[dispatchOffset + 1u] = keygenDim.y; indirectDispatchArgs[dispatchOffset + 2u] = 1u; let sortWorkgroupCount = (count + {SORT_ELEMENTS_PER_WORKGROUP}u - 1u) / {SORT_ELEMENTS_PER_WORKGROUP}u; - indirectDispatchArgs[dispatchOffset + 3u] = sortWorkgroupCount; - indirectDispatchArgs[dispatchOffset + 4u] = 1u; + let sortDim = calcDispatch2D(sortWorkgroupCount, {MAX_WORKGROUPS_PER_DIM}u); + indirectDispatchArgs[dispatchOffset + 3u] = sortDim.x; + indirectDispatchArgs[dispatchOffset + 4u] = sortDim.y; indirectDispatchArgs[dispatchOffset + 5u] = 1u; // Write sortElementCount for sort shaders (= visibleCount) diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatNodeCulling.js b/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatNodeCulling.js deleted file mode 100644 index bbd4d9f91f7..00000000000 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatNodeCulling.js +++ /dev/null @@ -1,96 +0,0 @@ -// Fragment shader for GPU frustum culling of bounding spheres. -// Each fragment processes 32 consecutive spheres and outputs a packed bitmask -// (bit b = 1 means sphere baseIndex+b is visible). The visibility texture is -// 32x smaller than the bounds texture. -export default /* wgsl */` -var boundsSphereTexture: texture_2d; -var boundsTransformIndexTexture: texture_2d; -var transformsTexture: texture_2d; - -uniform boundsTextureWidth: i32; -uniform transformsTextureWidth: i32; -uniform totalBoundsEntries: i32; -uniform frustumPlanes: array; - -@fragment -fn fragmentMain(input: FragmentInput) -> FragmentOutput { - var output: FragmentOutput; - - // Linear texel index in the (small) visibility texture - let visWidth = uniform.boundsTextureWidth / 32; - let texelIndex = i32(input.position.y) * visWidth + i32(input.position.x); - - // Base sphere index for this group of 32 - let baseIndex = texelIndex * 32; - - // Since boundsTextureWidth is a multiple of 32, all 32 spheres are on the same row. - // Compute row coordinates once. - let baseX = baseIndex % uniform.boundsTextureWidth; - let boundsY = baseIndex / uniform.boundsTextureWidth; - - var visBits = 0u; - var cachedTransformIdx = 0xFFFFFFFFu; - var row0: vec4f; - var row1: vec4f; - var row2: vec4f; - var worldMatrix: mat4x4f; - - for (var b = 0; b < 32; b++) { - let sphereIndex = baseIndex + b; - if (sphereIndex >= uniform.totalBoundsEntries) { break; } - - let boundsCoord = vec2i(baseX + b, boundsY); - - // Read local-space bounding sphere (center.xyz, radius) - let sphere = textureLoad(boundsSphereTexture, boundsCoord, 0); - let localCenter = sphere.xyz; - let radius = sphere.w; - - // Read GSplatInfo transform index - let transformIdx = textureLoad(boundsTransformIndexTexture, boundsCoord, 0).r; - - // Reconstruct world matrix only when transform index changes. - // The texture stores 3 texels per matrix (rows of a 4x3 affine matrix). - // Transpose back to column-major mat4 and append the implicit (0,0,0,1) row. - if (transformIdx != cachedTransformIdx) { - cachedTransformIdx = transformIdx; - let baseTexel = i32(transformIdx) * 3; - let tx = baseTexel % uniform.transformsTextureWidth; - let ty = baseTexel / uniform.transformsTextureWidth; - row0 = textureLoad(transformsTexture, vec2i(tx, ty), 0); - row1 = textureLoad(transformsTexture, vec2i(tx + 1, ty), 0); - row2 = textureLoad(transformsTexture, vec2i(tx + 2, ty), 0); - worldMatrix = mat4x4f( - row0.x, row1.x, row2.x, 0, - row0.y, row1.y, row2.y, 0, - row0.z, row1.z, row2.z, 0, - row0.w, row1.w, row2.w, 1.0 - ); - } - - // Transform sphere center to world space - let worldCenter = (worldMatrix * vec4f(localCenter, 1.0)).xyz; - - // World-space radius (uniform scale: all column lengths are equal) - let worldRadius = radius * length(vec3f(row0.x, row1.x, row2.x)); - - // Test against 6 frustum planes - var visible = true; - for (var p = 0; p < 6; p++) { - let plane = uniform.frustumPlanes[p]; - let dist = dot(plane.xyz, worldCenter) + plane.w; - if (dist <= -worldRadius) { - visible = false; - break; - } - } - - if (visible) { - visBits |= (1u << u32(b)); - } - } - - output.color = visBits; - return output; -} -`;