In [1]:
import json
import os
from argparse import ArgumentParser
from typing import NamedTuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from plyfile import PlyData
from pykdtree.kdtree import KDTree

from utils.camera_utils import focal2fov
from utils.colmap_utils import (
    qvec2rotmat,
    read_extrinsics_binary,
    read_intrinsics_binary,
    read_points3D_binary,
    read_points3D_text,
)
from utils.model_utils import InverseSigmoid

  from .autonotebook import tqdm as notebook_tqdm


# 参数读取

In [3]:
# Load the default option values from the JSON configuration file.
with open('config.json', 'r') as file:
    json_data = json.load(file)

# Every JSON entry becomes a command-line option whose default is the JSON value.
parser = ArgumentParser()
for key, value in json_data.items():
    parser.add_argument(f'--{key}', default=value)

parser.add_argument("--detect_anomaly", action="store_true", default=False)
# The three iteration schedules accept the same kind of value and share the same defaults;
# each option still receives its own list object.
for option in ("--test_iterations", "--save_iterations", "--checkpoint_iterations"):
    parser.add_argument(option, nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
parser.add_argument("--start_checkpoint", type=str, default=None)

# Parse a fixed argument list instead of sys.argv, which inside a notebook holds the kernel's own arguments.
args = parser.parse_args(args=['--source_path', './data', '--model_path', './data/output'])

# 数据读取

相机矩阵(4 X 4)：
$
viewMatrix = \begin{bmatrix}
R & T \\
0 & 1
\end{bmatrix}
$

投影矩阵(4 X 4)：
$
projMatrix = \begin{bmatrix}
\frac{2n}{r-l} & 0 & 0 & 0 \\
0 & \frac{2n}{t-b} & 0 & 0 \\
0 & 0 & \frac{f +n}{f - n} & \frac{2nf}{n - f} \\
0 & 0 & 1 & 0
\end{bmatrix}
$

In [4]:
# camera_utils
class BasicPointCloud(NamedTuple):
    """Point cloud stored as three parallel NumPy arrays.

    The fields are annotated with ``np.ndarray`` (the array type) rather than
    ``np.array`` (a factory function, which is not a type).
    """
    positions: np.ndarray  # point coordinates
    colors: np.ndarray     # per-point colors
    normals: np.ndarray    # per-point normals

In [5]:
# camera_utils
class ImageInfo:
    """One view: the image together with the camera matrices derived from its pose.

    All matrices use the column-vector convention (a point is transformed as
    ``M @ p``), matching the ``[[R, T], [0, 1]]`` layout built by getViewMatrix.

    Args:
        image_name: Name of the image.
        image_path: Path the image was read from.
        image: Image as a NumPy array; wrapped by ``torch.from_numpy``.
        image_width: Image width.
        image_height: Image height.
        R: 3x3 rotation placed in the upper-left block of the view matrix.
        T: Length-3 translation placed in the last column of the view matrix.
        fov_x: Horizontal field of view, passed to ``np.tan`` (so in radians).
        fov_y: Vertical field of view, passed to ``np.tan`` (so in radians).
        device: Device the camera matrices are moved to.
    """

    def __init__(self, image_name, image_path, image, image_width, image_height, R, T, fov_x, fov_y, device):
        self.image = torch.from_numpy(image)  # not moved to `device` here; load into GPU memory only when needed
        self.image_name = image_name
        self.image_path = image_path
        self.image_width = image_width
        self.image_height = image_height
        self.R = R
        self.T = T
        self.fov_x = fov_x
        self.fov_y = fov_y
        self.zfar = 100.0
        self.znear = 0.01

        self.viewMatrix = self.getViewMatrix(self.R, self.T).to(device)
        self.projMatrix = self.getProjMatrix(self.znear, self.zfar, self.fov_x, self.fov_y).to(device)
        # With column vectors the view transform is applied first, so it sits on the right:
        # clip = projMatrix @ viewMatrix @ p_world.
        self.viewProjMatrix = self.projMatrix @ self.viewMatrix
        # The translation column of the inverse view matrix is the camera position.
        self.cameraCenter = torch.inverse(self.viewMatrix)[:3, 3]

    def getViewMatrix(self, R, T):
        """Return the 4x4 float32 matrix ``[[R, T], [0, 1]]``."""
        viewMatrix = torch.eye(4, dtype=torch.float32)
        viewMatrix[:3, :3] = torch.tensor(R, dtype=torch.float32)
        viewMatrix[:3, 3] = torch.tensor(T, dtype=torch.float32)

        return viewMatrix

    def getProjMatrix(self, znear, zfar, fov_x, fov_y):
        """Return the 4x4 perspective projection matrix for a symmetric frustum.

        The layout is::

            2n/(r-l)     0          0            0
               0      2n/(t-b)      0            0
               0         0     (f+n)/(f-n)   2nf/(n-f)
               0         0          1            0

        so that w equals the camera-space depth, and depth ``znear`` maps to
        -1 while depth ``zfar`` maps to +1 after the perspective divide.
        """
        tan_fov_y = np.tan((fov_y / 2))
        tan_fov_x = np.tan((fov_x / 2))

        # Frustum extents on the near plane (symmetric about the optical axis).
        top = tan_fov_y * znear
        bottom = -top
        right = tan_fov_x * znear
        left = -right

        projMatrix = torch.zeros(4, 4)
        projMatrix[0, 0] = 2.0 * znear / (right - left)
        projMatrix[1, 1] = 2.0 * znear / (top - bottom)
        # (f+n)/(f-n) is the term that pairs with 2nf/(n-f) below.
        projMatrix[2, 2] = (zfar + znear) / (zfar - znear)
        projMatrix[2, 3] = 2.0 * znear * zfar / (znear - zfar)
        # The 1 belongs in column 2 so that w = z; at [3, 3] there would be no perspective divide.
        projMatrix[3, 2] = 1.0

        return projMatrix

In [6]:
class GSDataLoader:
    """Loads a COLMAP reconstruction: cameras with their images, plus the sparse point cloud.

    Args:
        data_path: Dataset root; ``sparse/0/`` below it holds ``images.bin``,
            ``cameras.bin`` and ``points3D.ply``.
        reading_dir: Sub-directory of ``data_path`` that contains the image files.
        device: Device passed on to every ``ImageInfo`` for its camera matrices.

    Attributes:
        cameras: List of ``ImageInfo``, one per entry of the extrinsics file.
        points: ``BasicPointCloud`` read from ``points3D.ply``.
    """

    def __init__(self, data_path, reading_dir, device):
        self.data_path = data_path
        self.device = device
        self.cameras = []
        self.points = {}
        self.loadColmap(reading_dir)

    @staticmethod
    def _computeFov(intr):
        """Return ``(fov_x, fov_y)`` for a SIMPLE_PINHOLE or PINHOLE camera."""
        if intr.model == "SIMPLE_PINHOLE":
            # colmap params [f, cx, cy]: one focal length shared by both axes
            focal_x = focal_y = intr.params[0]
        elif intr.model == "PINHOLE":
            # colmap params [fx, fy, cx, cy]
            focal_x, focal_y = intr.params[0], intr.params[1]
        else:
            # Raised explicitly: an `assert` is removed when Python runs with -O.
            raise ValueError("Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!")
        return focal2fov(focal_x, intr.width), focal2fov(focal_y, intr.height)

    def loadColmap(self, reading_dir):
        """Populate ``self.cameras`` and ``self.points`` from the files under ``self.data_path``.

        Raises:
            FileNotFoundError: If an image listed in the extrinsics cannot be read.
            ValueError: If a camera model other than SIMPLE_PINHOLE / PINHOLE is found.
        """
        # Cameras
        cameras_extrinsic_path = os.path.join(self.data_path, "sparse/0/images.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_path)
        cameras_intrinsic_path = os.path.join(self.data_path, "sparse/0/cameras.bin")
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_path)

        # Images
        images_folder = os.path.join(self.data_path, reading_dir)
        self.cameras = []
        for image_info in cam_extrinsics.values():
            intr = cam_intrinsics[image_info.camera_id]
            fov_x, fov_y = self._computeFov(intr)

            image_path = os.path.join(images_folder, image_info.name)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image: {image_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            self.cameras.append(ImageInfo(
                image_name=image_info.name,
                image_path=image_path,
                image=image,
                image_width=intr.width,
                image_height=intr.height,
                R=qvec2rotmat(image_info.qvec),
                T=image_info.tvec,
                fov_x=fov_x,
                fov_y=fov_y,
                device=self.device))

        # Point cloud
        ply_path = os.path.join(self.data_path, "sparse/0/points3D.ply")
        plydata = PlyData.read(ply_path)
        vertices = plydata['vertex']
        positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
        # Colors are divided by 255 (presumably stored as 8-bit values — TODO confirm).
        colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
        normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
        self.points = BasicPointCloud(positions=positions, colors=colors, normals=normals)

In [7]:
# Device handed to the loader; each ImageInfo moves its camera matrices onto it.
device = torch.device('cpu')
# NOTE(review): `args.images` is not registered explicitly on the parser, so it presumably
# comes from a key in config.json — confirm.
data = GSDataLoader(args.source_path, args.images, device)

# GS模型

四元数 $q=(w, x, y, z)$，需满足 $w^2 + x^2 + y^2 + z^2 = 1$

四元数转旋转矩阵：
$
R = \begin{bmatrix}
1 - 2y^2 - 2z^2 & 2xy - 2wz & 2xz + 2wy \\
2xy + 2wz & 1 - 2x^2 - 2z^2 & 2yz - 2wx \\
2xz - 2wy & 2yz + 2wx & 1 - 2x^2 - 2y^2
\end{bmatrix}
$

In [None]:
# model_utils
def BuildRotation(r):
    """Convert a batch of quaternions to rotation matrices.

    Args:
        r: Tensor of shape (N, 4) holding quaternions as (w, x, y, z).
            They do not need to be unit length; they are normalised here.

    Returns:
        Float32 tensor of shape (N, 3, 3), allocated on the same device as ``r``.
    """
    # Normalise so that w^2 + x^2 + y^2 + z^2 = 1
    norm = torch.norm(r, dim=-1)
    q = r / norm[:, None]

    # Allocate next to the input rather than on a hard-coded device.
    R = torch.zeros((q.shape[0], 3, 3), dtype=torch.float32, device=r.device)

    w = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2*y*y - 2*z*z
    R[:, 0, 1] = 2*x*y - 2*w*z
    R[:, 0, 2] = 2*x*z + 2*w*y
    R[:, 1, 0] = 2*x*y + 2*w*z
    R[:, 1, 1] = 1 - 2*x*x - 2*z*z
    R[:, 1, 2] = 2*y*z - 2*w*x
    R[:, 2, 0] = 2*x*z - 2*w*y
    R[:, 2, 1] = 2*y*z + 2*w*x
    R[:, 2, 2] = 1 - 2*x*x - 2*y*y

    return R

In [None]:
# model_utils
def BuildRotationScaling(r, s):
    """Return ``R @ diag(s)`` for every item of the batch.

    Args:
        r: Quaternions, forwarded unchanged to ``BuildRotation``.
        s: Tensor of shape (N, 3) holding the three scale factors per item.

    Returns:
        Tensor of shape (N, 3, 3): the rotation matrix built from ``r``
        multiplied by the diagonal scale matrix built from ``s``.
    """
    # Diagonal scale matrices, allocated on the same device as the scales.
    U = torch.zeros((s.shape[0], 3, 3), dtype=torch.float32, device=s.device)
    R = BuildRotation(r)

    U[:, 0, 0] = s[:, 0]
    U[:, 1, 1] = s[:, 1]
    U[:, 2, 2] = s[:, 2]

    U = R @ U

    return U

In [None]:
# model_utils
def GetUTM(M):
    """Extract the upper-triangular entries of a batch of 3x3 matrices.

    Args:
        M: Tensor of shape (N, 3, 3).

    Returns:
        Float32 tensor of shape (N, 6) on the same device as ``M``, with the
        entries ordered (0,0), (0,1), (0,2), (1,1), (1,2), (2,2).
    """
    # Row/column index of each upper-triangular entry, in output order.
    rows = [0, 0, 0, 1, 1, 2]
    cols = [0, 1, 2, 1, 2, 2]
    # Indexing the input keeps the result on M's device instead of a hard-coded one.
    return M[:, rows, cols].to(torch.float32)

In [12]:
# sh_utils
C0 = 0.28209479177387814


def RGB2SH(rgb):
    """Shift the color values by 0.5 and divide the result by C0."""
    centered = rgb - 0.5
    return centered / C0

In [92]:
class GSModel(nn.Module):
    """Learnable per-point parameters, initialised from a point cloud.

    Parameters are stored in their raw (pre-activation) form: scaling as a
    logarithm and opacity as an inverse sigmoid, matching the activation
    functions attached in setupFunctions.

    Args:
        sh_degree: Maximum spherical-harmonics degree; each color channel
            gets ``(sh_degree + 1) ** 2`` coefficients.
    """

    def __init__(self, sh_degree):
        super().__init__()
        self.max_sh_degree = sh_degree
        # Empty placeholders; createFromPointCloud replaces them with nn.Parameter objects.
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.setupFunctions()

    def setupFunctions(self):
        """Attach the activation functions and their inverses as attributes."""
        def buildCovarianceFromRotationScaling(rotation, scaling, scaling_modifier):
            # covariance = U @ U^T with U = R @ diag(scaling * scaling_modifier);
            # only the six upper-triangular entries are returned.
            U = BuildRotationScaling(rotation, scaling * scaling_modifier)
            covariance = U @ U.transpose(1, 2)
            utm = GetUTM(covariance)
            return utm

        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = InverseSigmoid

        self.rotation_activation = torch.nn.functional.normalize

        self.covariance_activation = buildCovarianceFromRotationScaling

    def createFromPointCloud(self, point_cloud, device='cuda'):
        """Initialise every parameter from ``point_cloud``.

        The main work is choosing the initial scale: each point starts with
        the same scale on all three axes, derived from the mean distance to
        its nearest neighbours.

        Args:
            point_cloud: Object exposing ``positions`` and ``colors`` arrays
                with one row per point.
            device: Device on which the parameters are created. Defaults to
                ``'cuda'``.
        """
        positions = point_cloud.positions
        N = positions.shape[0]
        # Position
        xyz = torch.tensor(positions, dtype=torch.float32, device=device)
        # Opacity: stored pre-activation, initialised from InverseSigmoid(0.1)
        opacity = torch.ones((N, 1), dtype=torch.float32, device=device) * InverseSigmoid(0.1)
        # Rotation: identity quaternion (w, x, y, z) = (1, 0, 0, 0)
        rotation = torch.zeros((N, 4), dtype=torch.float32, device=device)
        rotation[:, 0] = 1.0
        # Scale: log of the mean distance over neighbour columns 1..3
        # (column 0 is skipped; presumably the point itself — TODO confirm).
        kd_tree = KDTree(positions)
        dist, _ = kd_tree.query(positions, k=4)
        # Clamp before the log so coincident points do not produce -inf.
        mean_dist = np.maximum(dist[:, 1:].mean(axis=1), 1e-7)
        log_mean_dist = torch.tensor(np.log(mean_dist), dtype=torch.float32, device=device).unsqueeze(dim=1)
        # repeat() stays on `device`; multiplying by a CPU torch.ones(3) would mix devices.
        scaling = log_mean_dist.repeat(1, 3)
        # Color: taken from the argument (not a notebook global) and written to coefficient 0
        fused_color = torch.tensor(RGB2SH(point_cloud.colors), dtype=torch.float32, device=device)
        features = torch.zeros((N, 3, (self.max_sh_degree + 1) ** 2), dtype=torch.float32, device=device)
        features[:, :, 0] = fused_color

        self._xyz = nn.Parameter(xyz.requires_grad_(True))
        # contiguous() gives the two feature parameters their own storage instead of
        # two strided views into the same `features` buffer.
        self._features_dc = nn.Parameter(features[:, :, 0].contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(features[:, :, 1:].contiguous().requires_grad_(True))
        self._scaling = nn.Parameter(scaling.requires_grad_(True))
        self._rotation = nn.Parameter(rotation.requires_grad_(True))
        self._opacity = nn.Parameter(opacity.requires_grad_(True))

NameError: name 'nn' is not defined

In [13]:
# Preview: the loaded point-cloud colors after conversion with RGB2SH.
RGB2SH(data.points.colors)

array([[-0.77153874, -0.77153874, -0.98006272],
       [-0.78544033, -0.85494833, -0.99396432],
       [ 1.73074905,  1.75855225,  1.74465065],
       ...,
       [-1.35540589, -1.28589789, -1.38320908],
       [-1.1468819 , -1.03566911, -1.1885867 ],
       [-0.0208524 ,  0.0208524 ,  0.14596679]])

In [9]:
# Query the point cloud against itself, asking for 4 neighbours per point.
kd_tree = KDTree(data.points.positions)
dist, idx = kd_tree.query(data.points.positions, k=4)
# Column 0 is skipped; presumably it is each point's zero distance to itself — TODO confirm.
mean_dist = dist[:, 1:].mean(axis=1)

In [10]:
# Log of the mean neighbour distance as an (N, 1) float32 tensor. Derived from `dist`
# rather than from `mean_dist` itself, so re-running this cell does not take the
# logarithm of values that are already logarithms.
mean_dist = torch.tensor(np.log(dist[:, 1:].mean(axis=1)), dtype=torch.float32).unsqueeze(dim=1)

In [11]:
# Broadcast the (N, 1) column against a length-3 vector: the same value on all three axes.
mean_dist * torch.ones(3)

tensor([[-5.5668, -5.5668, -5.5668],
        [-5.5065, -5.5065, -5.5065],
        [-4.5285, -4.5285, -4.5285],
        ...,
        [-0.6764, -0.6764, -0.6764],
        [-2.6339, -2.6339, -2.6339],
        [-4.1649, -4.1649, -4.1649]])

In [89]:
# Same (N, 3) log-scale table with an explicit shape. Built from `dist` because an earlier
# cell rebinds `mean_dist` to an (N, 1) log tensor, which must not be logged a second time.
torch.ones((dist.shape[0], 3), dtype=torch.float32) * torch.tensor(np.log(dist[:, 1:].mean(axis=1)), dtype=torch.float32).unsqueeze(dim=1)

tensor([[-5.5668, -5.5668, -5.5668],
        [-5.5065, -5.5065, -5.5065],
        [-4.5285, -4.5285, -4.5285],
        ...,
        [-0.6764, -0.6764, -0.6764],
        [-2.6339, -2.6339, -2.6339],
        [-4.1649, -4.1649, -4.1649]])

In [88]:
# Repeat of the explicit-shape (N, 3) log-scale table; candidate for removal. Built from
# `dist` because an earlier cell rebinds `mean_dist` to an (N, 1) log tensor, which must
# not be logged a second time.
torch.ones((dist.shape[0], 3), dtype=torch.float32) * torch.tensor(np.log(dist[:, 1:].mean(axis=1)), dtype=torch.float32).unsqueeze(dim=1)

tensor([[-5.5668, -5.5668, -5.5668],
        [-5.5065, -5.5065, -5.5065],
        [-4.5285, -4.5285, -4.5285],
        ...,
        [-0.6764, -0.6764, -0.6764],
        [-2.6339, -2.6339, -2.6339],
        [-4.1649, -4.1649, -4.1649]])