2 changes: 2 additions & 0 deletions docs/docs/spark-renderer.md
@@ -57,6 +57,8 @@ const spark = new SparkRenderer({
| **preUpdate** | Controls whether to update the splats before or after rendering. For WebXR this *must* be false in order to complete rendering as soon as possible. (default: `false`)
| **originDistance** | Distance threshold for `SparkRenderer` movement triggering a splat update at the new origin. (default: `1.0`) This can be useful when your `SparkRenderer` is a child of your camera and you want to retain high precision coordinates near the camera.
| **maxStdDev** | Maximum standard deviations from the center to render Gaussians. Values `Math.sqrt(5)`..`Math.sqrt(9)` produce good results and can be tweaked for performance. (default: `Math.sqrt(8)`)
| **maxPixelRadius** | Maximum pixel radius for splat rendering. (default: `512.0`)
| **minAlpha** | Minimum alpha value for splat rendering. (default: `0.5 * (1.0 / 255.0)`)
| **enable2DGS** | Enable 2D Gaussian splatting rendering ability. When this mode is enabled, any `scale` x/y/z component that is exactly `0` (minimum quantized value) results in the other two non-zero axes being interpreted as an oriented 2D Gaussian Splat instead of the usual approximate projected 3DGS Z-slice. When reading PLY files, scale values less than e^-30 will be interpreted as `0`. (default: `false`)
| **preBlurAmount** | Scalar value to add to the 2D splat covariance diagonal, effectively blurring and enlarging splats. In scenes trained without the splat anti-aliasing tweak this value was typically 0.3, but with anti-aliasing it is 0.0. (default: `0.0`)
| **blurAmount** | Scalar value to add to the 2D splat covariance diagonal, with an opacity adjustment to correctly account for "blurring" when anti-aliasing. Typically 0.3 (equivalent to approximately 0.5 pixel radius) in scenes trained with anti-aliasing. (default: `0.3`)
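For example, the new options can be set alongside the existing ones when constructing the renderer. A minimal sketch (the `renderer` option and the `@sparkjsdev/spark` import follow Spark's quick-start examples and are not part of this diff; the values are illustrative):

```ts
import * as THREE from "three";
import { SparkRenderer } from "@sparkjsdev/spark";

const renderer = new THREE.WebGLRenderer();
const spark = new SparkRenderer({
  renderer,
  maxStdDev: Math.sqrt(8),  // existing option, shown for context
  maxPixelRadius: 256.0,    // tighter cap than the 512.0 default
  minAlpha: 1.0 / 255.0,    // one full 8-bit alpha step instead of the half-step default
});
```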
2 changes: 2 additions & 0 deletions docs/docs/spark-viewpoint.md
@@ -23,6 +23,7 @@ const viewpoint = spark.newViewpoint({
sortCoorient?: boolean;
depthBias?: number;
sort360?: boolean;
sort32?: boolean;
});
```

@@ -44,6 +45,7 @@ const viewpoint = spark.newViewpoint({
| **sortCoorient** | View direction dot product threshold for re-sorting splats. For `sortRadial: true` it defaults to 0.99 while `sortRadial: false` uses 0.999 because it is more sensitive to view direction. (default: `0.99` if `sortRadial` else `0.999`)
| **depthBias** | Constant added to Z-depth to bias values into the positive range for `sortRadial: false`, but also used for culling splats "well behind" the viewpoint origin (default: `1.0`)
| **sort360** | Set this to true if rendering a 360 to disable "behind the viewpoint" culling during sorting. This is set automatically when rendering 360 envMaps using the `SparkRenderer.renderEnvMap()` utility function. (default: `false`)
| **sort32** | Set this to true to sort with float32 precision using a two-pass radix sort. (default: `false`)
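For example, the flag can be set when creating a viewpoint or toggled later on an existing one. A sketch based on the `newViewpoint` call shown above (`spark.defaultView` is the renderer's default viewpoint, as used in the editor example in this PR):

```ts
// Opt a viewpoint into the float32 two-pass sort.
const viewpoint = spark.newViewpoint({
  sortRadial: true,
  sort32: true,
});

// Or enable it on the default viewpoint at runtime.
spark.defaultView.sort32 = true;
```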

## `dispose()`

4 changes: 4 additions & 0 deletions examples/editor/index.html
@@ -513,6 +513,10 @@
},
}, "AA").name("AA preset");
debugFolder.add(spark, "focalAdjustment", 0.1, 2.0, 0.1).name("Tweak focalAdjustment");
spark.defaultView.sort32 = true;
debugFolder.add(spark.defaultView, "sort32").name("Float32 sort").listen();
debugFolder.add(spark, "maxPixelRadius", 1, 1024, 1).name("Max pixel radius").listen();
debugFolder.add(spark, "minAlpha", 0, 1, 0.001).name("Min alpha").listen();

const splatsFolder = secondGui.addFolder("Files");

32 changes: 31 additions & 1 deletion rust/spark-internal-rs/src/lib.rs
@@ -4,7 +4,7 @@ use js_sys::{Float32Array, Uint16Array, Uint32Array};
use wasm_bindgen::prelude::*;

mod sort;
use sort::{sort_internal, SortBuffers};
use sort::{sort_internal, SortBuffers, sort32_internal, Sort32Buffers};

mod raycast;
use raycast::{raycast_ellipsoids, raycast_spheres};
@@ -13,6 +13,7 @@ const RAYCAST_BUFFER_COUNT: u32 = 65536;

thread_local! {
static SORT_BUFFERS: RefCell<SortBuffers> = RefCell::new(SortBuffers::default());
static SORT32_BUFFERS: RefCell<Sort32Buffers> = RefCell::new(Sort32Buffers::default());
static RAYCAST_BUFFER: RefCell<Vec<u32>> = RefCell::new(vec![0; RAYCAST_BUFFER_COUNT as usize * 4]);
}

@@ -45,6 +46,35 @@ pub fn sort_splats(
active_splats
}

#[wasm_bindgen]
pub fn sort32_splats(
    num_splats: u32, readback: Uint32Array, ordering: Uint32Array,
) -> u32 {
    let max_splats = readback.length() as usize;

    let active_splats = SORT32_BUFFERS.with_borrow_mut(|buffers| {
        buffers.ensure_size(max_splats);
        let sub_readback = readback.subarray(0, num_splats);
        sub_readback.copy_to(&mut buffers.readback[..num_splats as usize]);

        let active_splats = match sort32_internal(buffers, max_splats, num_splats as usize) {
            Ok(active_splats) => active_splats,
            Err(err) => {
                wasm_bindgen::throw_str(&format!("{}", err));
            }
        };

        if active_splats > 0 {
            // Copy out ordering result
            let subarray = &buffers.ordering[..active_splats as usize];
            ordering.subarray(0, active_splats).copy_from(&subarray);
        }
        active_splats
    });

    active_splats
}
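From JavaScript, this binding would be driven with the float32 depth keys reinterpreted as `u32` bit patterns. A hypothetical usage sketch (the module path, variable names, and scenario are illustrative assumptions, not Spark's actual sorter plumbing; keys at or above `0x7f800000`, i.e. `+Infinity`, are skipped per the check in `sort32_internal`):

```ts
import { sort32_splats } from "./spark-internal-rs"; // assumed path to the wasm-bindgen output

const maxSplats = 1024;
const depths = new Float32Array(maxSplats);
depths.set([2.5, 0.75, Infinity, 1.25]); // splat 2 has an infinite key and will be skipped

const readback = new Uint32Array(depths.buffer); // same bits, viewed as u32 keys
const ordering = new Uint32Array(maxSplats);

const active = sort32_splats(4, readback, ordering);
// active === 3; ordering starts [0, 3, 1]: indices sorted with the largest depth key first
```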

#[wasm_bindgen]
pub fn raycast_splats(
origin_x: f32, origin_y: f32, origin_z: f32,
122 changes: 115 additions & 7 deletions rust/spark-internal-rs/src/sort.rs
@@ -1,7 +1,7 @@
use anyhow::anyhow;

const DEPTH_INFINITY: u32 = 0x7c00;
const DEPTH_SIZE: usize = DEPTH_INFINITY as usize + 1;
const DEPTH_INFINITY_F16: u32 = 0x7c00;
const DEPTH_SIZE_F16: usize = DEPTH_INFINITY_F16 as usize + 1;

#[derive(Default)]
pub struct SortBuffers {
@@ -18,8 +18,8 @@ impl SortBuffers {
if self.ordering.len() < max_splats {
self.ordering.resize(max_splats, 0);
}
if self.buckets.len() < DEPTH_SIZE {
self.buckets.resize(DEPTH_SIZE, 0);
if self.buckets.len() < DEPTH_SIZE_F16 {
self.buckets.resize(DEPTH_SIZE_F16, 0);
}
}
}
@@ -30,11 +30,11 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> anyhow::Re

// Set the bucket counts to zero
buckets.clear();
buckets.resize(DEPTH_SIZE, 0);
buckets.resize(DEPTH_SIZE_F16, 0);

// Count the number of splats in each bucket
for &metric in readback.iter() {
if (metric as u32) < DEPTH_INFINITY {
if (metric as u32) < DEPTH_INFINITY_F16 {
buckets[metric as usize] += 1;
}
}
@@ -49,7 +49,7 @@

// Write out splat indices at the right location using bucket offsets
for (index, &metric) in readback.iter().enumerate() {
if (metric as u32) < DEPTH_INFINITY {
if (metric as u32) < DEPTH_INFINITY_F16 {
ordering[buckets[metric as usize] as usize] = index as u32;
buckets[metric as usize] += 1;
}
@@ -65,3 +65,111 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> anyhow::Re
}
Ok(active_splats)
}

const DEPTH_INFINITY_F32: u32 = 0x7f800000;
const RADIX_BASE: usize = 1 << 16; // 65536

#[derive(Default)]
pub struct Sort32Buffers {
    /// raw f32 bit patterns (one per splat)
    pub readback: Vec<u32>,
    /// output indices
    pub ordering: Vec<u32>,
    /// bucket counts / offsets (length == RADIX_BASE)
    pub buckets16: Vec<u32>,
    /// scratch space for indices
    pub scratch: Vec<u32>,
}

impl Sort32Buffers {
    /// ensure all internal buffers are large enough for up to `max_splats`
    pub fn ensure_size(&mut self, max_splats: usize) {
        if self.readback.len() < max_splats {
            self.readback.resize(max_splats, 0);
        }
        if self.ordering.len() < max_splats {
            self.ordering.resize(max_splats, 0);
        }
        if self.scratch.len() < max_splats {
            self.scratch.resize(max_splats, 0);
        }
        if self.buckets16.len() < RADIX_BASE {
            self.buckets16.resize(RADIX_BASE, 0);
        }
    }
}

/// Two-pass radix sort (base 2^16) of 32-bit float bit patterns,
/// descending order (largest keys first). Mirrors the JS `sort32Splats`.
pub fn sort32_internal(
    buffers: &mut Sort32Buffers,
    max_splats: usize,
    num_splats: usize,
) -> anyhow::Result<u32> {
    // make sure our buffers can hold `max_splats`
    buffers.ensure_size(max_splats);

    let Sort32Buffers { readback, ordering, buckets16, scratch } = buffers;
    let keys = &readback[..num_splats];

    // Pass #1: bucket by inv(low 16 bits)
    buckets16.fill(0);
    for &key in keys.iter() {
        if key < DEPTH_INFINITY_F32 {
            let inv = !key;
            buckets16[(inv & 0xFFFF) as usize] += 1;
        }
    }
    // exclusive prefix sum -> starting offsets
    let mut total: u32 = 0;
    for slot in buckets16.iter_mut() {
        let cnt = *slot;
        *slot = total;
        total = total.wrapping_add(cnt);
    }
    let active_splats = total;

    // scatter into scratch by low bits of inv
    for (i, &key) in keys.iter().enumerate() {
        if key < DEPTH_INFINITY_F32 {
            let inv = !key;
            let lo = (inv & 0xFFFF) as usize;
            scratch[buckets16[lo] as usize] = i as u32;
            buckets16[lo] += 1;
        }
    }

    // Pass #2: bucket by inv(high 16 bits)
    buckets16.fill(0);
    for &idx in scratch.iter().take(active_splats as usize) {
        let key = keys[idx as usize];
        let inv = !key;
        buckets16[(inv >> 16) as usize] += 1;
    }
    // exclusive prefix sum again
    let mut sum: u32 = 0;
    for slot in buckets16.iter_mut() {
        let cnt = *slot;
        *slot = sum;
        sum = sum.wrapping_add(cnt);
    }
    // scatter into final ordering by high bits of inv
    for &idx in scratch.iter().take(active_splats as usize) {
        let key = keys[idx as usize];
        let inv = !key;
        let hi = (inv >> 16) as usize;
        ordering[buckets16[hi] as usize] = idx;
        buckets16[hi] += 1;
    }

    // sanity check: last bucket should have consumed all entries
    if buckets16[RADIX_BASE - 1] != active_splats {
        return Err(anyhow!(
            "Expected {} active splats but got {}",
            active_splats,
            buckets16[RADIX_BASE - 1]
        ));
    }

    Ok(active_splats)
}
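For intuition about why two 16-bit passes over the inverted bit pattern give a full float32 sort, here is a compact TypeScript sketch of the same scheme (illustration only, not code from this PR). Positive IEEE-754 floats order the same way as their raw bit patterns, so radix-sorting the inverted bits in ascending order yields the descending, largest-key-first order documented above.

```ts
// Two-pass base-2^16 radix sort of finite float32 keys, largest key first.
// Mirrors the structure of sort32_internal (illustrative sketch).
function sort32Sketch(keys: Float32Array): Uint32Array {
  const bits = new Uint32Array(keys.buffer, keys.byteOffset, keys.length);
  const INFINITY_BITS = 0x7f800000; // keys at or above +Infinity's bit pattern are skipped
  const buckets = new Uint32Array(1 << 16);
  const scratch = new Uint32Array(keys.length);

  // Pass 1: count, prefix-sum, and scatter by the low 16 bits of the inverted key.
  for (let i = 0; i < keys.length; i++) {
    if (bits[i] < INFINITY_BITS) buckets[~bits[i] & 0xffff]++;
  }
  let total = 0;
  for (let b = 0; b < buckets.length; b++) {
    const count = buckets[b];
    buckets[b] = total; // exclusive prefix sum: starting offset of each bucket
    total += count;
  }
  const active = total;
  for (let i = 0; i < keys.length; i++) {
    if (bits[i] < INFINITY_BITS) scratch[buckets[~bits[i] & 0xffff]++] = i;
  }

  // Pass 2: same again on the high 16 bits; walking scratch in order keeps the sort stable.
  const ordering = new Uint32Array(active);
  buckets.fill(0);
  for (let j = 0; j < active; j++) {
    buckets[~bits[scratch[j]] >>> 16]++;
  }
  let sum = 0;
  for (let b = 0; b < buckets.length; b++) {
    const count = buckets[b];
    buckets[b] = sum;
    sum += count;
  }
  for (let j = 0; j < active; j++) {
    const idx = scratch[j];
    ordering[buckets[~bits[idx] >>> 16]++] = idx;
  }
  return ordering; // indices of the finite keys, largest key first
}
```

For example, `sort32Sketch(new Float32Array([0.5, 8, 2]))` returns a `Uint32Array` containing `[1, 2, 0]`.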
20 changes: 20 additions & 0 deletions src/SparkRenderer.ts
@@ -116,6 +116,16 @@ export type SparkRendererOptions = {
* @default Math.sqrt(8)
*/
maxStdDev?: number;
/**
* Maximum pixel radius for splat rendering.
* @default 512.0
*/
maxPixelRadius?: number;
/**
* Minimum alpha value for splat rendering.
* @default 0.5 * (1.0 / 255.0)
*/
minAlpha?: number;
/**
* Enable 2D Gaussian splatting rendering ability. When this mode is enabled,
* any scale x/y/z component that is exactly 0 (minimum quantized value) results
Expand Down Expand Up @@ -185,6 +195,8 @@ export class SparkRenderer extends THREE.Mesh {
preUpdate: boolean;
originDistance: number;
maxStdDev: number;
maxPixelRadius: number;
minAlpha: number;
enable2DGS: boolean;
preBlurAmount: number;
blurAmount: number;
@@ -301,6 +313,8 @@ export class SparkRenderer extends THREE.Mesh {
this.preUpdate = options.preUpdate ?? false;
this.originDistance = options.originDistance ?? 1;
this.maxStdDev = options.maxStdDev ?? Math.sqrt(8.0);
this.maxPixelRadius = options.maxPixelRadius ?? 512.0;
this.minAlpha = options.minAlpha ?? 0.5 * (1.0 / 255.0);
this.enable2DGS = options.enable2DGS ?? false;
this.preBlurAmount = options.preBlurAmount ?? 0.0;
this.blurAmount = options.blurAmount ?? 0.3;
@@ -350,6 +364,10 @@ export class SparkRenderer extends THREE.Mesh {
renderToViewPos: { value: new THREE.Vector3() },
// Maximum distance (in stddevs) from Gsplat center to render
maxStdDev: { value: 1.0 },
// Maximum pixel radius for splat rendering
maxPixelRadius: { value: 512.0 },
// Minimum alpha value for splat rendering
minAlpha: { value: 0.5 * (1.0 / 255.0) },
// Enable interpreting 0-thickness Gsplats as 2DGS
enable2DGS: { value: false },
// Add to projected 2D splat covariance diagonal (thickens and brightens)
@@ -517,6 +535,8 @@
this.uniforms.far.value = typedCamera.far;
this.uniforms.encodeLinear.value = viewpoint.encodeLinear;
this.uniforms.maxStdDev.value = this.maxStdDev;
this.uniforms.maxPixelRadius.value = this.maxPixelRadius;
this.uniforms.minAlpha.value = this.minAlpha;
this.uniforms.enable2DGS.value = this.enable2DGS;
this.uniforms.preBlurAmount.value = this.preBlurAmount;
this.uniforms.blurAmount.value = this.blurAmount;