In [60]:
import json
import os
from argparse import ArgumentParser

import cv2
import numpy as np
import torch
from plyfile import PlyData, PlyElement

from utils.colmap_utils import (
    read_extrinsics_binary,
    read_intrinsics_binary,
    read_points3D_binary,
    read_points3D_text,
    qvec2rotmat,
)
from utils.graphics_utils import focal2fov

# 参数读取

In [14]:
# Load the default option values from the JSON configuration file.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's locale encoding (which differs between operating systems).
with open('config.json', 'r', encoding='utf-8') as file:
    json_data = json.load(file)

# Expose every config entry as a `--<key>` option whose default is the JSON value.
# NOTE: no `type=` is declared for these options, so a value overridden on the
# command line arrives as a string, whatever the type of the JSON default.
parser = ArgumentParser()
for key, value in json_data.items():
    parser.add_argument(f'--{key}', default=value)

# Options that are declared explicitly rather than taken from config.json.
parser.add_argument("--detect_anomaly", action="store_true", default=False)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
parser.add_argument("--start_checkpoint", type=str, default=None)

# An explicit argument list is parsed instead of sys.argv.
# `--source_path` and `--model_path` are not declared above, so config.json
# must define both keys for this call to succeed.
args = parser.parse_args(args=['--source_path', './data',  '--model_path', './data/output'])
# config = vars(args)

# 数据读取

In [None]:
class GSInfo:
    """Placeholder class: no attributes or methods have been defined yet."""

In [50]:
class GSDataLoader:
    """Load a COLMAP reconstruction: cameras, their images and the point cloud.

    After construction:

    * ``cameras`` is a list with one dict per entry of ``sparse/0/images.bin``.
      Keys: ``R``, ``T``, ``fov_x``, ``fov_y``, ``image``, ``image_path``,
      ``image_name``, ``width``, ``height``.
    * ``points`` is a dict with the keys ``positions``, ``colors`` and
      ``normals``, read from ``sparse/0/points3D.ply``.
    """

    def __init__(self, data_path, reading_dir, device):
        """Store the settings and immediately load the dataset.

        Args:
            data_path: Dataset root directory; it must contain ``sparse/0``.
            reading_dir: Image folder, relative to ``data_path``.
            device: Kept as ``self.device``; the loader itself does not use it.

        Raises:
            OSError: An image listed in ``images.bin`` could not be read.
            AssertionError: A camera uses a model other than ``PINHOLE`` or
                ``SIMPLE_PINHOLE``.
        """
        self.data_path = data_path
        self.device = device
        self.cameras = []
        self.points = {}
        self.load_colmap(reading_dir)

    def load_colmap(self, reading_dir):
        """Populate ``self.cameras`` and ``self.points`` from ``self.data_path``.

        Args:
            reading_dir: Image folder, relative to ``self.data_path``.

        Raises:
            OSError: An image listed in ``images.bin`` could not be read.
            AssertionError: A camera uses an unsupported COLMAP model.
        """
        self.cameras = self._load_cameras(reading_dir)
        self.points = self._load_point_cloud()

    @staticmethod
    def _field_of_view(intr):
        """Return ``(fov_x, fov_y)`` for one COLMAP camera intrinsics record."""
        if intr.model == "SIMPLE_PINHOLE":
            # COLMAP params: [f, cx, cy] -- one focal length shared by both axes.
            focal_x = focal_y = intr.params[0]
        elif intr.model == "PINHOLE":
            # COLMAP params: [fx, fy, cx, cy]
            focal_x, focal_y = intr.params[0], intr.params[1]
        else:
            # Raised explicitly (not via `assert`) so the check still runs under
            # `python -O`; the exception type is the one `assert` would raise.
            raise AssertionError(
                "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
            )
        return focal2fov(focal_x, intr.width), focal2fov(focal_y, intr.height)

    def _load_cameras(self, reading_dir):
        """Read extrinsics, intrinsics and images; return one dict per image."""
        cameras_extrinsic_path = os.path.join(self.data_path, "sparse/0/images.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_path)
        cameras_intrinsic_path = os.path.join(self.data_path, "sparse/0/cameras.bin")
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_path)

        images_folder = os.path.join(self.data_path, reading_dir)
        cameras = []
        for image_info in cam_extrinsics.values():
            intr = cam_intrinsics[image_info.camera_id]
            fov_x, fov_y = self._field_of_view(intr)

            image_path = os.path.join(images_folder, image_info.name)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread reports failure by returning None instead of
                # raising; fail here with the offending path rather than with
                # an opaque error inside cvtColor.
                raise OSError(
                    f"Could not read image (missing, unreadable or unsupported format): {image_path}"
                )
            # OpenCV decodes to BGR channel order; convert to RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            cameras.append({
                'R': qvec2rotmat(image_info.qvec),
                'T': image_info.tvec,
                'fov_x': fov_x,
                'fov_y': fov_y,
                'image': image,
                'image_path': image_path,
                'image_name': image_info.name,
                # Width/height come from the COLMAP intrinsics, not from the
                # decoded image.
                'width': intr.width,
                'height': intr.height
            })
        return cameras

    def _load_point_cloud(self):
        """Read ``sparse/0/points3D.ply``; return positions, colors and normals."""
        ply_path = os.path.join(self.data_path, "sparse/0/points3D.ply")
        vertices = PlyData.read(ply_path)['vertex']

        # Each array is built by stacking three per-vertex properties and
        # transposing, giving one row per vertex.
        positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
        # Colors are divided by 255.0 (assumes 8-bit channels in the PLY file).
        colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
        normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
        return {
            'positions': positions,
            'colors': colors,
            'normals': normals
        }