In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import math
import matplotlib.pyplot as plt
import pyquaternion as pyq
import diff_gaussian_rasterization

In [2]:
# Single target device for the model, the rasterizer settings and every training batch.
DEVICE_TYPE = "cuda"
device = torch.device(DEVICE_TYPE)

In [3]:
# predicts a set of gaussians that accurately represent an input image
class GaussianPredictor(nn.Module):
    def __init__(self, n_gaussians=7):
        super().__init__()
        self.n_gaussians = n_gaussians
        # 64x64 -> 32x32
        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1)
        self.conv1_bn = nn.BatchNorm2d(32)
        # 32x32 -> 16x16
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)
        # 16x16 -> 8x8
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, n_gaussians*7) # [x, y, sx, sy, rx, ry, o] * n_gaussians

    def forward(self, x):
        # x: B x 3 x 64 x 64
        x = F.relu(self.conv1_bn(self.conv1(x)))
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = x.view(-1, self.n_gaussians, 7)
        return x

In [4]:
# Horizontal / vertical field of view; the value 90 is presumably meant in degrees.
FoVx, FoVy = 90, 90

def gen_orthographic_matrix(
    left: float,
    right: float,
    bottom: float,
    top: float,
    near: float,
    far: float,
    *,
    dtype=None,
    device=None,
):
    """Generate an OpenGL-style (glOrtho) orthographic projection matrix.

    The matrix maps the box [left, right] x [bottom, top] x [near, far] onto the
    [-1, 1] cube. The translation sits in the last column, i.e. the matrix is
    meant to multiply column vectors.

    Args:
        left (float): left plane
        right (float): right plane
        bottom (float): bottom plane
        top (float): top plane
        near (float): near plane
        far (float): far plane
        dtype (torch.dtype, optional): element dtype; ``None`` uses
            ``torch.get_default_dtype()``, which matches the previous behaviour.
        device (torch.device, optional): device on which to allocate the matrix;
            ``None`` allocates on the default (CPU) device.

    Returns:
        torch.Tensor: 4x4 orthographic projection matrix

    Raises:
        ZeroDivisionError: if a pair of opposite planes coincide
            (``right == left``, ``top == bottom`` or ``far == near``).
    """
    width = right - left
    height = top - bottom
    depth = far - near
    return torch.tensor(
        [
            [2 / width, 0, 0, -(right + left) / width],
            [0, 2 / height, 0, -(top + bottom) / height],
            [0, 0, -2 / depth, -(far + near) / depth],
            [0, 0, 0, 1],
        ],
        dtype=dtype,
        device=device,
    )

# Rasterizer configuration shared by every render call.
# FoVx / FoVy hold 90 (degrees) while math.tan expects radians, so convert first;
# tan(radians(90) / 2) == 1.
# NOTE(review): the predictor and the dataset work at 64x64 but this renders at
# 512x512; any pixel loss has to resample one to the other.
raster_settings = diff_gaussian_rasterization.GaussianRasterizationSettings(
    image_height=512,
    image_width=512,
    tanfovx=math.tan(math.radians(FoVx) * 0.5),
    tanfovy=math.tan(math.radians(FoVy) * 0.5),
    # One background value per color channel; the precomputed colors used for
    # rendering have 3 channels.
    bg=torch.zeros(3, dtype=torch.float32, device=device),
    scale_modifier=1.0,
    viewmatrix=torch.eye(4, dtype=torch.float32, device=device),
    # NOTE(review): the 3D Gaussian Splatting reference code hands its projection
    # matrices to the rasterizer transposed. This symmetric [-1, 1] box gives
    # diag(1, 1, -1, 1), which equals its own transpose. Asymmetric bounds would
    # need the transpose — confirm the convention of the installed build.
    projmatrix=gen_orthographic_matrix(-1, 1, -1, 1, -1, 1).to(device),
    sh_degree=1,
    campos=torch.tensor([0, 0, 0], dtype=torch.float32, device=device),
    prefiltered=False,
    debug=True,
)

rasterizer = diff_gaussian_rasterization.GaussianRasterizer(raster_settings=raster_settings)

def render_images_batch(
        # [B, N, 2]
        means: torch.Tensor,
        # [B, N, 2]
        scales: torch.Tensor,
        # [B, N, 2]
        rotations: torch.Tensor,
        # [B, N]
        opacities: torch.Tensor,
        depth: float = 1.0,
) -> torch.Tensor:
    """Rasterize a batch of planar 2D Gaussian sets with the module-level ``rasterizer``.

    Each 2D Gaussian is lifted to a flat 3D Gaussian in the plane z = ``depth``
    and drawn with a constant white color (``colors_precomp`` of ones). The
    rasterizer is called once per batch item.

    Args:
        means: [B, N, 2] Gaussian centers (x, y).
        scales: [B, N, 2] per-axis scales (sx, sy). The z scale is set to 0,
            so each Gaussian is flat along z.
        rotations: [B, N, 2] in-plane orientation (rx, ry), read as an
            unnormalized direction vector (cos θ, sin θ) and turned into a
            rotation by θ about the z axis.
            NOTE(review): confirm this is the intended meaning of the predictor's
            rx / ry outputs.
        opacities: [B, N] per-Gaussian opacity, passed through unchanged;
            presumably callers must already have mapped it into [0, 1].
        depth: z coordinate of the Gaussian plane in view space.
            NOTE(review): the reference diff_gaussian_rasterization culls points
            whose view-space z <= 0.2 and divides by z when projecting
            covariances. A z = 0 plane would therefore likely render nothing,
            hence the default of 1.0 — verify against the installed build.

    Returns:
        torch.Tensor: the per-item rendered images stacked along a new leading
        batch dimension.
    """
    B, N = means.shape[:2]
    assert means.shape == (B, N, 2)
    assert scales.shape == (B, N, 2)
    assert rotations.shape == (B, N, 2)
    assert opacities.shape == (B, N)

    like_means = dict(dtype=means.dtype, device=means.device)

    # Lift positions and scales to 3D along the LAST axis: [B, N, 2] -> [B, N, 3].
    means3D = torch.cat([means, torch.full((B, N, 1), depth, **like_means)], dim=-1)
    scales3D = torch.cat([scales, torch.zeros(B, N, 1, **like_means)], dim=-1)

    # Rotation by θ about z as a unit quaternion [w, x, y, z] (real part first, the
    # same layout as pyquaternion's `elements`): (cos θ/2, 0, 0, sin θ/2).
    # atan2 only needs the direction of (rx, ry), so no normalization is required;
    # its gradient is undefined only at exactly (0, 0).
    half_angle = 0.5 * torch.atan2(rotations[..., 1], rotations[..., 0])
    zero = torch.zeros_like(half_angle)
    quats = torch.stack([torch.cos(half_angle), zero, zero, torch.sin(half_angle)], dim=-1)

    # Zero screen-space placeholder and constant white colors.
    means2D = torch.zeros(B, N, 2, **like_means)
    colors_precomp = torch.ones(B, N, 3, **like_means)

    # NOTE(review): the reference rasterizer works with opacities shaped [N, 1]
    # (its backward returns gradients of that shape), so add the trailing axis.
    opacities = opacities.unsqueeze(-1)

    imlist = []
    for b in range(B):
        # First output is the rendered color image; any extra outputs are ignored.
        outputs = rasterizer(
            means3D=means3D[b],
            means2D=means2D[b],
            colors_precomp=colors_precomp[b],
            opacities=opacities[b],
            scales=scales3D[b],
            rotations=quats[b],
        )
        imlist.append(outputs[0])

    return torch.stack(imlist, dim=0)


RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.

In [None]:

def train_gaussian_predictor(
    model: GaussianPredictor,
    optim: torch.optim.Optimizer,
    images: torch.Tensor,
) -> float:
    """Run one render-and-compare optimization step.

    The model predicts per-image Gaussian parameters laid out as
    ``[x, y, sx, sy, rx, ry, o]``. They are rasterized with
    ``render_images_batch`` and compared to the input images with a pixel-wise
    MSE loss.

    Args:
        model: predictor mapping an image batch to ``[B, N, 7]`` parameters.
        optim: optimizer over ``model``'s parameters.
        images: ``[B, C, H, W]`` batch, used both as model input and as target.

    Returns:
        float: the loss value of this step.
    """
    model.train()

    images = images.to(device)

    # predict gaussians: [B, N, 7]
    gaussian_preds = model(images)

    # split each gaussian's parameter vector and rasterize
    means_batched = gaussian_preds[:, :, 0:2]      # x, y
    scales_batched = gaussian_preds[:, :, 2:4]     # sx, sy
    rotations_batched = gaussian_preds[:, :, 4:6]  # rx, ry
    opacities_batched = gaussian_preds[:, :, 6]    # o
    rasterized_pred = render_images_batch(
        means_batched, scales_batched, rotations_batched, opacities_batched
    )

    # The render resolution is fixed by the rasterizer settings, not by this batch.
    # Resample it to the target's H x W so mse_loss compares matching shapes
    # ("area" averages pixels when downsampling).
    if rasterized_pred.shape[-2:] != images.shape[-2:]:
        rasterized_pred = F.interpolate(
            rasterized_pred, size=tuple(images.shape[-2:]), mode="area"
        )

    # NOTE(review): the MNIST transform normalizes targets into [-1, 1], while a
    # white-on-zero-background render presumably lies in [0, 1]. Background pixels
    # can then never match; align the two ranges.
    loss = F.mse_loss(rasterized_pred, images)

    # step loss
    optim.zero_grad()
    loss.backward()
    optim.step()

    return loss.item()

In [None]:
# download mnist dataset
import itertools

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

N_EPOCHS = 10
# Cap on optimizer steps per epoch. 1 keeps runs short while debugging the
# pipeline; set to None to train on every batch of every epoch.
MAX_BATCHES_PER_EPOCH = 1
SEED = 0

# Seeds model initialization and DataLoader shuffling.
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    # MNIST digits are single-channel, but GaussianPredictor.conv1 expects 3 channels.
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

trainset = datasets.MNIST('data', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# create model
model = GaussianPredictor().to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

# train model
steps_per_epoch = (
    len(trainloader)
    if MAX_BATCHES_PER_EPOCH is None
    else min(MAX_BATCHES_PER_EPOCH, len(trainloader))
)
for epoch in range(N_EPOCHS):
    # islice(..., None) iterates the whole loader; an int stops after that many batches.
    batches = itertools.islice(trainloader, MAX_BATCHES_PER_EPOCH)
    progress = tqdm(batches, total=steps_per_epoch, desc=f"Epoch {epoch}")
    for images, _labels in progress:
        loss = train_gaussian_predictor(model, optim, images)
        progress.set_postfix(loss=loss)

NameError: name 'GaussianPredictor' is not defined