# 🔄 Glimpse3D - MVCRM Multi-View Consistent Refinement

**Enhance 3D representations using diffusion-enhanced multi-view images**

This notebook implements the refinement stage of the **Multi-View Consistent Refinement Module (MVCRM)**. The module as a whole:
1. Enhances rendered views with SDXL Lightning + ControlNet (done upstream; the enhanced images are uploaded in step 4)
2. Back-projects enhanced pixels into 3D space
3. Updates Gaussian splat parameters for consistent refinement

## Pipeline Role
```
TripoSR → gsplat → SyncDreamer → SDXL Enhancement → [This Notebook] → Final Output
```

## Novel Contribution
This module is a **key innovation** in the Glimpse3D pipeline: it ensures that 2D diffusion enhancements are propagated back into the 3D representation consistently across views.

---

## 1️⃣ Check GPU & Environment

In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules
print(f"Running in Colab: {IN_COLAB}")

!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv

import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## 2️⃣ Install Dependencies

In [None]:
%%capture
# Core dependencies (%%capture suppresses all output from this cell)
!pip install torch torchvision --quiet
!pip install diffusers transformers accelerate --quiet
!pip install gsplat plyfile --quiet
!pip install numpy pillow matplotlib tqdm opencv-python-headless --quiet
!pip install scipy scikit-image --quiet

# For depth estimation
!pip install timm --quiet

# For the 360° video export in step 12 (MP4 writing uses the ffmpeg plugin)
!pip install imageio imageio-ffmpeg --quiet

print("✅ Dependencies installed!")

## 3️⃣ Setup Working Directory

In [None]:
import os
from pathlib import Path

WORK_DIR = Path("/content/mvcrm_work")
WORK_DIR.mkdir(parents=True, exist_ok=True)

# Subdirectories
(WORK_DIR / "renders").mkdir(exist_ok=True)
(WORK_DIR / "enhanced").mkdir(exist_ok=True)
(WORK_DIR / "output").mkdir(exist_ok=True)

print(f"📂 Working directory: {WORK_DIR}")

## 4️⃣ Upload Inputs

Required inputs:
- Optimized Gaussian PLY from gsplat
- Multi-view images from SyncDreamer OR enhanced images from SDXL
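
As a quick pre-flight check before uploading, the sketch below lists the vertex properties that this notebook's PLY loader reads and reports any that a given PLY header lacks. The helper names `expected_properties` and `missing_properties` are illustrative only (they are not part of gsplat or plyfile), and the property list is taken from the loader in step 5, which only reads the `f_rest_*` fields that are present.

```python
def expected_properties(num_f_rest=45):
    """Vertex property names read by the notebook's PLY loader."""
    names = ["x", "y", "z"]
    names += [f"f_dc_{i}" for i in range(3)]
    names += [f"f_rest_{i}" for i in range(num_f_rest)]
    names += ["opacity"]
    names += [f"scale_{i}" for i in range(3)]
    names += [f"rot_{i}" for i in range(4)]
    return names


def missing_properties(header_names, optional_prefix="f_rest_"):
    """Return required properties absent from a PLY header (f_rest_* is optional)."""
    present = set(header_names)
    return [
        name for name in expected_properties()
        if name not in present and not name.startswith(optional_prefix)
    ]


# Hypothetical header that lacks the rotation fields
header = [n for n in expected_properties() if not n.startswith("rot_")]
print(missing_properties(header))
```

An empty result means every field the loader indexes by name is available.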

In [None]:
from google.colab import files
import zipfile

# Upload Gaussian PLY
print("📤 Upload optimized Gaussian PLY:")
uploaded_ply = files.upload()
PLY_PATH = WORK_DIR / list(uploaded_ply.keys())[0]
with open(PLY_PATH, 'wb') as f:
    f.write(list(uploaded_ply.values())[0])
print(f"✅ PLY saved: {PLY_PATH}")

In [None]:
# Upload enhanced images (ZIP or individual)
print("\n📤 Upload enhanced images (ZIP or individual PNGs):")
uploaded_images = files.upload()

IMAGE_DIR = WORK_DIR / "enhanced"

for fname, content in uploaded_images.items():
    if fname.endswith('.zip'):
        zip_path = WORK_DIR / fname
        with open(zip_path, 'wb') as f:
            f.write(content)
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(IMAGE_DIR)
        print(f"✅ Extracted: {fname}")
    else:
        img_path = IMAGE_DIR / fname
        with open(img_path, 'wb') as f:
            f.write(content)
        print(f"✅ Saved: {fname}")

# List images
enhanced_images = sorted([f for f in IMAGE_DIR.iterdir() if f.suffix.lower() in ['.png', '.jpg', '.jpeg']])
print(f"\n✅ Found {len(enhanced_images)} enhanced images")

## 5️⃣ Load Gaussian Model
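
The model class in this step treats the PLY values as raw, pre-activation parameters: opacity goes through a sigmoid, scales through an exponential, and the DC spherical-harmonic coefficients are mapped to a base color with `0.5 + C0 * f_dc`. The following is a minimal pure-Python sketch of those three mappings on a single hypothetical Gaussian; the names `sh_dc_to_rgb` and `sigmoid` and the sample values are illustrative assumptions.

```python
import math

SH_C0 = 0.28209479177387814  # zeroth-order spherical-harmonic basis constant


def sh_dc_to_rgb(f_dc):
    """Map a DC spherical-harmonic coefficient triple to a base RGB color."""
    return [0.5 + SH_C0 * c for c in f_dc]


def sigmoid(x):
    """Logistic function used to map raw opacity into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical raw values for one Gaussian, as they would be stored in the PLY
raw = {"f_dc": [0.0, 1.0, -1.0], "opacity": 0.0, "scale": [-2.0, -2.0, -2.0]}

rgb = sh_dc_to_rgb(raw["f_dc"])
opacity = sigmoid(raw["opacity"])
scales = [math.exp(s) for s in raw["scale"]]
print(rgb, opacity, scales)
```

Keeping the raw values as the trainable parameters means the optimizer can move them freely while the activated opacity stays in (0, 1) and the scales stay positive.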

In [None]:
import torch
import torch.nn as nn
import numpy as np
from plyfile import PlyData

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_gaussian_ply(path):
    """Load Gaussian Splat PLY file into tensors."""
    plydata = PlyData.read(path)
    vertex = plydata['vertex']
    
    xyz = np.stack([vertex['x'], vertex['y'], vertex['z']], axis=-1)
    f_dc = np.stack([vertex['f_dc_0'], vertex['f_dc_1'], vertex['f_dc_2']], axis=-1)
    
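    # Higher-order SH coefficients are optional; fall back to an empty (N, 0) array
    f_rest_names = [f'f_rest_{i}' for i in range(45)]
    f_rest_cols = [vertex[name] for name in f_rest_names if name in vertex.data.dtype.names]
    f_rest = np.stack(f_rest_cols, axis=-1) if f_rest_cols else np.zeros((xyz.shape[0], 0), dtype=np.float32)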
    
    opacity = vertex['opacity']
    scales = np.stack([vertex['scale_0'], vertex['scale_1'], vertex['scale_2']], axis=-1)
    rotations = np.stack([vertex['rot_0'], vertex['rot_1'], vertex['rot_2'], vertex['rot_3']], axis=-1)
    
    return {
        'xyz': torch.tensor(xyz, dtype=torch.float32, device=device),
        'f_dc': torch.tensor(f_dc, dtype=torch.float32, device=device),
        'f_rest': torch.tensor(f_rest, dtype=torch.float32, device=device),
        'opacity': torch.tensor(opacity, dtype=torch.float32, device=device),
        'scales': torch.tensor(scales, dtype=torch.float32, device=device),
        'rotations': torch.tensor(rotations, dtype=torch.float32, device=device),
    }

class GaussianModel(nn.Module):
    def __init__(self, gaussians):
        super().__init__()
        self.xyz = nn.Parameter(gaussians['xyz'].clone())
        self.f_dc = nn.Parameter(gaussians['f_dc'].clone())
        self.f_rest = nn.Parameter(gaussians['f_rest'].clone())
        self.opacity_raw = nn.Parameter(gaussians['opacity'].clone())
        self.scales_raw = nn.Parameter(gaussians['scales'].clone())
        self.rotations = nn.Parameter(gaussians['rotations'].clone())
        
    @property
    def opacity(self):
        return torch.sigmoid(self.opacity_raw)
    
    @property
    def scales(self):
        return torch.exp(self.scales_raw)
    
    def get_colors(self):
        C0 = 0.28209479177387814
        return 0.5 + C0 * self.f_dc
    
    def forward(self):
        return {
            'xyz': self.xyz,
            'colors': self.get_colors(),
            'opacity': self.opacity,
            'scales': self.scales,
            'rotations': self.rotations / (self.rotations.norm(dim=-1, keepdim=True) + 1e-8),
        }

# Load model
gaussians = load_gaussian_ply(PLY_PATH)
model = GaussianModel(gaussians).to(device)
print(f"✅ Loaded {len(gaussians['xyz']):,} Gaussians")

## 6️⃣ Camera System for SyncDreamer Views
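
The cameras sit on a sphere of fixed radius around the object and look at the origin. The pure-Python sketch below shows one way to turn an elevation/azimuth pair into a world-to-camera transform, assuming a z-up world and an OpenCV-style camera frame (+x right, +y down, +z forward), which is the convention gsplat expects for its `viewmats`. The helper names are illustrative, and the 16-view layout is the one listed in the cell below.

```python
import math


def camera_position(elevation_deg, azimuth_deg, radius=2.0):
    """Spherical-to-Cartesian camera position in a z-up world."""
    elev, azim = math.radians(elevation_deg), math.radians(azimuth_deg)
    return [
        radius * math.cos(elev) * math.cos(azim),
        radius * math.cos(elev) * math.sin(azim),
        radius * math.sin(elev),
    ]


def cross(a, b):
    return [a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]]


def normalize(v):
    n = math.sqrt(sum(c * c for c in v))
    return [c / n for c in v]


def look_at_w2c(cam_pos, up=(0.0, 0.0, 1.0)):
    """Rotation rows and translation of a world-to-camera transform aimed at the origin."""
    forward = normalize([-c for c in cam_pos])
    right = normalize(cross(forward, list(up)))
    down = cross(forward, right)
    rot = [right, down, forward]  # rows are the camera axes in world coordinates
    trans = [-sum(r[i] * cam_pos[i] for i in range(3)) for r in rot]
    return rot, trans


def to_camera(rot, trans, point):
    """Apply the world-to-camera transform to a 3D point."""
    return [sum(rot[r][i] * point[i] for i in range(3)) + trans[r] for r in range(3)]


elevations = [30.0] * 8 + [-20.0] * 8
azimuths = [i * 45.0 for i in range(8)] * 2
poses = [look_at_w2c(camera_position(e, a)) for e, a in zip(elevations, azimuths)]

# The world origin should land on the optical axis, `radius` units in front of the camera.
origin_in_cam = to_camera(*poses[0], [0.0, 0.0, 0.0])
print([round(c, 6) for c in origin_in_cam])
```

A useful sanity check for any pose function is the one printed here: the look-at target must have zero x and y and a positive z in camera space, otherwise the rasterizer culls everything.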

In [None]:
import math

# SyncDreamer camera configuration (16 views)
# Views 0-7: Elevation 30°, Azimuth 0°, 45°, 90°, ..., 315°
# Views 8-15: Elevation -20°, Azimuth 0°, 45°, 90°, ..., 315°

SYNCDREAMER_ELEVATIONS = [30.0] * 8 + [-20.0] * 8
SYNCDREAMER_AZIMUTHS = [i * 45.0 for i in range(8)] * 2

def create_camera_pose(elevation_deg, azimuth_deg, radius=2.0):
    """Create world-to-camera matrix for given elevation and azimuth.

    Uses the OpenCV-style camera frame that gsplat expects for `viewmats`:
    +x right, +y down, +z forward (looking at the origin, world z up).
    """
    elev = math.radians(elevation_deg)
    azim = math.radians(azimuth_deg)

    # Camera position
    x = radius * math.cos(elev) * math.cos(azim)
    y = radius * math.cos(elev) * math.sin(azim)
    z = radius * math.sin(elev)

    cam_pos = np.array([x, y, z])
    look_at = np.array([0.0, 0.0, 0.0])
    up = np.array([0.0, 0.0, 1.0])

    forward = look_at - cam_pos
    forward = forward / np.linalg.norm(forward)

    right = np.cross(forward, up)
    right = right / np.linalg.norm(right)

    down = np.cross(forward, right)

    # Rows of the rotation are the camera axes expressed in world coordinates
    w2c = np.eye(4)
    w2c[0, :3] = right
    w2c[1, :3] = down
    w2c[2, :3] = forward
    w2c[:3, 3] = -w2c[:3, :3] @ cam_pos

    return w2c

def get_projection_matrix(fov_deg=60, aspect=1.0, near=0.1, far=100.0):
    fov_rad = math.radians(fov_deg)
    f = 1.0 / math.tan(fov_rad / 2)
    
    proj = np.zeros((4, 4))
    proj[0, 0] = f / aspect
    proj[1, 1] = f
    proj[2, 2] = (far + near) / (near - far)
    proj[2, 3] = 2 * far * near / (near - far)
    proj[3, 2] = -1
    
    return proj

# Generate all camera poses
camera_poses = [create_camera_pose(e, a) for e, a in zip(SYNCDREAMER_ELEVATIONS, SYNCDREAMER_AZIMUTHS)]
projection = get_projection_matrix(fov_deg=60)

print(f"✅ Created {len(camera_poses)} camera poses matching SyncDreamer configuration")

## 7️⃣ Render and Compute Pixel-to-Gaussian Mapping
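
The renderer needs pinhole intrinsics in pixels, which the cell below derives from the field of view: the focal length is `(image_size / 2) / tan(fov / 2)` and the principal point is the image center. Here is a small sketch of that conversion and of projecting a camera-space point to a pixel, assuming a square image and +z pointing forward; `intrinsics_from_fov` and `project` are illustrative names.

```python
import math


def intrinsics_from_fov(fov_deg=60.0, image_size=512):
    """Pinhole intrinsics (fx, fy, cx, cy) for a square image and symmetric FOV."""
    focal = (image_size / 2) / math.tan(math.radians(fov_deg) / 2)
    return focal, focal, image_size / 2, image_size / 2


def project(point_cam, fx, fy, cx, cy):
    """Project a camera-space point (+z forward) to pixel coordinates."""
    x, y, z = point_cam
    return fx * x / z + cx, fy * y / z + cy


fx, fy, cx, cy = intrinsics_from_fov(60.0, 512)
print(round(fx, 3))                              # focal length in pixels
print(project([0.0, 0.0, 2.0], fx, fy, cx, cy))  # point on the optical axis
```

A point on the optical axis maps to the image center, and a point at the edge of the 60° frustum maps to the image border, which is a quick way to confirm the intrinsics match the intended field of view.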

In [None]:
from gsplat import rasterization
from PIL import Image

IMAGE_SIZE = 512

def render_with_info(model, w2c, proj, image_size=512):
    """
    Render Gaussians and return:
    - RGB image, shape (H, W, 3)
    - Alpha mask, shape (H, W, 1)
    - gsplat's `meta` dict of intermediate rasterization data (unused below)
    """
    params = model()
    
    viewmat = torch.tensor(w2c, dtype=torch.float32, device=device)
    K = torch.tensor([
        [proj[0, 0] * image_size / 2, 0, image_size / 2],
        [0, proj[1, 1] * image_size / 2, image_size / 2],
        [0, 0, 1]
    ], dtype=torch.float32, device=device)
    
    render_colors, render_alphas, meta = rasterization(
        means=params['xyz'],
        quats=params['rotations'],
        scales=params['scales'],
        opacities=params['opacity'],
        colors=params['colors'],
        viewmats=viewmat.unsqueeze(0),
        Ks=K.unsqueeze(0),
        width=image_size,
        height=image_size,
        packed=False,
        render_mode="RGB",
    )
    
    return render_colors[0], render_alphas[0], meta

# Render all views and save
print("🎬 Rendering all views...")
rendered_views = []

with torch.no_grad():
    for i, (w2c, elev, azim) in enumerate(zip(camera_poses, SYNCDREAMER_ELEVATIONS, SYNCDREAMER_AZIMUTHS)):
        rgb, alpha, _ = render_with_info(model, w2c, projection, IMAGE_SIZE)
        
        # Save render
        img_np = (rgb.cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
        render_path = WORK_DIR / "renders" / f"render_{i:02d}.png"
        Image.fromarray(img_np).save(render_path)
        
        rendered_views.append(rgb)
        print(f"  View {i}: E={elev}° A={azim}°")

print(f"\n✅ Rendered {len(rendered_views)} views")

## 8️⃣ Load Enhanced Images

In [None]:
import matplotlib.pyplot as plt

# Load enhanced images
enhanced_tensors = []

for img_path in sorted(enhanced_images)[:16]:  # Use first 16 images
    img = Image.open(img_path).convert('RGB')
    img = img.resize((IMAGE_SIZE, IMAGE_SIZE))
    img_tensor = torch.tensor(np.array(img) / 255.0, dtype=torch.float32, device=device)
    enhanced_tensors.append(img_tensor)

print(f"✅ Loaded {len(enhanced_tensors)} enhanced images")

# Visualize comparison
if len(enhanced_tensors) > 0 and len(rendered_views) > 0:
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    
    for i in range(min(4, len(enhanced_tensors))):
        axes[0, i].imshow(rendered_views[i].cpu().numpy().clip(0, 1))
        axes[0, i].set_title(f"Original Render {i}")
        axes[0, i].axis('off')
        
        axes[1, i].imshow(enhanced_tensors[i].cpu().numpy().clip(0, 1))
        axes[1, i].set_title(f"Enhanced {i}")
        axes[1, i].axis('off')
    
    plt.suptitle("Original vs Enhanced Views")
    plt.tight_layout()
    plt.show()

## 9️⃣ Multi-View Consistent Refinement (MVCRM)

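The key innovation: back-project 2D enhancements into 3D while maintaining consistency across views. Each iteration renders the Gaussians from a random subset of the SyncDreamer viewpoints and updates their colors and positions to match the enhanced images, so every view constrains the same shared 3D parameters.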
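
To see why optimizing shared 3D parameters yields consistency, consider a toy case: one Gaussian with a single scalar color that several views observe, where the 2D enhancer produced slightly different target values per view. The sketch below fits that shared color by gradient descent. It is a simplified stand-in, not the notebook's refiner: it uses mean squared error rather than the L1 + edge loss, and `refine_shared_color` and the target values are hypothetical.

```python
def refine_shared_color(targets, init=0.0, lr=0.1, steps=200):
    """Fit one shared scalar color to per-view targets by gradient descent.

    Toy stand-in for the refinement loop: uses mean squared error so the
    gradient is smooth, instead of the notebook's L1 + edge loss.
    """
    color = init
    for _ in range(steps):
        # Gradient of mean((color - t)^2) over all views
        grad = sum(2.0 * (color - t) for t in targets) / len(targets)
        color -= lr * grad
    return color


# Hypothetical per-view targets: the enhancer disagrees slightly between views.
view_targets = [0.70, 0.74, 0.66, 0.72]
refined = refine_shared_color(view_targets)
print(round(refined, 4))
```

Because every view pulls on the same parameter, the result settles on a compromise between the targets (their mean, under squared error) rather than reproducing any single view's inconsistency.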

In [None]:
from tqdm import tqdm
import torch.nn.functional as F

class MVCRMRefiner:
    """Multi-View Consistent Refinement Module."""
    
    def __init__(self, model, camera_poses, projection, image_size=512):
        self.model = model
        self.camera_poses = camera_poses
        self.projection = projection
        self.image_size = image_size
        self.device = next(model.parameters()).device
        
    def compute_color_loss(self, rendered, target):
        """L1 loss plus a luminance-gradient (edge) loss between rendered and target."""
        l1_loss = F.l1_loss(rendered, target)

        # Luminance (Rec. 601 weights) as a cheap proxy for perceived structure
        render_lum = 0.299 * rendered[..., 0] + 0.587 * rendered[..., 1] + 0.114 * rendered[..., 2]
        target_lum = 0.299 * target[..., 0] + 0.587 * target[..., 1] + 0.114 * target[..., 2]

        # Edges via forward finite differences (simpler than a Sobel filter)
        render_grad_x = render_lum[:, 1:] - render_lum[:, :-1]
        render_grad_y = render_lum[1:, :] - render_lum[:-1, :]
        target_grad_x = target_lum[:, 1:] - target_lum[:, :-1]
        target_grad_y = target_lum[1:, :] - target_lum[:-1, :]
        
        edge_loss = F.l1_loss(render_grad_x, target_grad_x) + F.l1_loss(render_grad_y, target_grad_y)
        
        return l1_loss + 0.1 * edge_loss
    
    def refine(self, enhanced_images, num_iterations=500, lr_color=1e-3, lr_position=1e-5):
        """Refine Gaussians to match enhanced images."""
        
        # Setup optimizer - only optimize colors and positions
        optimizer = torch.optim.Adam([
            {'params': self.model.f_dc, 'lr': lr_color},
            {'params': self.model.xyz, 'lr': lr_position},
        ])
        
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_iterations)
        
        losses = []
        n_views = len(enhanced_images)
        
        print(f"🔄 Starting MVCRM refinement with {n_views} views...")
        pbar = tqdm(range(num_iterations))
        
        for iteration in pbar:
            optimizer.zero_grad()
            total_loss = 0
            
            # Sample a subset of views for this iteration
            view_indices = np.random.choice(n_views, size=min(4, n_views), replace=False)
            
            for view_idx in view_indices:
                w2c = self.camera_poses[view_idx]
                target = enhanced_images[view_idx]
                
                # Render
                rendered, alpha, _ = render_with_info(self.model, w2c, self.projection, self.image_size)
                
                # Compute loss
                loss = self.compute_color_loss(rendered, target)
                total_loss += loss
            
            avg_loss = total_loss / len(view_indices)
            avg_loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            
            optimizer.step()
            scheduler.step()
            
            losses.append(avg_loss.item())
            
            if iteration % 50 == 0:
                pbar.set_postfix({'loss': f'{avg_loss.item():.4f}'})
        
        print(f"\n✅ Refinement complete! Final loss: {losses[-1]:.4f}")
        return losses

# Initialize refiner
refiner = MVCRMRefiner(model, camera_poses, projection, IMAGE_SIZE)

In [None]:
# Run refinement
if len(enhanced_tensors) > 0:
    losses = refiner.refine(
        enhanced_tensors,
        num_iterations=500,
        lr_color=1e-3,
        lr_position=1e-5
    )
    
    # Plot loss curve
    plt.figure(figsize=(10, 4))
    plt.plot(losses)
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title('MVCRM Refinement Loss')
    plt.grid(True, alpha=0.3)
    plt.show()
else:
    print("⚠️ No enhanced images loaded. Skipping refinement.")

## 🔟 Compare Before/After Refinement
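
Alongside the visual grid, a numeric score can make the before/after comparison easier to read. One common choice is PSNR between a render and its enhanced target. The sketch below is a plain-Python illustration of the formula on tiny hypothetical images; `psnr` is an illustrative helper rather than part of the notebook, and for full 512×512 tensors the same formula written with tensor operations would be far faster.

```python
import math


def psnr(image_a, image_b, max_value=1.0):
    """Peak signal-to-noise ratio (dB) between two equally sized images.

    Images are nested lists [row][col][channel] with values in [0, max_value].
    """
    squared_error, count = 0.0, 0
    for row_a, row_b in zip(image_a, image_b):
        for px_a, px_b in zip(row_a, row_b):
            for a, b in zip(px_a, px_b):
                squared_error += (a - b) ** 2
                count += 1
    mse = squared_error / count
    return float("inf") if mse == 0 else 10.0 * math.log10(max_value ** 2 / mse)


# Hypothetical 1x2 RGB images standing in for a target and two renders.
target = [[[0.5, 0.5, 0.5], [0.2, 0.4, 0.6]]]
before = [[[0.3, 0.3, 0.3], [0.0, 0.2, 0.4]]]
after = [[[0.45, 0.45, 0.45], [0.15, 0.35, 0.55]]]
print(psnr(before, target) < psnr(after, target))
```

A higher PSNR against the enhanced target after refinement indicates that the render moved toward the target for that view.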

In [None]:
# Render refined views
print("🎬 Rendering refined views...")
refined_views = []

with torch.no_grad():
    for i, w2c in enumerate(camera_poses):
        rgb, alpha, _ = render_with_info(model, w2c, projection, IMAGE_SIZE)
        refined_views.append(rgb)

# Compare
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

for i in range(4):
    # Original render
    axes[0, i].imshow(rendered_views[i].cpu().numpy().clip(0, 1))
    axes[0, i].set_title(f"Original {i}")
    axes[0, i].axis('off')
    
    # Enhanced target
    if i < len(enhanced_tensors):
        axes[1, i].imshow(enhanced_tensors[i].cpu().numpy().clip(0, 1))
    axes[1, i].set_title(f"Enhanced Target {i}")
    axes[1, i].axis('off')
    
    # Refined render
    axes[2, i].imshow(refined_views[i].cpu().numpy().clip(0, 1))
    axes[2, i].set_title(f"Refined {i}")
    axes[2, i].axis('off')

plt.suptitle("MVCRM: Original → Enhanced Target → Refined Result", fontsize=14)
plt.tight_layout()
plt.savefig(WORK_DIR / "output" / "comparison.png", dpi=150)
plt.show()

## 1️⃣1️⃣ Export Refined Model

In [None]:
from plyfile import PlyElement, PlyData

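def save_gaussian_ply(model, output_path):
    """Save Gaussian model to PLY file (opacity and scales stay in raw, pre-activation form)."""
    with torch.no_grad():
        params = model()

        # detach() so .numpy() also works when the parameters live on the CPU
        xyz = params['xyz'].detach().cpu().numpy()
        colors = model.f_dc.detach().cpu().numpy()
        f_rest = model.f_rest.detach().cpu().numpy()
        opacity = model.opacity_raw.detach().cpu().numpy()
        scales = model.scales_raw.detach().cpu().numpy()
        rotations = params['rotations'].detach().cpu().numpy()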
        
    num_points = len(xyz)
    
    dtype_full = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('f_dc_0', 'f4'), ('f_dc_1', 'f4'), ('f_dc_2', 'f4'),
    ]
    for i in range(f_rest.shape[1]):
        dtype_full.append((f'f_rest_{i}', 'f4'))
    dtype_full.extend([
        ('opacity', 'f4'),
        ('scale_0', 'f4'), ('scale_1', 'f4'), ('scale_2', 'f4'),
        ('rot_0', 'f4'), ('rot_1', 'f4'), ('rot_2', 'f4'), ('rot_3', 'f4'),
    ])
    
    elements = np.zeros(num_points, dtype=dtype_full)
    elements['x'] = xyz[:, 0]
    elements['y'] = xyz[:, 1]
    elements['z'] = xyz[:, 2]
    elements['f_dc_0'] = colors[:, 0]
    elements['f_dc_1'] = colors[:, 1]
    elements['f_dc_2'] = colors[:, 2]
    for i in range(f_rest.shape[1]):
        elements[f'f_rest_{i}'] = f_rest[:, i]
    elements['opacity'] = opacity
    elements['scale_0'] = scales[:, 0]
    elements['scale_1'] = scales[:, 1]
    elements['scale_2'] = scales[:, 2]
    elements['rot_0'] = rotations[:, 0]
    elements['rot_1'] = rotations[:, 1]
    elements['rot_2'] = rotations[:, 2]
    elements['rot_3'] = rotations[:, 3]
    
    el = PlyElement.describe(elements, 'vertex')
    PlyData([el]).write(output_path)
    print(f"✅ Saved: {output_path}")

# Save refined model
refined_ply_path = WORK_DIR / "output" / "refined_gaussian.ply"
save_gaussian_ply(model, str(refined_ply_path))

## 1️⃣2️⃣ Create Video

In [None]:
import imageio

# Generate smooth 360° rotation
print("🎬 Rendering 360° video...")
video_frames = []

with torch.no_grad():
    # endpoint=False avoids repeating the 0° frame at 360°, so the loop is seamless
    for azim in tqdm(np.linspace(0, 360, 60, endpoint=False)):
        w2c = create_camera_pose(30.0, azim, radius=2.0)
        rgb, _, _ = render_with_info(model, w2c, projection, IMAGE_SIZE)
        frame = (rgb.cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
        video_frames.append(frame)

# Save video
video_path = WORK_DIR / "output" / "refined_360.mp4"
imageio.mimsave(str(video_path), video_frames, fps=30)
print(f"✅ Saved video: {video_path}")

# Display in notebook
from IPython.display import HTML
from base64 import b64encode

mp4 = video_path.read_bytes()  # reads and closes the file in one call
data_url = f"data:video/mp4;base64,{b64encode(mp4).decode()}"
HTML(f'<video width=400 controls autoplay loop><source src="{data_url}" type="video/mp4"></video>')

## 📥 Download Results

In [None]:
from google.colab import files
import shutil

# Create output ZIP
output_zip = shutil.make_archive(
    str(WORK_DIR / "mvcrm_output"),
    'zip',
    WORK_DIR / "output"
)

print("📥 Downloading results...")
files.download(output_zip)

print("\n✅ Download complete!")

---

## ✅ MVCRM Complete!

The **Multi-View Consistent Refinement Module** has:
1. Loaded the optimized Gaussian splats
2. Aligned camera poses with the SyncDreamer configuration
3. Propagated the enhanced 2D images into 3D by optimizing the Gaussians against them through the differentiable renderer
4. Refined Gaussian colors and positions for multi-view consistency

The output `refined_gaussian.ply` is the final Glimpse3D result!