In [1]:
import os
import json
import math
from argparse import ArgumentParser
from typing import NamedTuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.autograd.profiler as profiler
from plyfile import PlyData
from pykdtree.kdtree import KDTree

from utils.colmap_utils import (
    read_extrinsics_binary,
    read_intrinsics_binary,
    read_points3D_binary,
    read_points3D_text,
    qvec2rotmat
)
from utils.camera_utils import focal2fov
from utils.model_utils import InverseSigmoid

# 参数读取

In [2]:
# Load config.json and register every key as a CLI option whose default is the
# configured value, then add the training/debug options not present in the file.
with open('config.json', 'r') as config_file:
    json_data = json.load(config_file)

parser = ArgumentParser()
for option_name, option_default in json_data.items():
    parser.add_argument(f'--{option_name}', default=option_default)

parser.add_argument("--detect_anomaly", action="store_true", default=False)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[100, 1_000, 7_000, 30_000])
parser.add_argument("--start_checkpoint", type=str, default=None)

# Inside a notebook sys.argv belongs to the kernel, so parse an explicit argv.
args = parser.parse_args(args=['--source_path', './data', '--model_path', './data/output'])

# 数据读取

相机（视图）矩阵 ($4 \times 4$)，其中 $R$、$T$ 为 COLMAP 给出的世界坐标系到相机坐标系的旋转与平移：
$$
viewMatrix = \begin{bmatrix}
R & T \\
\mathbf{0}^\top & 1
\end{bmatrix}
$$

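投影矩阵 ($4 \times 4$)，其中 $n$、$f$ 为近/远裁剪面（代码中的 `znear`、`zfar`），$r=-l=n\tan(\mathrm{fov}_x/2)$，$t=-b=n\tan(\mathrm{fov}_y/2)$，相机沿 $+z$ 方向观察；该矩阵把深度 $z\in[n,f]$ 映射到 NDC 深度 $[-1,1]$：
$$
projMatrix = \begin{bmatrix}
\frac{2n}{r-l} & 0 & 0 & 0 \\
0 & \frac{2n}{t-b} & 0 & 0 \\
0 & 0 & \frac{f+n}{f-n} & \frac{2nf}{n-f} \\
0 & 0 & 1 & 0
\end{bmatrix}
$$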

In [6]:
# camera_utils
class BasicPointCloud(NamedTuple):
    """Sparse point cloud read from the COLMAP reconstruction.

    Each field holds one row per point; ``GSDataLoader.loadColmap`` builds them
    as (N, 3) arrays, with colours rescaled from 0-255 to 0-1.
    """
    positions: np.ndarray  # xyz
    colors: np.ndarray     # RGB
    normals: np.ndarray    # nx, ny, nz

In [7]:
# camera_utils
class ImageInfo:
    """One calibrated view: the image plus its view/projection matrices.

    ``R`` and ``T`` are written verbatim into the upper 3x4 block of the view
    matrix, so they are expected to map world to camera coordinates (as
    COLMAP's qvec/tvec do). ``fov_x``/``fov_y`` are full fields of view in radians.
    """

    def __init__(self, image_name, image_path, image, image_width, image_height, R, T, fov_x, fov_y, focal_x, focal_y, device):
        # The image stays on the CPU; only the camera matrices are moved to `device`.
        self.image = torch.from_numpy(image)
        self.image_name = image_name
        self.image_path = image_path
        self.image_width = image_width
        self.image_height = image_height
        self.R = R
        self.T = T
        self.fov_x = fov_x
        self.fov_y = fov_y
        self.focal_x = focal_x
        self.focal_y = focal_y
        self.zfar = 100.0
        self.znear = 0.01

        self.view_matrix = self.getViewMatrix(self.R, self.T).to(device)
        self.proj_matrix = self.getProjMatrix(self.znear, self.zfar, self.fov_x, self.fov_y).to(device)
        # Column-vector convention: p_clip = P @ V @ p_world.
        self.view_proj_matrix = self.proj_matrix @ self.view_matrix
        # Camera position in world space = translation column of V^-1.
        self.camera_center = torch.inverse(self.view_matrix)[:3, 3]

    def getViewMatrix(self, R, T):
        """Return the 4x4 float32 matrix [[R, T], [0, 1]]."""
        view_matrix = torch.eye(4, dtype=torch.float32)
        view_matrix[:3, :3] = torch.as_tensor(R, dtype=torch.float32)
        view_matrix[:3, 3] = torch.as_tensor(T, dtype=torch.float32)
        return view_matrix

    def getProjMatrix(self, znear, zfar, fov_x, fov_y):
        """Return a symmetric-frustum perspective matrix looking down +z.

        x and y are mapped so the frustum edges land on NDC +-1. Depth z in
        [znear, zfar] is mapped to NDC [-1, 1], and the w row copies z.
        """
        top = math.tan(fov_y / 2) * znear
        bottom = -top
        right = math.tan(fov_x / 2) * znear
        left = -right

        proj_matrix = torch.zeros(4, 4)
        proj_matrix[0, 0] = 2.0 * znear / (right - left)
        proj_matrix[1, 1] = 2.0 * znear / (top - bottom)
        # Both depth terms follow the same [-1, 1] (OpenGL-style) convention;
        # z/w is -1 at znear and +1 at zfar.
        proj_matrix[2, 2] = (zfar + znear) / (zfar - znear)
        proj_matrix[2, 3] = 2.0 * znear * zfar / (znear - zfar)
        proj_matrix[3, 2] = 1.0
        return proj_matrix

In [8]:
class GSDataLoader:
    """Load COLMAP cameras, their images and the sparse point cloud.

    Expects ``<data_path>/sparse/0/{images.bin, cameras.bin, points3D.ply}`` and
    the images under ``<data_path>/<reading_dir>``. After construction,
    ``self.cameras`` is a list of ImageInfo and ``self.points`` a BasicPointCloud.
    """

    def __init__(self, data_path, reading_dir, device):
        self.data_path = data_path
        self.device = device
        self.cameras = []
        self.points = None  # set to a BasicPointCloud by loadColmap
        self.loadColmap(reading_dir)

    def loadColmap(self, reading_dir):
        """Populate ``self.cameras`` and ``self.points`` from the COLMAP output."""
        sparse_dir = os.path.join(self.data_path, "sparse/0")
        cam_extrinsics = read_extrinsics_binary(os.path.join(sparse_dir, "images.bin"))
        cam_intrinsics = read_intrinsics_binary(os.path.join(sparse_dir, "cameras.bin"))

        images_folder = os.path.join(self.data_path, reading_dir)
        self.cameras = []
        for image_info in cam_extrinsics.values():
            intr = cam_intrinsics[image_info.camera_id]

            image_path = os.path.join(images_folder, image_info.name)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread reports failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image: {image_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            focal_x, focal_y = self._focalLengths(intr)

            self.cameras.append(ImageInfo(
                image_name=image_info.name,
                image_path=image_path,
                image=image,
                image_width=intr.width,
                image_height=intr.height,
                R=qvec2rotmat(image_info.qvec),
                T=image_info.tvec,
                fov_x=focal2fov(focal_x, intr.width),
                fov_y=focal2fov(focal_y, intr.height),
                focal_x=focal_x,
                focal_y=focal_y,
                device=self.device))

        self.points = self._loadPointCloud(os.path.join(sparse_dir, "points3D.ply"))

    @staticmethod
    def _focalLengths(intr):
        """Return (focal_x, focal_y) for an undistorted COLMAP camera model."""
        if intr.model == "SIMPLE_PINHOLE":
            # colmap params [f, cx, cy]
            return intr.params[0], intr.params[0]
        if intr.model == "PINHOLE":
            # colmap params [fx, fy, cx, cy]
            return intr.params[0], intr.params[1]
        # Explicit raise (not `assert False`) so the check survives `python -O`.
        raise AssertionError("Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!")

    @staticmethod
    def _loadPointCloud(ply_path):
        """Read positions, RGB (rescaled from 0-255 to 0-1) and normals from a PLY file."""
        vertices = PlyData.read(ply_path)['vertex']
        positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
        colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
        normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
        return BasicPointCloud(positions=positions, colors=colors, normals=normals)

In [9]:
# GSModel allocates its parameters on CUDA, so the camera matrices go there too.
device = torch.device(type='cuda')
data = GSDataLoader(data_path=args.source_path, reading_dir=args.images, device=device)

# GS模型

四元数 $q=(w, x, y, z)$，需满足 $w^2 + x^2 + y^2 + z^2 = 1$（`BuildRotation` 会先对输入四元数归一化）。

四元数转旋转矩阵：
$$
R = \begin{bmatrix}
1 - 2y^2 - 2z^2 & 2xy - 2wz & 2xz + 2wy \\
2xy + 2wz & 1 - 2x^2 - 2z^2 & 2yz - 2wx \\
2xz - 2wy & 2yz + 2wx & 1 - 2x^2 - 2y^2
\end{bmatrix}
$$

In [10]:
# model_utils
def BuildRotation(r):
    """Convert quaternions (w, x, y, z) to 3x3 rotation matrices.

    Args:
        r: tensor with the 4 quaternion components on the last axis, e.g. (N, 4).
           Rows need not be unit length; they are normalized here first.

    Returns:
        Tensor of shape ``r.shape[:-1] + (3, 3)`` on the same device/dtype as ``r``.
    """
    # Normalize so the unit-quaternion formula below is valid.
    q = r / torch.norm(r, dim=-1, keepdim=True)
    w, x, y, z = q.unbind(dim=-1)

    row0 = torch.stack((1 - 2*y*y - 2*z*z, 2*x*y - 2*w*z, 2*x*z + 2*w*y), dim=-1)
    row1 = torch.stack((2*x*y + 2*w*z, 1 - 2*x*x - 2*z*z, 2*y*z - 2*w*x), dim=-1)
    row2 = torch.stack((2*x*z - 2*w*y, 2*y*z + 2*w*x, 1 - 2*x*x - 2*y*y), dim=-1)
    return torch.stack((row0, row1, row2), dim=-2)

In [11]:
# model_utils
def BuildRotationScaling(r, s):
    """Return ``R(r) @ diag(s)`` for each Gaussian.

    Args:
        r: (N, 4) quaternions (w, x, y, z), passed to BuildRotation.
        s: (N, 3) per-axis scales.

    Returns:
        (N, 3, 3) tensor whose product with its transpose is the covariance.
    """
    R = BuildRotation(r)
    # Right-multiplying by diag(s) scales column j of R by s[:, j]; broadcasting
    # does that without allocating the diagonal matrix (or pinning it to CUDA).
    return R * s[:, None, :]

In [12]:
# model_utils
def GetUTM(M):
    """Pack the upper triangle of (N, 3, 3) symmetric matrices into (N, 6).

    Column order: [m00, m01, m02, m11, m12, m22]. The result is on ``M``'s device.
    """
    rows = [0, 0, 0, 1, 1, 2]
    cols = [0, 1, 2, 1, 2, 2]
    return M[:, rows, cols]

In [13]:
# sh_utils
# Real spherical-harmonic basis constants, bands 0 to 4.
C0 = 0.28209479177387814  # band 0: 1 / (2 * sqrt(pi))
C1 = 0.4886025119029199   # band 1: sqrt(3 / (4 * pi))
C2 = [
    1.0925484305920792, -1.0925484305920792, 0.31539156525252005,
    -1.0925484305920792, 0.5462742152960396,
]
C3 = [
    -0.5900435899266435, 2.890611442640554, -0.4570457994644658,
    0.3731763325901154, -0.4570457994644658, 1.445305721320277,
    -0.5900435899266435,
]
C4 = [
    2.5033429417967046, -1.7701307697799304, 0.9461746957575601,
    -0.6690465435572892, 0.10578554691520431, -0.6690465435572892,
    0.47308734787878004, -1.7701307697799304, 0.6258357354491761,
]


def RGB2SH(rgb):
    """Encode a colour as the DC (band-0) SH coefficient; inverse of SH2RGB."""
    centered = rgb - 0.5
    return centered / C0


def SH2RGB(sh):
    """Decode a DC (band-0) SH coefficient back to a colour; inverse of RGB2SH."""
    scaled = sh * C0
    return scaled + 0.5

def EvalSh(deg, sh, dirs):
    """Evaluate real spherical harmonics up to degree `deg` along `dirs`.

    Args:
        deg: SH degree to evaluate, 0..4 (asserted).
        sh: coefficients with the SH index on the last axis; at least
            (deg + 1) ** 2 entries are required there (asserted). E.g.
            GSRender.buildColor passes (N, 3, K) features (K per colour channel).
        dirs: directions with x, y, z on the last axis. The basis polynomials
            below (e.g. ``7 * zz - 1``) assume unit-length directions, so
            callers should normalize them.

    Returns:
        ``sum_k basis_k(dirs) * sh[..., k]``. Its shape is ``sh.shape[:-1]`` for
        deg == 0; for deg > 0 it is that broadcast against ``dirs[..., 0:1]``.
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff
    
    # Band 0: constant term.
    result = C0 * sh[..., 0]
    if deg > 0:
        # Slices (not indices) keep a trailing singleton axis so x, y, z
        # broadcast against sh[..., k].
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        # Band 1: coefficients 1..3.
        result = (result -
                C1 * y * sh[..., 1] +
                C1 * z * sh[..., 2] -
                C1 * x * sh[..., 3])

        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            # Band 2: coefficients 4..8.
            result = (result +
                    C2[0] * xy * sh[..., 4] +
                    C2[1] * yz * sh[..., 5] +
                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                    C2[3] * xz * sh[..., 7] +
                    C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                # Band 3: coefficients 9..15.
                result = (result +
                C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                C3[1] * xy * z * sh[..., 10] +
                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                C3[5] * z * (xx - yy) * sh[..., 14] +
                C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    # Band 4: coefficients 16..24.
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result

In [14]:
class GSModel(nn.Module):
    """A set of learnable 3D Gaussians.

    Parameters are stored in unconstrained form and mapped through the
    activations installed by ``setupFunctions`` when read:

    * ``_xyz``            (N, 3) centres, used as-is;
    * ``_features_dc``    (N, 3, 1) and ``_features_rest`` (N, 3, K - 1) SH
      coefficients with K = (max_sh_degree + 1) ** 2;
    * ``_scaling``        (N, 3) log-scales (``exp`` on read);
    * ``_rotation``       (N, 4) quaternions (w, x, y, z), normalized on read;
    * ``_opacity``        (N, 1) logits (``sigmoid`` on read).
    """

    def __init__(self, sh_degree):
        super().__init__()
        self.active_sh_degree = 0
        self.max_sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.setupFunctions()

    def setupFunctions(self):
        """Install the activation / inverse-activation callables used by the getters."""
        def buildCovarianceFromRotationScaling(rotation, scaling, scaling_modifier):
            # Sigma = (R S)(R S)^T, packed into its 6 upper-triangular entries.
            U = BuildRotationScaling(rotation, scaling * scaling_modifier)
            covariance = U @ U.transpose(1, 2)
            return GetUTM(covariance)

        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = InverseSigmoid

        self.rotation_activation = torch.nn.functional.normalize

        self.covariance_activation = buildCovarianceFromRotationScaling

    def createFromPointCloud(self, point_cloud):
        """Initialise one Gaussian per point of `point_cloud`.

        `point_cloud` needs ``positions`` (N, 3) and ``colors`` (N, 3, in [0, 1])
        NumPy arrays, as in BasicPointCloud. The initial scale is isotropic and
        equals the mean distance to the 3 nearest neighbours.
        """
        positions = point_cloud.positions
        N = positions.shape[0]
        xyz = torch.tensor(positions, dtype=torch.float32, device='cuda')
        # Every Gaussian starts 10% opaque (stored as a logit).
        opacity = InverseSigmoid(torch.full((N, 1), 0.1, dtype=torch.float32, device='cuda'))
        # Identity rotation quaternion (w, x, y, z) = (1, 0, 0, 0).
        rotation = torch.zeros((N, 4), dtype=torch.float32, device='cuda')
        rotation[:, 0] = 1.0
        # k=4 because each query point is its own nearest neighbour (distance 0).
        kd_tree = KDTree(positions)
        dist, _ = kd_tree.query(positions, k=4)
        # Clamp so duplicate points do not produce log(0) = -inf.
        mean_dist = np.maximum(dist[:, 1:].mean(axis=1), 1e-7)
        log_scale = torch.tensor(np.log(mean_dist), dtype=torch.float32, device='cuda').unsqueeze(dim=1)
        scaling = log_scale.repeat(1, 3)
        # Only the DC SH band is initialised from the point colours; the rest start at 0.
        fused_color = torch.tensor(RGB2SH(point_cloud.colors), dtype=torch.float32, device='cuda')
        features = torch.zeros((N, 3, (self.max_sh_degree + 1) ** 2), dtype=torch.float32, device='cuda')
        features[:, :, 0] = fused_color

        self._xyz = nn.Parameter(xyz.requires_grad_(True))
        self._features_dc = nn.Parameter(features[:, :, 0:1].contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(features[:, :, 1:].contiguous().requires_grad_(True))
        self._scaling = nn.Parameter(scaling.requires_grad_(True))
        self._rotation = nn.Parameter(rotation.requires_grad_(True))
        self._opacity = nn.Parameter(opacity.requires_grad_(True))

    def getCovariance(self, scaling_modifier=1):
        """Per-Gaussian 3D covariance packed as 6 upper-triangular entries (N, 6).

        Uses the activated (positive) scales; the raw quaternion is fine here
        because BuildRotation normalizes it.
        """
        return self.covariance_activation(self._rotation, self.getScaling, scaling_modifier)

    @property
    def getScaling(self):
        return self.scaling_activation(self._scaling)

    @property
    def getRotation(self):
        return self.rotation_activation(self._rotation)

    @property
    def getXyz(self):
        return self._xyz

    @property
    def getFeatures(self):
        """All SH coefficients, (N, 3, K): DC band followed by the higher bands."""
        return torch.cat((self._features_dc, self._features_rest), dim=-1)

    @property
    def getOpacity(self):
        return self.opacity_activation(self._opacity)

    def loadPly(self, path):
        """Load trained Gaussian parameters from a PLY file and enable all SH bands.

        Expects vertex properties x/y/z, opacity, f_dc_0..2, f_rest_*, rot_0..3
        and scale_0..2. Values are stored exactly as read, i.e. already in the
        unconstrained log-scale / logit / raw-quaternion form.

        Raises:
            ValueError: if the number of f_rest_* properties does not match max_sh_degree.
        """
        plydata = PlyData.read(path)
        vertex = plydata.elements[0]
        xyz = np.stack((np.asarray(vertex["x"]),
                        np.asarray(vertex["y"]),
                        np.asarray(vertex["z"])), axis=1)
        opacity = np.asarray(vertex["opacity"])[..., np.newaxis]

        features_dc = np.zeros((xyz.shape[0], 3, 1))
        for channel in range(3):
            features_dc[:, channel, 0] = np.asarray(vertex[f"f_dc_{channel}"])

        rest_names = [p.name for p in vertex.properties if p.name.startswith("f_rest_")]
        # Sort numerically so f_rest_10 follows f_rest_9 whatever the file order.
        rest_names.sort(key=lambda name: int(name.split("_")[-1]))
        n_rest = (self.max_sh_degree + 1) ** 2 - 1
        if len(rest_names) != 3 * n_rest:
            raise ValueError(
                f"{path}: expected {3 * n_rest} f_rest_* properties for SH degree "
                f"{self.max_sh_degree}, found {len(rest_names)}")
        feature_rest = np.zeros((xyz.shape[0], len(rest_names)))
        for idx, rest_name in enumerate(rest_names):
            feature_rest[:, idx] = np.asarray(vertex[rest_name])
        feature_rest = feature_rest.reshape((feature_rest.shape[0], 3, n_rest))

        rotation = np.stack([np.asarray(vertex[f"rot_{i}"]) for i in range(4)], axis=1)
        scaling = np.stack([np.asarray(vertex[f"scale_{i}"]) for i in range(3)], axis=1)

        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float32, device="cuda").requires_grad_(True))
        self._opacity = nn.Parameter(torch.tensor(opacity, dtype=torch.float32, device="cuda").requires_grad_(True))
        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float32, device="cuda").contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(torch.tensor(feature_rest, dtype=torch.float32, device="cuda").contiguous().requires_grad_(True))
        self._rotation = nn.Parameter(torch.tensor(rotation, dtype=torch.float32, device="cuda").requires_grad_(True))
        self._scaling = nn.Parameter(torch.tensor(scaling, dtype=torch.float32, device="cuda").requires_grad_(True))

        self.active_sh_degree = self.max_sh_degree

In [15]:
gaussian = GSModel(sh_degree=3)  # up to SH degree 3: (3 + 1) ** 2 = 16 coefficients per channel

In [138]:
gaussian.createFromPointCloud(point_cloud=data.points)  # note: the loadPly cell below overwrites these parameters

In [16]:
gaussian.loadPly(path="data/point_cloud.ply")  # replaces all parameters and activates the max SH degree

# 渲染

In [161]:
class GSRender(nn.Module):
    """Pure-PyTorch tile-based rasterizer for a GSModel seen from one ImageInfo camera.

    ``forward`` runs the pipeline:
    1. project centres;
    2. evaluate SH colour;
    3. build 3D covariances and project them to 2D;
    4. sort by depth and alpha-composite each 16x16 tile.
    """

    def __init__(self, active_sh_degree):
        super().__init__()
        # SH degree passed to EvalSh when colouring the Gaussians.
        self.active_sh_degree = active_sh_degree

    def homogenize(self, point):
        """Append a homogeneous coordinate of 1 on the last axis."""
        return torch.cat((point, torch.ones_like(point[..., :1])), dim=-1)

    def buildPoint(self, xyz, view_matrix, view_proj_matrix):
        """Project world points.

        Returns:
            p_points: NDC points after the perspective divide.
            v_points: camera-space (homogeneous) points.
            in_mask: True for points at least 0.2 units in front of the camera.
        """
        h_points = self.homogenize(xyz)
        # Row-vector form of M @ p applied to every point.
        p_points_t = h_points @ view_proj_matrix.T
        # Perspective divide; the epsilon avoids dividing by w == 0.
        p_points_w = 1.0 / (p_points_t[..., -1:] + 0.000001)
        p_points = p_points_t * p_points_w
        v_points = h_points @ view_matrix.T
        in_mask = v_points[..., 2] >= 0.2
        return p_points, v_points, in_mask

    def buildColor(self, xyz, shs, camera_center):
        """View-dependent RGB from SH coefficients, clamped to be non-negative."""
        view_dirs = xyz - camera_center
        # The SH basis polynomials assume unit-length directions. Points kept by
        # buildPoint lie >= 0.2 in front of the camera, so the norm is non-zero.
        view_dirs = view_dirs / view_dirs.norm(dim=-1, keepdim=True)
        color = EvalSh(self.active_sh_degree, shs, view_dirs)
        return (color + 0.5).clip(min=0)

    def buildCov3D(self, r, s):
        """Full (M, 3, 3) covariance (R S)(R S)^T."""
        U = BuildRotationScaling(r, s)
        return U @ U.transpose(1, 2)

    def buildCov2D(self, v_points, cov3D, view_matrix, fov_x, fov_y, focal_x, focal_y):
        """Project 3D covariances to screen space as J W Sigma W^T J^T.

        W is the rotation part of the view matrix and J the Jacobian of the
        perspective projection at each camera-space point. A 0.3 px^2 isotropic
        term is added as a low-pass filter.
        """
        tan_fovx = math.tan(fov_x * 0.5)
        tan_fovy = math.tan(fov_y * 0.5)
        # Clamp x/z and y/z to 1.3x the frustum half-width so the Jacobian of
        # points far off-screen stays bounded.
        tx = (v_points[..., 0] / v_points[..., 2]).clip(min=-tan_fovx*1.3, max=tan_fovx*1.3) * v_points[..., 2]
        ty = (v_points[..., 1] / v_points[..., 2]).clip(min=-tan_fovy*1.3, max=tan_fovy*1.3) * v_points[..., 2]
        tz = v_points[..., 2]

        J = torch.zeros(v_points.shape[0], 3, 3).to(v_points)
        J[..., 0, 0] = 1 / tz * focal_x
        J[..., 0, 2] = -tx / (tz * tz) * focal_x
        J[..., 1, 1] = 1 / tz * focal_y
        J[..., 1, 2] = -ty / (tz * tz) * focal_y

        W = view_matrix[:3, :3]
        cov2D = J @ W @ cov3D @ W.T @ J.permute(0, 2, 1)

        low_filter = torch.eye(2, 2).to(cov2D) * 0.3
        return cov2D[:, :2, :2] + low_filter[None]

    def render(self, uv, cov2D, color, opacity, depth, height, width, bg_color=None):
        """Alpha-composite projected Gaussians front to back, one 16x16 tile at a time.

        Args:
            uv: (M, 2) pixel coordinates; integer values are pixel centres.
            cov2D: (M, 2, 2) screen-space covariances.
            color: (M, 3) RGB. opacity: (M, 1). depth: (M,) camera-space z.
            height, width: output size in pixels.
            bg_color: background RGB shown through the remaining transmittance;
                defaults to white.

        Returns:
            (height, width, 3) colour and (height, width, 1) depth images on uv's device.
        """
        device = uv.device
        if bg_color is None:
            bg_color = torch.ones(3)
        bg_color = torch.as_tensor(bg_color, dtype=torch.float32, device=device)

        radius = self.getRadius(cov2D)
        rect_min, rect_max = self.getRect(uv, radius, height, width)

        pix_coord = self.getPixCoord(height, width, device=device)
        # Pixels in tiles no Gaussian touches keep the background colour.
        render_color = bg_color.repeat(height, width, 1)
        render_depth = torch.zeros((height, width, 1), device=device)

        TILE_SIZE = 16
        for h in range(0, height, TILE_SIZE):
            for w in range(0, width, TILE_SIZE):
                # Gaussians whose bounding rectangle overlaps this tile.
                over_l = rect_min[..., 0].clip(min=w)
                over_t = rect_min[..., 1].clip(min=h)
                over_r = rect_max[..., 0].clip(max=w+TILE_SIZE-1)
                over_b = rect_max[..., 1].clip(max=h+TILE_SIZE-1)
                in_mask = (over_r > over_l) & (over_b > over_t)

                if not torch.any(in_mask):
                    continue

                # Sort front to back by camera-space depth.
                s_depth, s_index = torch.sort(depth[in_mask])
                s_uv = uv[in_mask][s_index]
                s_conic = cov2D[in_mask][s_index].inverse()
                s_opacity = opacity[in_mask][s_index]
                s_color = color[in_mask][s_index]

                tile_coord = pix_coord[h:h+TILE_SIZE, w:w+TILE_SIZE].flatten(0, 1)

                # exp(-0.5 * dx^T Sigma^-1 dx) for every (pixel p, Gaussian m) pair;
                # einsum avoids materialising a (P, M, 2, 2) copy of the conics.
                dx = tile_coord[:, None, :] - s_uv[None]
                power = torch.einsum('pmi,mij,pmj->pm', dx, s_conic, dx)[..., None]
                power = (-0.5 * power).clip(max=0)
                gauss2D = torch.exp(power)

                alpha = (s_opacity[None] * gauss2D).clip(max=0.99)

                # Transmittance before each Gaussian, plus what is left after all of them.
                T_all = torch.cat([torch.ones_like(alpha[:, :1]), 1 - alpha], dim=1).cumprod(dim=1)
                T = T_all[:, :-1]
                T_final = T_all[:, -1]

                weights = T * alpha
                tile_color = (weights * s_color[None]).sum(dim=1) + T_final * bg_color
                tile_depth = (weights * s_depth[None, :, None]).sum(dim=1)

                reshape_x = min(width-w, TILE_SIZE)
                reshape_y = min(height-h, TILE_SIZE)
                render_color[h:h+TILE_SIZE, w:w+TILE_SIZE] = tile_color.reshape(reshape_y, reshape_x, -1)
                render_depth[h:h+TILE_SIZE, w:w+TILE_SIZE] = tile_depth.reshape(reshape_y, reshape_x, -1)

        return render_color, render_depth

    @torch.no_grad()
    def getRadius(self, cov2D):
        """Pixel radius ceil(3 * sqrt(lambda_max)) covering 3 standard deviations."""
        a = cov2D[:, 0, 0]
        b = cov2D[:, 0, 1]
        c = cov2D[:, 1, 0]
        d = cov2D[:, 1, 1]
        # Largest eigenvalue of [[a, b], [c, d]].
        lambda1 = 0.5 * (a + d) + 0.5 * torch.sqrt((a - d) * (a - d) + 4 * b * c)
        return (3.0 * torch.sqrt(lambda1)).ceil()

    @torch.no_grad()
    def getRect(self, pix, radius, height, width):
        """Axis-aligned bounding rectangle of each splat, clipped to the image."""
        rect_min = pix - radius[:, None]
        rect_max = pix + radius[:, None]
        rect_min[..., 0] = rect_min[..., 0].clip(0, width-1.0)
        rect_min[..., 1] = rect_min[..., 1].clip(0, height-1.0)
        rect_max[..., 0] = rect_max[..., 0].clip(0, width-1.0)
        rect_max[..., 1] = rect_max[..., 1].clip(0, height-1.0)
        return rect_min, rect_max

    @torch.no_grad()
    def getPixCoord(self, height, width, device='cuda'):
        """(height, width, 2) grid of integer (x, y) pixel coordinates on `device`."""
        x = torch.arange(width)
        y = torch.arange(height)
        x_grid, y_grid = torch.meshgrid(x, y, indexing='xy')
        return torch.stack((x_grid, y_grid), dim=-1).to(device)

    def forward(self, model, camera):
        """Render `model` from `camera`.

        Returns:
            (H, W, 3) colour and (H, W, 1) depth images.
        """
        opacity = model.getOpacity
        xyz = model.getXyz
        rotation = model.getRotation
        scaling = model.getScaling
        features = model.getFeatures

        prof = profiler.record_function

        with prof("build point"):
            p_points, v_points, in_mask = self.buildPoint(xyz, camera.view_matrix, camera.view_proj_matrix)
            p_points = p_points[in_mask]
            v_points = v_points[in_mask]
            # NDC [-1, 1] -> pixel coordinates. getPixCoord puts pixel centres at
            # integers 0..W-1, so NDC +-1 must land on the outer pixel borders
            # at -0.5 and W - 0.5.
            u = ((p_points[..., 0] + 1) * camera.image_width - 1) * 0.5
            v = ((p_points[..., 1] + 1) * camera.image_height - 1) * 0.5
            uv = torch.stack([u, v], dim=-1)
            depth = v_points[:, 2]

        with prof("build color"):
            xyz = xyz[in_mask]
            features = features[in_mask]
            color = self.buildColor(xyz, features, camera.camera_center)

        with prof("build cov3D"):
            rotation = rotation[in_mask]
            scaling = scaling[in_mask]
            cov3D = self.buildCov3D(rotation, scaling)

        with prof("build cov2D"):
            cov2D = self.buildCov2D(v_points=v_points, cov3D=cov3D, view_matrix=camera.view_matrix,
                                    fov_x=camera.fov_x, fov_y=camera.fov_y, focal_x=camera.focal_x, focal_y=camera.focal_y)

        with prof("render"):
            opacity = opacity[in_mask]
            render_color, render_depth = self.render(uv, cov2D, color, opacity, depth, camera.image_height, camera.image_width)
        return render_color, render_depth

In [162]:
render = GSRender(active_sh_degree=3)  # evaluate SH colour up to degree 3

In [163]:
render_color, render_depth = render(model=gaussian, camera=data.cameras[1])  # (H, W, 3) colour, (H, W, 1) depth

In [166]:
import imageio


def to8b(x):
    """Clip a float image to [0, 1] and quantize it to uint8."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


# render_color is RGB, which imageio writes as-is (cv2.imwrite would expect BGR).
rgb_pd = render_color.detach().cpu().numpy()
imageio.imwrite("image.png", to8b(rgb_pd))

In [22]:
# GSRender.forward returns (colour image, depth image), not covariances.
render_color0, render_depth0 = render(gaussian, data.cameras[0])

In [None]:
# Toy front-to-back compositing example: initial transmittance, per-sample alphas and depths.
T, alpha, d = 1, [0.2, 0.5, 0.2, 0.3], [1, 1.5, 5, 6]

In [None]:
# Front-to-back alpha compositing of the toy samples:
#   C = sum_i T_i * alpha_i * d_i,  T_0 = T,  T_i = T_{i-1} * (1 - alpha_{i-1}).
# Transmittance is tracked in a local, so re-running this cell leaves T unchanged
# and reproduces the same result. Prints the weights of samples 1..n-1.
transmittance = T
C = transmittance * alpha[0] * d[0]
for i in range(1, len(alpha)):
    transmittance = transmittance * (1 - alpha[i - 1])
    w = transmittance * alpha[i]
    C += w * d[i]
    print(w)
C