# DiffusionGS - Baking Gaussian Splatting into Diffusion Denoiser for Fast and Scalable Single-stage Image-to-3D Generation

Official project page - https://caiyuanhao1998.github.io/project/DiffusionGS/

<div align="center">
  <img src="https://raw.githubusercontent.com/pleasure97/3D-AI-ML-Code-Implementation/main/2025/DiffusionGS/assets/pipeline.JPG" alt="Pipeline of DiffusionGS">
</div>

# 1. Dataset

On page 6 of the paper,

> We use **Objaverse and MVImgNet** as the training sets for objects.

> We center and scale each 3D object of Objaverse into $[-1, 1]^3$, and render 32 images at random viewpoints with 50 FOV.

> For MVImgNet, we crop the object, remove the background, normalize the cameras, and center and scale the object to $[-1, 1]^3$.
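
The 50 FOV in the quote above determines the pinhole intrinsics that ray generation needs later. The following is a minimal sketch, assuming the FOV is given in degrees across the image width, square pixels, and a principal point at the image centre. `intrinsics_from_fov` is a hypothetical helper introduced only for illustration; it does not come from the paper or from any library.

```python
import math
import numpy as np

def intrinsics_from_fov(fov_deg: float, height: int, width: int) -> np.ndarray:
    """Build a 3x3 pinhole intrinsic matrix from a horizontal FOV (assumed convention)."""
    # Focal length in pixels: half the image width divided by tan(FOV / 2).
    focal = 0.5 * width / math.tan(math.radians(fov_deg) / 2.0)
    return np.array([
        [focal, 0.0, width / 2.0],
        [0.0, focal, height / 2.0],
        [0.0, 0.0, 1.0],
    ])

# 50-degree FOV at the 256 x 256 training resolution
K = intrinsics_from_fov(50.0, 256, 256)
print(K)
```

A matrix of this form can be batched and passed as `intrinsics` to ray-generation code that expects `(B, 3, 3)` inputs.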

## 1.1 Objaverse Dataset

### 1.1.1 Prepare Objaverse Environment

For more details on loading the Objaverse dataset, see this Colab notebook - https://colab.research.google.com/drive/1ZLA4QufsiI_RuNlamKqV7D7mn40FbWoY

In [1]:
!pip install objaverse --upgrade --quiet

import objaverse
objaverse.__version__



'0.1.7'

Each object has a corresponding unique identifier (UID).

In [2]:
uids = objaverse.load_uids()
print(f"length of uids : {len(uids)}")
print(f"type of uids : {type(uids)}")

length of uids : 798759
type of uids : <class 'list'>


In [3]:
uids[:10]

['8476c4170df24cf5bbe6967222d1a42d',
 '8ff7f1f2465347cd8b80c9b206c2781e',
 'c786b97d08b94d02a1fa3b87d2e86cf1',
 '139331da744542009f146018fd0e05f4',
 'be2c02614d774f9da672dfdc44015219',
 'efd35e7d21ac482688c294e3b6c9f74e',
 '21d5f90dbc9f4f229b0faa7b56b67f3e',
 'dcd33159a0864de388de3a08f55e604a',
 'a7ad32b5d4d84ee5a40ebbd86da4dbe4',
 '7d6a14874eed48c2b720f0d1adfe6dd9']

We can get the annotations for each object using `objaverse.load_annotations()`.

In [4]:
annotations = objaverse.load_annotations(uids[:1])

We're going to use multiprocessing to download the objects.

In [5]:
import multiprocessing
processes = multiprocessing.cpu_count()
processes

2

`objaverse.load_objects()` takes a list of object UIDs and, optionally, the number of download processes, and returns a map from each object UID to its `.glb` file location on disk.

In [6]:
objaverse_objects = objaverse.load_objects(uids=uids[:10], download_processes=processes)
objaverse_objects

starting download of 10 objects with 2 processes
Downloaded 1 / 10 objects
Downloaded 2 / 10 objects
Downloaded 3 / 10 objects
Downloaded 4 / 10 objects
Downloaded 5 / 10 objects
Downloaded 6 / 10 objects
Downloaded 7 / 10 objects
Downloaded 8 / 10 objects
Downloaded 9 / 10 objects
Downloaded 10 / 10 objects


{'8476c4170df24cf5bbe6967222d1a42d': '/root/.objaverse/hf-objaverse-v1/glbs/000-023/8476c4170df24cf5bbe6967222d1a42d.glb',
 '8ff7f1f2465347cd8b80c9b206c2781e': '/root/.objaverse/hf-objaverse-v1/glbs/000-023/8ff7f1f2465347cd8b80c9b206c2781e.glb',
 'c786b97d08b94d02a1fa3b87d2e86cf1': '/root/.objaverse/hf-objaverse-v1/glbs/000-023/c786b97d08b94d02a1fa3b87d2e86cf1.glb',
 '139331da744542009f146018fd0e05f4': '/root/.objaverse/hf-objaverse-v1/glbs/000-023/139331da744542009f146018fd0e05f4.glb',
 'be2c02614d774f9da672dfdc44015219': '/root/.objaverse/hf-objaverse-v1/glbs/000-023/be2c02614d774f9da672dfdc44015219.glb',
 'efd35e7d21ac482688c294e3b6c9f74e': '/root/.objaverse/hf-objaverse-v1/glbs/000-023/efd35e7d21ac482688c294e3b6c9f74e.glb',
 '21d5f90dbc9f4f229b0faa7b56b67f3e': '/root/.objaverse/hf-objaverse-v1/glbs/000-023/21d5f90dbc9f4f229b0faa7b56b67f3e.glb',
 'dcd33159a0864de388de3a08f55e604a': '/root/.objaverse/hf-objaverse-v1/glbs/000-023/dcd33159a0864de388de3a08f55e604a.glb',
 'a7ad32b5d4d84e

Let's load up one of the `.glb` files to visualize it.

In [7]:
!pip install trimesh --quiet

In [8]:
import trimesh
trimesh.load(list(objaverse_objects.values())[1]).show()

### 1.1.2 Center and Scale Each Objaverse Object

In [14]:
!pip install --upgrade pyglet==v1.5.28

Collecting pyglet==v1.5.28
  Downloading pyglet-1.5.28-py3-none-any.whl.metadata (7.6 kB)
Downloading pyglet-1.5.28-py3-none-any.whl (1.1 MB)
Installing collected packages: pyglet
Successfully installed pyglet-1.5.28


In [15]:
import os
import numpy as np
import trimesh

def normalize_mesh(mesh):
    vertices = mesh.vertices
    min_bound = vertices.min(axis=0)
    max_bound = vertices.max(axis=0)
    center = (min_bound + max_bound) / 2
    scale = 2.0 / np.max(max_bound - min_bound)
    mesh.vertices = (vertices - center) * scale
    return mesh


def sample_random_views(num_images=32, radius=1.):
    camera_positions = []
    view_directions = []

    for _ in range(num_images):
        theta = np.random.uniform(0, 2 * np.pi)  # azimuth in radians, 0 to 360 degrees
        phi = np.random.uniform(0, np.pi)  # polar angle in radians, 0 to 180 degrees

        cam_x = radius * np.sin(phi) * np.cos(theta)
        cam_y = radius * np.sin(phi) * np.sin(theta)
        cam_z = radius * np.cos(phi)

        position = np.array([cam_x, cam_y, cam_z])
        direction = -position / np.linalg.norm(position)

        camera_positions.append(position)
        view_directions.append(direction)

    return camera_positions, view_directions


def render_images_trimesh(mesh_file, output_dir, num_images=32, fov=50, image_size=(256, 256)):
    os.makedirs(output_dir, exist_ok=True)

    mesh = trimesh.load(mesh_file, force='mesh')
    mesh = normalize_mesh(mesh)

    scene = trimesh.Scene()
    scene.add_geometry(mesh)

    camera_positions, view_directions = sample_random_views(num_images)

    for i, (position, direction) in enumerate(zip(camera_positions, view_directions)):
        try:
            camera = trimesh.scene.Camera(fov=(fov, fov), resolution=image_size)
            scene.camera = camera

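            # Camera-to-world pose: camera at `position`, looking along `direction`.
            # trimesh follows the OpenGL convention (the camera looks down its -Z axis).
            up_hint = np.array([0.0, 0.0, 1.0])
            if abs(np.dot(direction, up_hint)) > 0.999:  # avoid a degenerate cross product
                up_hint = np.array([0.0, 1.0, 0.0])
            right = np.cross(direction, up_hint)
            right /= np.linalg.norm(right)
            up = np.cross(right, direction)

            camera_pose = np.eye(4)
            camera_pose[:3, 0], camera_pose[:3, 1] = right, up
            camera_pose[:3, 2], camera_pose[:3, 3] = -direction, position
            scene.camera_transform = camera_pose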

            # Render the scene to PNG bytes
            image_data = scene.save_image(resolution=image_size)
            if image_data is not None:
                image_path = os.path.join(output_dir, f"{i:03d}.png")
                with open(image_path, 'wb') as f:
                    f.write(image_data)
        except ZeroDivisionError as e:
            print(f"[ERROR] ZeroDivisionError for view {i} : {e}")
            continue

    print(f"Saved {num_images} images to {output_dir}")


In [16]:
for idx, objaverse_object in enumerate(objaverse_objects.values()):
  render_images_trimesh(objaverse_object, f"/content/objaverse/images/{idx}")

ImportError: Library "GLU" not found.

### 1.1.3 Render Objaverse Images at Random Viewpoints

**Note:** Off-screen rendering with trimesh fails in Colab because the GLU library is missing (see the `ImportError` above), so we instead save views of the point cloud with Matplotlib.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d  # requires `pip install open3d`

def sample_images_from_random_views(ply_file: str, output_dir: str, image_size: tuple=(256, 256), fov: int=50, num_images: int=32):
  os.makedirs(output_dir, exist_ok=True)

  # Load the ply file
  pointcloud = o3d.io.read_point_cloud(ply_file)

  # Check if the point cloud is empty
  if (len(pointcloud.points) == 0):
    print(f"WARNING : Empty Point Cloud in {ply_file}")
    return

  for i in range(num_images):
    # Generate Random Camera Position
    theta = np.random.uniform(0, 2 * np.pi) # 0 to 360 degrees
    phi = np.random.uniform(0, np.pi) # 0 to 180 degrees
    radius = np.random.uniform(1.5, 3.)

    cam_x = radius * np.sin(phi) * np.cos(theta)
    cam_y = radius * np.sin(phi) * np.sin(theta)
    cam_z = radius * np.cos(phi)


    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='3d')

    points = np.asarray(pointcloud.points)
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=1)

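    # phi is a polar angle measured from +Z, while Matplotlib's elevation is measured from the XY plane
    ax.view_init(elev=90.0 - np.degrees(phi), azim=np.degrees(theta))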

    # Save the image
    output_path = os.path.join(output_dir, f"{i:03d}.png")
    plt.savefig(output_path, dpi=300)
    plt.close()

  print(f"Saved {num_images} images to {output_dir}")

In [None]:
PLY_FOLDER = "/content/objaverse/normalized"
IMAGES_FOLDER = "/content/objaverse/images"

ply_files = [file for file in os.listdir(PLY_FOLDER) if file.endswith(".ply")]

for ply_file in ply_files:
  ply_path = os.path.join(PLY_FOLDER, ply_file)
  image_dir = os.path.join(IMAGES_FOLDER, ply_file.split('.')[0])

  sample_images_from_random_views(ply_path, image_dir)

## 1.2 MVImgNet Dataset

To download MVImgNet, fill in the request form at the following link - https://docs.google.com/forms/d/e/1FAIpQLSfU9BkV1hY3r75n5rc37IvlzaK2VFYbdsvohqPGAjb2YWIbUg/viewform?usp=sf_link

### 1.2.1 Crop the MVImgNet Object

In [None]:
%%writefile /content/preprocess.py
import trimesh
import numpy as np

def center_scale_mesh(mesh: trimesh.Trimesh):

  # Calculate the minimum, maximum bound and center
  min_bound = mesh.vertices.min(axis=0)
  max_bound = mesh.vertices.max(axis=0)
  center = (min_bound + max_bound) / 2
  print("Minimum bound of Original: ", min_bound)
  print("Maximum bound of Original: ", max_bound)
  print("Center: ", center)

  # Move all the vertices to center
  mesh.vertices -= center

  bound_size = max_bound - min_bound
  max_extent = np.max(bound_size)

  mesh.vertices /= (max_extent / 2)

  return mesh

In [None]:
import trimesh
from preprocess import center_scale_mesh

# force='mesh' merges a multi-part scene into a single Trimesh
mesh = trimesh.load(list(objaverse_objects.values())[0], force='mesh')
normalized_mesh = center_scale_mesh(mesh)

print("Minimum Bound of Normalized: ", normalized_mesh.vertices.min(axis=0))
print("Maximum Bound of Normalized: ", normalized_mesh.vertices.max(axis=0))

### 1.2.2 Remove the Background

### 1.2.3 Normalize the Camera

### 1.2.4 Center and Scale MVImgNet Object

# 2. Training

In [None]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

## 2.1 Scene-Object Mixed Training Strategy

On page 3 of the paper,
> For each scene or object, we pick up a view as the condition, $N$ views as the noisy views to be denoised, and $M$ novel views for supervision.

---

### 2.1.1 Viewpoint Selecting

$$\theta_{c d}^{(i)} \leq \theta_1, \quad \theta_{d n}^{(i, j)} \leq \theta_2,$$

* The first constraint bounds the angles between viewpoint positions
* $\theta_{c d}^{(i)}$ - the angle between the $i$-th noisy view position and the condition view position
* $\theta_{d n}^{(i, j)}$ - the angle between the $i$-th noisy view position and the $j$-th novel view position
* $\theta_{1}, \theta_{2}$ - hyperparameters
* $1 \leq i \leq N$
* $1 \leq j \leq M$

On page 5 of the paper,
> The position vector can be read from the translation of the camera-to-world (c2w) matrix of the viewpoint.
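
The position constraint can be sketched as follows, assuming the angle is measured at the world origin between the two camera position vectors read from the c2w translations. `position_angle`, `make_c2w`, and the threshold value for $\theta_1$ are hypothetical and chosen only for illustration.

```python
import numpy as np

def position_angle(c2w_a: np.ndarray, c2w_b: np.ndarray) -> float:
    """Angle in degrees between the camera positions of two 4x4 c2w matrices."""
    p_a, p_b = c2w_a[:3, 3], c2w_b[:3, 3]  # translation column = camera position
    cosine = np.dot(p_a, p_b) / (np.linalg.norm(p_a) * np.linalg.norm(p_b))
    return float(np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0))))

def make_c2w(position) -> np.ndarray:
    """Hypothetical helper: identity rotation with the given camera position."""
    c2w = np.eye(4)
    c2w[:3, 3] = position
    return c2w

condition_view = make_c2w([2.0, 0.0, 0.0])
noisy_view = make_c2w([0.0, 2.0, 0.0])

theta_1 = 120.0  # hypothetical threshold in degrees, not a value from the paper
angle = position_angle(condition_view, noisy_view)
print(angle, angle <= theta_1)
```

The same function applies to the second inequality by passing a noisy view and a novel view and comparing against $\theta_2$.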


---
$$
\frac{\vec{z}_{c o n} \cdot \vec{z}_{n o i s e}^{(i)}}{\left|\vec{z}_{\text {con }}\right| \cdot\left|\vec{z}_{\text {noise }}^{(i)}\right|} \geq \cos \left(\varphi_1\right), \frac{\vec{z}_{\text {con }} \cdot \vec{z}_{n v}^{(j)}}{\left|\vec{z}_{\text {con }}\right| \cdot\left|\vec{z}_{n v}^{(j)}\right|} \geq \cos \left(\varphi_2\right)
$$

* The second constraint bounds the angles between viewpoint orientations
* $\vec{z}_{c o n}$ - the forward direction vector of the condition view
* $\vec{z}_{n o i s e}^{(i)}$ - the forward direction vector of the $i$-th noisy view
* $\vec{z}_{n v}^{(j)}$ - the forward direction vector of the $j$-th novel view
* $\varphi_1$, $\varphi_2$ - hyperparameters
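
A minimal sketch of the orientation check, assuming the forward direction is the third column of the c2w rotation (the camera's local +Z axis; flip the sign for OpenGL-style cameras that look down -Z). `forward_cosine`, `rotation_about_y`, and the value of $\varphi_2$ are hypothetical.

```python
import numpy as np

def forward_cosine(c2w_a: np.ndarray, c2w_b: np.ndarray) -> float:
    """Cosine of the angle between the forward (z) axes of two 4x4 c2w matrices."""
    z_a, z_b = c2w_a[:3, 2], c2w_b[:3, 2]
    return float(np.dot(z_a, z_b) / (np.linalg.norm(z_a) * np.linalg.norm(z_b)))

def rotation_about_y(degrees: float) -> np.ndarray:
    """Hypothetical helper: c2w matrix rotated about the world Y axis, at the origin."""
    a = np.radians(degrees)
    c2w = np.eye(4)
    c2w[:3, :3] = [
        [np.cos(a), 0.0, np.sin(a)],
        [0.0, 1.0, 0.0],
        [-np.sin(a), 0.0, np.cos(a)],
    ]
    return c2w

condition_view = rotation_about_y(0.0)
novel_view = rotation_about_y(60.0)

phi_2 = 90.0  # hypothetical threshold in degrees, not a value from the paper
cosine = forward_cosine(condition_view, novel_view)
is_valid = cosine >= np.cos(np.radians(phi_2))
print(cosine, is_valid)
```

Comparing cosines directly, as the inequality does, avoids calling `arccos` and is monotone in the angle on $[0, \pi]$.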

---

### 2.1.2 Reference-Point Plücker Coordinate (RPPC)

$$r=(o-(o \cdot d) d, d)$$

* $r$ - the pixel-aligned ray embeddings
* $o$ - the position of the ray landing on the pixel
* $d$ - the direction of the ray landing on the pixel

---
### Difference Between Original Plücker Coordinate and Reference-Point Plücker Coordinate

<div align="center">
  <img src="https://raw.githubusercontent.com/pleasure97/3D-AI-ML-Code-Implementation/main/2025/DiffusionGS/assets/Plucker Coordinates.JPG" alt="Plücker Coordinates">
</div>

On page 5 of the paper,

> Specifically, $o \times d$ represents the rotational effect of $o$ relative to $d$, showing limitations in perceiving the relative depth and geometry.

On page 6 of the paper,

> Our RPPC (Reference-Point Plücker Coordinate) satisfies the translation invariance assumption of the 4D light field.

> Plus, ..., our reference point can provide more information about the ray position and the relative depth, which are beneficial for the diffusion model to capture the 3D geometry of scenes and objects.
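
To make the two embeddings concrete, here is a small NumPy sketch for a single ray with a unit direction. It shows that the reference point $o-(o \cdot d) d$ is a 3D point on the ray (the one closest to the world origin), that it is perpendicular to $d$, and that it does not change when the origin slides along the ray. The function names are hypothetical.

```python
import numpy as np

def plucker_moment(o: np.ndarray, d: np.ndarray) -> np.ndarray:
    """Moment vector o x d of the original Plücker coordinate."""
    return np.cross(o, d)

def reference_point(o: np.ndarray, d: np.ndarray) -> np.ndarray:
    """Reference point o - (o . d) d of RPPC; assumes d has unit length."""
    return o - np.dot(o, d) * d

d = np.array([0.0, 0.0, 1.0])   # unit ray direction
o = np.array([1.0, 2.0, -3.0])  # ray origin
o_shifted = o + 5.0 * d         # same ray, origin moved along the direction

moment = plucker_moment(o, d)
ref = reference_point(o, d)
ref_shifted = reference_point(o_shifted, d)

print("moment:", moment)
print("reference point:", ref)
print("perpendicular to d:", np.isclose(np.dot(ref, d), 0.0))
print("unchanged after sliding the origin:", np.allclose(ref, ref_shifted))
```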

In [None]:
%%writefile /content/RPPC.py
# Source Code : https://github.com/echen01/ray-conditioning/blob/8e1d5ae76d4747c771d770d1f042af77af4b9b5d/training/plucker.py
import torch
from torch.nn import functional as F

def get_rays(H, W, intrinsics, c2w, jitter=False):
  """
  H : image height
  W : image width
  intrinsics : (B, 3, 3) batched intrinsic matrices
  c2w : (B, 4, 4) batched camera-to-world extrinsic matrices
  """
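  # indexing="xy" returns (H, W) grids, so flattening is row-major and matches the later view(-1, H, W, 6)
  u, v = torch.meshgrid(torch.arange(W, device=c2w.device), torch.arange(H, device=c2w.device), indexing="xy")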
  B = c2w.shape[0]
  u, v = u.reshape(-1), v.reshape(-1)
  u_noise = v_noise = 0.5
  if jitter:
    u_noise = torch.rand(u.shape, device=c2w.device)
    v_noise = torch.rand(v.shape, device=c2w.device)
  u, v = u + u_noise, v + v_noise # add half pixel
  pixels = torch.stack((u, v, torch.ones_like(u)), dim=0) # (3, H * W)
  pixels = pixels.unsqueeze(0).repeat(B, 1, 1) # (B, 3 , H * W)
  if intrinsics.sum() == 0:
    inv_intrinsics = torch.eye(3, device=c2w.device).tile(B, 1, 1)
  else:
    inv_intrinsics = torch.linalg.inv(intrinsics)
  rays_d = inv_intrinsics @ pixels # (B, 3, H * W)
  rays_d = c2w[:, :3, :3] @ rays_d
  rays_d = rays_d.transpose(-1, -2) # (B, H * W, 3)
  rays_d = F.normalize(rays_d, dim=-1)

  rays_o = c2w[:, :3, 3].reshape((-1, 3)) # (B, 3)
  rays_o = rays_o.unsqueeze(1).repeat(1, H * W, 1) # (B, H * W, 3)

  return rays_o, rays_d

def plucker_embedding(H, W, intrinsics, c2w, jitter=False):
  """Computes the plucker coordinates from batched cam2world & intrinsics matrices, as well as pixel coordinates
  C2W : (Batch Size, 4, 4)
  intrinsics : (Batch Size, 3, 3)
  """
  rays_o, rays_d = get_rays(H, W, intrinsics, c2w, jitter=jitter) # (B, H * W, 3), (B, H * W, 3)
  cross = torch.cross(rays_o, rays_d, dim=-1)
  plucker = torch.cat((rays_d, cross), dim=-1) # (B, H * W, 6), concatenate along the channel axis

  plucker = plucker.view(-1, H, W, 6).permute(0, 3, 1, 2)
  return plucker # (B, 6, H, W)

def reference_point_plucker_embedding(H, W, intrinsics, c2w, jitter=False):
  """Computes the reference point plucker coordinates from batched cam2world & intrinsics matrices, as well as pixel coordinates
  H : image height
  W : image width
  C2W : (Batch Size, 4, 4)
  intrinsics : (Batch Size, 3, 3)
  """
  rays_o, rays_d = get_rays(H, W, intrinsics, c2w, jitter=jitter) # (B, H * W, 3), (B, H * W, 3)
  o_dot_d = (rays_o * rays_d).sum(dim=-1, keepdim=True) # (B, H * W , 1)
  reference_point = rays_o - o_dot_d * rays_d # (B, H * W, 3)
  reference_point_plucker = torch.cat((rays_d, reference_point), dim=-1) # (B, H * W, 6), concatenate along the channel axis

  reference_point_plucker = reference_point_plucker.view(-1, H, W, 6).permute(0, 3, 1, 2)
  return reference_point_plucker # (B, 6, H, W)

---

### 2.1.3 Loss for Training

$$\mathcal{L}_{\text{pd}} = \mathbb{E}_k \left[ l_t^{(k)} - \frac{\mathbb{E}_k[l_t^{(k)}] - \sigma_0 + \mathbb{E}[o^{(k)}]}{\sqrt{\text{Var}(l_t^{(k)})}} \right]$$

* $\mathcal{L}_{\text{pd}}$ - the point distribution loss for training warm-up
* $\mathbb{E}$ - the mean value
* $l_t^{(k)} = |u_t^{(k)} d^{(k)}|$
  * $u_t^{(k)}$ - interpolated distance value between $u_{near}$ and $u_{far}$
  * $d^{(k)}$ - the direction of the $k$-th pixel-aligned ray
* $Var(l_t^{(k)})$ - the variance of $l_t^{(k)}$
* $\sigma_{0}$ - the target standard deviation (set to 0.5)
* $o^{(k)}$ - the origin of the $k$-th pixel-aligned ray



---
$$\mathcal{L}_{d e}=\mathcal{L}_2\left(\hat{\mathcal{X}}_{(0, t)}, \mathcal{X}_0\right)+\lambda \cdot \mathcal{L}_{\mathrm{VGG}}\left(\hat{\mathcal{X}}_{(0, t)}, \mathcal{X}_0\right)$$

* $\hat{\mathcal{X}}_{(0, t)}$ - the denoised multi-view images

* $\mathcal{X}_0=\left\{\mathbf{x}_0^{(1)}, \mathbf{x}_0^{(2)}, \cdots, \mathbf{x}_0^{(N)}\right\}$ - the $N$ clean views, i.e. the noisy views at timestep 0

* $\mathcal{L}_2$ - L2 loss (= MSE loss)

* $\mathcal{L}_{\mathrm{VGG}}$ - VGG loss

* $\lambda$ - hyperparameter


In [None]:
%%writefile /content/loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

def get_point_distribution_loss(rays_o, rays_d, k:int, sigma_0: float=0.5):
  pass

# Source code - https://gist.github.com/alper111/
class VGGLoss(nn.Module):
  def __init__(self, resize=True):
    super().__init__()
    blocks = []
    VGG16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
    blocks.append(VGG16.features[:4].eval())
    blocks.append(VGG16.features[4:9].eval())
    blocks.append(VGG16.features[9:16].eval())
    blocks.append(VGG16.features[16:23].eval())

    for block in blocks:
      for param in block.parameters():
        param.requires_grad=False

    self.blocks = nn.ModuleList(blocks)
    self.transform = F.interpolate
    self.resize = resize
    self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
    self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

  def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
    if input.shape[1] != 3:
      input = input.repeat(1, 3, 1, 1)
      target = target.repeat(1, 3, 1, 1)
    input = (input - self.mean) / self.std
    target = (target - self.mean) / self.std
    if self.resize:
      input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
      target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
    loss = 0.
    x = input
    y = target
    for i, block in enumerate(self.blocks):
      x = block(x)
      y = block(y)
      if i in feature_layers:
        loss += F.l1_loss(x, y)
      if i in style_layers:
        act_x = x.reshape(x.shape[0], x.shape[1], -1)
        act_y = y.reshape(y.shape[0], y.shape[1], -1)
        gram_x = act_x @ act_x.permute(0, 2, 1)
        gram_y = act_y @ act_y.permute(0, 2, 1)
        loss += F.l1_loss(gram_x, gram_y)
    return loss

def get_denoising_loss(source: torch.Tensor, target: torch.Tensor, hyperparameter: float=0.8):
  L2_Loss = nn.MSELoss()
  VGG_Loss = VGGLoss()
  return L2_Loss(source, target) + hyperparameter * VGG_Loss(source, target)

---
$$\mathcal{L} = (\mathcal{L}_{\text{de}} + \mathcal{L}_{\text{nv}}) \cdot \mathbb{1}_{\text{iter} > \text{iter}_0} + \mathcal{L}_{\text{pd}} \cdot \mathbb{1}_{\text{iter} \leq \text{iter}_0} \cdot \mathbb{1}_{\text{object}}$$

* $\mathcal{L}$ - the overall training objective
* $\mathcal{L}_{\text{nv}}$ - the novel view loss
* $\mathcal{L}_{d e}$ - the denoising loss
* $\mathbb{1}_{\text{iter} > \text{iter}_0}$ - the conditional indicator function which equals 1 if the current training iteration is greater than the threshold $iter_{0}$
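* $\mathbb{1}_{\text{iter} \leq \text{iter}_0}$ - the indicator function which equals 1 while the current training iteration is at most the threshold $iter_{0}$ (the warm-up phase)
* $\mathbb{1}_{\text{object}}$ - the indicator function which equals 1 when the training sample is an object rather than a scene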


In [None]:
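import torch
from loss import get_denoising_loss

total_iters = 100_000
warmup_iters = 2_000  # plays the role of iter_0 in the equation above

def get_total_loss(current_iter: int,
                   denoising_loss: torch.Tensor,
                   novel_view_loss: torch.Tensor,
                   point_distribution_loss: torch.Tensor,
                   is_object: bool,
                   warmup_iters: int = warmup_iters) -> torch.Tensor:
  """Combine the loss terms according to the indicator functions of the overall objective."""
  if current_iter > warmup_iters:
    return denoising_loss + novel_view_loss
  # During warm-up only the point distribution loss is used, and only for object data
  return point_distribution_loss * float(is_object)

# Inside the training loop, once the three loss terms have been computed:
# loss = get_total_loss(current_iter, denoising_loss, novel_view_loss, point_distribution_loss, is_object)
# loss.backward()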

---
### 2.1.4 Training Details

On page 6 of the paper,

> We implement DiffusionGS by PyTorch and train it with Adam Optimizer.

> To save GPU memory, we adopt mixed-precision training with BF16, sublinear memory training, and deferred GS rendering.

On page 7 of the paper,

> The learning rate is linearly warmed up to $4 \times 10^{-4}$ with $2K$ iterations and decays to $0$ using cosine annealing scheme.

> Finally, we scale up the training resolution from $256 \times 256$ to $512 \times 512$ and finetune the model for $20K$ iterations.

<div align="center">
  <img src="https://raw.githubusercontent.com/pleasure97/3D-AI-ML-Code-Implementation/main/2025/DiffusionGS/assets/Learning Rate Scheduler.png" alt="Learning Rate Scheduler">
</div>

In [None]:
import matplotlib.pyplot as plt
from typing import List

# lrs = []
# for i in range(total_iters):
#     optimizer.step()
#     scheduler.step()
#     lrs.append(optimizer.param_groups[0]['lr'])

def visualize_lr_scheduler(lrs: List[float]):
  plt.plot(lrs)
  plt.xlabel("Iteration")
  plt.ylabel("Learning Rate")
  plt.title("Linear Warm-up + Cosine Annealing")
  plt.show()

def warmup_lambda(epoch: int, warmup_iters: int):
  return epoch / warmup_iters if epoch < warmup_iters else 1.

In [None]:
import torch
import torch.optim as optim

# Adam Optimizer
optimizer = optim.Adam(DiffusionGS.parameters(), lr=4e-4)

# Linearly Warm-Up and Cosine Annealing Scheduler
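# LambdaLR calls lr_lambda with a single argument (the step), so bind warmup_iters here
warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: warmup_lambda(step, warmup_iters))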
cosine_annealing_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters, eta_min=0)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_annealing_scheduler], milestones=[warmup_iters])
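
As a sanity check for the scheduler, the intended schedule can also be written in closed form without PyTorch. This is a sketch, assuming a peak learning rate of `4e-4`, `2_000` warm-up iterations, and the notebook's own choice of `100_000` total iterations (the total is not taken from the quotes above). `learning_rate` is a hypothetical helper.

```python
import math

def learning_rate(step: int,
                  peak_lr: float = 4e-4,
                  warmup_iters: int = 2_000,
                  total_iters: int = 100_000) -> float:
    """Linear warm-up to peak_lr, then cosine annealing down to 0."""
    if step < warmup_iters:
        # Linear ramp from 0 to peak_lr
        return peak_lr * step / warmup_iters
    # Fraction of the cosine phase that has elapsed, in [0, 1]
    progress = (step - warmup_iters) / (total_iters - warmup_iters)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))

for step in (0, 1_000, 2_000, 51_000, 100_000):
    print(step, learning_rate(step))
```

Plotting `[learning_rate(s) for s in range(total_iters)]` should give the same shape as the figure above: a short linear ramp followed by a half cosine.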

## 2.2 3D Diffusion
---
* $\mathbf{x}_{\text{con}} \in \mathbb{R}^{H \times W \times 3}$ - 1 clean condition view
* $\mathcal{X}_t=\left\{\mathbf{x}_t^{(1)}, \mathbf{x}_t^{(2)}, \cdots, \mathbf{x}_t^{(N)}\right\}$ - $N$ noisy views
  * $\mathcal{X}_0=\left\{\mathbf{x}_0^{(1)}, \mathbf{x}_0^{(2)}, \cdots, \mathbf{x}_0^{(N)}\right\}$ - *Concatenated with $\mathcal{X}_t$*
* $\mathbf{v}_{\text{con}} \in \mathbb{R}^{H \times W \times 6}$ - viewpoint conditions
  * $\mathcal{V}=\left\{\mathbf{v}^{(1)}, \mathbf{v}^{(2)}, \cdots, \mathbf{v}^{(N)}\right\}$

$$\mathbf{x}_t^{(i)}=\sqrt{\overline{\alpha_t}}\, \mathbf{x}_0^{(i)}+\sqrt{1-\overline{\alpha_t}}\, \epsilon_t^{(i)}$$
* $\overline{\alpha_t}$ - pre-scheduled hyper-parameter
* $\epsilon_t^{(i)} \sim \mathcal{N}(0, \mathbf{I})$ and $i=1,2, \cdots, N$
* $t$ - timestep


In [None]:
import torch
from tqdm import tqdm

HEIGHT = 256
WIDTH = 256
x_con = torch.randn((HEIGHT, WIDTH, 3))
v_con = torch.randn((HEIGHT, WIDTH, 6))  # the viewpoint condition has 6 channels

class DiffusionNoiser:
  def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=2e-2, image_size=256, device=device):
    self.noise_steps = noise_steps
    self.beta_start = beta_start
    self.beta_end = beta_end
    self.image_size = image_size
    self.device = device

    self.beta = self.schedule_noise()
    self.alpha = 1. - self.beta
    self.alpha_hat = torch.cumprod(self.alpha, dim=0)

  def schedule_noise(self):
    return torch.linspace(self.beta_start, self.beta_end, self.noise_steps, device=self.device)

  def noise_images(self, image, timestep):
    alpha_hat_sqrt = torch.sqrt(self.alpha_hat[timestep])[:, None, None, None]
    one_minus_alpha_hat_sqrt = torch.sqrt(1 - self.alpha_hat[timestep])[:, None, None, None]
    epsilon = torch.randn_like(image)  # Gaussian noise, epsilon ~ N(0, I)
    return alpha_hat_sqrt * image + one_minus_alpha_hat_sqrt * epsilon, epsilon

  def sample_timesteps(self, num_timesteps):
    return torch.randint(low=1, high=self.noise_steps, size=(num_timesteps,), device=self.device)

  def sample(self, model, num_timesteps):
    model.eval()
    with torch.no_grad():
      x = torch.randn((num_timesteps, 3, self.image_size, self.image_size), device=self.device)
      for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
        time_steps = (torch.ones(num_timesteps, device=self.device) * i).long()
        predicted_noise = model(x, time_steps)
        alpha = self.alpha[time_steps][:, None, None, None]
        alpha_hat = self.alpha_hat[time_steps][:, None, None, None]
        beta = self.beta[time_steps][:, None, None, None]
        if i > 1:
          noise = torch.randn_like(x)
        else:
          noise = torch.zeros_like(x)
        x = 1. / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
    model.train()
    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8)
    return x

In [None]:
# Source Code link - https://huggingface.co/dylanebert/LGM-full/blob/main/pipeline.py
# !pip install kiui
from kiui.cam import orbit_camera
import numpy as np
import torch

def get_camera(num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False):
    angle_gap = azimuth_span / num_frames
    cameras = []
    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):

        pose = orbit_camera(-elevation, azimuth, radius=1)  # kiui's elevation is negated, [4, 4]

        # opengl to blender
        if blender_coord:
            pose[2] *= -1
            pose[[1, 2]] = pose[[2, 1]]

        cameras.append(pose.flatten())

    if extra_view:
        cameras.append(np.zeros_like(cameras[0]))

    return torch.from_numpy(np.stack(cameras, axis=0)).float()

In [None]:
# Source Code link - https://huggingface.co/dylanebert/LGM-full/blob/main/pipeline.py
# !pip install einops
import torch
import math
from einops import repeat

def timestep_embedding(timesteps, dim, max_period=10_000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half)
        args = timesteps[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding

Differentiable Gaussian rasterization (hands-on tutorial) - https://huggingface.co/learn/ml-for-3d-course/unit3/hands-on

---
$$\mathcal{G}_\theta\left(\mathcal{X}_t \mid \mathbf{x}_{c o n}, \mathbf{v}_{c o n}, t, \mathcal{V}\right)=\left\{G_t^{(k)}\left(\mu_t^{(k)}, \boldsymbol{\Sigma}_t^{(k)}, \alpha_t^{(k)}, c_t^{(k)}\right)\right\}$$
* $\theta$ - the learnable parameters of the denoiser
* $\mathcal{G}_\theta$ - the 3D Gaussians predicted by the denoiser
* $1 \leq k \leq N_g$
* $N_g=(N+1) H W$ - the number of per-pixel Gaussian $G_t^{(k)}$
* $H , W$ - Height and Width of the image
* $\mu_t^{(k)} \in$ $\mathbb{R}^3$ - the center position of each $G_t^{(k)}$ (clipped into $[-1, 1]^3$)
* $\Sigma_t^{(k)} \in \mathbb{R}^{3 \times 3}$ - the covariance of each $G_t^{(k)}$ controlling its shape
  * parameterized by a rotation matrix $\mathbf{R}_t^{(k)}$ and a scaling matrix $\mathbf{S}_t^{(k)}$
* $\alpha_t^{(k)} \in \mathbb{R}$ - the opacity of each $G_t^{(k)}$ characterizing the transmittance
* $c_t^{(k)} \in \mathbb{R}^3$ - the RGB color of each $G_t^{(k)}$
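
The covariance parameterization can be sketched as below, assuming the factorization $\boldsymbol{\Sigma}=\mathbf{R} \mathbf{S} \mathbf{S}^{\top} \mathbf{R}^{\top}$ used in the original 3D Gaussian Splatting formulation, with the rotation given as a `(w, x, y, z)` quaternion and the scaling as a diagonal matrix. The function names and the value of `N` are hypothetical.

```python
import numpy as np

def quaternion_to_rotation(q: np.ndarray) -> np.ndarray:
    """Convert a (w, x, y, z) quaternion to a 3x3 rotation matrix."""
    w, x, y, z = q / np.linalg.norm(q)  # normalise so the result is a valid rotation
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])

def covariance(q, scale) -> np.ndarray:
    """Sigma = R S S^T R^T, symmetric and positive semi-definite by construction."""
    R = quaternion_to_rotation(np.asarray(q, dtype=float))
    S = np.diag(np.asarray(scale, dtype=float))
    return R @ S @ S.T @ R.T

# Identity rotation: the covariance is just the squared scales on the diagonal
sigma = covariance([1.0, 0.0, 0.0, 0.0], [0.1, 0.2, 0.3])
print(sigma)

# Number of per-pixel Gaussians, N_g = (N + 1) * H * W
N, H, W = 4, 256, 256  # N = 4 is a hypothetical choice for illustration
num_gaussians = (N + 1) * H * W
print(num_gaussians)
```

Predicting a rotation and a scale instead of the raw matrix guarantees a valid covariance regardless of the network output.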


In [None]:
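from dataclasses import dataclass

@dataclass
class GaussianModel:
  center_positions: torch.Tensor  # mu, shape (N_g, 3)
  covariances: torch.Tensor       # Sigma, shape (N_g, 3, 3)
  opacities: torch.Tensor         # alpha, shape (N_g,)
  rgbs: torch.Tensor              # c, shape (N_g, 3)
  time_step: int                  # diffusion timestep t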

---
$$\mu_t^{(k)}=o^{(k)}+u_t^{(k)} d^{(k)}$$
* $o^{(k)}$ - the origin of the $k$-th pixel-aligned ray
* $d^{(k)}$ - the direction of the $k$-th pixel-aligned ray


---
$$u_t^{(k)}=w_t^{(k)} u_{\text {near }}+\left(1-w_t^{(k)}\right) u_{f a r}$$
* $w_t^{(k)} \in \mathbb{R}$ - the weight to control $u_t^{(k)}$
* $u_{\text {near }}$ - the nearest distances
* $u_{f a r}$ - the farthest distances
* For the object-level Gaussian decoder, $[u_{\text {near }}, u_{f a r}] = [0.1, 4.2]$
* For the scene-level Gaussian decoder, $[u_{\text {near }}, u_{f a r}] = [0, 500]$
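
The two equations above can be combined into a few lines. This is a minimal NumPy sketch, assuming the weights $w_t^{(k)}$ already lie in $[0, 1]$ (for example after a sigmoid) and using the object-level range $[0.1, 4.2]$; `gaussian_centers` is a hypothetical helper, and the clipping of the centers into $[-1, 1]^3$ is omitted.

```python
import numpy as np

def gaussian_centers(rays_o: np.ndarray,
                     rays_d: np.ndarray,
                     weights: np.ndarray,
                     u_near: float = 0.1,
                     u_far: float = 4.2) -> np.ndarray:
    """Compute mu = o + u * d with u = w * u_near + (1 - w) * u_far.

    rays_o, rays_d: (K, 3) ray origins and unit directions
    weights: (K, 1) interpolation weights in [0, 1]
    """
    u = weights * u_near + (1.0 - weights) * u_far  # (K, 1) distances along each ray
    return rays_o + u * rays_d                      # (K, 3) Gaussian centers

# Three rays starting at the origin and pointing along +Z
rays_o = np.zeros((3, 3))
rays_d = np.tile([0.0, 0.0, 1.0], (3, 1))
weights = np.array([[1.0], [0.5], [0.0]])  # w = 1 gives u_near, w = 0 gives u_far

centers = gaussian_centers(rays_o, rays_d, weights)
print(centers)
```

Because the network predicts a bounded weight rather than a raw depth, every Gaussian center is constrained to lie between `u_near` and `u_far` along its pixel-aligned ray.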

## 2.3 Denoiser

---

* $L$ - the number of transformer blocks
* Each transformer block contains 1 MSA, 1 MLP, and 2 LN.
* $\hat{\mathcal{H}}=\left\{\hat{\mathbf{H}}_{\text {con }}, \hat{\mathbf{H}}^{(1)}, \cdots, \hat{\mathbf{H}}^{(N)}\right\}$ - per-pixel Gaussian Maps
  * $\hat{\mathbf{H}}_{\text {con }}$, $\hat{\mathbf{H}}^{(i)} \in$ $\mathbb{R}^{H \times W \times 14}$



In [None]:
import torch
import torch.nn as nn

In [None]:
class PatchEmbedding(nn.Module):
  """ Turns a 2D input image into a 1D sequence of learnable embedding vectors.
  Args:
    in_channels (int) - Number of color channels for the input images. Defaults to 3.
    patch_size (int) - Size of patches to convert input images into. Defaults to 16.
    embedding_dim (int) - Size of embedding to turn image into. Defaults to 768.
  """
  def __init__(self, in_channels: int=3, patch_size: int=16, embedding_dim: int=768):
    super().__init__()

    self.in_channels = in_channels
    self.patch_size = patch_size
    self.embedding_dim = embedding_dim

    self.patchify = nn.Conv2d(in_channels=self.in_channels,
                              out_channels=self.embedding_dim,
                              kernel_size=self.patch_size,
                              stride=self.patch_size,
                              padding=0)

    self.flatten = nn.Flatten(start_dim=2, end_dim=3)

  def forward(self, x: torch.Tensor):
    image_resolution = x.shape[-1]
    assert image_resolution % self.patch_size == 0, \
      f"Input size must be divisible by patch size, image size : {image_resolution}, patch size : {self.patch_size}"

    x_patched = self.patchify(x)
    x_flattened = self.flatten(x_patched)

    return x_flattened.permute(0, 2, 1) # [batch_size, embedding_dim, num_patches] -> [batch_size, num_patches, embedding_dim]



In [None]:
!pip install torchinfo --quiet

In [None]:
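from torchinfo import summary

# Use a separate name for the instance so the PatchEmbedding class is not shadowed
patch_embedding = PatchEmbedding()
image = torch.randn(3, 224, 224)  # placeholder (C, H, W) image; the size is assumed for illustration
height, width = image.shape[1], image.shape[2]
num_patches = (height * width) // patch_embedding.patch_size ** 2

# summary(patch_embedding,
#         input_size=,
#         col_names=["input_size", "output_size", "num_params", "trainable"],
#         col_width=20,
#         row_settings=["var_names"] )

In [None]:
# One learnable positional vector per patch (no class token is prepended in this sketch)
positional_embedding = nn.Parameter(torch.ones(1, num_patches, patch_embedding.embedding_dim), requires_grad=True)
patch_and_positional_embedding = patch_embedding(image.unsqueeze(0)) + positional_embedding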

In [None]:
class MultiHeadSelfAttentionBlock(nn.Module):
  def __init__(self,
               embedding_dim: int=768,
               num_heads: int=12,
               attention_dropout: float=0.):
    super().__init__()

    self.embedding_dim = embedding_dim
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout

    self.layer_norm = nn.LayerNorm(normalized_shape=self.embedding_dim)

    self.multihead_attention = nn.MultiheadAttention(embed_dim=self.embedding_dim,
                                                     num_heads=self.num_heads,
                                                     dropout=self.attention_dropout,
                                                     batch_first=True)

  def forward(self, x: torch.Tensor):
    x = self.layer_norm(x)
    attention_output, _ = self.multihead_attention(query=x, key=x, value=x, need_weights=False)
    return attention_output

In [None]:
class MLPBlock(nn.Module):
  def __init__(self,
               embedding_dim: int=768,
               mlp_size: int=3072,
               dropout: float=0.1):
    super().__init__()

    self.embedding_dim = embedding_dim
    self.mlp_size = mlp_size
    self.dropout = dropout

    self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    self.mlp = nn.Sequential(
        nn.Linear(in_features=self.embedding_dim, out_features=self.mlp_size),
        nn.GELU(),
        nn.Dropout(p=dropout),
        nn.Linear(in_features=self.mlp_size, out_features=self.embedding_dim),
        nn.Dropout(p=dropout)
    )

  def forward(self, x: torch.Tensor):
    x = self.layer_norm(x)
    x = self.mlp(x)
    return x

In [None]:
class TransformerBlock(nn.Module):
  def __init__(self,
               embedding_dim: int=768,
               num_heads: int=12,
               mlp_size: int=3072,
               mlp_dropout: float=0.1,
               attention_dropout: float=0.):
    super().__init__()

    self.embedding_dim = embedding_dim
    self.num_heads = num_heads
    self.mlp_size = mlp_size
    self.mlp_dropout = mlp_dropout
    self.attention_dropout = attention_dropout

    self.MSABlock = MultiHeadSelfAttentionBlock(embedding_dim=self.embedding_dim,
                                                num_heads=self.num_heads,
                                                attention_dropout=self.attention_dropout)

    self.MLPBlock = MLPBlock(embedding_dim=self.embedding_dim,
                             mlp_size=self.mlp_size,
                             dropout=self.mlp_dropout)

  def forward(self, x: torch.Tensor):
    # Pre-norm residual connections: each sub-block applies LayerNorm internally,
    # and its output is added back to the unnormalized input.
    x = self.MSABlock(x) + x
    x = self.MLPBlock(x) + x

    return x

In [None]:
class TransformerLayer(nn.Module):
  def __init__(self,
               img_size: int=224,
               in_channels: int=3,
               patch_size: int=16,  # must divide img_size; 224 is not divisible by 12
               num_transformer_layers: int=12,
               embedding_dim: int=768,
               mlp_size: int=3072,
               num_heads: int=12,
               attention_dropout: float=0.,
               mlp_dropout: float=0.1,
               embedding_dropout: float=0.1,
               num_outputs: int=3072):
    super().__init__()

    assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

    self.img_size = img_size
    self.in_channels = in_channels
    self.patch_size = patch_size
    self.num_transformer_layers = num_transformer_layers
    self.embedding_dim = embedding_dim
    self.mlp_size = mlp_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout
    self.mlp_dropout = mlp_dropout
    self.embedding_dropout = embedding_dropout
    self.num_outputs = num_outputs

    self.num_patches = (self.img_size * self.img_size) // self.patch_size ** 2
    # One learnable position vector per patch token. No class token is prepended in forward(),
    # so the sequence length is num_patches, not num_patches + 1.
    self.PositionEmbedding = nn.Parameter(data=torch.randn(1, self.num_patches, self.embedding_dim), requires_grad=True)
    self.EmbeddingDropout = nn.Dropout(p=self.embedding_dropout)
    self.PatchEmbedding = PatchEmbedding(in_channels=self.in_channels,
                                         patch_size=self.patch_size,
                                         embedding_dim=self.embedding_dim)
    self.TransformerBlocks = nn.Sequential(*[TransformerBlock(embedding_dim=self.embedding_dim,
                                                              num_heads=self.num_heads,
                                                              mlp_size=self.mlp_size,
                                                              mlp_dropout=self.mlp_dropout,
                                                              attention_dropout=self.attention_dropout) for _ in range(self.num_transformer_layers)])
    self.classifier = nn.Sequential(nn.LayerNorm(normalized_shape=self.embedding_dim),
                                    nn.Linear(in_features=self.embedding_dim, out_features=self.num_outputs))

  def forward(self, x: torch.Tensor):
    x = self.PatchEmbedding(x)        # [batch_size, num_patches, embedding_dim]
    x = self.PositionEmbedding + x    # learnable parameter, broadcast over the batch
    x = self.EmbeddingDropout(x)
    x = self.TransformerBlocks(x)
    x = self.classifier(x)            # [batch_size, num_patches, num_outputs]
    return x
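
To see the end-to-end tensor shapes without depending on the classes above, here is a minimal stand-in built only from PyTorch built-ins. It is a sketch, not the notebook's model: `nn.TransformerEncoderLayer` with `norm_first=True` is used as an approximate substitute for the pre-norm block (its single `dropout` argument is shared by attention and MLP), and all sizes are small hypothetical values chosen so it runs quickly on a CPU.

```python
import torch
from torch import nn

# Small hypothetical sizes (assumptions for illustration only).
img_size, patch_size, in_channels = 32, 8, 3
embedding_dim, num_heads, mlp_size = 64, 4, 128
num_layers, num_outputs = 2, 48

num_patches = (img_size // patch_size) ** 2   # 4 x 4 grid of patches

patch_embedding = nn.Conv2d(in_channels, embedding_dim,
                            kernel_size=patch_size, stride=patch_size)
position_embedding = nn.Parameter(torch.randn(1, num_patches, embedding_dim))

# norm_first=True applies LayerNorm before attention and MLP (pre-norm residual form).
block = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                   dim_feedforward=mlp_size, dropout=0.1,
                                   activation="gelu", batch_first=True, norm_first=True)
encoder = nn.TransformerEncoder(block, num_layers=num_layers)
head = nn.Sequential(nn.LayerNorm(embedding_dim), nn.Linear(embedding_dim, num_outputs))

x = torch.randn(2, in_channels, img_size, img_size)
tokens = patch_embedding(x).flatten(2).permute(0, 2, 1)   # [2, num_patches, embedding_dim]
tokens = tokens + position_embedding                      # broadcast over the batch
out = head(encoder(tokens))                               # [2, num_patches, num_outputs]
print(out.shape)
```

The output keeps one vector per patch token, which only works because the position embedding has exactly `num_patches` rows.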

---
$$\hat{\mathcal{X}}_{(0, t)}=\left\{\hat{\mathbf{x}}_{(0, t)}^{(1)}, \hat{\mathbf{x}}_{(0, t)}^{(2)}, \cdots, \hat{\mathbf{x}}_{(0, t)}^{(N)}\right\}$$
* $\hat{\mathcal{X}}_{(0, t)}$ - the denoised multi-view images

---
$$\hat{\mathbf{x}}_{(0, t)}^{(i)}=F_r\left(\mathbf{M}_{e x t}^{(i)}, \mathbf{M}_{i n t}^{(i)}, \mathcal{G}_\theta\left(\mathcal{X}_t \mid \mathbf{x}_{c o n}, \mathbf{v}_{c o n}, t, \mathcal{V}\right)\right)$$
* $F_r$ - the differentiable rasterization function
* $1 \leq i \leq N$
* $\mathbf{M}_{e x t}^{(i)}$ - the extrinsic matrix of the viewpoint $\mathbf{c}^{(i)}$.
* $\mathbf{M}_{i n t}^{(i)}$ - the intrinsic matrix of the viewpoint $\mathbf{c}^{(i)}$.
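
The role of the two camera matrices can be illustrated with a plain pinhole projection. This is only a sketch of how $\mathbf{M}_{ext}^{(i)}$ and $\mathbf{M}_{int}^{(i)}$ map a 3D point to a pixel, not the rasterizer $F_r$ itself. It assumes a world-to-camera extrinsic matrix, a camera looking along its +z axis, a principal point at the image center, and a hypothetical 256 x 256 image; the 50 degree FOV is the value quoted in the dataset section.

```python
import numpy as np

def intrinsic_from_fov(fov_degrees: float, width: int, height: int) -> np.ndarray:
    """Pinhole intrinsics for a given horizontal FOV, principal point at the image center."""
    focal = 0.5 * width / np.tan(np.radians(fov_degrees) / 2.0)
    return np.array([[focal, 0.0, width / 2.0],
                     [0.0, focal, height / 2.0],
                     [0.0, 0.0, 1.0]])

def project(point_world: np.ndarray, M_ext: np.ndarray, M_int: np.ndarray):
    """Return the pixel coordinates and the camera-space depth of a world-space point."""
    point_h = np.append(point_world, 1.0)   # homogeneous coordinates
    point_cam = (M_ext @ point_h)[:3]       # world -> camera
    uvw = M_int @ point_cam                 # camera -> image plane
    return uvw[:2] / uvw[2], point_cam[2]

M_int = intrinsic_from_fov(50.0, width=256, height=256)
M_ext = np.eye(4)
M_ext[2, 3] = 2.0   # world origin sits 2 units in front of the camera

pixel, depth = project(np.array([0.0, 0.0, 0.0]), M_ext, M_int)
print(pixel, depth)
```

A point on the optical axis lands on the principal point, so the world origin projects to the image center here.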

In [None]:
!git clone --recursive https://github.com/graphdeco-inria/diff-gaussian-rasterization.git diff-gaussian-rasterization

---
$$\boldsymbol{\Sigma}_t^{\prime(k, i)}=\mathbf{J}_t^{(i)} \mathbf{W}_t^{(i)} \boldsymbol{\Sigma}_t^{(k)} \mathbf{W}_t^{(i)^{\top}} \mathbf{J}_t^{(i)^{\top}}$$
* $\boldsymbol{\Sigma}_t^{(k)} \in \mathbb{R}^{3 \times 3}$ - the 3D covariance matrix of each $G_t^{(k)}$ in the world coordinate system
* $\boldsymbol{\Sigma}_t^{\prime(k, i)} \in \mathbb{R}^{3 \times 3}$ - the covariance matrix of each $G_t^{(k)}$ at viewpoint $\mathbf{c}^{(i)}$ in the camera coordinate system
* $\mathbf{J}_t^{(i)} \in \mathbb{R}^{3 \times 3}$ - the Jacobian matrix of the affine approximation of the projective transformation
* $\mathbf{W}_t^{(i)} \in \mathbb{R}^{3 \times 3}$ - the viewing transformation matrix
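
The covariance projection above can be worked through numerically. The sketch below assumes the Jacobian takes the form commonly used in 3D Gaussian splatting rasterizers (the EWA affine approximation, with a zero last row), that $\mathbf{W}$ is the rotation part of the world-to-camera transform, and that the focal lengths and Gaussian scale are hypothetical values.

```python
import numpy as np

def project_covariance(cov_world, W, t_cam, fx, fy):
    """Compute Sigma' = J W Sigma W^T J^T for a Gaussian centered at t_cam in camera space."""
    tx, ty, tz = t_cam
    # Jacobian of the perspective projection, linearized at the Gaussian center.
    J = np.array([[fx / tz, 0.0, -fx * tx / tz ** 2],
                  [0.0, fy / tz, -fy * ty / tz ** 2],
                  [0.0, 0.0, 0.0]])
    return J @ W @ cov_world @ W.T @ J.T

scale = 0.1
cov_world = (scale ** 2) * np.eye(3)   # isotropic Gaussian with standard deviation 0.1
W = np.eye(3)                          # camera axes aligned with world axes

cov_prime = project_covariance(cov_world, W, t_cam=(0.0, 0.0, 2.0), fx=100.0, fy=100.0)
cov_2d = cov_prime[:2, :2]             # the upper-left 2x2 block is the image-plane covariance
print(cov_2d)
```

For an isotropic Gaussian on the optical axis the result is $(f s / z)^2 \mathbf{I}$, here $(100 \cdot 0.1 / 2)^2 = 25$ on the diagonal, so the footprint in pixels shrinks as depth grows.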

# References

* Loading Objaverse Dataset - https://colab.research.google.com/drive/1ZLA4QufsiI_RuNlamKqV7D7mn40FbWoY
* Paper Replicating - https://github.com/mrdbourke/pytorch-deep-learning/blob/main/08_pytorch_paper_replicating.ipynb
* 3D Gaussian Rasterization - https://github.com/graphdeco-inria/diff-gaussian-rasterization
* Original Plücker Coordinate - https://github.com/echen01/ray-conditioning