diff --git a/src/framework/parsers/ply.js b/src/framework/parsers/ply.js index 4cc41302d0c..8441fd962bd 100644 --- a/src/framework/parsers/ply.js +++ b/src/framework/parsers/ply.js @@ -294,7 +294,11 @@ class PlyParser { } else { readPly(response.body.getReader(), asset.data.elementFilter ?? defaultElementFilter) .then((response) => { - callback(null, new GSplatResource(this.device, new GSplatData(response))); + const gsplatData = new GSplatData(response, { + performZScale: asset.data.performZScale, + reorder: asset.data.reorder + }); + callback(null, new GSplatResource(this.device, gsplatData)); }) .catch((err) => { callback(err, null); diff --git a/src/platform/graphics/shader-utils.js b/src/platform/graphics/shader-utils.js index 1f6235f3885..d37ca80b688 100644 --- a/src/platform/graphics/shader-utils.js +++ b/src/platform/graphics/shader-utils.js @@ -209,7 +209,7 @@ class ShaderUtils { if (!device.isWebGPU) { - code = `precision ${precision} float;\n`; + code = `precision ${precision} float;\nprecision ${precision} int;`; if (device.isWebGL2) { code += `precision ${precision} sampler2DShadow;\n`; diff --git a/src/platform/graphics/texture.js b/src/platform/graphics/texture.js index f082af884e6..19ddb5a0bf2 100644 --- a/src/platform/graphics/texture.js +++ b/src/platform/graphics/texture.js @@ -781,7 +781,7 @@ class Texture { * - {@link TEXTURELOCK_READ} * - {@link TEXTURELOCK_WRITE} * Defaults to {@link TEXTURELOCK_WRITE}. - * @returns {Uint8Array|Uint16Array|Float32Array} A typed array containing the pixel data of + * @returns {Uint8Array|Uint16Array|Uint32Array|Float32Array} A typed array containing the pixel data of * the locked mip level. 
*/ lock(options = {}) { diff --git a/src/scene/gsplat/gsplat-data.js b/src/scene/gsplat/gsplat-data.js index 23cf07c1f5e..a01e72f1a3c 100644 --- a/src/scene/gsplat/gsplat-data.js +++ b/src/scene/gsplat/gsplat-data.js @@ -80,14 +80,25 @@ class GSplatData { // /** // * @param {import('./ply-reader').PlyElement[]} elements - The elements. // * @param {boolean} [performZScale] - Whether to perform z scaling. + // * @param {object} [options] - The options. + // * @param {boolean} [options.performZScale] - Whether to perform z scaling. + // * @param {boolean} [options.reorder] - Whether to reorder the data. // */ - constructor(elements, performZScale = true) { + constructor(elements, options = {}) { this.elements = elements; this.vertexElement = elements.find(element => element.name === 'vertex'); - if (!this.isCompressed && performZScale) { - mat4.setScale(-1, -1, 1); - this.transform(mat4); + if (!this.isCompressed) { + if (options.performZScale ?? true) { + mat4.setScale(-1, -1, 1); + this.transform(mat4); + } + + // reorder uncompressed splats in morton order for better memory access + // efficiency during rendering + if (options.reorder ?? 
true) { + this.reorderData(); + } } } @@ -406,7 +417,87 @@ class GSplatData { storage: data[name] }; }) - }], false); + }], { + performZScale: false, + reorder: false + }); + } + + calcMortonOrder() { + const calcMinMax = (arr) => { + let min = arr[0]; + let max = arr[0]; + for (let i = 1; i < arr.length; i++) { + if (arr[i] < min) min = arr[i]; + if (arr[i] > max) max = arr[i]; + } + return { min, max }; + }; + + // https://fgiesen.wordpress.com/2009/12/13/decoding-morton-codes/ + const encodeMorton3 = (x, y, z) => { + const Part1By2 = (x) => { + x &= 0x000003ff; + x = (x ^ (x << 16)) & 0xff0000ff; + x = (x ^ (x << 8)) & 0x0300f00f; + x = (x ^ (x << 4)) & 0x030c30c3; + x = (x ^ (x << 2)) & 0x09249249; + return x; + }; + + return (Part1By2(z) << 2) + (Part1By2(y) << 1) + Part1By2(x); + }; + + const x = this.getProp('x'); + const y = this.getProp('y'); + const z = this.getProp('z'); + + const { min: minX, max: maxX } = calcMinMax(x); + const { min: minY, max: maxY } = calcMinMax(y); + const { min: minZ, max: maxZ } = calcMinMax(z); + + const sizeX = 1024 / (maxX - minX); + const sizeY = 1024 / (maxY - minY); + const sizeZ = 1024 / (maxZ - minZ); + + const morton = new Uint32Array(this.numSplats); + for (let i = 0; i < this.numSplats; i++) { + const ix = Math.floor((x[i] - minX) * sizeX); + const iy = Math.floor((y[i] - minY) * sizeY); + const iz = Math.floor((z[i] - minZ) * sizeZ); + morton[i] = encodeMorton3(ix, iy, iz); + } + + // generate indices + const indices = new Uint32Array(this.numSplats); + for (let i = 0; i < this.numSplats; i++) { + indices[i] = i; + } + // order splats by morton code + indices.sort((a, b) => morton[a] - morton[b]); + + return indices; + } + + reorderData() { + // calculate splat morton order + const order = this.calcMortonOrder(); + + const reorder = (data) => { + const result = new data.constructor(data.length); + + for (let i = 0; i < order.length; i++) { + result[i] = data[order[i]]; + } + + return result; + }; + + 
this.elements.forEach((element) => { + element.properties.forEach((property) => { + property.storage = reorder(property.storage); + }); + }); } } diff --git a/src/scene/gsplat/gsplat-instance.js b/src/scene/gsplat/gsplat-instance.js index 02c8280d2f2..b6da74aa0e0 100644 --- a/src/scene/gsplat/gsplat-instance.js +++ b/src/scene/gsplat/gsplat-instance.js @@ -1,11 +1,13 @@ import { Mat4 } from '../../core/math/mat4.js'; import { Vec3 } from '../../core/math/vec3.js'; -import { SEMANTIC_POSITION, TYPE_UINT32 } from '../../platform/graphics/constants.js'; +import { BUFFER_STATIC, PIXELFORMAT_R32U, SEMANTIC_ATTR13, TYPE_UINT32 } from '../../platform/graphics/constants.js'; import { DITHER_NONE } from '../constants.js'; import { MeshInstance } from '../mesh-instance.js'; import { Mesh } from '../mesh.js'; import { createGSplatMaterial } from './gsplat-material.js'; import { GSplatSorter } from './gsplat-sorter.js'; +import { VertexFormat } from '../../platform/graphics/vertex-format.js'; +import { VertexBuffer } from '../../platform/graphics/vertex-buffer.js'; const mat = new Mat4(); const cameraPosition = new Vec3(); @@ -26,8 +28,8 @@ class GSplatInstance { /** @type {import('../materials/material.js').Material} */ material; - /** @type {import('../../platform/graphics/vertex-buffer.js').VertexBuffer} */ - vb; + /** @type {import('../../platform/graphics/texture.js').Texture} */ + orderTexture; options = {}; @@ -58,45 +60,67 @@ class GSplatInstance { // not supported on WebGL1 const device = splat.device; - if (device.isWebGL1) - return; + + // create the order texture + this.orderTexture = this.splat.createTexture( + 'splatOrder', + PIXELFORMAT_R32U, + this.splat.evalTextureSize(this.splat.numSplats) + ); // material this.createMaterial(options); - const numSplats = splat.numSplats; - const indices = new Uint32Array(numSplats * 6); - const ids = new Uint32Array(numSplats * 4); - - for (let i = 0; i < numSplats; ++i) { - const base = i * 4; - - // 4 vertices - ids[base + 
0] = i; - ids[base + 1] = i; - ids[base + 2] = i; - ids[base + 3] = i; - - // 2 triangles - const triBase = i * 6; - indices[triBase + 0] = base; - indices[triBase + 1] = base + 1; - indices[triBase + 2] = base + 2; - indices[triBase + 3] = base; - indices[triBase + 4] = base + 2; - indices[triBase + 5] = base + 3; + // number of quads to combine into a single instance. this is to increase occupancy + // in the vertex shader. + const splatInstanceSize = 128; + const numSplats = Math.ceil(splat.numSplats / splatInstanceSize) * splatInstanceSize; + const numSplatInstances = numSplats / splatInstanceSize; + + // specify the base splat index per instance + const indexData = new Uint32Array(numSplatInstances); + for (let i = 0; i < numSplatInstances; ++i) { + indexData[i] = i * splatInstanceSize; + } + + const vertexFormat = new VertexFormat(device, [ + { semantic: SEMANTIC_ATTR13, components: 1, type: TYPE_UINT32, asInt: true } + ]); + + const indicesVB = new VertexBuffer(device, vertexFormat, numSplatInstances, { + usage: BUFFER_STATIC, + data: indexData.buffer + }); + + // build the instance mesh + const meshPositions = new Float32Array(12 * splatInstanceSize); + const meshIndices = new Uint32Array(6 * splatInstanceSize); + for (let i = 0; i < splatInstanceSize; ++i) { + meshPositions.set([ + -2, -2, i, + 2, -2, i, + 2, 2, i, + -2, 2, i + ], i * 12); + + const b = i * 4; + meshIndices.set([ + 0 + b, 1 + b, 2 + b, 0 + b, 2 + b, 3 + b + ], i * 6); } - // mesh const mesh = new Mesh(device); - mesh.setVertexStream(SEMANTIC_POSITION, ids, 1, numSplats * 4, TYPE_UINT32, false, !device.isWebGL1); - mesh.setIndices(indices); + mesh.setPositions(meshPositions, 3); + mesh.setIndices(meshIndices); mesh.update(); + this.mesh = mesh; this.mesh.aabb.copy(splat.aabb); this.meshInstance = new MeshInstance(this.mesh, this.material); + this.meshInstance.setInstancing(indicesVB, true); this.meshInstance.gsplatInstance = this; + this.meshInstance.instancingCount = numSplatInstances; // 
clone centers to allow multiple instances of sorter this.centers = new Float32Array(splat.centers); @@ -104,7 +128,13 @@ class GSplatInstance { // create sorter if (!options.dither || options.dither === DITHER_NONE) { this.sorter = new GSplatSorter(); - this.sorter.init(mesh.vertexBuffer, this.centers, !this.splat.device.isWebGL1); + this.sorter.init(this.orderTexture, this.centers); + this.sorter.on('updated', (count) => { + // limit splat render count to exclude those behind the camera. + // NOTE: the last instance rendered may include non-existent splat + // data. this should be ok though as the data is filled with 0's. + this.meshInstance.instancingCount = Math.ceil(count / splatInstanceSize); + }); } } @@ -120,6 +150,7 @@ class GSplatInstance { createMaterial(options) { this.material = createGSplatMaterial(options); + this.material.setParameter('splatOrder', this.orderTexture); this.splat.setupMaterial(this.material); if (this.meshInstance) { this.meshInstance.material = this.material; diff --git a/src/scene/gsplat/gsplat-sorter.js b/src/scene/gsplat/gsplat-sorter.js index 3beec67a95b..0adcc5d2962 100644 --- a/src/scene/gsplat/gsplat-sorter.js +++ b/src/scene/gsplat/gsplat-sorter.js @@ -1,4 +1,5 @@ import { EventHandler } from "../../core/event-handler.js"; +import { TEXTURELOCK_READ } from "../../platform/graphics/constants.js"; // sort blind set of data function SortWorker() { @@ -26,6 +27,21 @@ function SortWorker() { let target; let countBuffer; + const binarySearch = (m, n, compare_fn) => { + while (m <= n) { + const k = (n + m) >> 1; + const cmp = compare_fn(k); + if (cmp > 0) { + m = k + 1; + } else if (cmp < 0) { + n = k - 1; + } else { + return k; + } + } + return ~m; + }; + const update = () => { if (!centers || !data || !cameraPosition || !cameraDirection) return; @@ -58,7 +74,10 @@ function SortWorker() { const numVertices = centers.length / 3; if (distances?.length !== numVertices) { distances = new Uint32Array(numVertices); - target = new 
Float32Array(numVertices * 4); // output 4 indices per splat (quad) + } + + if (target?.length !== data.length) { + target = data.slice(); } // calc min/max distance using bound @@ -104,17 +123,19 @@ function SortWorker() { countBuffer[i] += countBuffer[i - 1]; // Build the output array - const outputArray = new Uint32Array(target.buffer); for (let i = 0; i < numVertices; i++) { const distance = distances[i]; - const destIndex = (--countBuffer[distance]) * 4; - - outputArray[destIndex] = i; - outputArray[destIndex + 1] = i; - outputArray[destIndex + 2] = i; - outputArray[destIndex + 3] = i; + const destIndex = --countBuffer[distance]; + target[destIndex] = i; } + // find splat with distance 0 to limit rendering + const dist = i => distances[target[i]] / divider + minDist; + const findZero = () => { + const result = binarySearch(0, numVertices - 1, i => -dist(i)); + return Math.min(numVertices, Math.abs(result)); + }; + // swap const tmp = data; data = target; @@ -122,7 +143,8 @@ function SortWorker() { // send results self.postMessage({ - data: data.buffer + data: data.buffer, + count: dist(numVertices - 1) >= 0 ? 
findZero() : numVertices }, [data.buffer]); data = null; @@ -130,7 +152,7 @@ function SortWorker() { self.onmessage = (message) => { if (message.data.data) { - data = new Float32Array(message.data.data); + data = new Uint32Array(message.data.data); } if (message.data.centers) { centers = new Float32Array(message.data.centers); @@ -165,7 +187,7 @@ function SortWorker() { class GSplatSorter extends EventHandler { worker; - vertexBuffer; + orderTexture; constructor() { super(); @@ -176,19 +198,19 @@ class GSplatSorter extends EventHandler { this.worker.onmessage = (message) => { const newData = message.data.data; - const oldData = this.vertexBuffer.storage; + const oldData = this.orderTexture._levels[0].buffer; // send vertex storage to worker to start the next frame this.worker.postMessage({ data: oldData }, [oldData]); - // update vertex buffer data in the next event cycle so the above postMesssage - // call is queued before the relatively slow setData call below is invoked - setTimeout(() => { - this.vertexBuffer.setData(newData); - this.fire('updated'); - }); + // set new data directly on texture + this.orderTexture._levels[0] = new Uint32Array(newData); + this.orderTexture.upload(); + + // fire the updated event, passing the count reported by the worker + this.fire('updated', message.data.count); }; } @@ -197,11 +219,16 @@ class GSplatSorter extends EventHandler { this.worker = null; } - init(vertexBuffer, centers) { - this.vertexBuffer = vertexBuffer; + init(orderTexture, centers) { + this.orderTexture = orderTexture; + + // get the texture's storage buffer and make a copy + const buf = this.orderTexture.lock({ + mode: TEXTURELOCK_READ + }).buffer.slice(); + this.orderTexture.unlock(); // send the initial buffer to worker - const buf = vertexBuffer.storage.slice(0); this.worker.postMessage({ data: buf, centers: centers.buffer diff --git a/src/scene/gsplat/gsplat.js b/src/scene/gsplat/gsplat.js index c3cc129c784..49d1d7af6cf 100644 --- a/src/scene/gsplat/gsplat.js +++ 
b/src/scene/gsplat/gsplat.js @@ -4,7 +4,7 @@ import { Quat } from '../../core/math/quat.js'; import { Vec2 } from '../../core/math/vec2.js'; import { Mat3 } from '../../core/math/mat3.js'; import { - ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R16F, PIXELFORMAT_R32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F, + ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R16F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F, PIXELFORMAT_RGBA8 } from '../../platform/graphics/constants.js'; import { Texture } from '../../platform/graphics/texture.js'; @@ -24,14 +24,6 @@ class GSplat { numSplats; - /** - * True if half format should be used, false is float format should be used or undefined if none - * are available. - * - * @type {boolean|undefined} - */ - halfFormat; - /** @type {Texture} */ colorTexture; @@ -113,14 +105,13 @@ class GSplat { /** * Creates a new texture with the specified parameters. * - * @param {import('../../platform/graphics/graphics-device.js').GraphicsDevice} device - The graphics device to use for the texture creation. * @param {string} name - The name of the texture to be created. * @param {number} format - The pixel format of the texture. * @param {Vec2} size - The size of the texture in a Vec2 object, containing width (x) and height (y). * @returns {Texture} The created texture instance. 
*/ - createTexture(device, name, format, size) { - return new Texture(device, { + createTexture(name, format, size) { + return new Texture(this.device, { name: name, width: size.x, height: size.y, @@ -232,7 +223,6 @@ class GSplat { */ updateTransformData(x, y, z, rot0, rot1, rot2, rot3, scale0, scale1, scale2) { - const { halfFormat } = this; const float2Half = FloatPacking.float2Half; if (!this.transformATexture) @@ -262,34 +252,17 @@ class GSplat { this.computeCov3d(mat, _s, cA, cB); - if (halfFormat) { - - dataA[i * 4 + 0] = float2Half(x[i]); - dataA[i * 4 + 1] = float2Half(y[i]); - dataA[i * 4 + 2] = float2Half(z[i]); - dataA[i * 4 + 3] = float2Half(cB.x); + dataA[i * 4 + 0] = x[i]; + dataA[i * 4 + 1] = y[i]; + dataA[i * 4 + 2] = z[i]; + dataA[i * 4 + 3] = cB.x; - dataB[i * 4 + 0] = float2Half(cA.x); - dataB[i * 4 + 1] = float2Half(cA.y); - dataB[i * 4 + 2] = float2Half(cA.z); - dataB[i * 4 + 3] = float2Half(cB.y); + dataB[i * 4 + 0] = float2Half(cA.x); + dataB[i * 4 + 1] = float2Half(cA.y); + dataB[i * 4 + 2] = float2Half(cA.z); + dataB[i * 4 + 3] = float2Half(cB.y); - dataC[i] = float2Half(cB.z); - - } else { - - dataA[i * 4 + 0] = x[i]; - dataA[i * 4 + 1] = y[i]; - dataA[i * 4 + 2] = z[i]; - dataA[i * 4 + 3] = cB.x; - - dataB[i * 4 + 0] = cA.x; - dataB[i * 4 + 1] = cA.y; - dataB[i * 4 + 2] = cA.z; - dataB[i * 4 + 3] = cB.y; - - dataC[i] = cB.z; - } + dataC[i] = float2Half(cB.z); } this.transformATexture.unlock(); diff --git a/src/scene/gsplat/shader-generator-gsplat.js b/src/scene/gsplat/shader-generator-gsplat.js index 511dc0436d3..33c83a91645 100644 --- a/src/scene/gsplat/shader-generator-gsplat.js +++ b/src/scene/gsplat/shader-generator-gsplat.js @@ -1,5 +1,5 @@ import { hashCode } from "../../core/hash.js"; -import { SEMANTIC_POSITION } from "../../platform/graphics/constants.js"; +import { SEMANTIC_ATTR13, SEMANTIC_POSITION } from "../../platform/graphics/constants.js"; import { ShaderUtils } from "../../platform/graphics/shader-utils.js"; import { 
DITHER_NONE } from "../constants.js"; import { shaderChunks } from "../shader-lib/chunks/chunks.js"; @@ -21,6 +21,7 @@ const splatCoreVS = ` uniform vec4 tex_params; uniform sampler2D splatColor; + uniform highp usampler2D splatOrder; uniform highp sampler2D transformA; uniform highp sampler2D transformB; uniform highp sampler2D transformC; @@ -29,27 +30,37 @@ const splatCoreVS = ` vec3 covA; vec3 covB; - attribute uint vertex_id; - ivec2 dataUV; - void evalDataUV() { + attribute vec3 vertex_position; + attribute uint vertex_id_attrib; - // turn vertex_id into int grid coordinates - ivec2 textureSize = ivec2(tex_params.xy); - vec2 invTextureSize = tex_params.zw; + uint splatId; + ivec2 splatUV; + void evalSplatUV() { + int bufferSizeX = int(tex_params.x); - int gridV = int(float(vertex_id) * invTextureSize.x); - int gridU = int(vertex_id) - gridV * textureSize.x; - dataUV = ivec2(gridU, gridV); + // sample order texture + uint orderId = vertex_id_attrib + uint(vertex_position.z); + ivec2 orderUV = ivec2( + int(orderId) % bufferSizeX, + int(orderId) / bufferSizeX + ); + + // calculate splatUV + splatId = texelFetch(splatOrder, orderUV, 0).r; + splatUV = ivec2( + int(splatId) % bufferSizeX, + int(splatId) / bufferSizeX + ); } vec4 getColor() { - return texelFetch(splatColor, dataUV, 0); + return texelFetch(splatColor, splatUV, 0); } void getTransform() { - vec4 tA = texelFetch(transformA, dataUV, 0); - vec4 tB = texelFetch(transformB, dataUV, 0); - vec4 tC = texelFetch(transformC, dataUV, 0); + vec4 tA = texelFetch(transformA, splatUV, 0); + vec4 tB = texelFetch(transformB, splatUV, 0); + vec4 tC = texelFetch(transformC, splatUV, 0); center = tA.xyz; covA = tB.xyz; @@ -57,7 +68,7 @@ const splatCoreVS = ` } vec3 evalCenter() { - evalDataUV(); + evalSplatUV(); // get data getTransform(); @@ -75,7 +86,7 @@ const splatCoreVS = ` return vec4(0.0, 0.0, 2.0, 1.0); } - id = float(vertex_id); + id = float(splatId); color = getColor(); mat3 Vrk = mat3( @@ -118,11 +129,7 @@ 
const splatCoreVS = ` return vec4(0.0, 0.0, 2.0, 1.0); } - int vertexIndex = int(gl_VertexID) % 4; - texCoord = vec2( - float((vertexIndex == 0 || vertexIndex == 3) ? -2 : 2), - float((vertexIndex == 0 || vertexIndex == 1) ? -2 : 2) - ); + texCoord = vertex_position.xy; splat_proj.xy += (texCoord.x * v1 + texCoord.y * v2) / viewport * splat_proj.w; return splat_proj; @@ -200,7 +207,8 @@ class GShaderGeneratorSplat { return ShaderUtils.createDefinition(device, { name: 'SplatShader', attributes: { - vertex_id: SEMANTIC_POSITION + vertex_position: SEMANTIC_POSITION, + vertex_id_attrib: SEMANTIC_ATTR13 }, vertexCode: vs, fragmentCode: fs