Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions rust/spark-lib/src/ply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub const PLY_MAGIC: u32 = 0x00796c70; // "ply"
const MAX_SPLAT_CHUNK: usize = 65536;
const SH_C0: f32 = 0.28209479177387814;
const SUPER_CHUNK_SIZE: usize = 256;
const POINT_CLOUD_PROPERTIES: [&str; 6] = ["x", "y", "z", "red", "green", "blue"];
const DEFAULT_POINT_SCALE: f32 = 0.001;

pub struct PlyDecoder<T: SplatReceiver> {
splats: T,
Expand Down Expand Up @@ -69,6 +71,14 @@ impl<T: SplatReceiver> PlyDecoder<T> {
lod_tree: false,
})?;
PlyState::SuperSplat(state)
} else if parsed.is_pointcloud {
let state = PointCloudDecoderState::new(parsed.num_splats, parsed.vertex.record_size, parsed.vertex.properties.clone())?;
self.splats.init_splats(&SplatInit {
num_splats: parsed.num_splats,
max_sh_degree: 0,
lod_tree: false,
})?;
PlyState::PointCloud(state)
} else {
let state = PlyDecoderState::new(parsed.num_splats, parsed.vertex.record_size, parsed.vertex.properties.clone())?;
self.splats.init_splats(&SplatInit {
Expand All @@ -86,12 +96,65 @@ impl<T: SplatReceiver> PlyDecoder<T> {

fn poll_data(&mut self) -> anyhow::Result<()> {
match self.state {
Some(PlyState::PointCloud(_)) => self.poll_data_pointcloud(),
Some(PlyState::Standard(_)) => self.poll_data_standard(),
Some(PlyState::SuperSplat(_)) => self.poll_data_supersplat(),
None => unreachable!(),
}
}

fn poll_data_pointcloud(&mut self) -> anyhow::Result<()> {
let Some(PlyState::PointCloud(state)) = self.state.as_mut() else { unreachable!() };
let mut offset = 0;
loop {
let available = (self.buffer.len() - offset) / state.record_size;
let remaining = state.num_splats.saturating_sub(state.next_splat);
let count = remaining.min(available).min(MAX_SPLAT_CHUNK);
if count == 0 {
break;
}

state.ensure_out(count);

for i in 0..count {
let [i3, i4] = [i * 3, i * 4];
let base = offset + i * state.record_size;

for d in 0..3 {
state.out_center[i3 + d] = state.xyz[d].get_f32(&self.buffer, base);
}
state.out_opacity[i] = match state.alpha {
Some(alpha) => alpha.get_f32(&self.buffer, base),
None => 1.0
};
for d in 0..3 {
state.out_rgb[i3 + d] = state.rgb[d].get_f32(&self.buffer, base);
}
state.out_scale.splice(i3..i3+3, [DEFAULT_POINT_SCALE, DEFAULT_POINT_SCALE, DEFAULT_POINT_SCALE]);
state.out_quat.splice(i4..i4+4, [0.0, 0.0, 0.0, 1.0]);
}

self.splats.set_batch(state.next_splat, count, &SplatProps {
center: &state.out_center[..count * 3],
opacity: &state.out_opacity[..count],
rgb: &state.out_rgb[..count * 3],
scale: &state.out_scale[..count * 3],
quat: &state.out_quat[..count * 4],
sh1: &Vec::new(),
sh2: &Vec::new(),
sh3: &Vec::new(),
..Default::default()
});

state.next_splat += count;
offset += count * state.record_size;
}

self.buffer.drain(..offset);
Ok(())
}


fn poll_data_standard(&mut self) -> anyhow::Result<()> {
let Some(PlyState::Standard(state)) = self.state.as_mut() else { unreachable!() };
let mut offset = 0;
Expand Down Expand Up @@ -258,6 +321,11 @@ impl<T: SplatReceiver> ChunkReceiver for PlyDecoder<T> {
// As long as we've read the correct number of splats, we're good.

match state {
PlyState::PointCloud(state) => {
if state.next_splat != state.num_splats {
return Err(anyhow!("Expected {} splats, got {}", state.num_splats, state.next_splat));
}
},
PlyState::Standard(state) => {
if state.next_splat != state.num_splats {
return Err(anyhow!("Expected {} splats, got {}", state.num_splats, state.next_splat));
Expand All @@ -284,6 +352,7 @@ impl<T: SplatReceiver> ChunkReceiver for PlyDecoder<T> {

#[derive(Debug)]
enum PlyState {
PointCloud(PointCloudDecoderState),
Standard(PlyDecoderState),
SuperSplat(SuperSplatState),
}
Expand Down Expand Up @@ -335,6 +404,7 @@ struct ParsedHeader {
chunk: Option<PlyElementDesc>,
sh: Option<PlyElementDesc>,
num_splats: usize,
is_pointcloud: bool,
is_supersplat: bool,
}

Expand Down Expand Up @@ -421,12 +491,14 @@ fn parse_header(header: &str) -> anyhow::Result<ParsedHeader> {
let vertex = elements.iter().find(|e| e.name == "vertex").cloned().ok_or(anyhow!("Missing vertex element"))?;
let chunk = elements.iter().find(|e| e.name == "chunk").cloned();
let sh = elements.iter().find(|e| e.name == "sh").cloned();
let is_pointcloud = POINT_CLOUD_PROPERTIES.iter().all(|&p| vertex.properties.contains_key(p));

Ok(ParsedHeader {
num_splats: vertex.count,
vertex,
chunk,
sh,
is_pointcloud,
is_supersplat: elements.iter().any(|e| e.name == "chunk"),
elements,
})
Expand Down Expand Up @@ -955,6 +1027,74 @@ impl PlyDecoderState {
}
}

#[derive(Debug)]
struct PointCloudDecoderState {
num_splats: usize,
record_size: usize,
next_splat: usize,

#[allow(unused)]
properties: HashMap<String, PlyProperty>,
xyz: [PlyProperty; 3],
rgb: [PlyProperty; 3],
alpha: Option<PlyProperty>,

out_center: Vec<f32>,
out_opacity: Vec<f32>,
out_rgb: Vec<f32>,
out_scale: Vec<f32>,
out_quat: Vec<f32>,
}

impl PointCloudDecoderState {
fn new(num_splats: usize, record_size: usize, properties: HashMap<String, PlyProperty>) -> anyhow::Result<Self> {
let xyz = [
*properties.get("x").ok_or(anyhow!("Missing x property"))?,
*properties.get("y").ok_or(anyhow!("Missing y property"))?,
*properties.get("z").ok_or(anyhow!("Missing z property"))?,
];
let rgb = [
*properties.get("red").ok_or(anyhow!("Missing red property"))?,
*properties.get("green").ok_or(anyhow!("Missing green property"))?,
*properties.get("blue").ok_or(anyhow!("Missing blue property"))?,
];
let alpha = properties.get("alpha").map(|p| *p);

Ok(Self {
num_splats,
record_size,
next_splat: 0,
properties,
xyz,
rgb,
alpha,
out_center: Vec::new(),
out_opacity: Vec::new(),
out_rgb: Vec::new(),
out_scale: Vec::new(),
out_quat: Vec::new(),
})
}

fn ensure_out(&mut self, count: usize) {
if self.out_center.len() < (count * 3) {
self.out_center.resize(count * 3, 0.0);
}
if self.out_opacity.len() < count {
self.out_opacity.resize(count, 0.0);
}
if self.out_rgb.len() < (count * 3) {
self.out_rgb.resize(count * 3, 0.0);
}
if self.out_scale.len() < (count * 3) {
self.out_scale.resize(count * 3, 0.0);
}
if self.out_quat.len() < (count * 4) {
self.out_quat.resize(count * 4, 0.0);
}
}
}

#[derive(Debug, Clone, Copy)]
pub enum PlyPropertyType {
Char,
Expand Down
Loading