# Simple Gaussian Splatting - Educational Implementation

This notebook implements a simplified version of 3D Gaussian Splatting suitable for learning and class projects.

**Key Features:**
- Axis-aligned Gaussians (no rotation - simplest case)
- Pure-PyTorch rendering, vectorized over pixels (no custom CUDA kernels)
- Fast training (minutes, not hours)
- Easy to understand and modify
- Ready for diffusion extension

**Simplifications:**
- No quaternion rotations (axis-aligned only)
- Simple diagonal covariance matrices
- Reduced number of Gaussians (1000-2000)
- Lower resolution (64x64 to 128x128)

## 1. Setup & Imports

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
import sys

# Make project-local packages (e.g. `data.load_lff`) importable from the notebook's working directory
sys.path.insert(0, os.path.abspath(os.curdir))

# Compute device: CUDA when available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Seed the torch and NumPy RNGs (in that order) so initialization and shuffling are reproducible
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

## 2. Simple Gaussian Model

We use **axis-aligned Gaussians** (no rotation) for simplicity:
- Position: 3D coordinates (x, y, z)
- Scale: 3D scales (sx, sy, sz) - diagonal covariance
- Opacity: α ∈ [0, 1]
- Color: RGB values

Covariance matrix: Σ = diag(sx², sy², sz²)

In [None]:
class SimpleGaussian3D(nn.Module):
    """
    Simple 3D Gaussian model with axis-aligned ellipsoids (no rotation).

    Learnable parameters (N = n_gaussians):
    - pos: (N, 3) 3D positions
    - log_scales: (N, 3) log scales (exp keeps scales positive)
    - logit_opacity: (N,) opacity in logit space (sigmoid maps it to [0, 1])
    - rgb_logit: (N, 3) RGB colors in logit space (sigmoid maps them to [0, 1])

    Args:
        n_gaussians: number of Gaussians N.
        device: device on which the parameters are created.
    """
    def __init__(self, n_gaussians=1000, device='cpu'):
        super().__init__()
        self.num = n_gaussians
        # NOTE: recorded at construction time only; nn.Module.to() does not update it.
        self.device = device

        # Initialize parameters
        # Positions: random near origin (std 0.5)
        self.pos = nn.Parameter(torch.randn(n_gaussians, 3, device=device) * 0.5)

        # Scales: small initial values (exp(-2) ≈ 0.14)
        self.log_scales = nn.Parameter(torch.ones(n_gaussians, 3, device=device) * -2.0)

        # Opacity: high initial value (sigmoid(1) ≈ 0.73)
        self.logit_opacity = nn.Parameter(torch.ones(n_gaussians, device=device) * 1.0)

        # Colors: logits ~ N(1, 0.5²), i.e. mostly bright colors after the sigmoid
        self.rgb_logit = nn.Parameter(torch.randn(n_gaussians, 3, device=device) * 0.5 + 1.0)

    def get_scales(self):
        """Get positive scales, shape (N, 3)."""
        return torch.exp(self.log_scales)

    def get_opacity(self):
        """Get opacity in [0, 1], shape (N,)."""
        return torch.sigmoid(self.logit_opacity)

    def get_colors(self):
        """Get RGB colors in [0, 1], shape (N, 3)."""
        return torch.sigmoid(self.rgb_logit)

    def get_covariance_3d(self):
        """
        Get 3D covariance matrices (diagonal, axis-aligned): Σ = diag(s²).

        The result is built directly from the current parameters, so it follows
        their device and dtype (e.g. after ``.to('cuda')`` or ``.double()``)
        and their current count, and it stays differentiable w.r.t. ``log_scales``.

        Returns: (N, 3, 3) covariance matrices
        """
        return torch.diag_embed(self.get_scales() ** 2)

# Test the model: instantiate a set of Gaussians and print the ranges of their
# initial parameters (positions, activated scales, activated opacities)
n_gaussians = 1000
gaussians = SimpleGaussian3D(n_gaussians=n_gaussians, device=device)
print(f"✓ Created {n_gaussians} Gaussians")
print(f"  Position range: [{gaussians.pos.min().item():.2f}, {gaussians.pos.max().item():.2f}]")
print(f"  Scale range: [{gaussians.get_scales().min().item():.3f}, {gaussians.get_scales().max().item():.3f}]")
print(f"  Opacity range: [{gaussians.get_opacity().min().item():.3f}, {gaussians.get_opacity().max().item():.3f}]")

## 3. Simple Data Loader

Load an LLFF scene; if loading fails for any reason, fall back to a small synthetic dataset so the rest of the notebook still runs.

In [None]:
def load_simple_dataset(scene_root="data/nerf_llff_data/fern", downscale=8.0, max_views=None):
    """
    Simple dataset loader for LLFF format.

    Tries the project's ``LLFFDataset`` (created on the module-level ``device``).
    On any failure (missing module, missing scene, empty scene, ...) it prints a
    warning and falls back to ``create_synthetic_data``.

    Args:
        scene_root: path to the LLFF scene directory.
        downscale: image downscale factor forwarded to ``LLFFDataset``.
        max_views: optional cap on the number of views returned; it applies to
            the synthetic fallback as well (which produces at most 5 views).

    Returns:
        list of dicts with keys 'image' (3, H, W), 'pose' (4, 4), 'K' (3, 3),
        'H' and 'W'.
    """
    try:
        from data.load_lff import LLFFDataset
        dataset = LLFFDataset(scene_root=scene_root, downscale=downscale, device=device)

        # Convert to simple format
        n_views = len(dataset) if max_views is None else min(len(dataset), max_views)
        data = []
        for i in range(n_views):
            sample = dataset[i]
            image = sample['image']  # (3, H, W)
            data.append({
                'image': image,
                'pose': sample['pose'],    # (4, 4)
                'K': sample['K'],          # (3, 3)
                'H': image.shape[1],
                'W': image.shape[2],
            })

        # data[0] raises IndexError for an empty scene, which also triggers the fallback
        print(f"✓ Loaded {len(data)} views from {scene_root}")
        print(f"  Resolution: {data[0]['H']}x{data[0]['W']}")
        return data
    except Exception as e:
        # Deliberate best-effort fallback so the notebook always has data to run on.
        print(f"⚠ Could not load dataset: {e}")
        print("  Creating synthetic data for testing...")
        n_synthetic = 5 if max_views is None else min(5, max_views)
        return create_synthetic_data(n_views=n_synthetic, H=64, W=64)

def create_synthetic_data(n_views=5, H=64, W=64):
    """
    Create simple synthetic data for testing.

    Cameras sit on a circle of radius 3 in the XZ-plane and look at the origin.
    Poses are camera-to-world with the camera looking along its local -Z and its
    local +Y aligned with world +Y. Every view gets the same checkerboard image,
    so the targets are not geometrically consistent; the data only exercises the
    pipeline.

    Args:
        n_views: number of views.
        H, W: image height and width in pixels.

    Returns:
        list of dicts with 'image' (3, H, W) float32, 'pose' (4, 4) float32,
        'K' (3, 3) float32, 'H' and 'W'.
    """
    data = []
    focal = H * 0.7  # Simple focal length (used for both fx and fy)
    world_up = np.array([0.0, 1.0, 0.0])

    # Simple intrinsics (principal point at the image centre)
    K = np.array([
        [focal, 0, W/2],
        [0, focal, H/2],
        [0, 0, 1]
    ])

    # Simple test image (checkerboard pattern); identical for every view, so build it once
    checker = ((np.arange(H)[:, None] // 8) + (np.arange(W)[None, :] // 8)) % 2
    img = np.zeros((3, H, W))
    img[0] = checker * 0.8
    img[1] = (1 - checker) * 0.6
    img[2] = checker * 0.4

    for i in range(n_views):
        # Simple circular camera path
        angle = 2 * np.pi * i / n_views
        cam_pos = np.array([np.cos(angle), 0, np.sin(angle)]) * 3.0

        # Look-at frame. right = forward x up (NOT up x forward): with the latter,
        # the columns [right, up, -forward] have det = -1, i.e. a mirror instead
        # of a rotation.
        forward = -cam_pos / np.linalg.norm(cam_pos)
        right = np.cross(forward, world_up)
        right = right / np.linalg.norm(right)
        up = np.cross(right, forward)

        # Camera-to-world pose
        pose = np.eye(4)
        pose[:3, :3] = np.column_stack([right, up, -forward])
        pose[:3, 3] = cam_pos

        # .float() copies (float64 -> float32), so views do not share storage
        data.append({
            'image': torch.from_numpy(img).float(),
            'pose': torch.from_numpy(pose).float(),
            'K': torch.from_numpy(K).float(),
            'H': H,
            'W': W
        })

    print(f"✓ Created {len(data)} synthetic views")
    return data

# Load dataset (capped at 10 views; load_simple_dataset falls back to synthetic
# views if the LLFF scene cannot be loaded)
dataset = load_simple_dataset(scene_root="data/nerf_llff_data/fern", downscale=8.0, max_views=10)
print(f"Dataset size: {len(dataset)} views")

## 4. Simple Renderer

Project 3D Gaussians to 2D image plane and render using alpha blending.

**Steps:**
1. Transform 3D positions to camera space
2. Project to 2D image coordinates
3. Project 3D covariance to 2D
4. Evaluate Gaussian at each pixel (vectorized)
5. Alpha blend front-to-back

In [None]:
def _project_covariances_2d(cov3d, w2c_rot, x_cam, y_cam, z_cam, depths, fx, fy):
    """Project world-space 3D covariances (M, 3, 3) to pixel-space 2D covariances (M, 2, 2)."""
    # Rotate into camera space. A covariance that is axis-aligned in the world
    # is generally NOT axis-aligned in camera space, so the full product is needed.
    cov_cam = w2c_rot @ cov3d @ w2c_rot.T  # (M, 3, 3)
    # Jacobian of (u, v) = (fx * x / |z| + cx, fy * y / |z| + cy) w.r.t. (x, y, z).
    # d|z|/dz = sign(z), which covers both the -Z and the +Z visibility conventions.
    inv_d = 1.0 / depths
    sign_z = torch.sign(z_cam)
    zeros = torch.zeros_like(depths)
    jac = torch.stack([
        torch.stack([fx * inv_d, zeros, -fx * x_cam * sign_z * inv_d ** 2], dim=-1),
        torch.stack([zeros, fy * inv_d, -fy * y_cam * sign_z * inv_d ** 2], dim=-1),
    ], dim=-2)  # (M, 2, 3)
    return jac @ cov_cam @ jac.transpose(-1, -2)  # EWA: J R Σ Rᵀ Jᵀ


def _composite_chunk(img, transmittance, pixels, means, covs, opac, rgb, H, W):
    """Alpha-composite one depth-sorted chunk of C 2D Gaussians front to back; return (img, transmittance)."""
    # Gaussians whose centre is far off-screen contribute nothing (alpha = 0)
    px, py = means[:, 0], means[:, 1]
    on_screen = (px >= -50) & (px <= W + 50) & (py >= -50) & (py <= H + 50)  # (C,)

    # Mahalanobis distance of every pixel to every Gaussian in the chunk
    diff = pixels.unsqueeze(0) - means.unsqueeze(1)  # (C, P, 2)
    cov_inv = torch.inverse(covs + 1e-6 * torch.eye(2, device=covs.device))  # (C, 2, 2)
    mahal = torch.einsum('cpi,cij,cpj->cp', diff, cov_inv, diff)  # (C, P)

    # Gaussian kernel with 3-sigma cutoff
    g = torch.where(mahal < 9.0, torch.exp(-0.5 * mahal), torch.zeros_like(mahal))
    alpha = torch.where(on_screen.unsqueeze(1), opac.unsqueeze(1) * g, torch.zeros_like(g))  # (C, P)

    one_minus = (1.0 - alpha).clamp(0.0, 1.0)
    # Transmittance in front of each Gaussian *within* the chunk (exclusive cumulative product);
    # this reproduces the sequential per-Gaussian blending exactly.
    t_excl = torch.cat([torch.ones_like(one_minus[:1]), torch.cumprod(one_minus[:-1], dim=0)], dim=0)
    weights = transmittance.unsqueeze(0) * t_excl * alpha  # (C, P)
    img = img + weights.T @ rgb  # (P, 3)
    transmittance = transmittance * t_excl[-1] * one_minus[-1]
    return img, transmittance


def simple_render(gaussians, camera_dict, chunk_size=50):
    """
    Render 3D Gaussians to an RGB image by EWA splatting and front-to-back alpha blending.

    The camera is assumed to look along -Z. If no Gaussian lies in front under that
    convention, +Z is tried, and if still none, all Gaussians are used.
    NOTE(review): pixel rows grow with camera-space +y, so with a +Y-up camera the
    image may appear vertically flipped — confirm against the loader's pose convention.

    Args:
        gaussians: SimpleGaussian3D-like object exposing ``pos`` (N, 3),
            ``get_covariance_3d()`` (N, 3, 3), ``get_opacity()`` (N,) and
            ``get_colors()`` (N, 3)
        camera_dict: dict with 'pose' (4x4 camera-to-world), 'K' (3x3), 'width', 'height'
        chunk_size: number of Gaussians evaluated together (memory/speed trade-off)

    Returns:
        rendered: (H, W, 3) RGB image clamped to [0, 1]
    """
    positions_3d = gaussians.pos  # (N, 3)
    # Follow the parameters' actual device (gaussians.device can be stale after .to())
    device = positions_3d.device
    H, W = int(camera_dict['height']), int(camera_dict['width'])

    covariances_3d = gaussians.get_covariance_3d()  # (N, 3, 3)
    opacities = gaussians.get_opacity()  # (N,)
    colors = gaussians.get_colors()  # (N, 3)

    # Transform to camera space: p_cam = R @ p_world + t
    pose = camera_dict['pose'].to(device)  # camera-to-world
    w2c = torch.inverse(pose)  # world-to-camera
    w2c_rot = w2c[:3, :3]
    positions_cam = positions_3d @ w2c_rot.T + w2c[:3, 3]  # (N, 3)

    # Visibility culling (camera looks along -Z), with the fallbacks described above
    z = positions_cam[:, 2]
    visible = z < -0.01
    if not visible.any():
        visible = z > 0.01
        if not visible.any():
            visible = torch.ones_like(z, dtype=torch.bool)

    positions_cam = positions_cam[visible]
    covariances_3d = covariances_3d[visible]
    opacities = opacities[visible]
    colors = colors[visible]

    # Project means to 2D
    K = camera_dict['K'].to(device)
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    x_cam, y_cam, z_cam = positions_cam.unbind(-1)
    depths = z_cam.abs().clamp(min=0.01)
    positions_2d = torch.stack([(x_cam / depths) * fx + cx, (y_cam / depths) * fy + cy], dim=-1)  # (M, 2)

    # Project covariances to 2D (full EWA projection) plus a small regularizer
    covariances_2d = _project_covariances_2d(covariances_3d, w2c_rot, x_cam, y_cam, z_cam, depths, fx, fy)
    covariances_2d = covariances_2d + 1e-4 * torch.eye(2, device=device)

    # Sort by depth (front-to-back for alpha blending)
    order = torch.argsort(depths)
    positions_2d, covariances_2d, opacities, colors = (
        t[order] for t in (positions_2d, covariances_2d, opacities, colors)
    )

    # Pixel grid as (x, y) coordinates, row-major
    y_coords, x_coords = torch.meshgrid(
        torch.arange(H, device=device, dtype=torch.float32),
        torch.arange(W, device=device, dtype=torch.float32),
        indexing='ij'
    )
    pixels = torch.stack([x_coords.flatten(), y_coords.flatten()], dim=-1)  # (H*W, 2)

    img = torch.zeros(H * W, 3, device=device)
    transmittance = torch.ones(H * W, device=device)
    for start in range(0, len(positions_2d), chunk_size):
        sl = slice(start, start + chunk_size)
        img, transmittance = _composite_chunk(
            img, transmittance, pixels,
            positions_2d[sl], covariances_2d[sl], opacities[sl], colors[sl], H, W,
        )

    return torch.clamp(img.reshape(H, W, 3), 0, 1)

# Test render: run the renderer once on the first view to check it executes end to end.
# No gradients are needed here, so skip building the (large) autograd graph.
test_camera = dataset[0]
with torch.no_grad():
    rendered = simple_render(gaussians, {
        'pose': test_camera['pose'],
        'K': test_camera['K'],
        'width': test_camera['W'],
        'height': test_camera['H']
    })
print(f"✓ Render test successful! Shape: {rendered.shape}")

## 5. Training Loop

Simple training loop: per-view MSE (L2) loss, Adam optimizer, and gradient-norm clipping (max norm 1.0).

In [None]:
def _camera_from_sample(sample):
    """Build the camera dict expected by simple_render from a dataset entry."""
    return {
        'pose': sample['pose'],
        'K': sample['K'],
        'width': sample['W'],
        'height': sample['H'],
    }


def train_simple(gaussians, dataset, num_epochs=20, lr=0.01, print_every=5):
    """
    Simple training loop: per-view MSE loss, Adam, gradient-norm clipping.

    Args:
        gaussians: SimpleGaussian3D instance (optimized in place)
        dataset: List of view dicts with 'image' (3, H, W), 'pose', 'K', 'H', 'W'
        num_epochs: Number of training epochs
        lr: Learning rate
        print_every: Log (and record history) every N epochs; the final epoch
            is always logged

    Returns:
        history: dict of parallel lists with one entry per logged epoch:
            'epoch' (epoch index), 'loss' (mean training loss over that epoch),
            'psnr' (PSNR in dB of view 0 after that epoch).

    Raises:
        ValueError: if dataset is empty.
    """
    if not dataset:
        raise ValueError("train_simple requires a non-empty dataset")

    # Use the parameters' device rather than a module-level global
    param_device = gaussians.pos.device
    optimizer = torch.optim.Adam(gaussians.parameters(), lr=lr)

    history = {'epoch': [], 'loss': [], 'psnr': []}

    print(f"Training for {num_epochs} epochs...")
    print("="*60)

    for epoch in range(num_epochs):
        epoch_loss = 0.0

        # Visit the views in a fresh random order every epoch
        for idx in torch.randperm(len(dataset)).tolist():
            sample = dataset[idx]

            img_pred = simple_render(gaussians, _camera_from_sample(sample))  # (H, W, 3)
            img_gt = sample['image'].permute(1, 2, 0).to(param_device)  # (H, W, 3)

            loss = ((img_pred - img_gt) ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(gaussians.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataset)

        # Log loss and PSNR on the first view
        if epoch % print_every == 0 or epoch == num_epochs - 1:
            with torch.no_grad():
                test_sample = dataset[0]
                rendered = simple_render(gaussians, _camera_from_sample(test_sample))
                gt = test_sample['image'].permute(1, 2, 0).to(param_device)
                mse = ((rendered - gt) ** 2).mean()
                psnr = -10 * torch.log10(mse + 1e-10)

                # Record the epoch index so plots need not guess the logging schedule
                history['epoch'].append(epoch)
                history['loss'].append(avg_loss)
                history['psnr'].append(psnr.item())

                print(f"Epoch {epoch:3d}/{num_epochs}: Loss={avg_loss:.5f}, PSNR={psnr:.2f} dB")

    return history

# Initialize Gaussians: a fresh set, replacing the one created in the model smoke-test cell
n_gaussians = 1000
gaussians = SimpleGaussian3D(n_gaussians=n_gaussians, device=device)

# Train (logs and records history every 5 epochs and at the final epoch)
history = train_simple(gaussians, dataset, num_epochs=20, lr=0.01, print_every=5)

## 6. Visualization

Visualize results: rendered images, training curves, and 3D Gaussian positions.

In [None]:
# Visualize rendered images: ground truth (top row) vs. current render (bottom row)
n_show = min(3, len(dataset))
fig, axes = plt.subplots(2, n_show, figsize=(15, 10))
if len(dataset) == 1:
    # With a single column plt.subplots returns a 1-D array; make it 2-D so
    # axes[row, col] indexing works uniformly below.
    axes = axes.reshape(2, 1)

for view_idx in range(n_show):
    sample = dataset[view_idx]
    camera = dict(
        pose=sample['pose'],
        K=sample['K'],
        width=sample['W'],
        height=sample['H'],
    )

    with torch.no_grad():
        rendered = simple_render(gaussians, camera)

    gt = sample['image'].permute(1, 2, 0).cpu().numpy()
    top_ax, bottom_ax = axes[0, view_idx], axes[1, view_idx]

    top_ax.imshow(np.clip(gt, 0, 1))
    top_ax.set_title(f'Ground Truth (View {view_idx})')
    top_ax.axis('off')

    bottom_ax.imshow(np.clip(rendered.cpu().numpy(), 0, 1))
    bottom_ax.set_title(f'Rendered (View {view_idx})')
    bottom_ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# x-axis: the epochs that were actually logged. train_simple also logs the final
# epoch even when it is not a multiple of print_every, so a fixed 5-step range
# mislabels the last point. Fall back to that fixed range only if history has
# no recorded epoch indices.
epochs = history.get('epoch', range(0, len(history['loss']) * 5, 5))

ax1.plot(epochs, history['loss'])
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.grid(True)

ax2.plot(epochs, history['psnr'])
ax2.set_xlabel('Epoch')
ax2.set_ylabel('PSNR (dB)')
ax2.set_title('PSNR')
ax2.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Visualize 3D Gaussian positions (plus the first few camera centres)
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Snapshot the learned parameters as NumPy arrays
with torch.no_grad():
    positions = gaussians.pos.cpu().numpy()
    colors = gaussians.get_colors().cpu().numpy()
    scales = gaussians.get_scales().cpu().numpy()
    opacity = gaussians.get_opacity().cpu().numpy()

# Marker size grows with mean scale; marker transparency follows learned opacity
ax.scatter(*positions.T, c=colors, s=scales.mean(axis=1) * 100, alpha=opacity * 0.6)

# Camera centres = translation column of the camera-to-world pose (first 5 views)
if len(dataset) > 0:
    cam_positions = [
        (s['pose'].cpu().numpy() if isinstance(s['pose'], torch.Tensor) else s['pose'])[:3, 3]
        for s in dataset[:5]
    ]

    if cam_positions:
        cam_positions = np.array(cam_positions)
        ax.scatter(cam_positions[:, 0], cam_positions[:, 1], cam_positions[:, 2],
                   c='red', s=50, marker='^', label='Cameras')

for set_label, text in ((ax.set_xlabel, 'X'), (ax.set_ylabel, 'Y'), (ax.set_zlabel, 'Z')):
    set_label(text)
ax.set_title('3D Gaussian Positions')
ax.legend()
plt.show()

## 7. Extension Ideas (for Diffusion)

This simplified implementation can be extended with diffusion:

1. **Parameter Space**: The Gaussian parameters (pos, scales, colors, opacity) form a parameter space
2. **Diffusion Process**: Add noise and denoise in this space
3. **Manifold Structure**: Consider the manifold structure of rotations (if adding rotations later)

**Simple diffusion extension:**
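- Add noise to Gaussian parameters with a variance schedule: `θ_t = √ᾱ_t · θ_0 + √(1 − ᾱ_t) · ε`, where `ε ~ N(0, I)` and `ᾱ_t` decreases from 1 toward 0 as t grows. Apply it to the unconstrained parameters (`pos`, `log_scales`, `logit_opacity`, `rgb_logit`) so that sampled values stay valid after the exp/sigmoid activations.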
- Train a denoising network or use iterative denoising
- Sample new Gaussians via reverse diffusion process