From 37c7c7e0690b28b69f61b2be3d025241efb56757 Mon Sep 17 00:00:00 2001 From: Andreas Sundquist Date: Sat, 19 Jul 2025 08:13:06 -0700 Subject: [PATCH 1/4] Add commented benchmarkSort to compare JS vs Wasm sorting. --- src/worker.ts | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/worker.ts b/src/worker.ts index 98c95a5..d3b7a7a 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -134,11 +134,8 @@ async function onMessage(event: MessageEvent) { readback: Uint16Array; ordering: Uint32Array; }; - result = { - id, - readback, - ordering, - }; + // Benchmark sort + // benchmarkSort(numSplats, readback, ordering); if (WASM_SPLAT_SORT) { result = { id, @@ -180,6 +177,39 @@ async function onMessage(event: MessageEvent) { ); } +function benchmarkSort( + numSplats: number, + readback: Uint16Array, + ordering: Uint32Array, +) { + if (numSplats > 0) { + const WARMUP = 10; + for (let i = 0; i < WARMUP; ++i) { + const activeSplats = sort_splats(numSplats, readback, ordering); + const results = sortDoubleSplats({ numSplats, readback, ordering }); + } + + const TIMING_SAMPLES = 1000; + let start: number; + + start = performance.now(); + for (let i = 0; i < TIMING_SAMPLES; ++i) { + const activeSplats = sort_splats(numSplats, readback, ordering); + } + const wasmTime = (performance.now() - start) / TIMING_SAMPLES; + + start = performance.now(); + for (let i = 0; i < TIMING_SAMPLES; ++i) { + const results = sortDoubleSplats({ numSplats, readback, ordering }); + } + const jsTime = (performance.now() - start) / TIMING_SAMPLES; + + console.log( + `JS: ${jsTime} ms, WASM: ${wasmTime} ms, numSplats: ${numSplats}`, + ); + } +} + async function unpackPly({ packedArray, fileBytes, From 086125f0a8a7cc5ffaefb3958b887d2997a17039 Mon Sep 17 00:00:00 2001 From: Andreas Sundquist Date: Sat, 19 Jul 2025 10:25:22 -0700 Subject: [PATCH 2/4] Add 32-bit float sort via SparkViewpooint.sort32, using 2-pass radix-65536 sort. Turn on sort32 by default for examples/editor. Implemented Rust sort and JS sort, updated benchmarking code to include float16 and float32 sort. --- examples/editor/index.html | 2 + rust/spark-internal-rs/src/lib.rs | 32 +++- rust/spark-internal-rs/src/sort.rs | 122 ++++++++++++++- src/SparkViewpoint.ts | 70 +++++++-- src/worker.ts | 232 ++++++++++++++++++++++++----- 5 files changed, 403 insertions(+), 55 deletions(-) diff --git a/examples/editor/index.html b/examples/editor/index.html index f744c1d..317c93a 100644 --- a/examples/editor/index.html +++ b/examples/editor/index.html @@ -480,6 +480,8 @@ stats.dom.style.display = value ? "block" : "none"; }); gui.add(spark.defaultView, "sortRadial").name("Radial sort").listen(); + spark.defaultView.sort32 = true; + gui.add(spark.defaultView, "sort32").name("Float32 sort").listen(); gui.add(grid, "opacity", 0, 1, 0.01).name("Grid opacity").listen(); gui.add({ logFocalDistance: 0.0, diff --git a/rust/spark-internal-rs/src/lib.rs b/rust/spark-internal-rs/src/lib.rs index 417b0c2..b00414a 100644 --- a/rust/spark-internal-rs/src/lib.rs +++ b/rust/spark-internal-rs/src/lib.rs @@ -4,7 +4,7 @@ use js_sys::{Float32Array, Uint16Array, Uint32Array}; use wasm_bindgen::prelude::*; mod sort; -use sort::{sort_internal, SortBuffers}; +use sort::{sort_internal, SortBuffers, sort32_internal, Sort32Buffers}; mod raycast; use raycast::{raycast_ellipsoids, raycast_spheres}; @@ -13,6 +13,7 @@ const RAYCAST_BUFFER_COUNT: u32 = 65536; thread_local! 
{ static SORT_BUFFERS: RefCell = RefCell::new(SortBuffers::default()); + static SORT32_BUFFERS: RefCell = RefCell::new(Sort32Buffers::default()); static RAYCAST_BUFFER: RefCell> = RefCell::new(vec![0; RAYCAST_BUFFER_COUNT as usize * 4]); } @@ -45,6 +46,35 @@ pub fn sort_splats( active_splats } +#[wasm_bindgen] +pub fn sort32_splats( + num_splats: u32, readback: Uint32Array, ordering: Uint32Array, +) -> u32 { + let max_splats = readback.length() as usize; + + let active_splats = SORT32_BUFFERS.with_borrow_mut(|buffers| { + buffers.ensure_size(max_splats); + let sub_readback = readback.subarray(0, num_splats); + sub_readback.copy_to(&mut buffers.readback[..num_splats as usize]); + + let active_splats = match sort32_internal(buffers, max_splats, num_splats as usize) { + Ok(active_splats) => active_splats, + Err(err) => { + wasm_bindgen::throw_str(&format!("{}", err)); + } + }; + + if active_splats > 0 { + // Copy out ordering result + let subarray = &buffers.ordering[..active_splats as usize]; + ordering.subarray(0, active_splats).copy_from(&subarray); + } + active_splats + }); + + active_splats +} + #[wasm_bindgen] pub fn raycast_splats( origin_x: f32, origin_y: f32, origin_z: f32, diff --git a/rust/spark-internal-rs/src/sort.rs b/rust/spark-internal-rs/src/sort.rs index 9b16f2f..9b1f947 100644 --- a/rust/spark-internal-rs/src/sort.rs +++ b/rust/spark-internal-rs/src/sort.rs @@ -1,7 +1,7 @@ use anyhow::anyhow; -const DEPTH_INFINITY: u32 = 0x7c00; -const DEPTH_SIZE: usize = DEPTH_INFINITY as usize + 1; +const DEPTH_INFINITY_F16: u32 = 0x7c00; +const DEPTH_SIZE_F16: usize = DEPTH_INFINITY_F16 as usize + 1; #[derive(Default)] pub struct SortBuffers { @@ -18,8 +18,8 @@ impl SortBuffers { if self.ordering.len() < max_splats { self.ordering.resize(max_splats, 0); } - if self.buckets.len() < DEPTH_SIZE { - self.buckets.resize(DEPTH_SIZE, 0); + if self.buckets.len() < DEPTH_SIZE_F16 { + self.buckets.resize(DEPTH_SIZE_F16, 0); } } } @@ -30,11 +30,11 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> anyhow::Re // Set the bucket counts to zero buckets.clear(); - buckets.resize(DEPTH_SIZE, 0); + buckets.resize(DEPTH_SIZE_F16, 0); // Count the number of splats in each bucket for &metric in readback.iter() { - if (metric as u32) < DEPTH_INFINITY { + if (metric as u32) < DEPTH_INFINITY_F16 { buckets[metric as usize] += 1; } } @@ -49,7 +49,7 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> anyhow::Re // Write out splat indices at the right location using bucket offsets for (index, &metric) in readback.iter().enumerate() { - if (metric as u32) < DEPTH_INFINITY { + if (metric as u32) < DEPTH_INFINITY_F16 { ordering[buckets[metric as usize] as usize] = index as u32; buckets[metric as usize] += 1; } @@ -65,3 +65,111 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> anyhow::Re } Ok(active_splats) } + +const DEPTH_INFINITY_F32: u32 = 0x7f800000; +const RADIX_BASE: usize = 1 << 16; // 65536 + +#[derive(Default)] +pub struct Sort32Buffers { + /// raw f32 bit‑patterns (one per splat) + pub readback: Vec, + /// output indices + pub ordering: Vec, + /// bucket counts / offsets (length == RADIX_BASE) + pub buckets16: Vec, + /// scratch space for indices + pub scratch: Vec, +} + +impl Sort32Buffers { + /// ensure all internal buffers are large enough for up to `max_splats` + pub fn ensure_size(&mut self, max_splats: usize) { + if self.readback.len() < max_splats { + self.readback.resize(max_splats, 0); + } + if self.ordering.len() < max_splats { + 
self.ordering.resize(max_splats, 0); + } + if self.scratch.len() < max_splats { + self.scratch.resize(max_splats, 0); + } + if self.buckets16.len() < RADIX_BASE { + self.buckets16.resize(RADIX_BASE, 0); + } + } +} + +/// Two‑pass radix sort (base 2¹⁶) of 32‑bit float bit‑patterns, +/// descending order (largest keys first). Mirrors the JS `sort32Splats`. +pub fn sort32_internal( + buffers: &mut Sort32Buffers, + max_splats: usize, + num_splats: usize, +) -> anyhow::Result { + // make sure our buffers can hold `max_splats` + buffers.ensure_size(max_splats); + + let Sort32Buffers { readback, ordering, buckets16, scratch } = buffers; + let keys = &readback[..num_splats]; + + // ——— Pass #1: bucket by inv(low 16 bits) ——— + buckets16.fill(0); + for &key in keys.iter() { + if key < DEPTH_INFINITY_F32 { + let inv = !key; + buckets16[(inv & 0xFFFF) as usize] += 1; + } + } + // exclusive prefix‑sum → starting offsets + let mut total: u32 = 0; + for slot in buckets16.iter_mut() { + let cnt = *slot; + *slot = total; + total = total.wrapping_add(cnt); + } + let active_splats = total; + + // scatter into scratch by low bits of inv + for (i, &key) in keys.iter().enumerate() { + if key < DEPTH_INFINITY_F32 { + let inv = !key; + let lo = (inv & 0xFFFF) as usize; + scratch[buckets16[lo] as usize] = i as u32; + buckets16[lo] += 1; + } + } + + // ——— Pass #2: bucket by inv(high 16 bits) ——— + buckets16.fill(0); + for &idx in scratch.iter().take(active_splats as usize) { + let key = keys[idx as usize]; + let inv = !key; + buckets16[(inv >> 16) as usize] += 1; + } + // exclusive prefix‑sum again + let mut sum: u32 = 0; + for slot in buckets16.iter_mut() { + let cnt = *slot; + *slot = sum; + sum = sum.wrapping_add(cnt); + } + // scatter into final ordering by high bits of inv + for &idx in scratch.iter().take(active_splats as usize) { + let key = keys[idx as usize]; + let inv = !key; + let hi = (inv >> 16) as usize; + ordering[buckets16[hi] as usize] = idx; + buckets16[hi] += 1; + } + + // sanity‑check: last bucket should have consumed all entries + if buckets16[RADIX_BASE - 1] != active_splats { + return Err(anyhow!( + "Expected {} active splats but got {}", + active_splats, + buckets16[RADIX_BASE - 1] + )); + } + + Ok(active_splats) +} \ No newline at end of file diff --git a/src/SparkViewpoint.ts b/src/SparkViewpoint.ts index c2056d4..946b77b 100644 --- a/src/SparkViewpoint.ts +++ b/src/SparkViewpoint.ts @@ -18,6 +18,7 @@ import { dyno, dynoBlock, dynoConst, + floatBitsToUint, mul, packHalf2x16, readPackedSplat, @@ -117,6 +118,11 @@ export type SparkViewpointOptions = { * @default false */ sort360?: boolean; + /* + * Set this to true to sort with float32 precision with two-pass sort. 
+ * @default true + */ + sort32?: boolean; }; // A SparkViewpoint is created from and tied to a SparkRenderer, and represents @@ -149,6 +155,7 @@ export class SparkViewpoint { sortCoorient?: boolean; depthBias?: number; sort360?: boolean; + sort32?: boolean; display: { accumulator: SplatAccumulator; @@ -164,7 +171,8 @@ export class SparkViewpoint { } | null = null; private sortingCheck = false; - private readback: Uint16Array = new Uint16Array(0); + private readback16: Uint16Array = new Uint16Array(0); + private readback32: Uint32Array = new Uint32Array(0); private orderingFreelist: FreeList; constructor(options: SparkViewpointOptions & { spark: SparkRenderer }) { @@ -209,6 +217,7 @@ export class SparkViewpoint { this.sortCoorient = options.sortCoorient; this.depthBias = options.depthBias; this.sort360 = options.sort360; + this.sort32 = options.sort32; this.orderingFreelist = new FreeList({ allocate: (maxSplats) => new Uint32Array(maxSplats), @@ -557,6 +566,7 @@ export class SparkViewpoint { const { reader, doubleSortReader, + sort32Reader, dynoSortRadial, dynoOrigin, dynoDirection, @@ -564,8 +574,16 @@ export class SparkViewpoint { dynoSort360, dynoSplats, } = SparkViewpoint.makeSorter(); - const halfMaxSplats = Math.ceil(maxSplats / 2); - this.readback = reader.ensureBuffer(halfMaxSplats, this.readback); + const sort32 = this.sort32 ?? false; + let readback: Uint16Array | Uint32Array; + if (sort32) { + this.readback32 = reader.ensureBuffer(maxSplats, this.readback32); + readback = this.readback32; + } else { + const halfMaxSplats = Math.ceil(maxSplats / 2); + this.readback16 = reader.ensureBuffer(halfMaxSplats, this.readback16); + readback = this.readback16; + } const worldToOrigin = accumulator.toWorld.clone().invert(); const viewToOrigin = viewToWorld.clone().premultiply(worldToOrigin); @@ -581,25 +599,33 @@ export class SparkViewpoint { dynoSort360.value = this.sort360 ?? false; dynoSplats.packedSplats = accumulator.splats; + const sortReader = sort32 ? sort32Reader : doubleSortReader; + const count = sort32 ? numSplats : Math.ceil(numSplats / 2); await reader.renderReadback({ renderer: this.spark.renderer, - reader: doubleSortReader, - count: Math.ceil(numSplats / 2), - readback: this.readback, + reader: sortReader, + count, + readback, }); const result = (await withWorker(async (worker) => { - return worker.call("sortDoubleSplats", { + const rpcName = sort32 ? 
"sort32Splats" : "sortDoubleSplats"; + return worker.call(rpcName, { + maxSplats, numSplats, - readback: this.readback, + readback, ordering, }); })) as { - readback: Uint16Array; + readback: Uint16Array | Uint32Array; ordering: Uint32Array; activeSplats: number; }; - this.readback = result.readback; + if (sort32) { + this.readback32 = result.readback as Uint32Array; + } else { + this.readback16 = result.readback as Uint16Array; + } ordering = result.ordering; activeSplats = result.activeSplats; } @@ -669,6 +695,7 @@ export class SparkViewpoint { dynoSplats: DynoPackedSplats; reader: Readback; doubleSortReader: DynoBlock<{ index: "int" }, { rgba8: "vec4" }>; + sort32Reader: DynoBlock<{ index: "int" }, { rgba8: "vec4" }>; } | null = null; private static makeSorter() { @@ -716,6 +743,28 @@ export class SparkViewpoint { }, ); + const sort32Reader = dynoBlock( + { index: "int" }, + { rgba8: "vec4" }, + ({ index }) => { + if (!index) { + throw new Error("No index"); + } + const sortParams = { + sortRadial: dynoSortRadial, + sortOrigin: dynoOrigin, + sortDirection: dynoDirection, + sortDepthBias: dynoDepthBias, + sort360: dynoSort360, + }; + + const gsplat = readPackedSplat(dynoSplats, index); + const metric = computeSortMetric({ gsplat, ...sortParams }); + const rgba8 = uintToRgba8(floatBitsToUint(metric)); + return { rgba8 }; + }, + ); + SparkViewpoint.dynos = { dynoSortRadial, dynoOrigin, @@ -725,6 +774,7 @@ export class SparkViewpoint { dynoSplats, reader, doubleSortReader, + sort32Reader, }; } return SparkViewpoint.dynos; diff --git a/src/worker.ts b/src/worker.ts index d3b7a7a..6ebfe9e 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -1,4 +1,4 @@ -import init_wasm, { sort_splats } from "spark-internal-rs"; +import init_wasm, { sort_splats, sort32_splats } from "spark-internal-rs"; import type { PcSogsJson, TranscodeSpzInput } from "./SplatLoader"; import { unpackAntiSplat } from "./antisplat"; import { WASM_SPLAT_SORT } from "./defines"; @@ -18,6 +18,7 @@ import { setPackedSplatQuat, setPackedSplatRgb, setPackedSplatScales, + toHalf, } from "./utils"; // WebWorker for Spark's background CPU tasks, such as Gsplat file decoding @@ -134,8 +135,6 @@ async function onMessage(event: MessageEvent) { readback: Uint16Array; ordering: Uint32Array; }; - // Benchmark sort - // benchmarkSort(numSplats, readback, ordering); if (WASM_SPLAT_SORT) { result = { id, @@ -152,6 +151,31 @@ async function onMessage(event: MessageEvent) { } break; } + case "sort32Splats": { + const { maxSplats, numSplats, readback, ordering } = args as { + maxSplats: number; + numSplats: number; + readback: Uint32Array; + ordering: Uint32Array; + }; + // Benchmark sort + // benchmarkSort(numSplats, readback, ordering); + if (WASM_SPLAT_SORT) { + result = { + id, + readback, + ordering, + activeSplats: sort32_splats(numSplats, readback, ordering), + }; + } else { + result = { + id, + readback, + ...sort32Splats({ maxSplats, numSplats, readback, ordering }), + }; + } + break; + } case "transcodeSpz": { const input = args as TranscodeSpzInput; const spzBytes = await transcodeSpz(input); @@ -168,6 +192,7 @@ async function onMessage(event: MessageEvent) { } } catch (e) { error = e; + console.error(error); } // Send the result or error back to the main thread, making sure to transfer any ArrayBuffers @@ -179,14 +204,32 @@ async function onMessage(event: MessageEvent) { function benchmarkSort( numSplats: number, - readback: Uint16Array, + readback32: Uint32Array, ordering: Uint32Array, ) { if (numSplats > 0) { + console.log("Running 
sort benchmark"); + const readbackF32 = new Float32Array(readback32.buffer); + const readback16 = new Uint16Array(readback32.length); + for (let i = 0; i < numSplats; ++i) { + readback16[i] = toHalf(readbackF32[i]); + } + const WARMUP = 10; for (let i = 0; i < WARMUP; ++i) { - const activeSplats = sort_splats(numSplats, readback, ordering); - const results = sortDoubleSplats({ numSplats, readback, ordering }); + const activeSplats = sort_splats(numSplats, readback16, ordering); + const activeSplats32 = sort32_splats(numSplats, readback32, ordering); + const results = sortDoubleSplats({ + numSplats, + readback: readback16, + ordering, + }); + const results32 = sort32Splats({ + maxSplats: numSplats, + numSplats, + readback: readback32, + ordering, + }); } const TIMING_SAMPLES = 1000; @@ -194,19 +237,44 @@ function benchmarkSort( start = performance.now(); for (let i = 0; i < TIMING_SAMPLES; ++i) { - const activeSplats = sort_splats(numSplats, readback, ordering); + const activeSplats = sort_splats(numSplats, readback16, ordering); } const wasmTime = (performance.now() - start) / TIMING_SAMPLES; start = performance.now(); for (let i = 0; i < TIMING_SAMPLES; ++i) { - const results = sortDoubleSplats({ numSplats, readback, ordering }); + const results = sortDoubleSplats({ + numSplats, + readback: readback16, + ordering, + }); } const jsTime = (performance.now() - start) / TIMING_SAMPLES; console.log( `JS: ${jsTime} ms, WASM: ${wasmTime} ms, numSplats: ${numSplats}`, ); + + start = performance.now(); + for (let i = 0; i < TIMING_SAMPLES; ++i) { + const activeSplats32 = sort32_splats(numSplats, readback32, ordering); + } + const wasm32Time = (performance.now() - start) / TIMING_SAMPLES; + + start = performance.now(); + for (let i = 0; i < TIMING_SAMPLES; ++i) { + const results = sort32Splats({ + maxSplats: numSplats, + numSplats, + readback: readback32, + ordering, + }); + } + const js32Time = (performance.now() - start) / TIMING_SAMPLES; + + console.log( + `JS32: ${js32Time} ms, WASM32: ${wasm32Time} ms, numSplats: ${numSplats}`, + ); } } @@ -338,9 +406,9 @@ function unpackSpz(fileBytes: Uint8Array): { } // Array of buckets for sorting float16 distances with range [0, DEPTH_INFINITY]. -const DEPTH_INFINITY = 0x7c00; -const DEPTH_SIZE = DEPTH_INFINITY + 1; -let depthArray: Uint32Array | null = null; +const DEPTH_INFINITY_F16 = 0x7c00; +const DEPTH_SIZE_16 = DEPTH_INFINITY_F16 + 1; +let depthArray16: Uint32Array | null = null; function sortSplats({ totalSplats, @@ -353,10 +421,10 @@ function sortSplats({ // Sort totalSplats Gsplats, each with 4 bytes of readback, and outputs Uint32Array // of indices from most distant to nearest. Each 4 bytes encode a float16 distance // and unused high bytes. 
- if (!depthArray) { - depthArray = new Uint32Array(DEPTH_SIZE); + if (!depthArray16) { + depthArray16 = new Uint32Array(DEPTH_SIZE_16); } - depthArray.fill(0); + depthArray16.fill(0); const readbackUint32 = readback.map((layer) => new Uint32Array(layer.buffer)); const layerSize = readbackUint32[0].length; @@ -368,17 +436,17 @@ function sortSplats({ const layerSplats = Math.min(readbackLayer.length, totalSplats - layerBase); for (let i = 0; i < layerSplats; ++i) { const pri = readbackLayer[i] & 0x7fff; - if (pri < DEPTH_INFINITY) { - depthArray[pri] += 1; + if (pri < DEPTH_INFINITY_F16) { + depthArray16[pri] += 1; } } layerBase += layerSplats; } let activeSplats = 0; - for (let j = 0; j < DEPTH_SIZE; ++j) { - const nextIndex = activeSplats + depthArray[j]; - depthArray[j] = activeSplats; + for (let j = 0; j < DEPTH_SIZE_16; ++j) { + const nextIndex = activeSplats + depthArray16[j]; + depthArray16[j] = activeSplats; activeSplats = nextIndex; } @@ -388,16 +456,16 @@ function sortSplats({ const layerSplats = Math.min(readbackLayer.length, totalSplats - layerBase); for (let i = 0; i < layerSplats; ++i) { const pri = readbackLayer[i] & 0x7fff; - if (pri < DEPTH_INFINITY) { - ordering[depthArray[pri]] = layerBase + i; - depthArray[pri] += 1; + if (pri < DEPTH_INFINITY_F16) { + ordering[depthArray16[pri]] = layerBase + i; + depthArray16[pri] += 1; } } layerBase += layerSplats; } - if (depthArray[DEPTH_SIZE - 1] !== activeSplats) { + if (depthArray16[DEPTH_SIZE_16 - 1] !== activeSplats) { throw new Error( - `Expected ${activeSplats} active splats but got ${depthArray[DEPTH_SIZE - 1]}`, + `Expected ${activeSplats} active splats but got ${depthArray16[DEPTH_SIZE_16 - 1]}`, ); } @@ -415,16 +483,16 @@ function sortDoubleSplats({ ordering: Uint32Array; } { // Ensure depthArray is allocated and zeroed out for our buckets. - if (!depthArray) { - depthArray = new Uint32Array(DEPTH_SIZE); + if (!depthArray16) { + depthArray16 = new Uint32Array(DEPTH_SIZE_16); } - depthArray.fill(0); + depthArray16.fill(0); // Count the number of splats in each bucket (cull Gsplats at infinity). for (let i = 0; i < numSplats; ++i) { const pri = readback[i]; - if (pri < DEPTH_INFINITY) { - depthArray[pri] += 1; + if (pri < DEPTH_INFINITY_F16) { + depthArray16[pri] += 1; } } @@ -432,9 +500,9 @@ function sortDoubleSplats({ // total number of active (non-infinity) splats, going in reverse order // because we want most distant Gsplats to be first in the output array. let activeSplats = 0; - for (let j = DEPTH_INFINITY - 1; j >= 0; --j) { - const nextIndex = activeSplats + depthArray[j]; - depthArray[j] = activeSplats; + for (let j = DEPTH_INFINITY_F16 - 1; j >= 0; --j) { + const nextIndex = activeSplats + depthArray16[j]; + depthArray16[j] = activeSplats; activeSplats = nextIndex; } @@ -442,16 +510,106 @@ function sortDoubleSplats({ // bucket order. for (let i = 0; i < numSplats; ++i) { const pri = readback[i]; - if (pri < DEPTH_INFINITY) { - ordering[depthArray[pri]] = i; - depthArray[pri] += 1; + if (pri < DEPTH_INFINITY_F16) { + ordering[depthArray16[pri]] = i; + depthArray16[pri] += 1; } } // Sanity check that the end of the closest bucket is the same as // our total count of active splats (not at infinity). 
- if (depthArray[0] !== activeSplats) { + if (depthArray16[0] !== activeSplats) { + throw new Error( + `Expected ${activeSplats} active splats but got ${depthArray16[0]}`, + ); + } + + return { activeSplats, ordering }; +} + +const DEPTH_INFINITY_F32 = 0x7f800000; +let bucket16: Uint32Array | null = null; +let scratchSplats: Uint32Array | null = null; + +// two-pass radix sort (base 65536) of 32-bit keys in readback, +// but placing largest values first. +function sort32Splats({ + maxSplats, + numSplats, + readback, // Uint32Array of bit‑patterns + ordering, // Uint32Array to fill with sorted indices +}: { + maxSplats: number; + numSplats: number; + readback: Uint32Array; + ordering: Uint32Array; +}): { activeSplats: number; ordering: Uint32Array } { + const BASE = 1 << 16; // 65536 + + // allocate once + if (!bucket16) { + bucket16 = new Uint32Array(BASE); + } + if (!scratchSplats || scratchSplats.length < maxSplats) { + scratchSplats = new Uint32Array(maxSplats); + } + + // + // ——— Pass #1: bucket by inv(lo 16 bits) ——— + // + bucket16.fill(0); + for (let i = 0; i < numSplats; ++i) { + const key = readback[i]; + if (key < DEPTH_INFINITY_F32) { + const inv = ~key >>> 0; + bucket16[inv & 0xffff] += 1; + } + } + // exclusive prefix‑sum → starting offsets + let total = 0; + for (let b = 0; b < BASE; ++b) { + const c = bucket16[b]; + bucket16[b] = total; + total += c; + } + const activeSplats = total; + + // scatter into scratch by low bits of inv + for (let i = 0; i < numSplats; ++i) { + const key = readback[i]; + if (key < DEPTH_INFINITY_F32) { + const inv = ~key >>> 0; + scratchSplats[bucket16[inv & 0xffff]++] = i; + } + } + + // + // ——— Pass #2: bucket by inv(hi 16 bits) ——— + // + bucket16.fill(0); + for (let k = 0; k < activeSplats; ++k) { + const idx = scratchSplats[k]; + const inv = ~readback[idx] >>> 0; + bucket16[inv >>> 16] += 1; + } + // exclusive prefix‑sum again + let sum = 0; + for (let b = 0; b < BASE; ++b) { + const c = bucket16[b]; + bucket16[b] = sum; + sum += c; + } + + // scatter into final ordering by high bits of inv + for (let k = 0; k < activeSplats; ++k) { + const idx = scratchSplats[k]; + const inv = ~readback[idx] >>> 0; + ordering[bucket16[inv >>> 16]++] = idx; + } + + // sanity‑check: the last bucket should have eaten all entries + if (bucket16[BASE - 1] !== activeSplats) { throw new Error( - `Expected ${activeSplats} active splats but got ${depthArray[0]}`, + `Expected ${activeSplats} active splats but got ${bucket16[BASE - 1]}`, ); } From 39f9b898fb908b352ac9b850ad94b1e69bdc8e0a Mon Sep 17 00:00:00 2001 From: Andreas Sundquist Date: Sat, 19 Jul 2025 11:10:51 -0700 Subject: [PATCH 3/4] Add sort32 to docs/spark-viewpoint.md. --- docs/docs/spark-viewpoint.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/docs/spark-viewpoint.md b/docs/docs/spark-viewpoint.md index db24c05..22e8cf4 100644 --- a/docs/docs/spark-viewpoint.md +++ b/docs/docs/spark-viewpoint.md @@ -23,6 +23,7 @@ const viewpoint = spark.newViewpoint({ sortCoorient?: boolean; depthBias?: number; sort360?: boolean; + sort32?: boolean; }); ``` @@ -44,6 +45,7 @@ const viewpoint = spark.newViewpoint({ | **sortCoorient** | View direction dot product threshold for re-sorting splats. For `sortRadial: true` it defaults to 0.99 while `sortRadial: false` uses 0.999 because it is more sensitive to view direction. 
(default: `0.99` if `sortRadial` else `0.999`) | **depthBias** | Constant added to Z-depth to bias values into the positive range for `sortRadial: false`, but also used for culling splats "well behind" the viewpoint origin (default: `1.0`) | **sort360** | Set this to true if rendering a 360 to disable "behind the viewpoint" culling during sorting. This is set automatically when rendering 360 envMaps using the `SparkRenderer.renderEnvMap()` utility function. (default: `false`) +| **sort32** | Set this to true to sort with float32 precision with two-pass sort. (default: `false`) ## `dispose()` From 732df78e11e9fa9c2adf2fd799fd358373619d21 Mon Sep 17 00:00:00 2001 From: Andreas Sundquist Date: Sat, 19 Jul 2025 11:18:42 -0700 Subject: [PATCH 4/4] Externalized SparkRenderer.maxPixelRadius and .minAlpha settings. Added to documentation. Added to examples/editor under Debug folder, moved sort32 there. --- docs/docs/spark-renderer.md | 2 ++ examples/editor/index.html | 6 ++++-- src/SparkRenderer.ts | 20 ++++++++++++++++++++ src/shaders/splatDefines.glsl | 4 ---- src/shaders/splatFragment.glsl | 3 ++- src/shaders/splatVertex.glsl | 14 ++++++++------ 6 files changed, 36 insertions(+), 13 deletions(-) diff --git a/docs/docs/spark-renderer.md b/docs/docs/spark-renderer.md index 29959ad..84c6ca9 100644 --- a/docs/docs/spark-renderer.md +++ b/docs/docs/spark-renderer.md @@ -57,6 +57,8 @@ const spark = new SparkRenderer({ | **preUpdate** | Controls whether to update the splats before or after rendering. For WebXR this *must* be false in order to complete rendering as soon as possible. (default: `false`) | **originDistance** | Distance threshold for `SparkRenderer` movement triggering a splat update at the new origin. (default: `1.0`) This can be useful when your `SparkRenderer` is a child of your camera and you want to retain high precision coordinates near the camera. | **maxStdDev** | Maximum standard deviations from the center to render Gaussians. Values `Math.sqrt(5)`..`Math.sqrt(9)` produce good results and can be tweaked for performance. (default: `Math.sqrt(8)`) +| **maxPixelRadius** | Maximum pixel radius for splat rendering. (default: `512.0`) +| **minAlpha** | Minimum alpha value for splat rendering. (default: `0.5 * (1.0 / 255.0)`) | **enable2DGS** | Enable 2D Gaussian splatting rendering ability. When this mode is enabled, any `scale` x/y/z component that is exactly `0` (minimum quantized value) results in the other two non-zero axes being interpreted as an oriented 2D Gaussian Splat instead of the usual approximate projected 3DGS Z-slice. When reading PLY files, scale values less than e^-30 will be interpreted as `0`. (default: `false`) | **preBlurAmount** | Scalar value to add to 2D splat covariance diagonal, effectively blurring + enlarging splats. In scenes trained without the splat anti-aliasing tweak this value was typically 0.3, but with anti-aliasing it is 0.0 (default: `0.0`) | **blurAmount** | Scalar value to add to 2D splat covariance diagonal, with opacity adjustment to correctly account for "blurring" when anti-aliasing. Typically 0.3 (equivalent to approx 0.5 pixel radius) in scenes trained with anti-aliasing. diff --git a/examples/editor/index.html b/examples/editor/index.html index 317c93a..a37e418 100644 --- a/examples/editor/index.html +++ b/examples/editor/index.html @@ -480,8 +480,6 @@ stats.dom.style.display = value ? 
"block" : "none"; }); gui.add(spark.defaultView, "sortRadial").name("Radial sort").listen(); - spark.defaultView.sort32 = true; - gui.add(spark.defaultView, "sort32").name("Float32 sort").listen(); gui.add(grid, "opacity", 0, 1, 0.01).name("Grid opacity").listen(); gui.add({ logFocalDistance: 0.0, @@ -515,6 +513,10 @@ }, }, "AA").name("AA preset"); debugFolder.add(spark, "focalAdjustment", 0.1, 2.0, 0.1).name("Tweak focalAdjustment"); + spark.defaultView.sort32 = true; + debugFolder.add(spark.defaultView, "sort32").name("Float32 sort").listen(); + debugFolder.add(spark, "maxPixelRadius", 1, 1024, 1).name("Max pixel radius").listen(); + debugFolder.add(spark, "minAlpha", 0, 1, 0.001).name("Min alpha").listen(); const splatsFolder = secondGui.addFolder("Files"); diff --git a/src/SparkRenderer.ts b/src/SparkRenderer.ts index 18da9a0..ee898fe 100644 --- a/src/SparkRenderer.ts +++ b/src/SparkRenderer.ts @@ -116,6 +116,16 @@ export type SparkRendererOptions = { * @default Math.sqrt(8) */ maxStdDev?: number; + /** + * Maximum pixel radius for splat rendering. + * @default 512.0 + */ + maxPixelRadius?: number; + /** + * Minimum alpha value for splat rendering. + * @default 0.5 * (1.0 / 255.0) + */ + minAlpha?: number; /** * Enable 2D Gaussian splatting rendering ability. When this mode is enabled, * any scale x/y/z component that is exactly 0 (minimum quantized value) results @@ -185,6 +195,8 @@ export class SparkRenderer extends THREE.Mesh { preUpdate: boolean; originDistance: number; maxStdDev: number; + maxPixelRadius: number; + minAlpha: number; enable2DGS: boolean; preBlurAmount: number; blurAmount: number; @@ -301,6 +313,8 @@ export class SparkRenderer extends THREE.Mesh { this.preUpdate = options.preUpdate ?? false; this.originDistance = options.originDistance ?? 1; this.maxStdDev = options.maxStdDev ?? Math.sqrt(8.0); + this.maxPixelRadius = options.maxPixelRadius ?? 512.0; + this.minAlpha = options.minAlpha ?? 0.5 * (1.0 / 255.0); this.enable2DGS = options.enable2DGS ?? false; this.preBlurAmount = options.preBlurAmount ?? 0.0; this.blurAmount = options.blurAmount ?? 
0.3; @@ -350,6 +364,10 @@ export class SparkRenderer extends THREE.Mesh { renderToViewPos: { value: new THREE.Vector3() }, // Maximum distance (in stddevs) from Gsplat center to render maxStdDev: { value: 1.0 }, + // Maximum pixel radius for splat rendering + maxPixelRadius: { value: 512.0 }, + // Minimum alpha value for splat rendering + minAlpha: { value: 0.5 * (1.0 / 255.0) }, // Enable interpreting 0-thickness Gsplats as 2DGS enable2DGS: { value: false }, // Add to projected 2D splat covariance diagonal (thickens and brightens) @@ -517,6 +535,8 @@ export class SparkRenderer extends THREE.Mesh { this.uniforms.far.value = typedCamera.far; this.uniforms.encodeLinear.value = viewpoint.encodeLinear; this.uniforms.maxStdDev.value = this.maxStdDev; + this.uniforms.maxPixelRadius.value = this.maxPixelRadius; + this.uniforms.minAlpha.value = this.minAlpha; this.uniforms.enable2DGS.value = this.enable2DGS; this.uniforms.preBlurAmount.value = this.preBlurAmount; this.uniforms.blurAmount.value = this.blurAmount; diff --git a/src/shaders/splatDefines.glsl b/src/shaders/splatDefines.glsl index 4a4fc10..d7b8bae 100644 --- a/src/shaders/splatDefines.glsl +++ b/src/shaders/splatDefines.glsl @@ -21,10 +21,6 @@ const float PI = 3.1415926535897932384626433832795; const float INFINITY = 1.0 / 0.0; const float NEG_INFINITY = -INFINITY; -const float MAX_PIXEL_RADIUS = 512.0; -const float MIN_ALPHA = 0.5 * (1.0 / 255.0); // 0.00196 -const float MAX_STDDEV = sqrt(8.0); - float sqr(float x) { return x * x; } diff --git a/src/shaders/splatFragment.glsl b/src/shaders/splatFragment.glsl index c49de39..4c4ef77 100644 --- a/src/shaders/splatFragment.glsl +++ b/src/shaders/splatFragment.glsl @@ -8,6 +8,7 @@ uniform float near; uniform float far; uniform bool encodeLinear; uniform float maxStdDev; +uniform float minAlpha; uniform bool disableFalloff; uniform float falloff; @@ -60,7 +61,7 @@ void main() { rgba.a *= mix(1.0, exp(-0.5 * z), falloff); - if (rgba.a < MIN_ALPHA) { + if (rgba.a < minAlpha) { discard; } if (encodeLinear) { diff --git a/src/shaders/splatVertex.glsl b/src/shaders/splatVertex.glsl index aea0475..c7086db 100644 --- a/src/shaders/splatVertex.glsl +++ b/src/shaders/splatVertex.glsl @@ -16,9 +16,11 @@ uniform uint numSplats; uniform vec4 renderToViewQuat; uniform vec3 renderToViewPos; uniform float maxStdDev; +uniform float maxPixelRadius; uniform float time; uniform float deltaTime; uniform bool debugFlag; +uniform float minAlpha; uniform bool enable2DGS; uniform float blurAmount; uniform float preBlurAmount; @@ -52,7 +54,7 @@ void main() { vec4 quaternion, rgba; unpackSplat(packed, center, scales, quaternion, rgba); - if (rgba.a < MIN_ALPHA) { + if (rgba.a < minAlpha) { return; } bvec3 zeroScales = equal(scales, vec3(0.0)); @@ -141,13 +143,13 @@ void main() { float fullBlurAmount = blurAmount; if ((focalDistance > 0.0) && (apertureAngle > 0.0)) { - float focusRadius = MAX_PIXEL_RADIUS; + float focusRadius = maxPixelRadius; if (viewCenter.z < 0.0) { float focusBlur = abs((-viewCenter.z - focalDistance) / viewCenter.z); float apertureRadius = focal.x * tan(0.5 * apertureAngle); focusRadius = focusBlur * apertureRadius; } - fullBlurAmount = clamp(sqr(focusRadius), blurAmount, sqr(MAX_PIXEL_RADIUS)); + fullBlurAmount = clamp(sqr(focusRadius), blurAmount, sqr(maxPixelRadius)); } // Do convolution with a 0.5-pixel Gaussian for anti-aliasing: sqrt(0.3) ~= 0.5 @@ -159,7 +161,7 @@ void main() { // Compute anti-aliasing intensity scaling factor float blurAdjust = sqrt(max(0.0, detOrig / det)); rgba.a *= 
blurAdjust; - if (rgba.a < MIN_ALPHA) { + if (rgba.a < minAlpha) { return; } @@ -172,8 +174,8 @@ void main() { vec2 eigenVec1 = normalize(vec2((abs(b) < 0.001) ? 1.0 : b, eigen1 - a)); vec2 eigenVec2 = vec2(eigenVec1.y, -eigenVec1.x); - float scale1 = position.x * min(MAX_PIXEL_RADIUS, maxStdDev * sqrt(eigen1)); - float scale2 = position.y * min(MAX_PIXEL_RADIUS, maxStdDev * sqrt(eigen2)); + float scale1 = position.x * min(maxPixelRadius, maxStdDev * sqrt(eigen1)); + float scale2 = position.y * min(maxPixelRadius, maxStdDev * sqrt(eigen2)); // Compute the NDC coordinates for the ellipsoid's diagonal axes. vec2 pixelOffset = eigenVec1 * scale1 + eigenVec2 * scale2;
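
---

For reference, a minimal usage sketch of the options introduced by this patch series. This is an assumption-laden example, not part of the patches: the `@sparkjsdev/spark` import path, passing `renderer` to `SparkRenderer`, and passing `camera` to `newViewpoint` are assumed; the `sort32`, `maxPixelRadius`, and `minAlpha` names and defaults come from the docs and `examples/editor` changes above.

```ts
// Minimal sketch, assuming the public API shown in the docs diffs above.
import * as THREE from "three";
import { SparkRenderer } from "@sparkjsdev/spark"; // assumed package name

const renderer = new THREE.WebGLRenderer();
const camera = new THREE.PerspectiveCamera(60, 1, 0.1, 1000);
const scene = new THREE.Scene();

const spark = new SparkRenderer({
  renderer,                  // assumed option, not shown in the diffs
  maxPixelRadius: 512.0,     // new option: clamp on projected splat pixel radius
  minAlpha: 0.5 * (1 / 255), // new option: alpha cutoff below which splats are skipped
});
scene.add(spark); // SparkRenderer extends THREE.Mesh

// Enable the float32 two-pass radix sort on the default view,
// as examples/editor now does in its Debug folder...
spark.defaultView.sort32 = true;

// ...or request it per viewpoint (sort32 defaults to false per the docs table).
const viewpoint = spark.newViewpoint({
  camera, // assumed option
  sortRadial: true,
  sort32: true,
});
```

The float32 path relies on the fact that non-negative IEEE-754 floats compare in the same order as their raw bit patterns: `sort32Splats` and `sort32_splats` bucket the bitwise complement of each key (`~key >>> 0`), low 16 bits in the first pass and high 16 bits in the second, which yields a most-distant-first ordering in two linear passes while culling any splat whose metric is at or beyond `0x7f800000` (+Infinity).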