In [2]:
import json
import os
from argparse import ArgumentParser
from typing import NamedTuple

import cv2
import numpy as np
import torch
from plyfile import PlyData

from utils.colmap_utils import (
    read_extrinsics_binary,
    read_intrinsics_binary,
    read_points3D_binary,
    read_points3D_text,
    qvec2rotmat
)
from utils.graphics_utils import focal2fov

# 参数读取

In [3]:
# Load the default option values from the JSON config file.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale encoding (which differs across systems).
with open('config.json', 'r', encoding='utf-8') as file:
    json_data = json.load(file)

# Expose every config entry as a `--<key>` option whose default is the value
# stored in the file; no `type=` is given, so overrides arrive as strings.
parser = ArgumentParser()
for key, value in json_data.items():
    parser.add_argument(f'--{key}', default=value)
parser.add_argument("--detect_anomaly", action="store_true", default=False)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[100,1_000, 7_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[100,1_000, 7_000, 30_000])
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
parser.add_argument("--start_checkpoint", type=str, default = None)
# Parse an explicit argument list instead of the process command line.
# `--source_path` and `--model_path` are not declared above, so config.json
# must define `source_path` and `model_path` keys for this call to succeed.
args = parser.parse_args(args=['--source_path', './data',  '--model_path', './data/output'])

# 数据读取

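相机矩阵 / 视图矩阵 viewMatrix（4 × 4，世界坐标系 → 相机坐标系，列向量约定 $p_{cam} = \text{viewMatrix} \cdot p_{world}$）：

$$
\text{viewMatrix} = \begin{bmatrix}
R & T \\
0 & 1
\end{bmatrix}
$$

其中 $R$ 为 3 × 3 旋转矩阵，$T$ 为 3 × 1 平移向量。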

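投影矩阵 projMatrix（4 × 4）：

$$
\text{projMatrix} = \begin{bmatrix}
\frac{2n}{r-l} & 0 & 0 & 0 \\
0 & \frac{2n}{t-b} & 0 & 0 \\
0 & 0 & \frac{f+n}{f-n} & \frac{2nf}{n-f} \\
0 & 0 & 1 & 0
\end{bmatrix}
$$

其中 $n$、$f$ 为近、远裁剪面的距离（对应代码中的 znear、zfar），$l$、$r$、$b$、$t$ 为近裁剪面上视锥体的左、右、下、上边界。最后一行为 $(0, 0, 1, 0)$，因此裁剪坐标的 $w$ 等于相机坐标系下的深度 $z$（相机朝 $+z$ 方向观察）；透视除法之后，$z = n$ 映射到 $-1$，$z = f$ 映射到 $+1$。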

In [22]:
class BasicPointCloud(NamedTuple):
    """Immutable bundle of the per-point arrays that describe a point cloud.

    The fields are annotated with ``np.ndarray`` (the array type) rather than
    ``np.array`` (a factory function, which is not a valid type annotation).
    """
    # Point coordinates.
    # NOTE(review): presumably shape (N, 3) as x, y, z -- confirm with the loader.
    positions : np.ndarray
    # Per-point colors.
    # NOTE(review): value range and dtype are not fixed here -- confirm with the loader.
    colors : np.ndarray
    # Per-point normals.
    normals : np.ndarray

In [23]:
class ImageInfo:
    """One view of the scene: the image plus its camera pose and intrinsics.

    On construction the image is converted to a tensor on ``device`` and the
    camera matrices are derived from the pose and the field of view:

    * ``viewMatrix``     -- 4x4 ``[[R, T], [0, 1]]`` (see ``getViewMatrix``)
    * ``projMatrix``     -- 4x4 perspective projection (see ``getProjMatrix``)
    * ``viewProjMatrix`` -- ``projMatrix @ viewMatrix``
    * ``cameraCenter``   -- translation column of ``inverse(viewMatrix)``

    All matrices follow the column-vector convention ``p_out = M @ p_in``.
    """

    def __init__(self, image_name, image_path, image, image_width, image_height, R, T, fov_x, fov_y, device):
        """Store the view's metadata and build its camera matrices.

        Args:
            image_name: Name of the image.
            image_path: Path the image was read from.
            image: Image pixels as a NumPy array (``torch.from_numpy`` is
                applied to it).
            image_width: Image width.
            image_height: Image height.
            R: 3x3 rotation written into the upper-left block of the view
                matrix.
            T: Length-3 translation written into the last column of the view
                matrix.
            fov_x: Horizontal field of view in radians.
            fov_y: Vertical field of view in radians.
            device: Torch device that receives the image and the matrices.
        """
        self.image = torch.from_numpy(image).to(device)
        self.image_name = image_name
        self.image_path = image_path
        self.image_width = image_width
        self.image_height = image_height
        self.R = R
        self.T = T
        self.fov_x = fov_x
        self.fov_y = fov_y
        # Fixed clipping-plane distances used for the projection matrix.
        self.zfar = 100.0
        self.znear = 0.01

        self.viewMatrix = self.getViewMatrix(self.R, self.T).to(device)
        self.projMatrix = self.getProjMatrix(self.znear, self.zfar, self.fov_x, self.fov_y).to(device)
        # Column-vector convention: a point is first moved into camera space
        # and then projected, so the projection multiplies from the left.
        self.viewProjMatrix = self.projMatrix @ self.viewMatrix
        # The inverse view matrix maps camera space back to the input space;
        # its translation column is the camera position there.
        self.cameraCenter = torch.inverse(self.viewMatrix)[:3, 3]

    def getViewMatrix(self, R, T):
        """Return the 4x4 float32 matrix ``[[R, T], [0, 1]]``.

        Args:
            R: 3x3 rotation (anything ``torch.tensor`` accepts).
            T: Length-3 translation (anything ``torch.tensor`` accepts).
        """
        viewMatrix = torch.eye(4, dtype=torch.float32)
        viewMatrix[:3, :3] = torch.tensor(R, dtype=torch.float32)
        viewMatrix[:3, 3] = torch.tensor(T, dtype=torch.float32)

        return viewMatrix

    def getProjMatrix(self, znear, zfar, fov_x, fov_y):
        """Return the 4x4 perspective projection matrix.

        The layout matches the projection matrix given in the notebook's
        markdown::

            [ 2n/(r-l)  0         0            0          ]
            [ 0         2n/(t-b)  0            0          ]
            [ 0         0         (f+n)/(f-n)  2nf/(n-f)  ]
            [ 0         0         1            0          ]

        The last row copies the camera-space depth into clip-space ``w``, so
        after the perspective divide ``z = znear`` maps to -1 and
        ``z = zfar`` maps to +1.

        Args:
            znear: Distance to the near clipping plane.
            zfar: Distance to the far clipping plane.
            fov_x: Horizontal field of view in radians.
            fov_y: Vertical field of view in radians.
        """
        tan_fov_y = np.tan((fov_y / 2))
        tan_fov_x = np.tan((fov_x / 2))

        # Symmetric frustum bounds measured on the near plane.
        top = tan_fov_y * znear
        bottom = -top
        right = tan_fov_x * znear
        left = -right

        projMatrix = torch.zeros(4, 4)
        projMatrix[0, 0] = 2.0 * znear / (right - left)
        projMatrix[1, 1] = 2.0 * znear / (top - bottom)
        projMatrix[2, 2] = (zfar + znear) / (zfar - znear)
        projMatrix[2, 3] = 2.0 * znear * zfar / (znear - zfar)
        # w_clip = z_camera; without this entry no perspective divide occurs.
        projMatrix[3, 2] = 1.0

        return projMatrix

In [24]:
class GSDataLoader:
    def __init__(self, data_path, reading_dir, device):
        """Remember the loader settings and load the COLMAP data right away.

        Args:
            data_path: Dataset root; the COLMAP files are read from
                ``sparse/0`` beneath it.
            reading_dir: Sub-directory of ``data_path`` that holds the images.
            device: Device handed to every ``ImageInfo`` that is created.
        """
        self.data_path, self.device = data_path, device
        # Start from empty containers; loadColmap populates the camera list.
        self.cameras, self.points = [], {}
        self.loadColmap(reading_dir)
    
    def loadColmap(self, reading_dir):
        # 读取相机
        cameras_extrinsic_path = os.path.join(self.data_path, "sparse/0/images.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_path)        
        cameras_intrinsic_path = os.path.join(self.data_path, "sparse/0/cameras.bin")
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_path)
        
        # 读取图片
        images_folder=os.path.join(self.data_path, reading_dir)
        self.cameras = []
        for _, image_info in cam_extrinsics.items():
            intr = cam_intrinsics[image_info.camera_id]

            image_height = intr.height
            image_width = intr.width

            image_path = os.path.join(images_folder, image_info.name)
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            R = qvec2rotmat(image_info.qvec)
            T = image_info.tvec

            if intr.model == "SIMPLE_PINHOLE":
                # colmap params [f, cx, cy]
                fov_y = focal2fov(intr.params[0], image_height)
                fov_x = focal2fov(intr.params[0], image_width) 
            elif intr.model == "PINHOLE":
                # colmap params [fx, fy, cx, cy]
                fov_y = focal2fov(intr.params[1], image_height)
                fov_x = focal2fov(intr.params[0], image_width)
            else:
                assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"

            self.cameras.append(ImageInfo(
                image_name=image_info.name, 
                image_path=image_path, 
                image=image, 
                image_width=image_width, 
                image_height=image_height, 
                R=R, T=T, 
                fov_x=fov_x,
                fov_y=fov_y, 
                device=self.device))
                
        # 读取点云
        ply_path = os.path.join(self.data_path, "sparse/0/points3D.ply")
        plydata = PlyData.read(ply_path)
        vertices = plydata['vertex']
        positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
        colors = np.vstack([vertices['red 

In [25]:
# Keep everything on the CPU and load the COLMAP scene named by the parsed
# options. `args.images` is not declared explicitly on the parser, so it
# presumably comes from an `images` key in config.json -- verify.
device = torch.device("cpu")
data = GSDataLoader(
    data_path=args.source_path,
    reading_dir=args.images,
    device=device,
)

# GS模型