Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 27 additions & 71 deletions src/scene/gsplat-unified/gsplat-compute-local-renderer.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,16 @@ import { BindGroupFormat, BindStorageBufferFormat, BindStorageTextureFormat, Bin
import { UniformBufferFormat, UniformFormat } from '../../platform/graphics/uniform-buffer-format.js';
import {
BUFFERUSAGE_COPY_DST, BUFFERUSAGE_COPY_SRC, BUFFERUSAGE_INDIRECT,
CULLFACE_NONE,
FILTER_NEAREST,
PIXELFORMAT_RGBA8,
SAMPLETYPE_UINT,
SEMANTIC_POSITION,
SHADERLANGUAGE_WGSL,
SHADERSTAGE_COMPUTE,
UNIFORMTYPE_FLOAT,
UNIFORMTYPE_MAT4,
UNIFORMTYPE_UINT
} from '../../platform/graphics/constants.js';
import { BLEND_PREMULTIPLIED, GSPLAT_FORWARD } from '../constants.js';
import { ShaderMaterial } from '../materials/shader-material.js';
import { MeshInstance } from '../mesh-instance.js';
import { Mesh } from '../mesh.js';
import { GSPLAT_FORWARD } from '../constants.js';
import { Mat4 } from '../../core/math/mat4.js';
import { GSplatRenderer } from './gsplat-renderer.js';
import { FramePassGSplatComputeLocal } from './frame-pass-gsplat-compute-local.js';
Expand All @@ -37,6 +32,7 @@ import { computeGsplatLocalCopySource } from '../shader-lib/wgsl/chunks/gsplat/c
import { computeGsplatLocalBitonicSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-bitonic.js';
import { computeGsplatCommonSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-common.js';
import { computeGsplatTileIntersectSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-tile-intersect.js';
import { GSplatTileComposite } from './gsplat-tile-composite.js';

/**
* @import { GraphNode } from '../graph-node.js'
Expand Down Expand Up @@ -128,11 +124,8 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
/** @type {FramePassGSplatComputeLocal} */
framePass;

/** @type {ShaderMaterial} */
_material;

/** @type {MeshInstance} */
meshInstance;
/** @type {GSplatTileComposite} */
tileComposite;

/** @type {boolean} */
_needsFramePassRegister = false;
Expand Down Expand Up @@ -264,16 +257,20 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
this._createChunkSortCompute();
this._createRasterizeCompute();
this.framePass = new FramePassGSplatComputeLocal(this);
this._createCompositeMaterial();
this.meshInstance = this._createMeshInstance();

const thisCamera = cameraNode.camera;
this.tileComposite = new GSplatTileComposite(device, node, (camera) => {
const renderMode = this.renderMode ?? 0;
return thisCamera.camera === camera && (renderMode & GSPLAT_FORWARD) !== 0;
});
}

destroy() {
this._unregisterFramePass();

if (this.renderMode) {
if (this.renderMode & GSPLAT_FORWARD) {
this.layer.removeMeshInstances([this.meshInstance], true);
this.layer.removeMeshInstances([this.tileComposite.meshInstance], true);
}
}

Expand Down Expand Up @@ -309,14 +306,13 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
this._chunkSortIndirectBuffer?.destroy();

this.outputTexture.destroy();
this._material.destroy();
this.meshInstance.destroy();
this.tileComposite.destroy();

super.destroy();
}

get material() {
return this._material;
return this.tileComposite.material;
}

setRenderMode(renderMode) {
Expand All @@ -325,12 +321,12 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
const isForward = (renderMode & GSPLAT_FORWARD) !== 0;

if (!wasForward && isForward) {
this.layer.addMeshInstances([this.meshInstance], true);
this.layer.addMeshInstances([this.tileComposite.meshInstance], true);
this._registerFramePass();
}

if (wasForward && !isForward) {
this.layer.removeMeshInstances([this.meshInstance], true);
this.layer.removeMeshInstances([this.tileComposite.meshInstance], true);
this._unregisterFramePass();
}

Expand Down Expand Up @@ -532,17 +528,22 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
// Reserve 3 indirect dispatch slots: 0=smallSort, 1=bucketSort, 2=rasterize
const indirectSlot = device.getIndirectDispatchSlot(3);

// Reserve 1 indirect draw slot for the tile-based composite
const drawSlot = device.getIndirectDrawSlot(1);

this.classifyCompute.setParameter('tileSplatCounts', this._tileSplatCountsBuffer);
this.classifyCompute.setParameter('smallTileList', this._smallTileListBuffer);
this.classifyCompute.setParameter('largeTileList', this._largeTileListBuffer);
this.classifyCompute.setParameter('rasterizeTileList', this._rasterizeTileListBuffer);
this.classifyCompute.setParameter('tileListCounts', this._tileListCountsBuffer);
this.classifyCompute.setParameter('indirectDispatchArgs', device.indirectDispatchBuffer);
this.classifyCompute.setParameter('largeTileOverflowBases', this._largeTileOverflowBasesBuffer);
this.classifyCompute.setParameter('indirectDrawArgs', device.indirectDrawBuffer);
this.classifyCompute.setParameter('numTiles', numTiles);
this.classifyCompute.setParameter('dispatchSlotOffset', indirectSlot * 3);
this.classifyCompute.setParameter('bufferCapacity', maxEntries);
this.classifyCompute.setParameter('maxWorkgroupsPerDim', device.limits.maxComputeWorkgroupsPerDimension || 65535);
this.classifyCompute.setParameter('drawSlot', drawSlot);

this.classifyCompute.setupDispatch(1, 1, 1);
device.computeDispatch([this.classifyCompute], 'GSplatLocalClassify');
Expand Down Expand Up @@ -608,6 +609,9 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
this.rasterizeCompute.setupIndirectDispatch(indirectSlot + 2);
device.computeDispatch([this.rasterizeCompute], 'GSplatLocalRasterize');

// Update tile composite for indirect draw
this.tileComposite.update(drawSlot, this.outputTexture, this._rasterizeTileListBuffer, numTilesX, width, height);

// Async readback: check if the buffer was large enough, grow multiplier if not.
// Reads totalEntries (from prefix sum) and totalOverflowUsed (from classify).
// The readback arrives 1-2 frames later; until then rendering may be degraded.
Expand Down Expand Up @@ -747,7 +751,8 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
new UniformFormat('numTiles', UNIFORMTYPE_UINT),
new UniformFormat('dispatchSlotOffset', UNIFORMTYPE_UINT),
new UniformFormat('bufferCapacity', UNIFORMTYPE_UINT),
new UniformFormat('maxWorkgroupsPerDim', UNIFORMTYPE_UINT)
new UniformFormat('maxWorkgroupsPerDim', UNIFORMTYPE_UINT),
new UniformFormat('drawSlot', UNIFORMTYPE_UINT)
]);

this._classifyBindGroupFormat = new BindGroupFormat(device, [
Expand All @@ -758,7 +763,8 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
new BindStorageBufferFormat('tileListCounts', SHADERSTAGE_COMPUTE),
new BindStorageBufferFormat('indirectDispatchArgs', SHADERSTAGE_COMPUTE),
new BindStorageBufferFormat('largeTileOverflowBases', SHADERSTAGE_COMPUTE),
new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE),
new BindStorageBufferFormat('indirectDrawArgs', SHADERSTAGE_COMPUTE)
]);

const shader = new Shader(device, {
Expand Down Expand Up @@ -911,56 +917,6 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {

this.rasterizeCompute = new Compute(device, shader, 'GSplatLocalRasterize');
}

/** @private */
_createCompositeMaterial() {
this._material = new ShaderMaterial({
uniqueName: 'GSplatLocalComputeComposite',
vertexGLSL: '#include "fullscreenQuadVS"',
fragmentGLSL: '#include "outputTex2DPS"',
vertexWGSL: '#include "fullscreenQuadVS"',
fragmentWGSL: '#include "outputTex2DPS"',
attributes: {
vertex_position: SEMANTIC_POSITION
}
});

this._material.setParameter('source', this.outputTexture);
this._material.blendType = BLEND_PREMULTIPLIED;
this._material.cull = CULLFACE_NONE;
this._material.depthWrite = false;
this._material.update();
}

/**
* @returns {MeshInstance} The compositing mesh instance.
* @private
*/
_createMeshInstance() {
const mesh = new Mesh(this.device);
mesh.setPositions(new Float32Array([
-1, -1,
1, -1,
1, 1,
-1, 1
]), 2);
mesh.setIndices(new Uint32Array([0, 1, 2, 0, 2, 3]));
mesh.update();

const meshInstance = new MeshInstance(mesh, this._material);
meshInstance.node = this.node;

const thisCamera = this.cameraNode.camera;
meshInstance.isVisibleFunc = (camera) => {
const renderMode = this.renderMode ?? 0;
if (thisCamera.camera === camera && (renderMode & GSPLAT_FORWARD)) {
return true;
}
return false;
};

return meshInstance;
}
}

export { GSplatComputeLocalRenderer };
104 changes: 104 additions & 0 deletions src/scene/gsplat-unified/gsplat-tile-composite.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import {
CULLFACE_NONE,
PRIMITIVE_TRIANGLES
} from '../../platform/graphics/constants.js';
import { BLEND_PREMULTIPLIED } from '../constants.js';
import { ShaderMaterial } from '../materials/shader-material.js';
import { MeshInstance } from '../mesh-instance.js';
import { Mesh } from '../mesh.js';

/**
* @import { GraphNode } from '../graph-node.js'
* @import { GraphicsDevice } from '../../platform/graphics/graphics-device.js'
* @import { StorageBuffer } from '../../platform/graphics/storage-buffer.js'
* @import { Texture } from '../../platform/graphics/texture.js'
*/

/**
* Manages the tile-based composite for the local compute gsplat renderer. Instead of blitting a
* full-screen quad, only tiles that contain splats are drawn using indirect rendering. The vertex
* shader procedurally generates tile quads from the built-in vertex index and a storage buffer
* of non-empty tile indices populated by the classify pass.
*
* @ignore
*/
class GSplatTileComposite {
/** @type {GraphicsDevice} */
device;

/** @type {ShaderMaterial} */
_material;

/** @type {Mesh} */
_mesh;

/** @type {MeshInstance} */
_meshInstance;

/**
* @param {GraphicsDevice} device - The graphics device.
* @param {GraphNode} node - The graph node for the mesh instance.
* @param {Function} isVisibleFunc - Visibility callback: `(camera) => boolean`.
*/
constructor(device, node, isVisibleFunc) {
this.device = device;

this._material = new ShaderMaterial({
uniqueName: 'GSplatTileComposite',
vertexWGSL: '#include "gsplatTileCompositeVS"',
fragmentWGSL: '#include "outputTex2DPS"'
});

this._material.blendType = BLEND_PREMULTIPLIED;
this._material.cull = CULLFACE_NONE;
this._material.depthWrite = false;
this._material.update();

this._mesh = new Mesh(device);
this._mesh.primitive[0].type = PRIMITIVE_TRIANGLES;
this._mesh.primitive[0].base = 0;
this._mesh.primitive[0].count = 0;
this._mesh.primitive[0].indexed = false;

this._meshInstance = new MeshInstance(this._mesh, this._material);
this._meshInstance.node = node;
this._meshInstance.instancingCount = 1;
this._meshInstance.isVisibleFunc = isVisibleFunc;
}

destroy() {
this._material.destroy();
this._mesh.destroy();
this._meshInstance.destroy();
}

get material() {
return this._material;
}

get meshInstance() {
return this._meshInstance;
}

/**
* Per-frame update: binds the indirect draw slot and updates material parameters.
*
* @param {number} drawSlot - The indirect draw slot reserved for this frame.
* @param {Texture} outputTexture - The compute-rasterized splat texture.
* @param {StorageBuffer} rasterizeTileList - Buffer of non-empty tile indices.
* @param {number} numTilesX - Number of tiles horizontally.
* @param {number} screenWidth - Viewport width in pixels.
* @param {number} screenHeight - Viewport height in pixels.
*/
update(drawSlot, outputTexture, rasterizeTileList, numTilesX, screenWidth, screenHeight) {
this._meshInstance.setIndirect(null, drawSlot, 1);

this._material.setParameter('source', outputTexture);
this._material.setParameter('rasterizeTileList', rasterizeTileList);
this._material.setParameter('numTilesX', numTilesX);
this._material.setParameter('screenWidth', screenWidth);
this._material.setParameter('screenHeight', screenHeight);
}
}

export { GSplatTileComposite };
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
// Tile classification: scans prefix-summed tile counts, builds small/large/rasterize
// tile lists, and writes indirect dispatch args for subsequent passes.
// tile lists, writes indirect dispatch args for subsequent passes, and writes indirect
// draw args for the tile-based composite.
// For large tiles (>4096 entries), assigns compact overflow scratch offsets within
// the shared tileEntries buffer (overflow region starts at totalEntries).
// Single workgroup (256 threads) — each thread processes ceil(numTiles/256) tiles.

import indirectCoreCS from '../common/comp/indirect-core.js';

export const computeGsplatLocalClassifySource = /* wgsl */`

${indirectCoreCS}

const MAX_TILE_ENTRIES: u32 = 4096u;
const CLASSIFY_WORKGROUP: u32 = 256u;

Expand All @@ -15,12 +21,14 @@ const CLASSIFY_WORKGROUP: u32 = 256u;
@group(0) @binding(4) var<storage, read_write> tileListCounts: array<atomic<u32>>;
@group(0) @binding(5) var<storage, read_write> indirectDispatchArgs: array<u32>;
@group(0) @binding(6) var<storage, read_write> largeTileOverflowBases: array<u32>;
@group(0) @binding(8) var<storage, read_write> indirectDrawArgs: array<DrawIndirectArgs>;

struct Uniforms {
numTiles: u32,
dispatchSlotOffset: u32,
bufferCapacity: u32,
maxWorkgroupsPerDim: u32,
drawSlot: u32,
}
@group(0) @binding(7) var<uniform> uniforms: Uniforms;

Expand Down Expand Up @@ -88,6 +96,9 @@ fn main(@builtin(local_invocation_index) localIdx: u32) {
indirectDispatchArgs[off + 6u] = (rasterizeCount + ry - 1u) / ry;
indirectDispatchArgs[off + 7u] = ry;
indirectDispatchArgs[off + 8u] = 1u;

// Indirect draw args for tile-based composite: 6 vertices per tile quad
indirectDrawArgs[uniforms.drawSlot] = DrawIndirectArgs(rasterizeCount * 6u, 1u, 0u, 0u, 0u);
}
}
`;
Loading