Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 66 additions & 18 deletions examples/src/examples/test/radix-sort-indirect-compute.example.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ function logError(msg) {
statusOverlay.textContent = logLines.join('\n');
}

// Create radix sort instance
radixSort = new pc.ComputeRadixSort(device);
// Create radix sort instance in indirect mode — sortIndirect requires this.
radixSort = new pc.ComputeRadixSort(device, { indirect: true });

// Create sortElementCount buffer (single u32, GPU-readable storage buffer)
sortElementCountBuffer = new pc.StorageBuffer(device, 4, pc.BUFFERUSAGE_COPY_SRC | pc.BUFFERUSAGE_COPY_DST);
Expand All @@ -110,13 +110,23 @@ sortElementCountBuffer = new pc.StorageBuffer(device, 4, pc.BUFFERUSAGE_COPY_SRC
// Simulates the GSplat pipeline's prepareIndirect: a compute shader writes
// sortElementCount and indirect dispatch args within the command buffer
// (instead of queue.writeBuffer which executes before the command buffer).
//
// The shader mirrors the sortIndirectArgsCS WGSL chunk: it uses the slot
// count and per-slot granularities from ComputeRadixSort#prepareIndirect()
// to support any backend (Multipass = 1 slot, OneSweep = 2 slots).

const prepareSource = /* wgsl */`
@group(0) @binding(0) var<storage, read_write> sortElementCountBuf: array<u32>;
@group(0) @binding(1) var<storage, read_write> indirectDispatchArgs: array<u32>;
struct PrepareUniforms {
visibleCount: u32,
dispatchSlotOffset: u32
visibleCount: u32,
sortSlotBase: u32, // first slot index (not u32 offset)
_pad0: u32,
_pad1: u32,
sortInfoX: u32, // slotCount (from prepareIndirect()[0])
sortInfoY: u32, // granularity for slot 0
sortInfoZ: u32, // granularity for slot 1 (0 if unused)
sortInfoW: u32 // granularity for slot 2 (0 if unused)
};
@group(0) @binding(2) var<uniform> uniforms: PrepareUniforms;

Expand All @@ -125,17 +135,42 @@ const prepareSource = /* wgsl */`
let count = uniforms.visibleCount;
sortElementCountBuf[0] = count;

let sortWorkgroupCount = (count + 255u) / 256u;
let offset = uniforms.dispatchSlotOffset;
indirectDispatchArgs[offset + 0u] = sortWorkgroupCount;
indirectDispatchArgs[offset + 1u] = 1u;
indirectDispatchArgs[offset + 2u] = 1u;
let base = uniforms.sortSlotBase;
let n = uniforms.sortInfoX;

if (n >= 1u) {
let g = uniforms.sortInfoY;
let wc = (count + g - 1u) / g;
indirectDispatchArgs[base * 3u + 0u] = wc;
indirectDispatchArgs[base * 3u + 1u] = 1u;
indirectDispatchArgs[base * 3u + 2u] = 1u;
}
if (n >= 2u) {
let g = uniforms.sortInfoZ;
let wc = (count + g - 1u) / g;
indirectDispatchArgs[(base + 1u) * 3u + 0u] = wc;
indirectDispatchArgs[(base + 1u) * 3u + 1u] = 1u;
indirectDispatchArgs[(base + 1u) * 3u + 2u] = 1u;
}
if (n >= 3u) {
let g = uniforms.sortInfoW;
let wc = (count + g - 1u) / g;
indirectDispatchArgs[(base + 2u) * 3u + 0u] = wc;
indirectDispatchArgs[(base + 2u) * 3u + 1u] = 1u;
indirectDispatchArgs[(base + 2u) * 3u + 2u] = 1u;
}
}
`;

const prepareUniformFormat = new pc.UniformBufferFormat(device, [
new pc.UniformFormat('visibleCount', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('dispatchSlotOffset', pc.UNIFORMTYPE_UINT)
new pc.UniformFormat('sortSlotBase', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('_pad0', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('_pad1', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('sortInfoX', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('sortInfoY', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('sortInfoZ', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('sortInfoW', pc.UNIFORMTYPE_UINT)
]);

const prepareBindGroupFormat = new pc.BindGroupFormat(device, [
Expand Down Expand Up @@ -284,7 +319,11 @@ data.on('*:set', (/** @type {string} */ path, /** @type {any} */ value) => {
needsRegen = true;
}
} else if (path === 'options.bits') {
const validBits = [4, 8, 12, 16, 20, 24, 28, 32];
// Only accept multiples of the active backend's radix width
// (4 for Multipass, 8 for OneSweep). Snapping to an invalid
// multiple would trigger the sortIndirect assert every frame.
const rb = radixSort ? radixSort.radixBits : 4;
const validBits = [4, 8, 12, 16, 20, 24, 28, 32].filter(b => b % rb === 0);
const nearest = validBits.reduce((prev, curr) => (Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev));
if (nearest !== currentNumBits) {
currentNumBits = nearest;
Expand Down Expand Up @@ -331,17 +370,26 @@ app.on('update', () => {

if (!keysBuffer || !radixSort || !sortElementCountBuffer) return;

// Allocate per-frame indirect dispatch slot
const dispatchSlot = device.getIndirectDispatchSlot(1);
// Query backend slot requirements — varies by backend:
// Multipass: [1, 2048, 0, 0] (1 slot; see ELEMENTS_PER_WORKGROUP)
// OneSweep: [2, 3840, 32768, 0] (2 slots: binning + globalHist)
const sortInfo = radixSort.prepareIndirect();
const slotCount = sortInfo[0];

// Write sortElementCount and dispatch args via compute shader
const dispatchBuffer = device.indirectDispatchBuffer;
const slotOffset = dispatchSlot * 3;
// Allocate the required number of consecutive per-frame dispatch slots.
const dispatchSlot = device.getIndirectDispatchSlot(slotCount);

// Write sortElementCount and all dispatch slot args via compute shader.
prepareCompute.setParameter('sortElementCountBuf', sortElementCountBuffer);
prepareCompute.setParameter('indirectDispatchArgs', dispatchBuffer);
prepareCompute.setParameter('indirectDispatchArgs', device.indirectDispatchBuffer);
prepareCompute.setParameter('visibleCount', visibleCount);
prepareCompute.setParameter('dispatchSlotOffset', slotOffset);
prepareCompute.setParameter('sortSlotBase', dispatchSlot);
prepareCompute.setParameter('_pad0', 0);
prepareCompute.setParameter('_pad1', 0);
prepareCompute.setParameter('sortInfoX', sortInfo[0]);
prepareCompute.setParameter('sortInfoY', sortInfo[1]);
prepareCompute.setParameter('sortInfoZ', sortInfo[2]);
prepareCompute.setParameter('sortInfoW', sortInfo[3]);
prepareCompute.setupDispatch(1, 1, 1);
device.computeDispatch([prepareCompute], 'PrepareIndirectTest');

Expand Down
22 changes: 16 additions & 6 deletions src/scene/graphics/radix-sort/compute-radix-sort-onesweep.js
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,14 @@ class ComputeRadixSortOneSweep extends ComputeRadixSortBase {
// 64 / 128 lane subgroups (AMD Wave64, future wider architectures)
// those expressions silently truncate / UB and will corrupt the
// sort output. Refuse to run rather than producing garbage.
Debug.assert(device.minSubgroupSize > 0, 'ComputeRadixSortOneSweep requires a valid minimum subgroup size');
Debug.assert(device.maxSubgroupSize <= 32, 'ComputeRadixSortOneSweep currently requires subgroup sizes <= 32 (binning shader uses 32-bit subgroup masks)');
//
// We only check the *minimum* runtime size: NVIDIA hardware always
// executes with 32-wide warps even though the adapter may report a
// higher max (e.g. Dawn / D3D12 surfaces a spec ceiling of 128).
// A `minSubgroupSize` of 0 means the adapter omitted the field
// entirely (older Chrome / Dawn on Windows); accept that — runtime
// size on validated NVIDIA targets is still 32.
Debug.assert(device.minSubgroupSize <= 32, 'ComputeRadixSortOneSweep currently requires runtime subgroup size <= 32 (binning shader uses 32-bit subgroup masks)');

// Create uniform formats (shared between direct and indirect modes), then
// create bind group formats and shaders for the chosen mode only.
Expand Down Expand Up @@ -203,7 +209,6 @@ class ComputeRadixSortOneSweep extends ComputeRadixSortBase {
const maxSubgroups = Math.max(1, Math.ceil(256 / minSubgroupSize));

const suffix = indirect ? 'Indirect' : '';
const elementCountBinding = new BindStorageBufferFormat('b_sortElementCount', SHADERSTAGE_COMPUTE, true);

const histGroupEntries = [
new BindStorageBufferFormat('b_sort', SHADERSTAGE_COMPUTE, true),
Expand All @@ -225,9 +230,14 @@ class ComputeRadixSortOneSweep extends ComputeRadixSortBase {
new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
];
if (indirect) {
histGroupEntries.push(elementCountBinding);
scanGroupEntries.push(elementCountBinding);
binGroupEntries.push(elementCountBinding);
// Each BindStorageBufferFormat must be a separate instance: BindGroupFormat
// assigns slot numbers by mutating the objects in-place, so sharing a single
// instance across multiple formats would cause the last format to overwrite
// the slot on the shared object, producing wrong binding indices for the
// earlier formats when their bind groups are created.
histGroupEntries.push(new BindStorageBufferFormat('b_sortElementCount', SHADERSTAGE_COMPUTE, true));
scanGroupEntries.push(new BindStorageBufferFormat('b_sortElementCount', SHADERSTAGE_COMPUTE, true));
binGroupEntries.push(new BindStorageBufferFormat('b_sortElementCount', SHADERSTAGE_COMPUTE, true));
}

this._globalHistBindGroupFormat = new BindGroupFormat(device, histGroupEntries);
Expand Down
18 changes: 12 additions & 6 deletions src/scene/graphics/radix-sort/compute-radix-sort.js
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ class ComputeRadixSort {
}

if (chosen === RADIX_SORT_ONESWEEP) {
// Hard hardware prerequisites (compute, subgroups, subgroupSize
// <= 32, valid minSubgroupSize) are asserted inside the OneSweep
// Hard hardware prerequisites (compute, subgroups, runtime
// subgroup size <= 32) are asserted inside the OneSweep
// constructor. Here we only warn on soft policy mismatches
// (non-NVIDIA vendors), so callers can opt in for experimentation
// on devices where OneSweep has not been validated.
if (!this._canUseOneSweep(device)) {
Debug.warnOnce('ComputeRadixSort: RADIX_SORT_ONESWEEP requested on a device that is not a validated OneSweep target (non-NVIDIA, or missing compute / subgroups / subgroupSize <= 32). OneSweep may hang or produce incorrect results. Consider RADIX_SORT_PORTABLE or RADIX_SORT_AUTO.');
Debug.warnOnce('ComputeRadixSort: RADIX_SORT_ONESWEEP requested on a device that is not a validated OneSweep target (non-NVIDIA, or minSubgroupSize > 32). OneSweep may hang or produce incorrect results. Consider RADIX_SORT_PORTABLE or RADIX_SORT_AUTO.');
}
this._impl = new ComputeRadixSortOneSweep(device, indirect);
} else {
Expand All @@ -103,14 +103,20 @@ class ComputeRadixSort {
*/
_canUseOneSweep(device) {
if (!device.supportsCompute || !device.supportsSubgroups) return false;
// Adapter info may omit subgroup sizes (both stay 0); do not auto-select OneSweep then.
if (!device.minSubgroupSize || !device.maxSubgroupSize || device.maxSubgroupSize > 32) return false;
// Only enable on NVIDIA for now; validated on Turing+ and Ampere.
// Other vendors either lack forward-progress guarantees (Apple) or
// have shown correctness issues in the lookback (Mali / Imagination /
// some Adreno).
const vendor = device.gpuAdapter?.info?.vendor?.toLowerCase?.();
return vendor === 'nvidia';
if (vendor !== 'nvidia') return false;
// NVIDIA always executes with 32-wide warps. The adapter may report
// a wider max (Dawn / D3D12 surfaces a spec ceiling of 128) or omit
// size info entirely (some Chrome / Dawn builds on Windows report 0).
// Only refuse if the *minimum* runtime size is guaranteed to be > 32,
// which would mean we'd actually receive >32 lanes and the binning
// shader's 32-bit ballot masks would corrupt the output.
if (device.minSubgroupSize > 32) return false;
return true;
}

/**
Expand Down
10 changes: 7 additions & 3 deletions src/scene/gsplat-unified/gsplat-manager.js
Original file line number Diff line number Diff line change
Expand Up @@ -1611,8 +1611,12 @@ class GSplatManager {

// number of bits used for sorting to match CPU sorter
const numBits = Math.max(10, Math.min(20, Math.round(Math.log2(elementCount / 4))));
// Round up to multiple of 4 for radix sort
const roundedNumBits = Math.ceil(numBits / 4) * 4;
// Round up to a multiple of the active sorter's radix width (4 for the
// portable multipass backend, 8 for OneSweep). Multipass would accept
// 4-multiples, but OneSweep asserts when numBits is not a multiple of
// 8, so we always align to the backend's actual radixBits.
const radixBits = gpuSorter.radixBits;
const roundedNumBits = Math.ceil(numBits / radixBits) * radixBits;

// Compute min/max distances for key normalization
const { minDist, maxDist } = this.computeDistanceRange(worldState);
Expand Down Expand Up @@ -1706,7 +1710,7 @@ class GSplatManager {
* (sorting only the visible splat count determined by interval compaction).
*
* @param {number} elementCount - Total number of splats.
* @param {number} roundedNumBits - Number of sort bits (rounded to multiple of 4).
* @param {number} roundedNumBits - Number of sort bits aligned to the active GPU radix sorter width (4 or 8).
* @param {number} minDist - Minimum distance for key normalization.
* @param {number} maxDist - Maximum distance for key normalization.
* @param {StorageBuffer|null} compactedSplatIds - Compacted splat IDs from interval compaction.
Expand Down