Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/scene/gsplat-unified/gsplat-compute-local-renderer.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import { computeGsplatCommonSource } from '../shader-lib/wgsl/chunks/gsplat/comp
import { computeGsplatTileIntersectSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-tile-intersect.js';
import { GSplatTileComposite } from './gsplat-tile-composite.js';
import { GSplatLocalDispatchSet } from './gsplat-local-dispatch-set.js';
import { CACHE_STRIDE } from './gsplat-local-constants.js';
import computeSplatSource from '../shader-lib/wgsl/chunks/gsplat/vert/gsplatComputeSplat.js';

/**
Expand Down Expand Up @@ -76,7 +77,6 @@ const INITIAL_LARGE_SPLAT_CAPACITY = 16384;

const TILE_SIZE = 16;
const MAX_TILES = 65535; // tile index must fit in 16 bits for pair packing (tileIdx << 16 | localOffset)
const CACHE_STRIDE = 7;
const MAX_CHUNKS_PER_TILE = 8;

// ---- Module-scope scratch (reusable, never exported) ----
Expand Down Expand Up @@ -1031,11 +1031,15 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
]);

const cdefines = new Map();
cdefines.set('{CACHE_STRIDE}', CACHE_STRIDE.toString());

this._largeSplatShader = new Shader(device, {
name: 'GSplatLocalTileCountLarge',
shaderLanguage: SHADERLANGUAGE_WGSL,
cshader: computeGsplatLocalTileCountLargeSource,
cincludes,
cdefines,
computeBindGroupFormat: this._largeSplatBindGroupFormat,
computeUniformBufferFormats: { uniforms: ubf }
});
Expand Down Expand Up @@ -1261,6 +1265,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
cincludes.set('gsplatFormatReadCS', wbFormat.getReadCode());

const cdefines = new Map();
cdefines.set('{CACHE_STRIDE}', CACHE_STRIDE.toString());
if (fisheyeEnabled) {
cdefines.set('GSPLAT_FISHEYE', '');
}
Expand Down
8 changes: 8 additions & 0 deletions src/scene/gsplat-unified/gsplat-local-constants.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Shared constants for the local compute gsplat renderer. These values are mirrored into
// the WGSL shaders via cdefines text substitution (e.g. `{CACHE_STRIDE}u`), so they must
// live in a single place to keep JS-side buffer allocations and shader-side indexing in sync.

// Number of u32 slots per splat in projCache. 8 = 32 bytes (cache-line friendly).
// Slots: [0] centerX, [1] centerY, [2..4] conic coeffs, [5] pickId/color, [6] viewDepth/opacity,
// [7] precomputed -0.5 * radiusFactor (power cutoff for rasterize early-out).
export const CACHE_STRIDE = 8;
2 changes: 2 additions & 0 deletions src/scene/gsplat-unified/gsplat-local-dispatch-set.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
import { PrefixSumKernel } from '../graphics/prefix-sum-kernel.js';
import { shaderChunksWGSL } from '../shader-lib/wgsl/collections/shader-chunks-wgsl.js';
import { computeGsplatLocalRasterizeSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-rasterize.js';
import { CACHE_STRIDE } from './gsplat-local-constants.js';

/**
* @import { GraphicsDevice } from '../../platform/graphics/graphics-device.js'
Expand Down Expand Up @@ -355,6 +356,7 @@ class GSplatLocalDispatchSet {
const bgf = new BindGroupFormat(device, [...sharedBindings, ...outputBindings, ...depthBindings]);

const cdefines = new Map();
cdefines.set('{CACHE_STRIDE}', CACHE_STRIDE.toString());
if (pickMode) cdefines.set('PICK_MODE', '');
if (depthTest) cdefines.set('DEPTH_TEST', '');
if (heatmap) cdefines.set('HEATMAP_MODE', '');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ export const computeGsplatLocalRasterizeSource = /* wgsl */`
#endif
#endif

const CACHE_STRIDE: u32 = 7u;
const BATCH_SIZE: u32 = 64u;
const ALPHA_THRESHOLD: half = half(1.0) / half(255.0);
const EXP4: half = exp(half(-4.0));
Expand Down Expand Up @@ -60,6 +59,9 @@ struct Uniforms {

var<workgroup> sharedCenterScreen: array<vec2f, 64>;
var<workgroup> sharedCoeffs: array<vec3f, 64>;
// Per-splat Gaussian exponent cutoff: -radiusFactor / 2. Used to skip exp() and the
// blend chain for splats whose contribution at all 4 pixels of the quad is below alphaClip.
var<workgroup> sharedPowerCutoff: array<f32, 64>;
#ifdef HEATMAP_MODE
var<workgroup> sharedHeatCount: atomic<u32>;
#endif
Expand Down Expand Up @@ -210,7 +212,7 @@ fn main(
let batchOffset = batch * BATCH_SIZE + localIdx;
if (batchOffset < tileCount) {
let cacheIdx = tileEntries[tStart + batchOffset];
let base = cacheIdx * CACHE_STRIDE;
let base = cacheIdx * {CACHE_STRIDE}u;
sharedCenterScreen[localIdx] = vec2f(
bitcast<f32>(projCache[base + 0u]),
bitcast<f32>(projCache[base + 1u])
Expand All @@ -220,6 +222,7 @@ fn main(
let cy = bitcast<f32>(projCache[base + 3u]);
let cz = bitcast<f32>(projCache[base + 4u]);
sharedCoeffs[localIdx] = vec3f(cx * -0.5, cz * -0.5, -cy);
sharedPowerCutoff[localIdx] = bitcast<f32>(projCache[base + 7u]);

#ifdef PICK_MODE
sharedPickId[localIdx] = projCache[base + 5u];
Expand Down Expand Up @@ -265,6 +268,16 @@ fn main(
let splatPickId = sharedPickId[i];
let splatDepth = sharedViewDepth[i];

// Skip the 4 per-pixel pick evaluations entirely when the splat
// contributes nothing to any pixel in this quad.
let d = p00 - center;
let dxV = vec4f(d.x, d.x + 1.0, d.x, d.x + 1.0);
let dyV = vec4f(d.y, d.y, d.y + 1.0, d.y + 1.0);
let power4 = coeffs.x * dxV * dxV + coeffs.z * dxV * dyV + coeffs.y * dyV * dyV;
if (all(power4 <= vec4f(sharedPowerCutoff[i]))) {
continue;
}

evalSplatPick(p00, center, coeffs.x, coeffs.y, coeffs.z, splatOpacity, splatPickId, splatDepth, clipH, &pickId00, &dAcc00, &wAcc00, &T00);
evalSplatPick(p10, center, coeffs.x, coeffs.y, coeffs.z, splatOpacity, splatPickId, splatDepth, clipH, &pickId10, &dAcc10, &wAcc10, &T10);
evalSplatPick(p01, center, coeffs.x, coeffs.y, coeffs.z, splatOpacity, splatPickId, splatDepth, clipH, &pickId01, &dAcc01, &wAcc01, &T01);
Expand Down Expand Up @@ -295,6 +308,14 @@ fn main(
let dxV = vec4f(d.x, d.x + 1.0, d.x, d.x + 1.0);
let dyV = vec4f(d.y, d.y, d.y + 1.0, d.y + 1.0);
let power4 = coeffs.x * dxV * dxV + coeffs.z * dxV * dyV + coeffs.y * dyV * dyV;

// Skip exp() and the blend chain when the splat contributes nothing
// at any of the 4 pixels. The per-splat cutoff is tighter than -4 for
// low-opacity splats, so they drop earlier.
if (all(power4 <= vec4f(sharedPowerCutoff[i]))) {
continue;
}

let gauss4 = (half4(exp(power4)) - half4(EXP4)) * half4(INV_EXP4);
let alpha4 = min(half4(0.99), half4(splatColor.a) * gauss4);
let newT = T * (half4(1.0) - alpha4);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ export const computeGsplatLocalTileCountLargeSource = /* wgsl */`
#include "gsplatCommonCS"
#include "gsplatTileIntersectCS"

const CACHE_STRIDE: u32 = 7u;
const WG_SIZE: u32 = 256u;
const MAX_TILE_ENTRIES: u32 = 0xFFFFu;

Expand Down Expand Up @@ -72,7 +71,7 @@ fn main(
if (isActive) {
threadIdx = largeSplatIds[largeSplatIdx];

let cacheBase = threadIdx * CACHE_STRIDE;
let cacheBase = threadIdx * {CACHE_STRIDE}u;
screen = vec2f(bitcast<f32>(projCache[cacheBase + 0u]), bitcast<f32>(projCache[cacheBase + 1u]));
cx = bitcast<f32>(projCache[cacheBase + 2u]);
cy = bitcast<f32>(projCache[cacheBase + 3u]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ export const computeGsplatLocalTileCountSource = /* wgsl */`
#include "gsplatCommonCS"
#include "gsplatTileIntersectCS"

const CACHE_STRIDE: u32 = 7u;

// Caps the 16-bit localOffset field in packed pairs (tileIdx << 16 | localOffset).
const MAX_TILE_ENTRIES: u32 = 0xFFFFu;

Expand Down Expand Up @@ -105,7 +103,7 @@ fn main(
let opacity = getOpacity();

if (opacity < uniforms.alphaClip) {
projCache[threadIdx * CACHE_STRIDE + 6u] = 0u;
projCache[threadIdx * {CACHE_STRIDE}u + 6u] = 0u;
splatPairStart[threadIdx] = 0u;
splatPairCount[threadIdx] = 0u;
return;
Expand All @@ -127,7 +125,7 @@ fn main(
);

if (!proj.valid) {
projCache[threadIdx * CACHE_STRIDE + 6u] = 0u;
projCache[threadIdx * {CACHE_STRIDE}u + 6u] = 0u;
splatPairStart[threadIdx] = 0u;
splatPairCount[threadIdx] = 0u;
return;
Expand All @@ -139,7 +137,7 @@ fn main(
let cy = -4.0 * proj.b * invDet;
let cz = 4.0 * proj.a * invDet;

let base = threadIdx * CACHE_STRIDE;
let base = threadIdx * {CACHE_STRIDE}u;
projCache[base + 0u] = bitcast<u32>(proj.screen.x);
projCache[base + 1u] = bitcast<u32>(proj.screen.y);
projCache[base + 2u] = bitcast<u32>(cx);
Expand All @@ -165,6 +163,13 @@ fn main(
uniforms.alphaClip);
let radiusFactor = eval.radiusFactor;

// Per-splat power cutoff for the rasterize pass: the Gaussian exponent below which
// the splat's contribution at a pixel drops below alphaClip. Equal to -radiusFactor / 2
// = -log(opacity / alphaClip), clamped with radiusFactor. For high-opacity splats this
// is -4 (matching the global cutoff); for low-opacity splats it's tighter, letting the
// rasterize kernel skip exp() and the blend chain entirely for non-contributing pixels.
projCache[base + 7u] = bitcast<u32>(-0.5 * radiusFactor);

let minTileX = max(0i, i32(floor(eval.splatMin.x / f32(TILE_SIZE))));
let maxTileX = min(i32(uniforms.numTilesX) - 1i, i32(floor(eval.splatMax.x / f32(TILE_SIZE))));
let minTileY = max(0i, i32(floor(eval.splatMin.y / f32(TILE_SIZE))));
Expand Down