From 304ef5d59b0dbf29dbcf1d4daf5718aa9c840f06 Mon Sep 17 00:00:00 2001 From: Andreas Sundquist Date: Sat, 19 Jul 2025 11:18:42 -0700 Subject: [PATCH 1/3] Externalized SparkRenderer.maxPixelRadius and .minAlpha settings. Added to documentation. Added to examples/editor under Debug folder, moved sort32 there. --- examples/editor/index.html | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/editor/index.html b/examples/editor/index.html index 4f17afc..a37e418 100644 --- a/examples/editor/index.html +++ b/examples/editor/index.html @@ -480,8 +480,6 @@ stats.dom.style.display = value ? "block" : "none"; }); gui.add(spark.defaultView, "sortRadial").name("Radial sort").listen(); - spark.defaultView.sort32 = true; - gui.add(spark.defaultView, "sort32").name("Float32 sort").listen(); gui.add(grid, "opacity", 0, 1, 0.01).name("Grid opacity").listen(); gui.add({ logFocalDistance: 0.0, From 7658bc7877512f50e4a4555bd2bbf5a455916bc3 Mon Sep 17 00:00:00 2001 From: Andreas Sundquist Date: Tue, 22 Jul 2025 09:00:57 -0700 Subject: [PATCH 2/3] Added SplatEncoding concept, applied individually to SparkRenderer (rgbMin/Max, lnScaleMinMax), SplatMesh (additionally sh1Min/Max, sh2, sh3). Updated examples/editor to control settings via gui. Updated all related code to pipe along settings where needed. Added SparkRenderer.premultipliedAlpha to control blend mode and enable hyper-saturated color contributions by default. --- examples/editor/index.html | 54 +++++++-- rust/spark-internal-rs/src/lib.rs | 5 +- rust/spark-internal-rs/src/raycast.rs | 32 ++---- src/PackedSplats.ts | 116 +++++++++++++++++++- src/SparkRenderer.ts | 60 +++++++++- src/SplatLoader.ts | 18 ++- src/SplatMesh.ts | 53 ++++++++- src/antisplat.ts | 7 +- src/defines.ts | 1 - src/dyno/output.ts | 28 +++-- src/dyno/splats.ts | 9 +- src/index.ts | 3 + src/ksplat.ts | 13 ++- src/pcsogs.ts | 21 +++- src/shaders/splatDefines.glsl | 46 ++++++-- src/shaders/splatFragment.glsl | 8 +- src/shaders/splatVertex.glsl | 3 +- src/utils.ts | 152 +++++++++++++++++++------- src/worker.ts | 79 +++++++++---- 19 files changed, 561 insertions(+), 147 deletions(-) diff --git a/examples/editor/index.html b/examples/editor/index.html index a37e418..4963eb0 100644 --- a/examples/editor/index.html +++ b/examples/editor/index.html @@ -69,7 +69,7 @@ import * as THREE from "three"; import { OrbitControls } from "three/addons/controls/OrbitControls.js"; import { GUI } from "lil-gui"; - import { constructGrid, SparkControls, SparkRenderer, SplatMesh, textSplats, dyno, transcodeSpz, isMobile, isPcSogs } from "@sparkjsdev/spark"; + import { constructGrid, SparkControls, SparkRenderer, SplatMesh, textSplats, dyno, transcodeSpz, isMobile, isPcSogs, LN_SCALE_MIN, LN_SCALE_MAX } from "@sparkjsdev/spark"; import { getAssetFileURL } from "/examples/js/get-asset-url.js"; const scene = new THREE.Scene(); @@ -230,15 +230,21 @@ applyCameraFromQuery(); const cameraFolder = gui.addFolder("Camera"); - cameraFolder.add(camera.position, "x", -10, 10, 0.01).name("X").listen(); - cameraFolder.add(camera.position, "y", -10, 10, 0.01).name("Y").listen(); - cameraFolder.add(camera.position, "z", -10, 10, 0.01).name("Z").listen(); - const rotX = cameraFolder.add(camera.rotation, "x", -Math.PI, Math.PI, 0.01).name("RotateX").listen(); - const rotY = cameraFolder.add(camera.rotation, "y", -Math.PI, Math.PI, 0.01).name("RotateY").listen(); - const rotZ = cameraFolder.add(camera.rotation, "z", -Math.PI, Math.PI, 0.01).name("RotateZ").listen(); - cameraFolder.add(camera, "fov", 1, 179, 1).name("Fov Y degrees").listen().onChange((value) => { + const cameraPose = cameraFolder.addFolder("Camera Pose"); + cameraPose.add(camera.position, "x", -10, 10, 0.01).name("X").listen(); + cameraPose.add(camera.position, "y", -10, 10, 0.01).name("Y").listen(); + cameraPose.add(camera.position, "z", -10, 10, 0.01).name("Z").listen(); + const rotX = cameraPose.add(camera.rotation, "x", -Math.PI, Math.PI, 0.01).name("RotateX").listen(); + const rotY = cameraPose.add(camera.rotation, "y", -Math.PI, Math.PI, 0.01).name("RotateY").listen(); + const rotZ = cameraPose.add(camera.rotation, "z", -Math.PI, Math.PI, 0.01).name("RotateZ").listen(); + cameraPose.add(camera, "fov", 1, 179, 1).name("Fov Y degrees").listen().onChange((value) => { camera.updateProjectionMatrix(); }); + cameraPose.close(); + + function touch() { + spark.needsUpdate = true; + } // Progress bar functions const progressBar = document.getElementById('progress-bar'); @@ -383,6 +389,7 @@ } const init = url ? { url } : { fileBytes: fileBytes.slice(), fileName }; + init.splatEncoding = { ...splatEncoding }; const splatMesh = new SplatMesh(init); const translate = guiOptions.loadOffset * index splatMesh.position.set(translate, 0.5 * translate, 0.1 * translate); @@ -518,6 +525,37 @@ debugFolder.add(spark, "maxPixelRadius", 1, 1024, 1).name("Max pixel radius").listen(); debugFolder.add(spark, "minAlpha", 0, 1, 0.001).name("Min alpha").listen(); + debugFolder.add(spark, "premultipliedAlpha").name("Premultiplied alpha").listen(); + const accumFolder = debugFolder.addFolder("Accumulator encoding").close();; + accumFolder.add(spark.splatEncoding, "rgbMin", -1, 1, 0.1).name("RGB min").onChange(touch); + accumFolder.add(spark.splatEncoding, "rgbMax", 0, 4, 0.1).name("RGB max").onChange(touch); + accumFolder.add(spark.splatEncoding, "lnScaleMin", -14, -2.5, 0.1).name("Ln scale min").onChange(touch); + accumFolder.add(spark.splatEncoding, "lnScaleMax", -14, 14, 0.1).name("Ln scale max").onChange(touch); + + const splatEncoding = { + rgbMin: 0.0, + rgbMax: 1.0, + lnScaleMin: LN_SCALE_MIN, + lnScaleMax: LN_SCALE_MAX, + sh1Min: -1, + sh1Max: 1, + sh2Min: -1, + sh2Max: 1, + sh3Min: -1, + sh3Max: 1, + }; + const splatFolder = debugFolder.addFolder("SplatMesh encoding").close(); + splatFolder.add(splatEncoding, "rgbMin", -1, 1, 0.1).name("RGB min").onChange(touch); + splatFolder.add(splatEncoding, "rgbMax", 0, 4, 0.1).name("RGB max").onChange(touch); + splatFolder.add(splatEncoding, "lnScaleMin", -14, -2.5, 0.1).name("Ln scale min").onChange(touch); + splatFolder.add(splatEncoding, "lnScaleMax", -14, 14, 0.1).name("Ln scale max").onChange(touch); + splatFolder.add(splatEncoding, "sh1Min", -6, 6, 0.1).name("SH1 min").onChange(touch); + splatFolder.add(splatEncoding, "sh1Max", -6, 6, 0.1).name("SH1 max").onChange(touch); + splatFolder.add(splatEncoding, "sh2Min", -6, 6, 0.1).name("SH2 min").onChange(touch); + splatFolder.add(splatEncoding, "sh2Max", -6, 6, 0.1).name("SH2 max").onChange(touch); + splatFolder.add(splatEncoding, "sh3Min", -6, 6, 0.1).name("SH3 min").onChange(touch); + splatFolder.add(splatEncoding, "sh3Max", -6, 6, 0.1).name("SH3 max").onChange(touch); + const splatsFolder = secondGui.addFolder("Files"); const clipFolder = gui.addFolder("Clip Splats").close(); diff --git a/rust/spark-internal-rs/src/lib.rs b/rust/spark-internal-rs/src/lib.rs index b00414a..03ffdbf 100644 --- a/rust/spark-internal-rs/src/lib.rs +++ b/rust/spark-internal-rs/src/lib.rs @@ -82,6 +82,7 @@ pub fn raycast_splats( near: f32, far: f32, num_splats: u32, packed_splats: Uint32Array, raycast_ellipsoid: bool, + ln_scale_min: f32, ln_scale_max: f32, ) -> Float32Array { let mut distances = Vec::::new(); @@ -94,9 +95,9 @@ pub fn raycast_splats( subarray.copy_to(subbuffer); if raycast_ellipsoid { - raycast_ellipsoids(subbuffer, &mut distances, [origin_x, origin_y, origin_z], [dir_x, dir_y, dir_z], near, far); + raycast_ellipsoids(subbuffer, &mut distances, [origin_x, origin_y, origin_z], [dir_x, dir_y, dir_z], near, far, ln_scale_min, ln_scale_max); } else { - raycast_spheres(subbuffer, &mut distances, [origin_x, origin_y, origin_z], [dir_x, dir_y, dir_z], near, far); + raycast_spheres(subbuffer, &mut distances, [origin_x, origin_y, origin_z], [dir_x, dir_y, dir_z], near, far, ln_scale_min, ln_scale_max); } base += chunk_size; diff --git a/rust/spark-internal-rs/src/raycast.rs b/rust/spark-internal-rs/src/raycast.rs index a75d2b3..a1a0aab 100644 --- a/rust/spark-internal-rs/src/raycast.rs +++ b/rust/spark-internal-rs/src/raycast.rs @@ -2,30 +2,19 @@ use half::f16; const MIN_OPACITY: f32 = 0.1; -pub const LN_SCALE_MIN: f32 = -12.0; -pub const LN_SCALE_MAX: f32 = 9.0; -pub const LN_RESCALE: f32 = (LN_SCALE_MAX - LN_SCALE_MIN) / 254.0; // 1..=255 - -// pub fn encode_scale(scale: f32) -> u8 { -// if scale == 0.0 { -// 0 -// } else { -// // Allow scales below LN_SCALE_MIN to be encoded as 0, which signifies a 2DGS -// ((scale.ln() - LN_SCALE_MIN) / LN_RESCALE + 1.0).clamp(0.0, 255.0).round() as u8 -// } -// } - -pub fn decode_scale(scale: u8) -> f32 { +pub fn decode_scale(scale: u8, ln_scale_min: f32, ln_scale_max: f32) -> f32 { if scale == 0 { 0.0 } else { - (LN_SCALE_MIN + (scale - 1) as f32 * LN_RESCALE).exp() + let ln_scale_scale = (ln_scale_max - ln_scale_min) / 254.0; + (ln_scale_min + (scale - 1) as f32 * ln_scale_scale).exp() } } pub fn raycast_spheres( buffer: &[u32], distances: &mut Vec, origin: [f32; 3], dir: [f32; 3], near: f32, far: f32, + ln_scale_min: f32, ln_scale_max: f32, ) { let quad_a = vec3_dot(dir, dir); @@ -36,7 +25,7 @@ pub fn raycast_spheres( } let origin = vec3_sub(origin, extract_center(packed)); - let scale = extract_scale(packed); + let scale = extract_scale(packed, ln_scale_min, ln_scale_max); // Model the Gsplat as a sphere for faster approximate raycasting let radius = (scale[0] + scale[1] + scale[2]) / 3.0; @@ -58,6 +47,7 @@ pub fn raycast_spheres( pub fn raycast_ellipsoids( buffer: &[u32], distances: &mut Vec, origin: [f32; 3], dir: [f32; 3], near: f32, far: f32, + ln_scale_min: f32, ln_scale_max: f32, ) { for packed in buffer.chunks(4) { let opacity = ((packed[0] >> 24) as u8) as f32 / 255.0; @@ -66,7 +56,7 @@ pub fn raycast_ellipsoids( } let origin = vec3_sub(origin, extract_center(packed)); - let scale = extract_scale(packed); + let scale = extract_scale(packed, ln_scale_min, ln_scale_max); let quat = extract_quat(packed); let inv_quat = [-quat[0], -quat[1], -quat[2], quat[3]]; @@ -139,10 +129,10 @@ fn extract_center(packed: &[u32]) -> [f32; 3] { [x, y, z] } -fn extract_scale(packed: &[u32]) -> [f32; 3] { - let scale_x = decode_scale(packed[3] as u8); - let scale_y = decode_scale((packed[3] >> 8) as u8); - let scale_z = decode_scale((packed[3] >> 16) as u8); +fn extract_scale(packed: &[u32], ln_scale_min: f32, ln_scale_max: f32) -> [f32; 3] { + let scale_x = decode_scale(packed[3] as u8, ln_scale_min, ln_scale_max); + let scale_y = decode_scale((packed[3] >> 8) as u8, ln_scale_min, ln_scale_max); + let scale_z = decode_scale((packed[3] >> 16) as u8, ln_scale_min, ln_scale_max); [scale_x, scale_y, scale_z] } diff --git a/src/PackedSplats.ts b/src/PackedSplats.ts index 4cc8387..bb52d19 100644 --- a/src/PackedSplats.ts +++ b/src/PackedSplats.ts @@ -3,11 +3,18 @@ import { FullScreenQuad } from "three/addons/postprocessing/Pass.js"; import type { GsplatGenerator } from "./SplatGenerator"; import { type SplatFileType, SplatLoader, unpackSplats } from "./SplatLoader"; -import { SPLAT_TEX_HEIGHT, SPLAT_TEX_WIDTH } from "./defines"; +import { + LN_SCALE_MAX, + LN_SCALE_MIN, + SPLAT_TEX_HEIGHT, + SPLAT_TEX_WIDTH, +} from "./defines"; import { DynoProgram, DynoProgramTemplate, DynoUniform, + DynoVec2, + DynoVec4, dynoBlock, outputPackedSplat, } from "./dyno"; @@ -15,6 +22,32 @@ import { TPackedSplats, definePackedSplats } from "./dyno/splats"; import computeUvec4Template from "./shaders/computeUvec4.glsl"; import { getTextureSize, setPackedSplat, unpackSplat } from "./utils"; +export type SplatEncoding = { + rgbMin?: number; + rgbMax?: number; + lnScaleMin?: number; + lnScaleMax?: number; + sh1Min?: number; + sh1Max?: number; + sh2Min?: number; + sh2Max?: number; + sh3Min?: number; + sh3Max?: number; +}; + +export const DEFAULT_SPLAT_RANGES: SplatEncoding = { + rgbMin: 0, + rgbMax: 1, + lnScaleMin: LN_SCALE_MIN, + lnScaleMax: LN_SCALE_MAX, + sh1Min: -1, + sh1Max: 1, + sh2Min: -1, + sh2Max: 1, + sh3Min: -1, + sh3Max: 1, +}; + // Initialize a PackedSplats collection from source data via // url, fileBytes, or packedArray. Creates an empty array if none are set, // and splat data can be constructed using pushSplat()/setSplat(). The maximum @@ -47,6 +80,9 @@ export type PackedSplatsOptions = { construct?: (splats: PackedSplats) => Promise | void; // Additional splat data, such as spherical harmonics components (sh1, sh2, sh3). (default: {}) extra?: Record; + // Override the default splat encoding ranges for the PackedSplats. + // (default: undefined) + splatEncoding?: SplatEncoding; }; // A PackedSplats is a collection of Gaussian splats, packed into a format that @@ -61,6 +97,7 @@ export class PackedSplats { numSplats = 0; packedArray: Uint32Array | null = null; extra: Record; + splatEncoding?: SplatEncoding; initialized: Promise; isInitialized = false; @@ -75,10 +112,60 @@ export class PackedSplats { // A PackedSplats can be used in a dyno graph using the below property dyno: // const gsplat = dyno.readPackedSplats(this.dyno, dynoIndex); dyno: DynoUniform; + dynoRgbMinMaxLnScaleMinMax: DynoUniform<"vec4", "rgbMinMaxLnScaleMinMax">; + dynoSh1MinMax: DynoUniform<"vec2", "sh1MinMax">; + dynoSh2MinMax: DynoUniform<"vec2", "sh2MinMax">; + dynoSh3MinMax: DynoUniform<"vec2", "sh3MinMax">; constructor(options: PackedSplatsOptions = {}) { this.extra = {}; this.dyno = new DynoPackedSplats({ packedSplats: this }); + this.dynoRgbMinMaxLnScaleMinMax = new DynoVec4({ + key: "rgbMinMaxLnScaleMinMax", + value: new THREE.Vector4(0.0, 1.0, LN_SCALE_MIN, LN_SCALE_MAX), + update: (value) => { + value.set( + this.splatEncoding?.rgbMin ?? 0.0, + this.splatEncoding?.rgbMax ?? 1.0, + this.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, + this.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, + ); + return value; + }, + }); + this.dynoSh1MinMax = new DynoVec2({ + key: "sh1MinMax", + value: new THREE.Vector2(-1, 1), + update: (value) => { + value.set( + this.splatEncoding?.sh1Min ?? -1, + this.splatEncoding?.sh1Max ?? 1, + ); + return value; + }, + }); + this.dynoSh2MinMax = new DynoVec2({ + key: "sh2MinMax", + value: new THREE.Vector2(-1, 1), + update: (value) => { + value.set( + this.splatEncoding?.sh2Min ?? -1, + this.splatEncoding?.sh2Max ?? 1, + ); + return value; + }, + }); + this.dynoSh3MinMax = new DynoVec2({ + key: "sh3MinMax", + value: new THREE.Vector2(-1, 1), + update: (value) => { + value.set( + this.splatEncoding?.sh3Min ?? -1, + this.splatEncoding?.sh3Max ?? 1, + ); + return value; + }, + }); // The following line will be overridden by reinitialize() this.initialized = Promise.resolve(this); @@ -87,6 +174,10 @@ export class PackedSplats { reinitialize(options: PackedSplatsOptions) { this.isInitialized = false; + + this.extra = {}; + this.splatEncoding = options.splatEncoding; + if (options.url || options.fileBytes || options.construct) { // We need to initialize asynchronously given the options this.initialized = this.asyncInitialize(options).then(() => { @@ -131,6 +222,7 @@ export class PackedSplats { input: fileBytes, fileType: options.fileType, pathOrUrl: options.fileName ?? url, + splatEncoding: options.splatEncoding ?? DEFAULT_SPLAT_RANGES, }); this.initialize(unpacked); } @@ -239,7 +331,7 @@ export class PackedSplats { if (!this.packedArray || index >= this.numSplats) { throw new Error("Invalid index"); } - return unpackSplat(this.packedArray, index); + return unpackSplat(this.packedArray, index, this.splatEncoding); } // Set all PackedSplat components at index with the provided Gsplat attributes @@ -322,7 +414,7 @@ export class PackedSplats { return; } for (let i = 0; i < this.numSplats; ++i) { - const unpacked = unpackSplat(this.packedArray, i); + const unpacked = unpackSplat(this.packedArray, i, this.splatEncoding); callback( i, unpacked.center, @@ -473,7 +565,10 @@ export class PackedSplats { ({ index }) => { generator.inputs.index = index; const gsplat = generator.outputs.gsplat; - const output = outputPackedSplat(gsplat); + const output = outputPackedSplat( + gsplat, + this.dynoRgbMinMaxLnScaleMinMax, + ); return { output }; }, ); @@ -616,6 +711,7 @@ export class DynoPackedSplats extends DynoUniform< { texture: THREE.DataArrayTexture; numSplats: number; + rgbMinMaxLnScaleMinMax: THREE.Vector4; } > { packedSplats?: PackedSplats; @@ -628,11 +724,23 @@ export class DynoPackedSplats extends DynoUniform< value: { texture: PackedSplats.getEmpty(), numSplats: 0, + rgbMinMaxLnScaleMinMax: new THREE.Vector4( + 0, + 1, + LN_SCALE_MIN, + LN_SCALE_MAX, + ), }, update: (value) => { value.texture = this.packedSplats?.getTexture() ?? PackedSplats.getEmpty(); value.numSplats = this.packedSplats?.numSplats ?? 0; + value.rgbMinMaxLnScaleMinMax.set( + this.packedSplats?.splatEncoding?.rgbMin ?? 0, + this.packedSplats?.splatEncoding?.rgbMax ?? 1, + this.packedSplats?.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, + this.packedSplats?.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, + ); return value; }, }); diff --git a/src/SparkRenderer.ts b/src/SparkRenderer.ts index ee898fe..6743207 100644 --- a/src/SparkRenderer.ts +++ b/src/SparkRenderer.ts @@ -1,6 +1,10 @@ import * as THREE from "three"; -import { PackedSplats } from "./PackedSplats"; +import { + DEFAULT_SPLAT_RANGES, + PackedSplats, + type SplatEncoding, +} from "./PackedSplats"; import { RgbaArray } from "./RgbaArray"; import { SparkViewpoint, type SparkViewpointOptions } from "./SparkViewpoint"; import { type GeneratorMapping, SplatAccumulator } from "./SplatAccumulator"; @@ -8,6 +12,7 @@ import { SplatEdit } from "./SplatEdit"; import { SplatGenerator, SplatModifier } from "./SplatGenerator"; import { SplatGeometry } from "./SplatGeometry"; import { SplatMesh } from "./SplatMesh"; +import { LN_SCALE_MAX, LN_SCALE_MIN } from "./defines"; import { DynoVec3, DynoVec4, @@ -85,6 +90,11 @@ export type SparkRendererOptions = { * rendering and significantly reduces performance. */ renderer: THREE.WebGLRenderer; + /** + * Whether to use premultiplied alpha when accumulating splat RGB + * @default true + */ + premultipliedAlpha?: boolean; /** * Pass in a THREE.Clock to synchronize time-based effects across different * systems. Alternatively, you can set the SparkRenderer properties time and @@ -184,15 +194,22 @@ export type SparkRendererOptions = { * radial distance or Z-depth) */ view?: SparkViewpointOptions; + /** + * Override the default splat encoding ranges for the PackedSplats. + * (default: undefined) + */ + splatEncoding?: SplatEncoding; }; export class SparkRenderer extends THREE.Mesh { renderer: THREE.WebGLRenderer; + premultipliedAlpha: boolean; material: THREE.ShaderMaterial; uniforms: ReturnType; autoUpdate: boolean; preUpdate: boolean; + needsUpdate: boolean; originDistance: number; maxStdDev: number; maxPixelRadius: number; @@ -205,6 +222,7 @@ export class SparkRenderer extends THREE.Mesh { falloff: number; clipXY: number; focalAdjustment: number; + splatEncoding: SplatEncoding; splatTexture: null | { enable?: boolean; @@ -270,13 +288,20 @@ export class SparkRenderer extends THREE.Mesh { constructor(options: SparkRendererOptions) { const uniforms = SparkRenderer.makeUniforms(); const shaders = getShaders(); + const premultipliedAlpha = options.premultipliedAlpha ?? true; const material = new THREE.ShaderMaterial({ glslVersion: THREE.GLSL3, vertexShader: shaders.splatVertex, fragmentShader: shaders.splatFragment, uniforms, transparent: true, - blending: THREE.NormalBlending, + blending: premultipliedAlpha + ? THREE.CustomBlending + : THREE.NormalBlending, + blendSrc: premultipliedAlpha ? THREE.OneFactor : THREE.SrcAlphaFactor, + blendDst: premultipliedAlpha + ? THREE.OneMinusSrcAlphaFactor + : THREE.OneFactor, depthTest: true, depthWrite: false, side: THREE.DoubleSide, @@ -309,8 +334,10 @@ export class SparkRenderer extends THREE.Mesh { ); this.modifier = new SplatModifier(modifier); + this.premultipliedAlpha = premultipliedAlpha; this.autoUpdate = options.autoUpdate ?? true; this.preUpdate = options.preUpdate ?? false; + this.needsUpdate = false; this.originDistance = options.originDistance ?? 1; this.maxStdDev = options.maxStdDev ?? Math.sqrt(8.0); this.maxPixelRadius = options.maxPixelRadius ?? 512.0; @@ -323,6 +350,7 @@ export class SparkRenderer extends THREE.Mesh { this.falloff = options.falloff ?? 1.0; this.clipXY = options.clipXY ?? 1.4; this.focalAdjustment = options.focalAdjustment ?? 1.0; + this.splatEncoding = options.splatEncoding ?? { ...DEFAULT_SPLAT_RANGES }; this.active = new SplatAccumulator(); this.accumulatorCount = 1; @@ -401,10 +429,14 @@ export class SparkRenderer extends THREE.Mesh { splatTexMid: { value: 0.0 }, // Gsplat collection to render packedSplats: { type: "t", value: PackedSplats.getEmpty() }, + // Splat encoding ranges + rgbMinMaxLnScaleMinMax: { value: new THREE.Vector4() }, // Time in seconds for time-based effects time: { value: 0 }, // Delta time in seconds since last frame deltaTime: { value: 0 }, + // Whether to use premultiplied alpha when accumulating splat RGB + premultipliedAlpha: { value: true }, // Whether to encode Gsplat with linear RGB (for environment mapping) encodeLinear: { value: false }, // Debug flag that alternates each frame @@ -499,6 +531,20 @@ export class SparkRenderer extends THREE.Mesh { if (isNewFrame) { // Keep these uniforms the same for both eyes if in WebXR + const blending = this.premultipliedAlpha + ? THREE.CustomBlending + : THREE.NormalBlending; + if (blending !== this.material.blending) { + this.material.blending = blending; + this.material.blendSrc = this.premultipliedAlpha + ? THREE.OneFactor + : THREE.SrcAlphaFactor; + this.material.blendDst = this.premultipliedAlpha + ? THREE.OneMinusSrcAlphaFactor + : THREE.OneFactor; + this.material.needsUpdate = true; + } + this.uniforms.premultipliedAlpha.value = this.premultipliedAlpha; this.uniforms.time.value = time; this.uniforms.deltaTime.value = deltaTime; // Alternating debug flag that can aid in visual debugging @@ -598,6 +644,12 @@ export class SparkRenderer extends THREE.Mesh { const { accumulator, geometry } = this.viewpoint.display; this.uniforms.numSplats.value = accumulator.splats.numSplats; this.uniforms.packedSplats.value = accumulator.splats.getTexture(); + this.uniforms.rgbMinMaxLnScaleMinMax.value.set( + accumulator.splats.splatEncoding?.rgbMin ?? 0.0, + accumulator.splats.splatEncoding?.rgbMax ?? 1.0, + accumulator.splats.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, + accumulator.splats.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, + ); this.geometry = geometry; } else { // No Gsplats to display for this viewpoint yet @@ -691,6 +743,7 @@ export class SparkRenderer extends THREE.Mesh { const isVisible = object.generator && visibleGenHash.has(object.uuid); const numSplats = isVisible ? object.numSplats : 0; if ( + this.needsUpdate || object.generator !== current?.generator || numSplats !== current?.count ) { @@ -708,9 +761,11 @@ export class SparkRenderer extends THREE.Mesh { // Check if we need any update at all const needsUpdate = + this.needsUpdate || originUpdate || generators.length !== activeMapping.size || generators.some((g) => g.version !== activeMapping.get(g)?.version); + this.needsUpdate = false; let accumulator: SplatAccumulator | null = null; if (needsUpdate) { @@ -782,6 +837,7 @@ export class SparkRenderer extends THREE.Mesh { // Generate the Gsplats according to the mapping that need updating accumulator.ensureGenerate(maxSplats); + accumulator.splats.splatEncoding = { ...this.splatEncoding }; const generated = accumulator.generateSplats({ renderer: this.renderer, modifier: this.modifier, diff --git a/src/SplatLoader.ts b/src/SplatLoader.ts index d77bde7..933dc91 100644 --- a/src/SplatLoader.ts +++ b/src/SplatLoader.ts @@ -1,6 +1,6 @@ import { unzipSync } from "fflate"; import { FileLoader, Loader, type LoadingManager } from "three"; -import { PackedSplats } from "./PackedSplats"; +import { PackedSplats, type SplatEncoding } from "./PackedSplats"; import { SplatMesh } from "./SplatMesh"; import { PlyReader } from "./ply"; import { withWorker } from "./splatWorker"; @@ -396,11 +396,13 @@ export async function unpackSplats({ extraFiles, fileType, pathOrUrl, + splatEncoding, }: { input: Uint8Array | ArrayBuffer; extraFiles?: Record; fileType?: SplatFileType; pathOrUrl?: string; + splatEncoding?: SplatEncoding; }): Promise<{ packedArray: Uint32Array; numSplats: number; @@ -422,7 +424,11 @@ export async function unpackSplats({ await ply.parseHeader(); const numSplats = ply.numSplats; const maxSplats = getTextureSize(numSplats).maxSplats; - const args = { fileBytes, packedArray: new Uint32Array(maxSplats * 4) }; + const args = { + fileBytes, + packedArray: new Uint32Array(maxSplats * 4), + splatEncoding, + }; return await withWorker(async (worker) => { const { packedArray, numSplats, extra } = (await worker.call( "unpackPly", @@ -441,6 +447,7 @@ export async function unpackSplats({ "decodeSpz", { fileBytes, + splatEncoding, }, )) as { packedArray: Uint32Array; @@ -456,6 +463,7 @@ export async function unpackSplats({ "decodeAntiSplat", { fileBytes, + splatEncoding, }, )) as { packedArray: Uint32Array; numSplats: number }; return { packedArray, numSplats }; @@ -465,7 +473,7 @@ export async function unpackSplats({ return await withWorker(async (worker) => { const { packedArray, numSplats, extra } = (await worker.call( "decodeKsplat", - { fileBytes }, + { fileBytes, splatEncoding }, )) as { packedArray: Uint32Array; numSplats: number; @@ -478,7 +486,7 @@ export async function unpackSplats({ return await withWorker(async (worker) => { const { packedArray, numSplats, extra } = (await worker.call( "decodePcSogs", - { fileBytes, extraFiles }, + { fileBytes, extraFiles, splatEncoding }, )) as { packedArray: Uint32Array; numSplats: number; @@ -491,7 +499,7 @@ export async function unpackSplats({ return await withWorker(async (worker) => { const { packedArray, numSplats, extra } = (await worker.call( "decodePcSogsZip", - { fileBytes }, + { fileBytes, splatEncoding }, )) as { packedArray: Uint32Array; numSplats: number; diff --git a/src/SplatMesh.ts b/src/SplatMesh.ts index e34579f..9b735b1 100644 --- a/src/SplatMesh.ts +++ b/src/SplatMesh.ts @@ -1,7 +1,11 @@ import * as THREE from "three"; import init_wasm, { raycast_splats } from "spark-internal-rs"; -import { PackedSplats } from "./PackedSplats"; +import { + DEFAULT_SPLAT_RANGES, + PackedSplats, + type SplatEncoding, +} from "./PackedSplats"; import { type RgbaArray, readRgbaArray } from "./RgbaArray"; import { SplatEdit, SplatEditSdf, SplatEdits } from "./SplatEdit"; import { @@ -11,6 +15,7 @@ import { } from "./SplatGenerator"; import type { SplatFileType } from "./SplatLoader"; import type { SplatSkinning } from "./SplatSkinning"; +import { LN_SCALE_MAX, LN_SCALE_MIN } from "./defines"; import { DynoFloat, DynoUsampler2DArray, @@ -27,6 +32,7 @@ import { mul, normalize, readPackedSplat, + split, splitGsplat, sub, unindent, @@ -80,6 +86,9 @@ export type SplatMeshOptions = { // Gsplat modifier to apply in world-space after transformations. // (default: undefined) worldModifier?: GsplatModifier; + // Override the default splat encoding ranges for the PackedSplats. + // (default: undefined) + splatEncoding?: SplatEncoding; }; export type SplatMeshContext = { @@ -183,6 +192,9 @@ export class SplatMesh extends SplatGenerator { }); this.packedSplats = options.packedSplats ?? new PackedSplats(); + this.packedSplats.splatEncoding = options.splatEncoding ?? { + ...DEFAULT_SPLAT_RANGES, + }; this.numSplats = this.packedSplats.numSplats; this.editable = options.editable ?? true; this.onFrame = options.onFrame; @@ -226,8 +238,15 @@ export class SplatMesh extends SplatGenerator { } async asyncInitialize(options: SplatMeshOptions) { - const { url, fileBytes, fileType, fileName, maxSplats, constructSplats } = - options; + const { + url, + fileBytes, + fileType, + fileName, + maxSplats, + constructSplats, + splatEncoding, + } = options; if (url || fileBytes || constructSplats) { const packedSplatsOptions = { url, @@ -236,6 +255,7 @@ export class SplatMesh extends SplatGenerator { fileName, maxSplats, construct: constructSplats, + splatEncoding, }; this.packedSplats.reinitialize(packedSplatsOptions); } @@ -321,13 +341,32 @@ export class SplatMesh extends SplatGenerator { const { center } = splitGsplat(gsplat).outputs; const viewDir = normalize(sub(center, viewCenterInObject)); + function rescaleSh( + sNorm: DynoVal<"vec3">, + minMax: DynoVal<"vec2">, + ) { + const { x: min, y: max } = split(minMax).outputs; + const mid = mul(add(min, max), dynoConst("float", 0.5)); + const scale = mul(sub(max, min), dynoConst("float", 0.5)); + return add(mid, mul(sNorm, scale)); + } + // Evaluate Spherical Harmonics - let rgb = evaluateSH1(gsplat, sh1Texture, viewDir); + const sh1Snorm = evaluateSH1(gsplat, sh1Texture, viewDir); + let rgb = rescaleSh(sh1Snorm, this.packedSplats.dynoSh1MinMax); if (this.maxSh >= 2 && sh2Texture) { - rgb = add(rgb, evaluateSH2(gsplat, sh2Texture, viewDir)); + const sh2Snorm = evaluateSH2(gsplat, sh2Texture, viewDir); + rgb = add( + rgb, + rescaleSh(sh2Snorm, this.packedSplats.dynoSh2MinMax), + ); } if (this.maxSh >= 3 && sh3Texture) { - rgb = add(rgb, evaluateSH3(gsplat, sh3Texture, viewDir)); + const sh3Snorm = evaluateSH3(gsplat, sh3Texture, viewDir); + rgb = add( + rgb, + rescaleSh(sh3Snorm, this.packedSplats.dynoSh3MinMax), + ); } // Flash off for 0.3 / 1.0 sec for debugging @@ -539,6 +578,8 @@ export class SplatMesh extends SplatGenerator { this.packedSplats.numSplats, this.packedSplats.packedArray, RAYCAST_ELLIPSOID, + this.packedSplats.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, + this.packedSplats.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, ); for (const distance of distances) { diff --git a/src/antisplat.ts b/src/antisplat.ts index d24017d..4943d56 100644 --- a/src/antisplat.ts +++ b/src/antisplat.ts @@ -1,3 +1,4 @@ +import type { SplatEncoding } from "./PackedSplats"; import { computeMaxSplats, setPackedSplat } from "./utils"; export function decodeAntiSplat( @@ -65,7 +66,10 @@ export function decodeAntiSplat( } } -export function unpackAntiSplat(fileBytes: Uint8Array): { +export function unpackAntiSplat( + fileBytes: Uint8Array, + splatEncoding: SplatEncoding, +): { packedArray: Uint32Array; numSplats: number; } { @@ -113,6 +117,7 @@ export function unpackAntiSplat(fileBytes: Uint8Array): { r, g, b, + splatEncoding, ); }, ); diff --git a/src/defines.ts b/src/defines.ts index 5b8e17a..f7eafa3 100644 --- a/src/defines.ts +++ b/src/defines.ts @@ -6,7 +6,6 @@ export const LN_SCALE_MIN = -12.0; export const LN_SCALE_MAX = 9.0; -export const LN_RESCALE = (LN_SCALE_MAX - LN_SCALE_MIN) / 254.0; // 1..=255 export const SCALE_MIN = Math.exp(LN_SCALE_MIN); export const SCALE_MAX = Math.exp(LN_SCALE_MAX); diff --git a/src/dyno/output.ts b/src/dyno/output.ts index 494295b..379a8d9 100644 --- a/src/dyno/output.ts +++ b/src/dyno/output.ts @@ -1,3 +1,4 @@ +import * as THREE from "three"; import { Dyno, unindentLines } from "./base"; import { Gsplat, defineGsplat } from "./splats"; import { @@ -7,30 +8,41 @@ import { type HasDynoOut, } from "./value"; -export const outputPackedSplat = (gsplat: DynoVal) => - new OutputPackedSplat({ gsplat }); +export const outputPackedSplat = ( + gsplat: DynoVal, + rgbMinMaxLnScaleMinMax: DynoVal<"vec4">, +) => new OutputPackedSplat({ gsplat, rgbMinMaxLnScaleMinMax }); export const outputRgba8 = (rgba8: DynoVal<"vec4">) => new OutputRgba8({ rgba8 }); export class OutputPackedSplat - extends Dyno<{ gsplat: typeof Gsplat }, { output: "uvec4" }> + extends Dyno< + { gsplat: typeof Gsplat; rgbMinMaxLnScaleMinMax: "vec4" }, + { output: "uvec4" } + > implements HasDynoOut<"uvec4"> { - constructor({ gsplat }: { gsplat?: DynoVal }) { + constructor({ + gsplat, + rgbMinMaxLnScaleMinMax, + }: { + gsplat?: DynoVal; + rgbMinMaxLnScaleMinMax?: DynoVal<"vec4">; + }) { super({ - inTypes: { gsplat: Gsplat }, - inputs: { gsplat }, + inTypes: { gsplat: Gsplat, rgbMinMaxLnScaleMinMax: "vec4" }, + inputs: { gsplat, rgbMinMaxLnScaleMinMax }, globals: () => [defineGsplat], statements: ({ inputs, outputs }) => { const { output } = outputs; if (!output) { return []; } - const { gsplat } = inputs; + const { gsplat, rgbMinMaxLnScaleMinMax } = inputs; if (gsplat) { return unindentLines(` if (isGsplatActive(${gsplat}.flags)) { - ${output} = packSplat(${gsplat}.center, ${gsplat}.scales, ${gsplat}.quaternion, ${gsplat}.rgba); + ${output} = packSplatEncoding(${gsplat}.center, ${gsplat}.scales, ${gsplat}.quaternion, ${gsplat}.rgba, ${rgbMinMaxLnScaleMinMax}); } else { ${output} = uvec4(0u, 0u, 0u, 0u); } diff --git a/src/dyno/splats.ts b/src/dyno/splats.ts index 51abe39..309a601 100644 --- a/src/dyno/splats.ts +++ b/src/dyno/splats.ts @@ -118,6 +118,7 @@ export const definePackedSplats = unindent(` struct PackedSplats { usampler2DArray texture; int numSplats; + vec4 rgbMinMaxLnScaleMinMax; }; `); @@ -137,10 +138,10 @@ export class NumPackedSplats extends UnaryOp< } const defineReadPackedSplat = unindent(` - bool readPackedSplat(usampler2DArray texture, int numSplats, int index, out Gsplat gsplat) { + bool readPackedSplat(usampler2DArray texture, int numSplats, vec4 rgbMinMaxLnScaleMinMax, int index, out Gsplat gsplat) { if ((index >= 0) && (index < numSplats)) { uvec4 packed = texelFetch(texture, splatTexCoord(index), 0); - unpackSplat(packed, gsplat.center, gsplat.scales, gsplat.quaternion, gsplat.rgba); + unpackSplatEncoding(packed, gsplat.center, gsplat.scales, gsplat.quaternion, gsplat.rgba, rgbMinMaxLnScaleMinMax); return true; } else { return false; @@ -173,7 +174,7 @@ export class ReadPackedSplat let statements: string[]; if (packedSplats && index) { statements = unindentLines(` - if (readPackedSplat(${packedSplats}.texture, ${packedSplats}.numSplats, ${index}, ${gsplat})) { + if (readPackedSplat(${packedSplats}.texture, ${packedSplats}.numSplats, ${packedSplats}.rgbMinMaxLnScaleMinMax, ${index}, ${gsplat})) { bool zeroSize = all(equal(${gsplat}.scales, vec3(0.0, 0.0, 0.0))); ${gsplat}.flags = zeroSize ? 0u : GSPLAT_FLAG_ACTIVE; } else { @@ -238,7 +239,7 @@ export class ReadPackedSplatRange statements = unindentLines(` ${gsplat}.flags = 0u; if ((${index} >= ${base}) && (${index} < (${base} + ${count}))) { - if (readPackedSplat(${packedSplats}.texture, ${packedSplats}.numSplats, ${index}, ${gsplat})) { + if (readPackedSplat(${packedSplats}.texture, ${packedSplats}.numSplats, ${packedSplats}.rgbMinMaxLnScaleMinMax, ${index}, ${gsplat})) { bool zeroSize = all(equal(${gsplat}.scales, vec3(0.0, 0.0, 0.0))); ${gsplat}.flags = zeroSize ? 0u : GSPLAT_FLAG_ACTIVE; } diff --git a/src/index.ts b/src/index.ts index 01c7d15..0856dc5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -90,3 +90,6 @@ export { unpackSplat, } from "./utils"; export * as utils from "./utils"; + +export { LN_SCALE_MIN, LN_SCALE_MAX } from "./defines"; +export * as defines from "./defines"; diff --git a/src/ksplat.ts b/src/ksplat.ts index 33a892c..17b1be3 100644 --- a/src/ksplat.ts +++ b/src/ksplat.ts @@ -1,3 +1,4 @@ +import type { SplatEncoding } from "./PackedSplats"; import { computeMaxSplats, encodeSh1Rgb, @@ -352,7 +353,10 @@ export function decodeKsplat( } } -export function unpackKsplat(fileBytes: Uint8Array): { +export function unpackKsplat( + fileBytes: Uint8Array, + splatEncoding: SplatEncoding, +): { packedArray: Uint32Array; numSplats: number; extra: Record; @@ -593,6 +597,7 @@ export function unpackKsplat(fileBytes: Uint8Array): { r, g, b, + splatEncoding, ); if (sphericalHarmonicsDegree >= 1) { @@ -603,7 +608,7 @@ export function unpackKsplat(fileBytes: Uint8Array): { for (const [i, key] of sh1Index.entries()) { sh1[i] = getSh(splatOffset, key); } - encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1); + encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1, splatEncoding); } if (sh2) { if (!extra.sh2) { @@ -612,7 +617,7 @@ export function unpackKsplat(fileBytes: Uint8Array): { for (const [i, key] of sh2Index.entries()) { sh2[i] = getSh(splatOffset, key); } - encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2); + encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2, splatEncoding); } if (sh3) { if (!extra.sh3) { @@ -621,7 +626,7 @@ export function unpackKsplat(fileBytes: Uint8Array): { for (const [i, key] of sh3Index.entries()) { sh3[i] = getSh(splatOffset, key); } - encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3); + encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3, splatEncoding); } } } diff --git a/src/pcsogs.ts b/src/pcsogs.ts index 7eb090d..ac1b3eb 100644 --- a/src/pcsogs.ts +++ b/src/pcsogs.ts @@ -1,4 +1,5 @@ import { unzip } from "fflate"; +import type { SplatEncoding } from "./PackedSplats"; import { type PcSogsJson, tryPcSogsZip } from "./SplatLoader"; import { computeMaxSplats, @@ -14,6 +15,7 @@ import { export async function unpackPcSogs( json: PcSogsJson, extraFiles: Record, + splatEncoding: SplatEncoding, ): Promise<{ packedArray: Uint32Array; numSplats: number; @@ -72,6 +74,7 @@ export async function unpackPcSogs( Math.exp(x), Math.exp(y), Math.exp(z), + splatEncoding, ); } }, @@ -116,7 +119,7 @@ export async function unpackPcSogs( const g = SH_C0 * dc1 + 0.5; const b = SH_C0 * dc2 + 0.5; const a = 1.0 / (1.0 + Math.exp(-opa)); - setPackedSplatRgba(packedArray, i, r, g, b, a); + setPackedSplatRgba(packedArray, i, r, g, b, a, splatEncoding); } }, ); @@ -178,9 +181,12 @@ export async function unpackPcSogs( } } - if (useSH1) encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1); - if (useSH2) encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2); - if (useSH3) encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3); + if (useSH1) + encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1, splatEncoding); + if (useSH2) + encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2, splatEncoding); + if (useSH3) + encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3, splatEncoding); } }); promises.push(shNPromise); @@ -248,7 +254,10 @@ async function decodeImageRgba(fileBytes: ArrayBuffer) { return rgba; } -export async function unpackPcSogsZip(fileBytes: Uint8Array): Promise<{ +export async function unpackPcSogsZip( + fileBytes: Uint8Array, + splatEncoding: SplatEncoding, +): Promise<{ packedArray: Uint32Array; numSplats: number; extra: Record; @@ -300,5 +309,5 @@ export async function unpackPcSogsZip(fileBytes: Uint8Array): Promise<{ extraFiles[name] = unzipped[full]; } - return await unpackPcSogs(json, extraFiles); + return await unpackPcSogs(json, extraFiles, splatEncoding); } diff --git a/src/shaders/splatDefines.glsl b/src/shaders/splatDefines.glsl index d7b8bae..79e10c1 100644 --- a/src/shaders/splatDefines.glsl +++ b/src/shaders/splatDefines.glsl @@ -1,6 +1,5 @@ const float LN_SCALE_MIN = -12.0; const float LN_SCALE_MAX = 9.0; -const float LN_RESCALE = (LN_SCALE_MAX - LN_SCALE_MIN) / 254.0; // 1..=255 const uint SPLAT_TEX_WIDTH_BITS = 11u; const uint SPLAT_TEX_HEIGHT_BITS = 11u; @@ -199,8 +198,13 @@ vec4 decodeQuatOctXy88R8(uint encoded) { // } // Pack a Gsplat into a uvec4 -uvec4 packSplat(vec3 center, vec3 scales, vec4 quaternion, vec4 rgba) { - uvec4 uRgba = uvec4(round(clamp(rgba * 255.0, 0.0, 255.0))); +uvec4 packSplatEncoding( + vec3 center, vec3 scales, vec4 quaternion, vec4 rgba, vec4 rgbMinMaxLnScaleMinMax +) { + float rgbMin = rgbMinMaxLnScaleMinMax.x; + float rgbMax = rgbMinMaxLnScaleMinMax.y; + vec3 encRgb = (rgba.rgb - vec3(rgbMin)) / (rgbMax - rgbMin); + uvec4 uRgba = uvec4(round(clamp(vec4(encRgb, rgba.a) * 255.0, 0.0, 255.0))); uint uQuat = encodeQuatOctXy88R8(quaternion); // uint uQuat = encodeQuatXyz888(quaternion); @@ -208,10 +212,13 @@ uvec4 packSplat(vec3 center, vec3 scales, vec4 quaternion, vec4 rgba) { uvec3 uQuat3 = uvec3(uQuat & 0xffu, (uQuat >> 8u) & 0xffu, (uQuat >> 16u) & 0xffu); // Encode scales in three uint8s, where 0=>0.0 and 1..=255 stores log scale + float lnScaleMin = rgbMinMaxLnScaleMinMax.z; + float lnScaleMax = rgbMinMaxLnScaleMinMax.w; + float lnScaleScale = 254.0 / (lnScaleMax - lnScaleMin); uvec3 uScales = uvec3( - (scales.x == 0.0) ? 0u : uint(round(clamp((log(scales.x) - LN_SCALE_MIN) / LN_RESCALE, 0.0, 254.0))) + 1u, - (scales.y == 0.0) ? 0u : uint(round(clamp((log(scales.y) - LN_SCALE_MIN) / LN_RESCALE, 0.0, 254.0))) + 1u, - (scales.z == 0.0) ? 0u : uint(round(clamp((log(scales.z) - LN_SCALE_MIN) / LN_RESCALE, 0.0, 254.0))) + 1u + (scales.x == 0.0) ? 0u : uint(round(clamp((log(scales.x) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u, + (scales.y == 0.0) ? 0u : uint(round(clamp((log(scales.y) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u, + (scales.z == 0.0) ? 0u : uint(round(clamp((log(scales.z) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u ); // Pack it all into 4 x uint32 @@ -222,12 +229,19 @@ uvec4 packSplat(vec3 center, vec3 scales, vec4 quaternion, vec4 rgba) { return uvec4(word0, word1, word2, word3); } -// Unpack a Gsplat from a uvec4 -void unpackSplat(uvec4 packed, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba) { +// Pack a Gsplat into a uvec4 +uvec4 packSplat(vec3 center, vec3 scales, vec4 quaternion, vec4 rgba) { + return packSplatEncoding(center, scales, quaternion, rgba, vec4(0.0, 1.0, LN_SCALE_MIN, LN_SCALE_MAX)); +} + +void unpackSplatEncoding(uvec4 packed, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba, vec4 rgbMinMaxLnScaleMinMax) { uint word0 = packed.x, word1 = packed.y, word2 = packed.z, word3 = packed.w; uvec4 uRgba = uvec4(word0 & 0xffu, (word0 >> 8u) & 0xffu, (word0 >> 16u) & 0xffu, (word0 >> 24u) & 0xffu); - rgba = vec4(uRgba) / 255.0; + float rgbMin = rgbMinMaxLnScaleMinMax.x; + float rgbMax = rgbMinMaxLnScaleMinMax.y; + rgba = (vec4(uRgba) / 255.0); + rgba.rgb = rgba.rgb * (rgbMax - rgbMin) + rgbMin; center = vec4( unpackHalf2x16(word1), @@ -235,10 +249,13 @@ void unpackSplat(uvec4 packed, out vec3 center, out vec3 scales, out vec4 quater ).xyz; uvec3 uScales = uvec3(word3 & 0xffu, (word3 >> 8u) & 0xffu, (word3 >> 16u) & 0xffu); + float lnScaleMin = rgbMinMaxLnScaleMinMax.z; + float lnScaleMax = rgbMinMaxLnScaleMinMax.w; + float lnScaleScale = (lnScaleMax - lnScaleMin) / 254.0; scales = vec3( - (uScales.x == 0u) ? 0.0 : exp(LN_SCALE_MIN + float(uScales.x - 1u) * LN_RESCALE), - (uScales.y == 0u) ? 0.0 : exp(LN_SCALE_MIN + float(uScales.y - 1u) * LN_RESCALE), - (uScales.z == 0u) ? 0.0 : exp(LN_SCALE_MIN + float(uScales.z - 1u) * LN_RESCALE) + (uScales.x == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.x - 1u) * lnScaleScale), + (uScales.y == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.y - 1u) * lnScaleScale), + (uScales.z == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.z - 1u) * lnScaleScale) ); @@ -248,6 +265,11 @@ void unpackSplat(uvec4 packed, out vec3 center, out vec3 scales, out vec4 quater // quaternion = decodeQuatEulerXyz888(uQuat); } +// Unpack a Gsplat from a uvec4 +void unpackSplat(uvec4 packed, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba) { + unpackSplatEncoding(packed, center, scales, quaternion, rgba, vec4(0.0, 1.0, LN_SCALE_MIN, LN_SCALE_MAX)); +} + // Rotate vector v by quaternion q vec3 quatVec(vec4 q, vec3 v) { // Rotate vector v by quaternion q diff --git a/src/shaders/splatFragment.glsl b/src/shaders/splatFragment.glsl index 4c4ef77..e51be13 100644 --- a/src/shaders/splatFragment.glsl +++ b/src/shaders/splatFragment.glsl @@ -6,6 +6,7 @@ precision highp int; uniform float near; uniform float far; +uniform bool premultipliedAlpha; uniform bool encodeLinear; uniform float maxStdDev; uniform float minAlpha; @@ -67,5 +68,10 @@ void main() { if (encodeLinear) { rgba.rgb = srgbToLinear(rgba.rgb); } - fragColor = rgba; + + if (premultipliedAlpha) { + fragColor = vec4(rgba.rgb * rgba.a, rgba.a); + } else { + fragColor = rgba; + } } diff --git a/src/shaders/splatVertex.glsl b/src/shaders/splatVertex.glsl index c7086db..e148cde 100644 --- a/src/shaders/splatVertex.glsl +++ b/src/shaders/splatVertex.glsl @@ -30,6 +30,7 @@ uniform float clipXY; uniform float focalAdjustment; uniform usampler2DArray packedSplats; +uniform vec4 rgbMinMaxLnScaleMinMax; void main() { // Default to outside the frustum so it's discarded if we return early @@ -52,7 +53,7 @@ void main() { vec3 center, scales; vec4 quaternion, rgba; - unpackSplat(packed, center, scales, quaternion, rgba); + unpackSplatEncoding(packed, center, scales, quaternion, rgba, rgbMinMaxLnScaleMinMax); if (rgba.a < minAlpha) { return; diff --git a/src/utils.ts b/src/utils.ts index efeeb8b..d359d25 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -4,7 +4,7 @@ import * as THREE from "three"; // Miscellaneous utility functions for Spark import { - LN_RESCALE, + LN_SCALE_MAX, LN_SCALE_MIN, SCALE_ZERO, SPLAT_TEX_HEIGHT, @@ -353,10 +353,19 @@ export function setPackedSplat( r: number, g: number, b: number, + encoding?: { + rgbMin?: number; + rgbMax?: number; + lnScaleMin?: number; + lnScaleMax?: number; + }, ) { - const uR = floatToUint8(r); - const uG = floatToUint8(g); - const uB = floatToUint8(b); + const rgbMin = encoding?.rgbMin ?? 0.0; + const rgbMax = encoding?.rgbMax ?? 1.0; + const rgbRange = rgbMax - rgbMin; + const uR = floatToUint8((r - rgbMin) / rgbRange); + const uG = floatToUint8((g - rgbMin) / rgbRange); + const uB = floatToUint8((b - rgbMin) / rgbRange); const uA = floatToUint8(opacity); // Alternate internal encodings commented out below. @@ -370,6 +379,9 @@ export function setPackedSplat( const uQuatZ = (uQuat >>> 16) & 0xff; // Allow scales below LN_SCALE_MIN to be encoded as 0, which signifies a 2DGS + const lnScaleMin = encoding?.lnScaleMin ?? LN_SCALE_MIN; + const lnScaleMax = encoding?.lnScaleMax ?? LN_SCALE_MAX; + const lnScaleScale = 254.0 / (lnScaleMax - lnScaleMin); const uScaleX = scaleX < SCALE_ZERO ? 0 @@ -377,7 +389,7 @@ export function setPackedSplat( 255, Math.max( 1, - Math.round((Math.log(scaleX) - LN_SCALE_MIN) / LN_RESCALE) + 1, + Math.round((Math.log(scaleX) - lnScaleMin) * lnScaleScale) + 1, ), ); const uScaleY = @@ -387,7 +399,7 @@ export function setPackedSplat( 255, Math.max( 1, - Math.round((Math.log(scaleY) - LN_SCALE_MIN) / LN_RESCALE) + 1, + Math.round((Math.log(scaleY) - lnScaleMin) * lnScaleScale) + 1, ), ); const uScaleZ = @@ -397,7 +409,7 @@ export function setPackedSplat( 255, Math.max( 1, - Math.round((Math.log(scaleZ) - LN_SCALE_MIN) / LN_RESCALE) + 1, + Math.round((Math.log(scaleZ) - lnScaleMin) * lnScaleScale) + 1, ), ); @@ -439,8 +451,15 @@ export function setPackedSplatScales( scaleX: number, scaleY: number, scaleZ: number, + encoding?: { + lnScaleMin?: number; + lnScaleMax?: number; + }, ) { // Allow scales below LN_SCALE_MIN to be encoded as 0, which signifies a 2DGS + const lnScaleMin = encoding?.lnScaleMin ?? LN_SCALE_MIN; + const lnScaleMax = encoding?.lnScaleMax ?? LN_SCALE_MAX; + const lnScaleScale = 254.0 / (lnScaleMax - lnScaleMin); const uScaleX = scaleX < SCALE_ZERO ? 0 @@ -448,7 +467,7 @@ export function setPackedSplatScales( 255, Math.max( 1, - Math.round((Math.log(scaleX) - LN_SCALE_MIN) / LN_RESCALE) + 1, + Math.round((Math.log(scaleX) - lnScaleMin) * lnScaleScale) + 1, ), ); const uScaleY = @@ -458,7 +477,7 @@ export function setPackedSplatScales( 255, Math.max( 1, - Math.round((Math.log(scaleY) - LN_SCALE_MIN) / LN_RESCALE) + 1, + Math.round((Math.log(scaleY) - lnScaleMin) * lnScaleScale) + 1, ), ); const uScaleZ = @@ -468,7 +487,7 @@ export function setPackedSplatScales( 255, Math.max( 1, - Math.round((Math.log(scaleZ) - LN_SCALE_MIN) / LN_RESCALE) + 1, + Math.round((Math.log(scaleZ) - lnScaleMin) * lnScaleScale) + 1, ), ); @@ -513,10 +532,17 @@ export function setPackedSplatRgba( g: number, b: number, a: number, + encoding?: { + rgbMin?: number; + rgbMax?: number; + }, ) { - const uR = floatToUint8(r); - const uG = floatToUint8(g); - const uB = floatToUint8(b); + const rgbMin = encoding?.rgbMin ?? 0.0; + const rgbMax = encoding?.rgbMax ?? 1.0; + const rgbRange = rgbMax - rgbMin; + const uR = floatToUint8((r - rgbMin) / rgbRange); + const uG = floatToUint8((g - rgbMin) / rgbRange); + const uB = floatToUint8((b - rgbMin) / rgbRange); const uA = floatToUint8(a); const i4 = index * 4; packedSplats[i4] = uR | (uG << 8) | (uB << 16) | (uA << 24); @@ -529,10 +555,17 @@ export function setPackedSplatRgb( r: number, g: number, b: number, + encoding?: { + rgbMin?: number; + rgbMax?: number; + }, ) { - const uR = floatToUint8(r); - const uG = floatToUint8(g); - const uB = floatToUint8(b); + const rgbMin = encoding?.rgbMin ?? 0.0; + const rgbMax = encoding?.rgbMax ?? 1.0; + const rgbRange = rgbMax - rgbMin; + const uR = floatToUint8((r - rgbMin) / rgbRange); + const uG = floatToUint8((g - rgbMin) / rgbRange); + const uB = floatToUint8((b - rgbMin) / rgbRange); const i4 = index * 4; packedSplats[i4] = @@ -568,6 +601,12 @@ const packedFields = { export function unpackSplat( packedSplats: Uint32Array, index: number, + encoding?: { + rgbMin?: number; + rgbMax?: number; + lnScaleMin?: number; + lnScaleMax?: number; + }, ): { center: THREE.Vector3; scales: THREE.Vector3; @@ -584,10 +623,13 @@ export function unpackSplat( const word2 = packedSplats[i4 + 2]; const word3 = packedSplats[i4 + 3]; + const rgbMin = encoding?.rgbMin ?? 0.0; + const rgbMax = encoding?.rgbMax ?? 1.0; + const rgbRange = rgbMax - rgbMin; result.color.set( - (word0 & 0xff) / 255, - ((word0 >>> 8) & 0xff) / 255, - ((word0 >>> 16) & 0xff) / 255, + rgbMin + ((word0 & 0xff) / 255) * rgbRange, + rgbMin + (((word0 >>> 8) & 0xff) / 255) * rgbRange, + rgbMin + (((word0 >>> 16) & 0xff) / 255) * rgbRange, ); result.opacity = ((word0 >>> 24) & 0xff) / 255; result.center.set( @@ -596,15 +638,18 @@ export function unpackSplat( fromHalf(word2 & 0xffff), ); + const lnScaleMin = encoding?.lnScaleMin ?? LN_SCALE_MIN; + const lnScaleMax = encoding?.lnScaleMax ?? LN_SCALE_MAX; + const lnScaleScale = (lnScaleMax - lnScaleMin) / 254.0; const uScalesX = word3 & 0xff; result.scales.x = - uScalesX === 0 ? 0.0 : Math.exp(LN_SCALE_MIN + (uScalesX - 1) * LN_RESCALE); + uScalesX === 0 ? 0.0 : Math.exp(lnScaleMin + (uScalesX - 1) * lnScaleScale); const uScalesY = (word3 >>> 8) & 0xff; result.scales.y = - uScalesY === 0 ? 0.0 : Math.exp(LN_SCALE_MIN + (uScalesY - 1) * LN_RESCALE); + uScalesY === 0 ? 0.0 : Math.exp(lnScaleMin + (uScalesY - 1) * lnScaleScale); const uScalesZ = (word3 >>> 16) & 0xff; result.scales.z = - uScalesZ === 0 ? 0.0 : Math.exp(LN_SCALE_MIN + (uScalesZ - 1) * LN_RESCALE); + uScalesZ === 0 ? 0.0 : Math.exp(lnScaleMin + (uScalesZ - 1) * lnScaleScale); const uQuat = ((word2 >>> 16) & 0xffff) | ((word3 >>> 8) & 0xff0000); decodeQuatOctXy88R8(uQuat, result.quaternion); @@ -1110,11 +1155,21 @@ export function encodeSh1Rgb( sh1Array: Uint32Array, index: number, sh1Rgb: Float32Array, + encoding?: { + sh1Min?: number; + sh1Max?: number; + }, ) { + const sh1Min = encoding?.sh1Min ?? -1; + const sh1Max = encoding?.sh1Max ?? 1; + const sh1Mid = 0.5 * (sh1Min + sh1Max); + const sh1Scale = 63 / (sh1Max - sh1Min); + // Pack sint7 values into 2 x uint32 const base = index * 2; for (let i = 0; i < 9; ++i) { - const value = Math.max(-63, Math.min(63, sh1Rgb[i] * 63)) & 0x7f; + const s = (sh1Rgb[i] - sh1Mid) * sh1Scale; + const value = Math.round(Math.max(-63, Math.min(63, s))) & 0x7f; const bitStart = i * 7; const bitEnd = bitStart + 7; @@ -1136,30 +1191,39 @@ export function encodeSh2Rgb( sh2Array: Uint32Array, index: number, sh2Rgb: Float32Array, + encoding?: { + sh2Min?: number; + sh2Max?: number; + }, ) { + const sh2Min = encoding?.sh2Min ?? -1; + const sh2Max = encoding?.sh2Max ?? 1; + const sh2Mid = 0.5 * (sh2Min + sh2Max); + const sh2Scale = 0.5 / (sh2Max - sh2Min); + // Pack sint8 values into 4 x uint32 sh2Array[index * 4 + 0] = packSint8Bytes( - sh2Rgb[0], - sh2Rgb[1], - sh2Rgb[2], - sh2Rgb[3], + (sh2Rgb[0] - sh2Mid) * sh2Scale, + (sh2Rgb[1] - sh2Mid) * sh2Scale, + (sh2Rgb[2] - sh2Mid) * sh2Scale, + (sh2Rgb[3] - sh2Mid) * sh2Scale, ); sh2Array[index * 4 + 1] = packSint8Bytes( - sh2Rgb[4], - sh2Rgb[5], - sh2Rgb[6], - sh2Rgb[7], + (sh2Rgb[4] - sh2Mid) * sh2Scale, + (sh2Rgb[5] - sh2Mid) * sh2Scale, + (sh2Rgb[6] - sh2Mid) * sh2Scale, + (sh2Rgb[7] - sh2Mid) * sh2Scale, ); sh2Array[index * 4 + 2] = packSint8Bytes( - sh2Rgb[8], - sh2Rgb[9], - sh2Rgb[10], - sh2Rgb[11], + (sh2Rgb[8] - sh2Mid) * sh2Scale, + (sh2Rgb[9] - sh2Mid) * sh2Scale, + (sh2Rgb[10] - sh2Mid) * sh2Scale, + (sh2Rgb[11] - sh2Mid) * sh2Scale, ); sh2Array[index * 4 + 3] = packSint8Bytes( - sh2Rgb[12], - sh2Rgb[13], - sh2Rgb[14], + (sh2Rgb[12] - sh2Mid) * sh2Scale, + (sh2Rgb[13] - sh2Mid) * sh2Scale, + (sh2Rgb[14] - sh2Mid) * sh2Scale, 0, ); } @@ -1170,11 +1234,21 @@ export function encodeSh3Rgb( sh3Array: Uint32Array, index: number, sh3Rgb: Float32Array, + encoding?: { + sh3Min?: number; + sh3Max?: number; + }, ) { + const sh3Min = encoding?.sh3Min ?? -1; + const sh3Max = encoding?.sh3Max ?? 1; + const sh3Mid = 0.5 * (sh3Min + sh3Max); + const sh3Scale = 31 / (sh3Max - sh3Min); + // Pack sint6 values into 4 x uint32 const base = index * 4; for (let i = 0; i < 21; ++i) { - const value = Math.max(-31, Math.min(31, sh3Rgb[i] * 31)) & 0x3f; + const s = (sh3Rgb[i] - sh3Mid) * sh3Scale; + const value = Math.round(Math.max(-31, Math.min(31, s))) & 0x3f; const bitStart = i * 6; const bitEnd = bitStart + 6; diff --git a/src/worker.ts b/src/worker.ts index 6ebfe9e..943e1ca 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -1,4 +1,5 @@ import init_wasm, { sort_splats, sort32_splats } from "spark-internal-rs"; +import type { SplatEncoding } from "./PackedSplats"; import type { PcSogsJson, TranscodeSpzInput } from "./SplatLoader"; import { unpackAntiSplat } from "./antisplat"; import { WASM_SPLAT_SORT } from "./defines"; @@ -37,11 +38,16 @@ async function onMessage(event: MessageEvent) { try { switch (name) { case "unpackPly": { - const { packedArray, fileBytes } = args as { + const { packedArray, fileBytes, splatEncoding } = args as { packedArray: Uint32Array; fileBytes: Uint8Array; + splatEncoding: SplatEncoding; }; - const decoded = await unpackPly({ packedArray, fileBytes }); + const decoded = await unpackPly({ + packedArray, + fileBytes, + splatEncoding, + }); result = { id, numSplats: decoded.numSplats, @@ -51,8 +57,11 @@ async function onMessage(event: MessageEvent) { break; } case "decodeSpz": { - const { fileBytes } = args as { fileBytes: Uint8Array }; - const decoded = unpackSpz(fileBytes); + const { fileBytes, splatEncoding } = args as { + fileBytes: Uint8Array; + splatEncoding: SplatEncoding; + }; + const decoded = unpackSpz(fileBytes, splatEncoding); result = { id, numSplats: decoded.numSplats, @@ -62,8 +71,11 @@ async function onMessage(event: MessageEvent) { break; } case "decodeAntiSplat": { - const { fileBytes } = args as { fileBytes: Uint8Array }; - const decoded = unpackAntiSplat(fileBytes); + const { fileBytes, splatEncoding } = args as { + fileBytes: Uint8Array; + splatEncoding: SplatEncoding; + }; + const decoded = unpackAntiSplat(fileBytes, splatEncoding); result = { id, numSplats: decoded.numSplats, @@ -72,8 +84,11 @@ async function onMessage(event: MessageEvent) { break; } case "decodeKsplat": { - const { fileBytes } = args as { fileBytes: Uint8Array }; - const decoded = unpackKsplat(fileBytes); + const { fileBytes, splatEncoding } = args as { + fileBytes: Uint8Array; + splatEncoding: SplatEncoding; + }; + const decoded = unpackKsplat(fileBytes, splatEncoding); result = { id, numSplats: decoded.numSplats, @@ -83,14 +98,15 @@ async function onMessage(event: MessageEvent) { break; } case "decodePcSogs": { - const { fileBytes, extraFiles } = args as { + const { fileBytes, extraFiles, splatEncoding } = args as { fileBytes: Uint8Array; extraFiles: Record; + splatEncoding: SplatEncoding; }; const json = JSON.parse( new TextDecoder().decode(fileBytes), ) as PcSogsJson; - const decoded = await unpackPcSogs(json, extraFiles); + const decoded = await unpackPcSogs(json, extraFiles, splatEncoding); result = { id, numSplats: decoded.numSplats, @@ -100,8 +116,11 @@ async function onMessage(event: MessageEvent) { break; } case "decodePcSogsZip": { - const { fileBytes } = args as { fileBytes: Uint8Array }; - const decoded = await unpackPcSogsZip(fileBytes); + const { fileBytes, splatEncoding } = args as { + fileBytes: Uint8Array; + splatEncoding: SplatEncoding; + }; + const decoded = await unpackPcSogsZip(fileBytes, splatEncoding); result = { id, numSplats: decoded.numSplats, @@ -281,7 +300,12 @@ function benchmarkSort( async function unpackPly({ packedArray, fileBytes, -}: { packedArray: Uint32Array; fileBytes: Uint8Array }): Promise<{ + splatEncoding, +}: { + packedArray: Uint32Array; + fileBytes: Uint8Array; + splatEncoding: SplatEncoding; +}): Promise<{ packedArray: Uint32Array; numSplats: number; extra: Record; @@ -327,6 +351,7 @@ async function unpackPly({ r, g, b, + splatEncoding, ); }, (index, sh1, sh2, sh3) => { @@ -334,19 +359,19 @@ async function unpackPly({ if (!extra.sh1) { extra.sh1 = new Uint32Array(numSplats * 2); } - encodeSh1Rgb(extra.sh1 as Uint32Array, index, sh1); + encodeSh1Rgb(extra.sh1 as Uint32Array, index, sh1, splatEncoding); } if (sh2) { if (!extra.sh2) { extra.sh2 = new Uint32Array(numSplats * 4); } - encodeSh2Rgb(extra.sh2 as Uint32Array, index, sh2); + encodeSh2Rgb(extra.sh2 as Uint32Array, index, sh2, splatEncoding); } if (sh3) { if (!extra.sh3) { extra.sh3 = new Uint32Array(numSplats * 4); } - encodeSh3Rgb(extra.sh3 as Uint32Array, index, sh3); + encodeSh3Rgb(extra.sh3 as Uint32Array, index, sh3, splatEncoding); } }, ); @@ -354,7 +379,10 @@ async function unpackPly({ return { packedArray, numSplats, extra }; } -function unpackSpz(fileBytes: Uint8Array): { +function unpackSpz( + fileBytes: Uint8Array, + splatEncoding: SplatEncoding, +): { packedArray: Uint32Array; numSplats: number; extra: Record; @@ -373,10 +401,17 @@ function unpackSpz(fileBytes: Uint8Array): { setPackedSplatOpacity(packedArray, index, alpha); }, (index, r, g, b) => { - setPackedSplatRgb(packedArray, index, r, g, b); + setPackedSplatRgb(packedArray, index, r, g, b, splatEncoding); }, (index, scaleX, scaleY, scaleZ) => { - setPackedSplatScales(packedArray, index, scaleX, scaleY, scaleZ); + setPackedSplatScales( + packedArray, + index, + scaleX, + scaleY, + scaleZ, + splatEncoding, + ); }, (index, quatX, quatY, quatZ, quatW) => { setPackedSplatQuat(packedArray, index, quatX, quatY, quatZ, quatW); @@ -386,19 +421,19 @@ function unpackSpz(fileBytes: Uint8Array): { if (!extra.sh1) { extra.sh1 = new Uint32Array(numSplats * 2); } - encodeSh1Rgb(extra.sh1 as Uint32Array, index, sh1); + encodeSh1Rgb(extra.sh1 as Uint32Array, index, sh1, splatEncoding); } if (sh2) { if (!extra.sh2) { extra.sh2 = new Uint32Array(numSplats * 4); } - encodeSh2Rgb(extra.sh2 as Uint32Array, index, sh2); + encodeSh2Rgb(extra.sh2 as Uint32Array, index, sh2, splatEncoding); } if (sh3) { if (!extra.sh3) { extra.sh3 = new Uint32Array(numSplats * 4); } - encodeSh3Rgb(extra.sh3 as Uint32Array, index, sh3); + encodeSh3Rgb(extra.sh3 as Uint32Array, index, sh3, splatEncoding); } }, ); From 6d0ece3a7ae20d9d8ad8a5d8210956513a517b87 Mon Sep 17 00:00:00 2001 From: Andreas Sundquist Date: Tue, 22 Jul 2025 09:37:43 -0700 Subject: [PATCH 3/3] Fix SplatLoader not passing along encoding. Complete renaming to splat encoding. --- src/PackedSplats.ts | 4 ++-- src/SparkRenderer.ts | 4 ++-- src/SplatLoader.ts | 9 ++++++++- src/SplatMesh.ts | 4 ++-- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/PackedSplats.ts b/src/PackedSplats.ts index bb52d19..f9ba443 100644 --- a/src/PackedSplats.ts +++ b/src/PackedSplats.ts @@ -35,7 +35,7 @@ export type SplatEncoding = { sh3Max?: number; }; -export const DEFAULT_SPLAT_RANGES: SplatEncoding = { +export const DEFAULT_SPLAT_ENCODING: SplatEncoding = { rgbMin: 0, rgbMax: 1, lnScaleMin: LN_SCALE_MIN, @@ -222,7 +222,7 @@ export class PackedSplats { input: fileBytes, fileType: options.fileType, pathOrUrl: options.fileName ?? url, - splatEncoding: options.splatEncoding ?? DEFAULT_SPLAT_RANGES, + splatEncoding: options.splatEncoding ?? DEFAULT_SPLAT_ENCODING, }); this.initialize(unpacked); } diff --git a/src/SparkRenderer.ts b/src/SparkRenderer.ts index 6743207..c3e84d7 100644 --- a/src/SparkRenderer.ts +++ b/src/SparkRenderer.ts @@ -1,7 +1,7 @@ import * as THREE from "three"; import { - DEFAULT_SPLAT_RANGES, + DEFAULT_SPLAT_ENCODING, PackedSplats, type SplatEncoding, } from "./PackedSplats"; @@ -350,7 +350,7 @@ export class SparkRenderer extends THREE.Mesh { this.falloff = options.falloff ?? 1.0; this.clipXY = options.clipXY ?? 1.4; this.focalAdjustment = options.focalAdjustment ?? 1.0; - this.splatEncoding = options.splatEncoding ?? { ...DEFAULT_SPLAT_RANGES }; + this.splatEncoding = options.splatEncoding ?? { ...DEFAULT_SPLAT_ENCODING }; this.active = new SplatAccumulator(); this.accumulatorCount = 1; diff --git a/src/SplatLoader.ts b/src/SplatLoader.ts index 933dc91..8f01bbe 100644 --- a/src/SplatLoader.ts +++ b/src/SplatLoader.ts @@ -1,6 +1,10 @@ import { unzipSync } from "fflate"; import { FileLoader, Loader, type LoadingManager } from "three"; -import { PackedSplats, type SplatEncoding } from "./PackedSplats"; +import { + DEFAULT_SPLAT_ENCODING, + PackedSplats, + type SplatEncoding, +} from "./PackedSplats"; import { SplatMesh } from "./SplatMesh"; import { PlyReader } from "./ply"; import { withWorker } from "./splatWorker"; @@ -110,11 +114,14 @@ export class SplatLoader extends Loader { await Promise.all(promises); if (onLoad) { + const splatEncoding = + this.packedSplats?.splatEncoding ?? DEFAULT_SPLAT_ENCODING; const decoded = await unpackSplats({ input, extraFiles, fileType, pathOrUrl: resolvedURL, + splatEncoding, }); if (this.packedSplats) { diff --git a/src/SplatMesh.ts b/src/SplatMesh.ts index 9b735b1..4a696db 100644 --- a/src/SplatMesh.ts +++ b/src/SplatMesh.ts @@ -2,7 +2,7 @@ import * as THREE from "three"; import init_wasm, { raycast_splats } from "spark-internal-rs"; import { - DEFAULT_SPLAT_RANGES, + DEFAULT_SPLAT_ENCODING, PackedSplats, type SplatEncoding, } from "./PackedSplats"; @@ -193,7 +193,7 @@ export class SplatMesh extends SplatGenerator { this.packedSplats = options.packedSplats ?? new PackedSplats(); this.packedSplats.splatEncoding = options.splatEncoding ?? { - ...DEFAULT_SPLAT_RANGES, + ...DEFAULT_SPLAT_ENCODING, }; this.numSplats = this.packedSplats.numSplats; this.editable = options.editable ?? true;