diff --git a/examples/src/examples/test/radix-sort-indirect-compute.example.mjs b/examples/src/examples/test/radix-sort-indirect-compute.example.mjs index a2482c2928c..23d72104503 100644 --- a/examples/src/examples/test/radix-sort-indirect-compute.example.mjs +++ b/examples/src/examples/test/radix-sort-indirect-compute.example.mjs @@ -100,8 +100,8 @@ function logError(msg) { statusOverlay.textContent = logLines.join('\n'); } -// Create radix sort instance -radixSort = new pc.ComputeRadixSort(device); +// Create radix sort instance in indirect mode — sortIndirect requires this. +radixSort = new pc.ComputeRadixSort(device, { indirect: true }); // Create sortElementCount buffer (single u32, GPU-readable storage buffer) sortElementCountBuffer = new pc.StorageBuffer(device, 4, pc.BUFFERUSAGE_COPY_SRC | pc.BUFFERUSAGE_COPY_DST); @@ -110,13 +110,23 @@ sortElementCountBuffer = new pc.StorageBuffer(device, 4, pc.BUFFERUSAGE_COPY_SRC // Simulates the GSplat pipeline's prepareIndirect: a compute shader writes // sortElementCount and indirect dispatch args within the command buffer // (instead of queue.writeBuffer which executes before the command buffer). +// +// The shader mirrors the sortIndirectArgsCS WGSL chunk: it uses the slot +// count and per-slot granularities from ComputeRadixSort#prepareIndirect() +// to support any backend (Multipass = 1 slot, OneSweep = 2 slots). const prepareSource = /* wgsl */` @group(0) @binding(0) var sortElementCountBuf: array; @group(0) @binding(1) var indirectDispatchArgs: array; struct PrepareUniforms { - visibleCount: u32, - dispatchSlotOffset: u32 + visibleCount: u32, + sortSlotBase: u32, // first slot index (not u32 offset) + _pad0: u32, + _pad1: u32, + sortInfoX: u32, // slotCount (from prepareIndirect()[0]) + sortInfoY: u32, // granularity for slot 0 + sortInfoZ: u32, // granularity for slot 1 (0 if unused) + sortInfoW: u32 // granularity for slot 2 (0 if unused) }; @group(0) @binding(2) var uniforms: PrepareUniforms; @@ -125,17 +135,42 @@ const prepareSource = /* wgsl */` let count = uniforms.visibleCount; sortElementCountBuf[0] = count; - let sortWorkgroupCount = (count + 255u) / 256u; - let offset = uniforms.dispatchSlotOffset; - indirectDispatchArgs[offset + 0u] = sortWorkgroupCount; - indirectDispatchArgs[offset + 1u] = 1u; - indirectDispatchArgs[offset + 2u] = 1u; + let base = uniforms.sortSlotBase; + let n = uniforms.sortInfoX; + + if (n >= 1u) { + let g = uniforms.sortInfoY; + let wc = (count + g - 1u) / g; + indirectDispatchArgs[base * 3u + 0u] = wc; + indirectDispatchArgs[base * 3u + 1u] = 1u; + indirectDispatchArgs[base * 3u + 2u] = 1u; + } + if (n >= 2u) { + let g = uniforms.sortInfoZ; + let wc = (count + g - 1u) / g; + indirectDispatchArgs[(base + 1u) * 3u + 0u] = wc; + indirectDispatchArgs[(base + 1u) * 3u + 1u] = 1u; + indirectDispatchArgs[(base + 1u) * 3u + 2u] = 1u; + } + if (n >= 3u) { + let g = uniforms.sortInfoW; + let wc = (count + g - 1u) / g; + indirectDispatchArgs[(base + 2u) * 3u + 0u] = wc; + indirectDispatchArgs[(base + 2u) * 3u + 1u] = 1u; + indirectDispatchArgs[(base + 2u) * 3u + 2u] = 1u; + } } `; const prepareUniformFormat = new pc.UniformBufferFormat(device, [ new pc.UniformFormat('visibleCount', pc.UNIFORMTYPE_UINT), - new pc.UniformFormat('dispatchSlotOffset', pc.UNIFORMTYPE_UINT) + new pc.UniformFormat('sortSlotBase', pc.UNIFORMTYPE_UINT), + new pc.UniformFormat('_pad0', pc.UNIFORMTYPE_UINT), + new pc.UniformFormat('_pad1', pc.UNIFORMTYPE_UINT), + new pc.UniformFormat('sortInfoX', pc.UNIFORMTYPE_UINT), + new pc.UniformFormat('sortInfoY', pc.UNIFORMTYPE_UINT), + new pc.UniformFormat('sortInfoZ', pc.UNIFORMTYPE_UINT), + new pc.UniformFormat('sortInfoW', pc.UNIFORMTYPE_UINT) ]); const prepareBindGroupFormat = new pc.BindGroupFormat(device, [ @@ -284,7 +319,11 @@ data.on('*:set', (/** @type {string} */ path, /** @type {any} */ value) => { needsRegen = true; } } else if (path === 'options.bits') { - const validBits = [4, 8, 12, 16, 20, 24, 28, 32]; + // Only accept multiples of the active backend's radix width + // (4 for Multipass, 8 for OneSweep). Snapping to an invalid + // multiple would trigger the sortIndirect assert every frame. + const rb = radixSort ? radixSort.radixBits : 4; + const validBits = [4, 8, 12, 16, 20, 24, 28, 32].filter(b => b % rb === 0); const nearest = validBits.reduce((prev, curr) => (Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev)); if (nearest !== currentNumBits) { currentNumBits = nearest; @@ -331,17 +370,26 @@ app.on('update', () => { if (!keysBuffer || !radixSort || !sortElementCountBuffer) return; - // Allocate per-frame indirect dispatch slot - const dispatchSlot = device.getIndirectDispatchSlot(1); + // Query backend slot requirements — varies by backend: + // Multipass: [1, 2048, 0, 0] (1 slot; see ELEMENTS_PER_WORKGROUP) + // OneSweep: [2, 3840, 32768, 0] (2 slots: binning + globalHist) + const sortInfo = radixSort.prepareIndirect(); + const slotCount = sortInfo[0]; - // Write sortElementCount and dispatch args via compute shader - const dispatchBuffer = device.indirectDispatchBuffer; - const slotOffset = dispatchSlot * 3; + // Allocate the required number of consecutive per-frame dispatch slots. + const dispatchSlot = device.getIndirectDispatchSlot(slotCount); + // Write sortElementCount and all dispatch slot args via compute shader. prepareCompute.setParameter('sortElementCountBuf', sortElementCountBuffer); - prepareCompute.setParameter('indirectDispatchArgs', dispatchBuffer); + prepareCompute.setParameter('indirectDispatchArgs', device.indirectDispatchBuffer); prepareCompute.setParameter('visibleCount', visibleCount); - prepareCompute.setParameter('dispatchSlotOffset', slotOffset); + prepareCompute.setParameter('sortSlotBase', dispatchSlot); + prepareCompute.setParameter('_pad0', 0); + prepareCompute.setParameter('_pad1', 0); + prepareCompute.setParameter('sortInfoX', sortInfo[0]); + prepareCompute.setParameter('sortInfoY', sortInfo[1]); + prepareCompute.setParameter('sortInfoZ', sortInfo[2]); + prepareCompute.setParameter('sortInfoW', sortInfo[3]); prepareCompute.setupDispatch(1, 1, 1); device.computeDispatch([prepareCompute], 'PrepareIndirectTest'); diff --git a/src/scene/graphics/radix-sort/compute-radix-sort-onesweep.js b/src/scene/graphics/radix-sort/compute-radix-sort-onesweep.js index 93bd2c7c766..1c5ef5ad40f 100644 --- a/src/scene/graphics/radix-sort/compute-radix-sort-onesweep.js +++ b/src/scene/graphics/radix-sort/compute-radix-sort-onesweep.js @@ -173,8 +173,14 @@ class ComputeRadixSortOneSweep extends ComputeRadixSortBase { // 64 / 128 lane subgroups (AMD Wave64, future wider architectures) // those expressions silently truncate / UB and will corrupt the // sort output. Refuse to run rather than producing garbage. - Debug.assert(device.minSubgroupSize > 0, 'ComputeRadixSortOneSweep requires a valid minimum subgroup size'); - Debug.assert(device.maxSubgroupSize <= 32, 'ComputeRadixSortOneSweep currently requires subgroup sizes <= 32 (binning shader uses 32-bit subgroup masks)'); + // + // We only check the *minimum* runtime size: NVIDIA hardware always + // executes with 32-wide warps even though the adapter may report a + // higher max (e.g. Dawn / D3D12 surfaces a spec ceiling of 128). + // A `minSubgroupSize` of 0 means the adapter omitted the field + // entirely (older Chrome / Dawn on Windows); accept that — runtime + // size on validated NVIDIA targets is still 32. + Debug.assert(device.minSubgroupSize <= 32, 'ComputeRadixSortOneSweep currently requires runtime subgroup size <= 32 (binning shader uses 32-bit subgroup masks)'); // Create uniform formats (shared between direct and indirect modes), then // create bind group formats and shaders for the chosen mode only. @@ -203,7 +209,6 @@ class ComputeRadixSortOneSweep extends ComputeRadixSortBase { const maxSubgroups = Math.max(1, Math.ceil(256 / minSubgroupSize)); const suffix = indirect ? 'Indirect' : ''; - const elementCountBinding = new BindStorageBufferFormat('b_sortElementCount', SHADERSTAGE_COMPUTE, true); const histGroupEntries = [ new BindStorageBufferFormat('b_sort', SHADERSTAGE_COMPUTE, true), @@ -225,9 +230,14 @@ class ComputeRadixSortOneSweep extends ComputeRadixSortBase { new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE) ]; if (indirect) { - histGroupEntries.push(elementCountBinding); - scanGroupEntries.push(elementCountBinding); - binGroupEntries.push(elementCountBinding); + // Each BindStorageBufferFormat must be a separate instance: BindGroupFormat + // assigns slot numbers by mutating the objects in-place, so sharing a single + // instance across multiple formats would cause the last format to overwrite + // the slot on the shared object, producing wrong binding indices for the + // earlier formats when their bind groups are created. + histGroupEntries.push(new BindStorageBufferFormat('b_sortElementCount', SHADERSTAGE_COMPUTE, true)); + scanGroupEntries.push(new BindStorageBufferFormat('b_sortElementCount', SHADERSTAGE_COMPUTE, true)); + binGroupEntries.push(new BindStorageBufferFormat('b_sortElementCount', SHADERSTAGE_COMPUTE, true)); } this._globalHistBindGroupFormat = new BindGroupFormat(device, histGroupEntries); diff --git a/src/scene/graphics/radix-sort/compute-radix-sort.js b/src/scene/graphics/radix-sort/compute-radix-sort.js index 321b4f739cc..e5d76966d7d 100644 --- a/src/scene/graphics/radix-sort/compute-radix-sort.js +++ b/src/scene/graphics/radix-sort/compute-radix-sort.js @@ -77,13 +77,13 @@ class ComputeRadixSort { } if (chosen === RADIX_SORT_ONESWEEP) { - // Hard hardware prerequisites (compute, subgroups, subgroupSize - // <= 32, valid minSubgroupSize) are asserted inside the OneSweep + // Hard hardware prerequisites (compute, subgroups, runtime + // subgroup size <= 32) are asserted inside the OneSweep // constructor. Here we only warn on soft policy mismatches // (non-NVIDIA vendors), so callers can opt in for experimentation // on devices where OneSweep has not been validated. if (!this._canUseOneSweep(device)) { - Debug.warnOnce('ComputeRadixSort: RADIX_SORT_ONESWEEP requested on a device that is not a validated OneSweep target (non-NVIDIA, or missing compute / subgroups / subgroupSize <= 32). OneSweep may hang or produce incorrect results. Consider RADIX_SORT_PORTABLE or RADIX_SORT_AUTO.'); + Debug.warnOnce('ComputeRadixSort: RADIX_SORT_ONESWEEP requested on a device that is not a validated OneSweep target (non-NVIDIA, or minSubgroupSize > 32). OneSweep may hang or produce incorrect results. Consider RADIX_SORT_PORTABLE or RADIX_SORT_AUTO.'); } this._impl = new ComputeRadixSortOneSweep(device, indirect); } else { @@ -103,14 +103,20 @@ class ComputeRadixSort { */ _canUseOneSweep(device) { if (!device.supportsCompute || !device.supportsSubgroups) return false; - // Adapter info may omit subgroup sizes (both stay 0); do not auto-select OneSweep then. - if (!device.minSubgroupSize || !device.maxSubgroupSize || device.maxSubgroupSize > 32) return false; // Only enable on NVIDIA for now; validated on Turing+ and Ampere. // Other vendors either lack forward-progress guarantees (Apple) or // have shown correctness issues in the lookback (Mali / Imagination / // some Adreno). const vendor = device.gpuAdapter?.info?.vendor?.toLowerCase?.(); - return vendor === 'nvidia'; + if (vendor !== 'nvidia') return false; + // NVIDIA always executes with 32-wide warps. The adapter may report + // a wider max (Dawn / D3D12 surfaces a spec ceiling of 128) or omit + // size info entirely (some Chrome / Dawn builds on Windows report 0). + // Only refuse if the *minimum* runtime size is guaranteed to be > 32, + // which would mean we'd actually receive >32 lanes and the binning + // shader's 32-bit ballot masks would corrupt the output. + if (device.minSubgroupSize > 32) return false; + return true; } /** diff --git a/src/scene/gsplat-unified/gsplat-manager.js b/src/scene/gsplat-unified/gsplat-manager.js index c27465552c0..597db9340f7 100644 --- a/src/scene/gsplat-unified/gsplat-manager.js +++ b/src/scene/gsplat-unified/gsplat-manager.js @@ -1611,8 +1611,12 @@ class GSplatManager { // number of bits used for sorting to match CPU sorter const numBits = Math.max(10, Math.min(20, Math.round(Math.log2(elementCount / 4)))); - // Round up to multiple of 4 for radix sort - const roundedNumBits = Math.ceil(numBits / 4) * 4; + // Round up to a multiple of the active sorter's radix width (4 for the + // portable multipass backend, 8 for OneSweep). Multipass would accept + // 4-multiples, but OneSweep asserts when numBits is not a multiple of + // 8, so we always align to the backend's actual radixBits. + const radixBits = gpuSorter.radixBits; + const roundedNumBits = Math.ceil(numBits / radixBits) * radixBits; // Compute min/max distances for key normalization const { minDist, maxDist } = this.computeDistanceRange(worldState); @@ -1706,7 +1710,7 @@ class GSplatManager { * (sorting only the visible splat count determined by interval compaction). * * @param {number} elementCount - Total number of splats. - * @param {number} roundedNumBits - Number of sort bits (rounded to multiple of 4). + * @param {number} roundedNumBits - Number of sort bits aligned to the active GPU radix sorter width (4 or 8). * @param {number} minDist - Minimum distance for key normalization. * @param {number} maxDist - Maximum distance for key normalization. * @param {StorageBuffer|null} compactedSplatIds - Compacted splat IDs from interval compaction.