GPU-accelerate KNN and edge cost in the decimator#244
Conversation
There was a problem hiding this comment.
Pull request overview
This PR accelerates the decimator by moving KNN search and per-edge cost evaluation from CPU to WebGPU compute, while also reducing peak allocations (shared radix sort, scratch reuse, Float32 caches). It introduces a breaking API change: simplifyGaussians is now async and must be awaited by direct callers.
Changes:
- Add GPU implementations for KD-tree KNN (
GpuKnn) and edge cost evaluation (GpuEdgeCost), and thread an optionalcreateDevicefactory into the decimation pipeline. - Consolidate and reuse a Float32 radix sort implementation across rendering and decimation to reduce duplicate code and large temporary allocations.
- Optimize CPU merge path allocations (module-level scratch for
momentMatch, fewer transient typed array allocations), and update tests for the async signature.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| test/decimate.test.mjs | Updates tests to await the now-async simplifyGaussians. |
| src/lib/spatial/radix-sort.ts | New shared radix-sort utility with reusable scratch buffers. |
| src/lib/spatial/kd-tree.ts | Adds KdTree.flattenForGpu() to export a GPU-friendly tree layout. |
| src/lib/spatial/index.ts | Re-exports the new radix-sort utilities. |
| src/lib/render/preprocess.ts | Switches depth sorting to shared radixSortIndicesByFloat. |
| src/lib/process.ts | Awaits simplifyGaussians and threads options.createDevice through. |
| src/lib/gpu/index.ts | Exports new GPU decimator helpers. |
| src/lib/gpu/gpu-knn.ts | New GPU KD-tree KNN compute implementation. |
| src/lib/gpu/gpu-edge-cost.ts | New GPU per-edge cost compute implementation. |
| src/lib/data-table/decimate.ts | Makes simplifyGaussians async, adds GPU paths, reduces CPU allocations, and uses shared radix sort. |
Comments suppressed due to low confidence (1)
src/lib/data-table/decimate.ts:651
- The GPU device is created (
await createDevice()) before verifying required columns. If a required column is missing, the function falls back to visibility pruning and returns, but the (potentially expensive) device creation has already happened unnecessarily. Consider deferringawait createDevice()until after the required-column check (or until you actually choose the GPU path).
// Mirrors the factory contract used by `filterFloaters` — caller hands us
// a `DeviceCreator`, we own creation here so multiple decimate actions
// don't each leak a device.
const device = createDevice ? await createDevice() : undefined;
const requiredCols = ['x', 'y', 'z', 'opacity', 'scale_0', 'scale_1', 'scale_2',
'rot_0', 'rot_1', 'rot_2', 'rot_3'];
for (const name of requiredCols) {
if (!dataTable.hasColumn(name)) {
logger.debug(`missing required column '${name}', falling back to visibility pruning`);
const indices = new Uint32Array(N);
for (let i = 0; i < N; i++) indices[i] = i;
sortByVisibility(dataTable, indices);
return dataTable.clone({ rows: indices.subarray(0, targetCount) });
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (1)
src/lib/data-table/decimate.ts:652
createDevice()is awaited before validating required columns / deciding whether the GPU path will be used. If the input is missing required columns (fallback-to-visibility) orcreateDevice()throws (e.g., no WebGPU), this prevents the CPU fallback and/or does unnecessary GPU initialization. Consider moving device creation until after the required-column check (and only when the GPU path is actually needed).
// Mirrors the factory contract used by `filterFloaters` — caller hands us
// a `DeviceCreator`, we own creation here so multiple decimate actions
// don't each leak a device.
const device = createDevice ? await createDevice() : undefined;
const requiredCols = ['x', 'y', 'z', 'opacity', 'scale_0', 'scale_1', 'scale_2',
'rot_0', 'rot_1', 'rot_2', 'rot_3'];
for (const name of requiredCols) {
if (!dataTable.hasColumn(name)) {
logger.debug(`missing required column '${name}', falling back to visibility pruning`);
const indices = new Uint32Array(N);
for (let i = 0; i < N; i++) indices[i] = i;
sortByVisibility(dataTable, indices);
return dataTable.clone({ rows: indices.subarray(0, targetCount) });
}
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (1)
src/lib/data-table/decimate.ts:667
createDeviceis awaited before validating required columns / deciding to fall back tosortByVisibility. This means GPU device creation can happen even when the decimator will immediately take the fallback path, which is expensive and may allocate resources unnecessarily. Consider movingawait createDevice()until after the required-column check (and any other early-return conditions) so callers only pay the GPU setup cost when the GPU path can actually run.
// Mirrors the factory contract used by `filterFloaters` — caller hands us
// a `DeviceCreator`, we own creation here so multiple decimate actions
// don't each leak a device.
const device = createDevice ? await createDevice() : undefined;
const requiredCols = ['x', 'y', 'z', 'opacity', 'scale_0', 'scale_1', 'scale_2',
'rot_0', 'rot_1', 'rot_2', 'rot_3'];
for (const name of requiredCols) {
if (!dataTable.hasColumn(name)) {
logger.debug(`missing required column '${name}', falling back to visibility pruning`);
const indices = new Uint32Array(N);
for (let i = 0; i < N; i++) indices[i] = i;
sortByVisibility(dataTable, indices);
return dataTable.clone({ rows: indices.subarray(0, targetCount) });
}
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (1)
src/lib/data-table/decimate.ts:667
createDeviceis awaited before validating required columns / deciding whether to fall back to visibility pruning. If the input is missing a required column (e.g. rotation columns) andcreateDeviceis provided, this will still create a GPU device even though the GPU path won’t be used. Consider moving device creation to after the required-column check (and any other early-return/fallback decisions) so the factory is only invoked when GPU execution is actually possible.
// Mirrors the factory contract used by `filterFloaters` — caller hands us
// a `DeviceCreator`, we own creation here so multiple decimate actions
// don't each leak a device.
const device = createDevice ? await createDevice() : undefined;
const requiredCols = ['x', 'y', 'z', 'opacity', 'scale_0', 'scale_1', 'scale_2',
'rot_0', 'rot_1', 'rot_2', 'rot_3'];
for (const name of requiredCols) {
if (!dataTable.hasColumn(name)) {
logger.debug(`missing required column '${name}', falling back to visibility pruning`);
const indices = new Uint32Array(N);
for (let i = 0; i < N; i++) indices[i] = i;
sortByVisibility(dataTable, indices);
return dataTable.clone({ rows: indices.subarray(0, targetCount) });
}
Summary
Moves the two dominant phases of
simplifyGaussiansonto the GPU. On the 17.9M-splat windmill scene at 50%, total wall time drops from ~2m25s to ~32s (~4.5× faster) and peak RAM from 11.1 GB to 8.5 GB. Output is PSNR-equivalent to the CPU path on real scenes (26.19 dB vs reference for both; tiny byte-level differences come from Float32 vs Float64 in cost evaluation and resolve to different tie-breaking in greedy pair selection).The KNN port (
src/lib/gpu/gpu-knn.ts) flattens the existing CPUKdTreeinto a typed-array representation (newKdTree.flattenForGpu()), uploads it once, then runs an iterative DFS in a WGSL compute shader — one thread per query, per-thread stack of 48 entries, top-K maintained unsorted with worst-index tracking so the dominant candidate-rejection path is a single compare. SameO(N log N)total work as the CPU KD-tree, parallelised across queries. Replaces the 92 s CPU loop with ~10 s of GPU work on windmill.The edge-cost port (
src/lib/gpu/gpu-edge-cost.ts) mirrorscomputeEdgeCostexactly — merged covariance / determinant / single Monte-Carlo sample / log-add-exp + L2 over SH coefficients — one thread per edge. The per-splat cache is packed into three buffers (interleaved positions, row-major R, 5-wide scalars) to stay under the WebGPU per-stage 10-storage-buffer limit. Replaces the 20 s CPU loop with ~2 s of GPU work.simplifyGaussiansis now async and accepts an optionalcreateDevice?: DeviceCreatorfactory (matches the pattern used byfilterFloaters);processDataTablethreadsoptions.createDevicethrough, falling back to the existing CPU KD-tree when no device is supplied. The async signature is a breaking change for direct callers —awaitis required.Several CPU-side wins came along for the ride and apply to the CPU fallback path too: a shared
radixSortIndicesByFloat(insrc/lib/spatial/radix-sort.ts) replaces the duplicated 4-pass LSD radix-sort impls in the rasterizer (render/preprocess.ts) and decimator; module-level scratch buffers inmomentMatcheliminate ~5 GB of throwaway per-call allocation on a 17.9M run (this alone cut merge phase from 13.7 s → 6.1 s on the windmill); per-splat cache is Float32 throughout (~860 MB saved on the cache); and aggressive reference-nulling on the giant edge/KNN/cache buffers lets V8 reclaim them before the merge phase pushes peak. The shared sort also fixes a small consolidation point — five separate radix-sort sites collapse to one.GpuEdgeCost sizes its edge buffers to
n · k(the true upper bound, not then · k / 2expected count) — variance in the directed-edge filter (j > i) lets the actual edgeCount exceedn · k / 2by a few percent, which the CPU path handles via dynamic growth but the fixed-size GPU buffers cannot.Behavior change — opacity pre-pruning removed. The previous
simplifyGaussiansstarted with a median-based opacity pruning pass (drop splats withsigmoid(opacity) < min(0.1, median)before merging). Investigation showed this caused the visible darkening / desaturation on dense scenes: on windmill at 50% reduction, pruning removed 21% of splats (3.75M) carrying 3.84% of totalα·areamass — and because the dropped splats were spatially concentrated, the loss read as a ~9-unit luma drop (PSNR 23.0) vs the un-decimated reference. The merge step alone is mass-conserving, so removing the pruning lifts windmill 50% from PSNR 23.01 → 27.69 (ΔLuma −9.33 → −0.25) and 25% from 19.83 → 23.39 (ΔLuma −12.69 → −2.81). Net cost is ~25% more KNN/edge-cost work (those low-α splats now participate in the merge), which the new GPU path absorbs comfortably.Build clean, all 490 existing tests pass. The
axis-sorted-knnscaffolding from an earlier exploration was removed before this PR.