In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
from scene.gaussian_model import GaussianModel
import os, sys
from plyfile import PlyData, PlyElement
import numpy as np
import torch

In [None]:
def process_ply(plydata, sh_degree=1, device='cuda'):
    """Convert a parsed Gaussian-splat PLY into a dict of float32 torch tensors.

    Only the first element of ``plydata`` is read. Values are copied exactly
    as stored in the file; no activation or normalisation is applied here.

    Args:
        plydata: Parsed PLY object (e.g. the result of ``PlyData.read``). Its
            first element must provide the properties ``x``, ``y``, ``z``,
            ``opacity``, ``f_dc_0``..``f_dc_2``, ``f_rest_<i>``,
            ``scale_<i>`` and ``rot_<i>``.
        sh_degree: Spherical-harmonics degree the file is expected to hold.
            The file must contain exactly ``3 * (sh_degree + 1) ** 2 - 3``
            ``f_rest_*`` properties. The default of 1 reproduces the
            previously hard-coded check (noted in the original code as the
            3DGStream layout; 4DGS was noted as using ``(3 + 1)``, i.e.
            degree 3).
        device: Torch device the returned tensors are created on. Defaults to
            ``'cuda'`` as before; pass ``'cpu'`` on machines without a GPU.

    Returns:
        dict of contiguous float32 tensors, with ``P`` the number of points:
            ``'means'``     (P, 3)
            ``'opacities'`` (P,)
            ``'quats'``     (P, number of ``rot*`` properties)
            ``'scales'``    (P, number of ``scale_*`` properties)
            ``'sh0'``       (P, 1, 3)
            ``'shN'``       (P, (sh_degree + 1) ** 2 - 1, 3)

    Raises:
        ValueError: If the number of ``f_rest_*`` properties does not match
            ``sh_degree``.
    """
    vertex = plydata.elements[0]

    def _indexed_names(prefix):
        # Property names starting with `prefix`, ordered by their numeric
        # suffix so that e.g. f_rest_10 sorts after f_rest_2.
        names = [p.name for p in vertex.properties if p.name.startswith(prefix)]
        return sorted(names, key=lambda name: int(name.split('_')[-1]))

    def _stack_columns(names):
        # Gather the named per-point properties into a (P, len(names)) array.
        columns = np.zeros((num_points, len(names)))
        for col, name in enumerate(names):
            columns[:, col] = np.asarray(vertex[name])
        return columns

    def _to_tensor(array):
        # torch.tensor always copies, so the result never aliases PLY memory.
        return torch.tensor(array, dtype=torch.float32, device=device)

    xyz = np.stack([np.asarray(vertex[axis]) for axis in ("x", "y", "z")], axis=1)
    num_points = xyz.shape[0]
    opacities = np.asarray(vertex["opacity"])

    # DC term: one coefficient per colour channel -> (P, 1, 3).
    features_dc = _stack_columns(["f_dc_0", "f_dc_1", "f_dc_2"]).reshape(num_points, 1, 3)

    # Higher-order SH terms are stored flat, channel-major: (P, 3 * K).
    rest_per_channel = (sh_degree + 1) ** 2 - 1
    extra_f_names = _indexed_names("f_rest_")
    if len(extra_f_names) != 3 * rest_per_channel:
        raise ValueError(
            f"expected {3 * rest_per_channel} f_rest_* properties for "
            f"sh_degree={sh_degree}, found {len(extra_f_names)}"
        )
    # Reshape (P, 3 * K) to (P, 3, K); transposed to (P, K, 3) on return.
    features_extra = _stack_columns(extra_f_names).reshape((num_points, 3, rest_per_channel))

    scales = _stack_columns(_indexed_names("scale_"))
    rots = _stack_columns(_indexed_names("rot"))

    return {'means': _to_tensor(xyz),
            'opacities': _to_tensor(opacities),
            'quats': _to_tensor(rots),
            'scales': _to_tensor(scales),
            'sh0': _to_tensor(features_dc).contiguous(),
            'shN': _to_tensor(features_extra).transpose(1, 2).contiguous(),
            }

In [None]:
# Load one point cloud and print the shape of every converted tensor as a sanity check.
# NOTE: machine-specific absolute path; point it at your own model directory.
ply_before_path = "/data2/wlsgur4011/3DGStream_reproduction/models/Diva360/dog/point_cloud/iteration_15000/point_cloud.ply" # before

# One name per stage (path -> parsed PLY -> tensor dict) instead of rebinding
# a single variable to three different kinds of object.
ply_before_data = PlyData.read(ply_before_path)
ply_before = process_ply(ply_before_data)

for key, value in ply_before.items():
    print(f'{key} shape:', value.shape)

means shape: torch.Size([29360, 3])
opacities shape: torch.Size([29360])
quats shape: torch.Size([29360, 4])
scales shape: torch.Size([29360, 3])
sh0 shape: torch.Size([29360, 1, 3])
shN shape: torch.Size([29360, 3, 3])


In [5]:
def post_process(input_dict):
    """Wrap a splat dictionary in a three-key checkpoint dictionary.

    Args:
        input_dict: Object stored under ``'splats'``. It is stored as-is,
            not copied.

    Returns:
        dict with keys ``'step'``, ``'splats'`` and ``'clustered'`` (in that
        order); ``'step'`` and ``'clustered'`` are ``None``.
    """
    # fromkeys fills every slot with None; only 'splats' carries a payload.
    checkpoint = dict.fromkeys(('step', 'splats', 'clustered'))
    checkpoint['splats'] = input_dict
    return checkpoint

### Convert every object directory to a checkpoint (adjust "output_path" / "ply_path" below for your dataset)

In [None]:
# TODO: selective path
# NOTE: machine-specific absolute path; point it at your own output directory.
output_path = '/data2/wlsgur4011/3DGStream_reproduction/output/Diva360'

# For every entry under output_path, convert its iteration_250 point cloud
# into a checkpoint.pth saved next to the PLY file.
# `object_name` (not `object`) keeps the builtin `object` intact for later cells;
# sorted() makes the processing order reproducible.
for object_name in sorted(os.listdir(output_path)):
    ply_path = os.path.join(output_path, object_name, 'point_cloud/iteration_250/point_cloud.ply')
    if not os.path.isfile(ply_path):
        # Stray files or unfinished runs should not abort the whole batch.
        print(f'skipping {object_name}: {ply_path} not found')
        continue

    ply_data = PlyData.read(ply_path)
    ckpt = post_process(process_ply(ply_data))

    ckpt_path = os.path.join(os.path.dirname(ply_path), 'checkpoint.pth')
    torch.save(ckpt, ckpt_path)