# Generate Convolutional Layers by Splatting Gaussians

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import glm
import math
import dataclasses
import matplotlib.pyplot as plt
import diff_gaussian_rasterization_batched as dgrb

In [None]:
# Global device used by the helpers and modules below for every tensor they
# allocate. CUDA is hard-coded here; there is no CPU fallback.
device = torch.device("cuda")

In [None]:
@dataclasses.dataclass()
class GaussianSplattingRendererConfig:
    """Static settings used to build a ``GaussianSplattingRenderer``."""

    # Field of view in radians (the renderer takes ``tan(fov / 2)``).
    fov: float
    # Side length of a square output image, in pixels.
    img_size: int
    # Number of channels in each rendered image (length of each color vector).
    img_channels: int
    # Output image height in pixels.
    img_height: int
    # Output image width in pixels.
    img_width: int
    # 4x4 projection matrix tensor handed to the rasterizer.
    projection_matrix: torch.Tensor

class GaussianSplattingRenderer(nn.Module):
    """Batched Gaussian-splatting renderer wrapping ``dgrb.GaussianRasterizer``.

    Each forward call renders ``B`` images, one per view matrix, of shape
    ``(config.img_channels, config.img_height, config.img_width)``.
    """

    def __init__(self, config: GaussianSplattingRendererConfig):
        super().__init__()
        self.config = config

        # Tensors baked into the (fixed) rasterization settings are created on
        # the same device as the projection matrix they are used with.
        settings_device = config.projection_matrix.device

        # ``glm.perspective`` takes a vertical FOV, and GaussianSplattingLayer
        # passes the same FOV there, so treat ``config.fov`` as vertical and
        # widen the horizontal tangent by the aspect ratio. For square images
        # (the only case GaussianSplattingLayer builds) both tangents are equal.
        tan_half_fovy = math.tan(config.fov * 0.5)
        tan_half_fovx = tan_half_fovy * config.img_width / config.img_height

        # Stored as ``rasterizer`` because that is the name ``forward`` uses.
        self.rasterizer = dgrb.GaussianRasterizer(
            dgrb.GaussianRasterizationSettings(
                image_channels=config.img_channels,
                image_height=config.img_height,
                image_width=config.img_width,
                tanfovx=tan_half_fovx,
                tanfovy=tan_half_fovy,
                bg=torch.zeros(
                    config.img_channels, dtype=torch.float32, device=settings_device
                ),
                scale_modifier=1.0,
                projmatrix=config.projection_matrix,
                sh_degree=1,
                campos=torch.tensor(
                    [0, 0, 0], dtype=torch.float32, device=settings_device
                ),
                prefiltered=False,
                # NOTE(review): debug mode kept as in the original; consider
                # disabling it once the pipeline is stable.
                debug=True,
            )
        )

    @property
    def renderer(self):
        """Alias of ``rasterizer`` kept for code that used the old attribute name."""
        return self.rasterizer

    def forward(
        self,
        g_mean: torch.Tensor,
        g_scale: torch.Tensor,
        g_quat: torch.Tensor,
        g_color: torch.Tensor,
        view_matrix: torch.Tensor,
    ) -> torch.Tensor:
        """Render one image per batch entry.

        Args:
            g_mean: (B, N, 3) Gaussian centres.
            g_scale: (B, N, 3) per-axis scales.
            g_quat: (B, N, 4) rotation quaternions; normalized here, so they
                need not be unit length.
            g_color: (B, N, config.img_channels) precomputed colors.
            view_matrix: (B, 4, 4) one view matrix per batch entry.

        Returns:
            (B, config.img_channels, H, W) rendered images.
        """
        B, N, _ = g_mean.shape
        assert g_mean.shape == (B, N, 3)
        assert g_scale.shape == (B, N, 3)
        assert g_quat.shape == (B, N, 4)
        assert g_color.shape == (B, N, self.config.img_channels)
        assert view_matrix.shape == (B, 4, 4)

        # Normalize quaternions so unconstrained parameters are valid rotations.
        g_quat = F.normalize(g_quat, dim=2)

        # Allocate per-call helpers next to the inputs rather than on a global device.
        dev = g_mean.device

        # Every Gaussian is fully opaque; colors are passed precomputed.
        # The second output of the rasterizer is not needed here.
        rendered_image, _ = self.rasterizer(
            means3D=g_mean,
            means2D=torch.zeros(B, N, 2, dtype=torch.float32, device=dev),
            colors_precomp=g_color,
            opacities=torch.ones(B, N, dtype=torch.float32, device=dev),
            scales=g_scale,
            rotations=g_quat,
            viewmatrixes=view_matrix,
        )

        return rendered_image


In [None]:
def random_point_on_sphere(r: float) -> list[float]:
    """Return ``[x, y, z]`` drawn uniformly at random from the sphere of radius ``r``.

    Drawing the polar angle uniformly in ``[0, pi]`` would crowd points near
    the poles. Instead ``cos(theta) = 1 - 2u`` with ``u ~ U[0, 1)``, which
    gives an area-uniform distribution. The function uses torch's global RNG
    (two draws: azimuth first, then polar), so ``torch.manual_seed`` makes it
    reproducible.
    """
    phi = torch.rand(1).item() * 2 * math.pi
    cos_theta = 1.0 - 2.0 * torch.rand(1).item()
    # Clamp guards against a tiny negative value from rounding.
    sin_theta = math.sqrt(max(0.0, 1.0 - cos_theta * cos_theta))
    return [
        r * sin_theta * math.cos(phi),
        r * sin_theta * math.sin(phi),
        r * cos_theta,
    ]


def toTensor(m: glm.mat4) -> torch.Tensor:
    """Convert a PyGLM matrix into a float32 tensor on the global ``device``.

    The nested list from ``m.to_list()`` is used as-is: outer list -> tensor rows.
    NOTE(review): PyGLM stores matrices column-major, so the outer list is
    presumably the columns and the tensor is the transpose of the
    mathematical matrix. Confirm this is the layout the rasterizer expects.
    """
    nested = m.to_list()
    return torch.tensor(nested, device=device, dtype=torch.float32)


class GaussianSplattingLayer(nn.Module):
    """2D convolution whose kernels are renderings of shared learnable 3D Gaussians.

    One set of ``n_gaussians`` Gaussians is rendered from ``out_channels``
    fixed cameras. Each Gaussian has a mean, scale, quaternion and an
    ``in_channels``-dimensional color. The cameras are placed at random on a
    sphere of radius ``CAM_DISTANCE`` and look at the origin. Each rendering,
    of shape ``(in_channels, filter_size, filter_size)``, is the kernel of one
    output channel of a ``padding='same'`` convolution.
    """

    # Field of view of each camera (radians)
    FOV = math.radians(60.0)

    # Distance from each camera to the origin
    CAM_DISTANCE = 3.0

    def __init__(
        self, filter_size: int, in_channels: int, out_channels: int, n_gaussians: int
    ):
        super().__init__()
        self.filter_size = filter_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_gaussians = n_gaussians

        # Construct the renderer. If the Gaussians lie in the box [-1, 1]^3,
        # a 60 degree FOV at distance 3 makes them just about fill the
        # viewport (3 * tan(30 deg) ~= 1.73 ~= sqrt(3)).
        # NOTE(review): g_mean below is initialised with randn, so many means
        # start outside that box.
        self.renderer = GaussianSplattingRenderer(
            GaussianSplattingRendererConfig(
                fov=self.FOV,
                img_size=self.filter_size,
                img_channels=self.in_channels,
                img_height=self.filter_size,
                img_width=self.filter_size,
                projection_matrix=toTensor(glm.perspective(self.FOV, 1.0, 0.1, 20.0)),
            )
        )

        # One fixed random camera per output channel (should be learnable
        # later on). Each camera looks at the origin with [0, 1, 0] as up;
        # glm.scale([1, 1, -1]) flips the z axis of the resulting view matrix.
        # camera_viewmatrices has shape (out_channels, 4, 4).
        # It is registered as a buffer rather than a plain attribute, so it
        # follows .to()/.cuda() and is stored in state_dict. The cameras are
        # random, so without this a reloaded layer would render different
        # kernels. Checkpoints saved before this change lack the key; load
        # them with load_state_dict(..., strict=False).
        self.register_buffer(
            "camera_viewmatrices",
            torch.stack(
                [
                    toTensor(
                        glm.scale([1, 1, -1])
                        * glm.lookAt(
                            random_point_on_sphere(self.CAM_DISTANCE),
                            [0, 0, 0],
                            [0, 1, 0],
                        )
                    )
                    for _ in range(out_channels)
                ]
            ),
        )

        # Learnable weights, shared by all cameras.

        # g_mean in (n_gaussians, 3)
        self.g_mean = nn.Parameter(torch.randn(n_gaussians, 3, device=device))
        # g_scales in (n_gaussians, 3)
        self.g_scales = nn.Parameter(torch.randn(n_gaussians, 3, device=device))
        # g_quats in (n_gaussians, 4); normalized by the renderer on every call
        self.g_quats = nn.Parameter(torch.randn(n_gaussians, 4, device=device))
        # g_colors in (n_gaussians, in_channels)
        self.g_colors = nn.Parameter(
            torch.randn(n_gaussians, in_channels, device=device)
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Convolve ``image`` with kernels rendered from the current Gaussians.

        Args:
            image: tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W). The convolution uses
            ``padding='same'`` and no bias.
        """
        B, _, H, W = image.shape
        assert image.shape == (B, self.in_channels, H, W)

        # Broadcast the shared Gaussians to one copy per camera (expand creates
        # views, not copies) and render them. The result is the kernel, of
        # shape (out_channels, in_channels, filter_size, filter_size).
        kernel = self.renderer(
            self.g_mean.expand(self.out_channels, self.n_gaussians, 3),
            self.g_scales.expand(self.out_channels, self.n_gaussians, 3),
            self.g_quats.expand(self.out_channels, self.n_gaussians, 4),
            self.g_colors.expand(self.out_channels, self.n_gaussians, self.in_channels),
            self.camera_viewmatrices,
        )

        # out in (B, out_channels, H, W)
        return F.conv2d(image, kernel, padding='same')