# 3D Gaussian Splatting API on Google Colab
**Version: 3.1.1** (Updated: 2025-06-01)
**IMPORTANT: Check version number first! If not 3.1.1, reload from GitHub**

Based on the working implementation from: https://dev.classmethod.jp/articles/3d-gaussian-splatting-on-colab/

Major changes in v3.1.1:
- Fixed DISPLAY environment variable type error
- Proper CUDA/PyTorch configuration for T4 GPU
- Correct COLMAP camera model settings
- Working Gaussian Splatting training pipeline
- Produces a real trained PLY file (no placeholder fallback)

This notebook creates a FastAPI server that:
- Accepts image/video uploads via HTTP POST
- Extracts frames from videos (if needed)
- Runs COLMAP for camera pose estimation
- Trains a 3D Gaussian Splatting model
- Saves the trained model (.ply file) to Google Drive and returns its path in the JSON response

**Note**: Targets a T4 GPU (16 GB VRAM, compute capability 7.5) with CUDA 12.2.

## 1️⃣ Install System Dependencies

In [ ]:
# Install system dependencies and virtual display.
# Development headers (GLEW, Assimp, Boost, GTK, OpenCV, GLFW, FFmpeg libs,
# Eigen, Embree), CMake, the ImageMagick/FFmpeg CLIs, and Xvfb plus Mesa
# GL/EGL/OSMesa libraries so OpenGL-using tools can run without a real display.
# NOTE(review): libgl1-mesa-glx is a transitional package that newer Ubuntu
# releases no longer ship — confirm it still resolves on the current Colab image.
print("Installing system dependencies...")
!apt-get update -qq
!apt-get install -y \
    libglew-dev \
    libassimp-dev \
    libboost-all-dev \
    libgtk-3-dev \
    libopencv-dev \
    libglfw3-dev \
    libavdevice-dev \
    libavcodec-dev \
    libeigen3-dev \
    libxxf86vm-dev \
    libembree-dev \
    cmake \
    imagemagick \
    ffmpeg \
    xvfb \
    x11-utils \
    python3-opengl \
    libegl1-mesa \
    libgl1-mesa-glx \
    libgles2-mesa \
    libosmesa6

# Install virtual display library (Python wrapper used in the next cell)
!pip install -q pyvirtualdisplay

# Fix ImageMagick policy for PDF/PS files (sometimes needed).
# Deletes the line matching "disable ghostscript format types" plus the 6 lines
# after it from policy.xml; `|| true` keeps the cell going if sed fails
# (e.g. the policy file is absent).
!sed -i '/disable ghostscript format types/,+6d' /etc/ImageMagick-6/policy.xml || true

print("✓ System dependencies installed")

In [ ]:
# Setup virtual display and install COLMAP
print("Setting up virtual display...")

# Start a hidden virtual X display via pyvirtualdisplay (xvfb was installed in
# the previous cell) so tools that expect an X display can run headless.
from pyvirtualdisplay import Display
display = Display(visible=False, size=(1024, 768))
display.start()
print(f"✓ Virtual display started: {display}")

# Set environment for headless operation.
# display.display is the display number; DISPLAY must be a string like ":1".
import os
os.environ['DISPLAY'] = ':' + str(display.display)
os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ['QT_QPA_PLATFORM'] = 'offscreen'

# Install COLMAP from the apt repositories
print("\nInstalling COLMAP...")
!apt-get install -y colmap

# Verify installation
print("\nVerifying COLMAP installation...")
!which colmap
# NOTE: without `set -o pipefail` the pipeline's exit status is head's, so the
# `|| echo` fallback does not fire when colmap itself fails.
!colmap -h | head -5 || echo "COLMAP help"

# Test COLMAP can run: a zero exit status from `feature_extractor -h` shows the
# binary starts in this environment; on failure its stderr is printed.
import subprocess
result = subprocess.run(['colmap', 'feature_extractor', '-h'], capture_output=True, text=True)
if result.returncode == 0:
    print("\n✅ COLMAP successfully installed and can run!")
    print("Feature extractor help available")
else:
    print("\n⚠️ COLMAP installation issue detected")
    print(f"Error: {result.stderr}")

## 2️⃣ Clone and Setup Gaussian Splatting

In [ ]:
# Clone the repository and setup properly
!git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive
%cd gaussian-splatting

# Check CUDA availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

In [None]:
# Install Python dependencies
!pip install -q plyfile tqdm opencv-python joblib
!pip install -q fastapi uvicorn pyngrok nest-asyncio python-multipart aiofiles

# Install submodules
!pip install -q submodules/diff-gaussian-rasterization
!pip install -q submodules/simple-knn

# Install Python dependencies in the correct order
print("Installing Python dependencies...")

# First install plyfile and other basic deps
!pip install -q plyfile tqdm opencv-python joblib numpy

# Install FastAPI dependencies
!pip install -q fastapi uvicorn pyngrok nest-asyncio python-multipart aiofiles

# Install PyTorch with CUDA support (make sure it's compatible)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install Gaussian Splatting submodules - order matters!
print("Installing Gaussian Splatting submodules...")
!pip install -q submodules/diff-gaussian-rasterization
!pip install -q submodules/simple-knn

print("✓ All dependencies installed")

In [ ]:
import os
import sys
import shutil
import asyncio
import nest_asyncio
import subprocess
import cv2
import numpy as np
from pathlib import Path
from datetime import datetime
from contextlib import asynccontextmanager
from typing import List, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Header, Form
from fastapi.responses import JSONResponse
import uvicorn
from pyngrok import ngrok
from google.colab import drive, userdata

# Add gaussian-splatting to path so its modules are importable from this kernel
sys.path.append('/content/gaussian-splatting')

# Set environment for headless operation (EGL for PyOpenGL, offscreen Qt)
os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ['QT_QPA_PLATFORM'] = 'offscreen'

# Make sure COLMAP's apt install location (/usr/bin) is on PATH.
# Directories are only APPENDED when missing: prepending /usr/bin would shadow
# /usr/local/bin and could change which `python` the training/COLMAP
# subprocesses resolve to.
_path_entries = [p for p in os.environ.get('PATH', '').split(os.pathsep) if p]
for _required_dir in ('/usr/bin', '/usr/local/bin'):
    if _required_dir not in _path_entries:
        _path_entries.append(_required_dir)
os.environ['PATH'] = os.pathsep.join(_path_entries)

# Enable nested event loops (uvicorn is started from inside the notebook)
nest_asyncio.apply()

In [ ]:
# Configuration
# Secrets come from Colab userdata (Settings → Secrets). Each one is loaded on
# its own so that a single missing secret does not also discard the other.
def _load_colab_secret(name):
    """Return the Colab secret *name*, or None (after a warning) if unavailable."""
    try:
        return userdata.get(name)
    except Exception:
        print(f"⚠️  Warning: Could not load {name} from Colab userdata")
        print("Please set NGROK_AUTHTOKEN and API_KEY in Colab secrets")
        print("Settings → Secrets → Add new secret")
        return None


NGROK_AUTHTOKEN = _load_colab_secret('NGROK_AUTHTOKEN')
API_KEY = _load_colab_secret('API_KEY')

UPLOAD_DIR = Path("/content/uploads")
DATASET_DIR = Path("/content/datasets")
OUTPUT_DIR = Path("/content/drive/MyDrive/gaussian_splatting_outputs")

# Create working directories
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
DATASET_DIR.mkdir(parents=True, exist_ok=True)

# Set ngrok auth token
if NGROK_AUTHTOKEN:
    ngrok.set_auth_token(NGROK_AUTHTOKEN)
else:
    raise ValueError("NGROK_AUTHTOKEN not found in Colab secrets")

# The API is exposed publicly through ngrok: refuse to continue without a key
# instead of letting an unset key silently mean "no authentication".
if not API_KEY:
    raise ValueError("API_KEY not found in Colab secrets")

In [None]:
# Mount Google Drive so finished models outlive the Colab VM, then make sure
# the output folder (and any missing parents) exists.
drive.mount("/content/drive")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
print("Output directory: " + str(OUTPUT_DIR))

## 4️⃣ Processing Functions

In [ ]:
def extract_frames_from_video(video_path: Path, output_dir: Path, fps: int = 2) -> List[Path]:
    """Extract frames from *video_path* at roughly *fps* frames per second.

    Every ``int(video_fps / fps)``-th frame (at least every frame) is written as
    ``<video stem>_frame_NNNN.jpg`` into *output_dir*. The video stem prefix
    keeps frames from several videos extracted into the same directory from
    overwriting each other.

    Args:
        video_path: Video file readable by OpenCV.
        output_dir: Destination directory (created if missing).
        fps: Desired sampling rate; must be positive.

    Returns:
        Paths of the frames that were successfully written, in order.

    Raises:
        ValueError: if *fps* is not positive or the video cannot be opened.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    output_dir.mkdir(parents=True, exist_ok=True)

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    frame_paths = []
    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS, and a requested fps above the
        # source rate makes int(video_fps / fps) == 0; either way the modulo
        # below would divide by zero, so clamp the interval to at least 1.
        if video_fps > 0:
            frame_interval = max(1, int(video_fps / fps))
        else:
            frame_interval = 1

        frame_count = 0
        saved_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_interval == 0:
                frame_path = output_dir / f"{video_path.stem}_frame_{saved_count:04d}.jpg"
                # imwrite reports failure via its return value, not an exception
                if cv2.imwrite(str(frame_path), frame):
                    frame_paths.append(frame_path)
                    saved_count += 1

            frame_count += 1
    finally:
        # Always release the decoder, even if reading/writing raised
        cap.release()

    print(f"Extracted {len(frame_paths)} frames from video")
    return frame_paths


def prepare_dataset_from_images(image_dir: Path, dataset_name: str,
                                base_dir: Optional[Path] = None) -> Path:
    """Prepare dataset structure for Gaussian Splatting.

    Copies every ``.jpg``/``.jpeg``/``.png`` file (case-insensitive suffix)
    from *image_dir* into ``<base_dir>/<dataset_name>/input``, which is the
    folder ``convert.py`` is later pointed at via ``-s <dataset_path>``.

    Args:
        image_dir: Directory holding the source images.
        dataset_name: Name of the dataset folder to create.
        base_dir: Parent directory for the dataset; defaults to DATASET_DIR.

    Returns:
        Path of the dataset folder (``<base_dir>/<dataset_name>``).
    """
    root = DATASET_DIR if base_dir is None else base_dir
    dataset_path = root / dataset_name

    # Create input directory for images (and the dataset folder itself)
    input_dir = dataset_path / "input"
    input_dir.mkdir(parents=True, exist_ok=True)

    # Match suffixes case-insensitively so e.g. IMG_0001.JPG or photo.jpeg from
    # phones/cameras are not silently dropped; sorted for a stable copy order.
    image_suffixes = {".jpg", ".jpeg", ".png"}
    image_paths = sorted(
        p for p in image_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_suffixes
    )
    for img_path in image_paths:
        shutil.copy(img_path, input_dir)

    print(f"Copied {len(image_paths)} images to dataset")
    return dataset_path


def run_colmap_convert(dataset_path: Path) -> bool:
    """Run COLMAP on *dataset_path* through gaussian-splatting's ``convert.py``.

    The images are expected in ``<dataset_path>/input`` (as laid out by
    ``prepare_dataset_from_images``). After the run, any text-format
    ``sparse/<N>/cameras.txt`` is passed to ``fix_camera_model``.

    Args:
        dataset_path: Dataset root passed to ``convert.py -s``.

    Returns:
        True only if convert.py exited with status 0 AND at least one numbered
        ``sparse/<N>`` directory exists; False otherwise.
    """
    print("\nRunning COLMAP via convert.py...")

    cmd = [
        "python", "/content/gaussian-splatting/convert.py",
        "-s", str(dataset_path)
    ]

    print(f"Command: {' '.join(cmd)}")

    # Headless environment for the COLMAP subprocess
    env = os.environ.copy()
    env['QT_QPA_PLATFORM'] = 'offscreen'

    # DISPLAY must be a string: use the pyvirtualdisplay instance if one was
    # started in this kernel, otherwise fall back to ":0".
    if 'display' in globals():
        env['DISPLAY'] = f":{display.display}"
    else:
        env['DISPLAY'] = ':0'

    result = subprocess.run(cmd, capture_output=True, text=True, env=env)

    print(f"Return code: {result.returncode}")
    if result.stdout:
        print(f"STDOUT: {result.stdout}")
    if result.stderr:
        print(f"STDERR: {result.stderr}")

    # A non-zero exit means convert.py itself reported a failure; do not trust
    # whatever partial output it may have left behind.
    if result.returncode != 0:
        print("❌ COLMAP failed: convert.py exited with a non-zero status")
        return False

    # Check if sparse reconstruction was created
    sparse_dir = dataset_path / "sparse"
    sparse_subdirs = []
    if sparse_dir.exists():
        sparse_subdirs = [d for d in sparse_dir.iterdir() if d.is_dir() and d.name.isdigit()]

    if not sparse_subdirs:
        print("❌ COLMAP failed or no sparse reconstruction created")
        return False

    print(f"✓ COLMAP created sparse reconstruction in: {sparse_subdirs}")

    # Only text-format models have cameras.txt; binary models are left as-is.
    for sparse_subdir in sparse_subdirs:
        cameras_txt = sparse_subdir / "cameras.txt"
        if cameras_txt.exists():
            fix_camera_model(cameras_txt)

    return True


def fix_camera_model(cameras_txt_path: Path):
    """Convert OPENCV cameras in a COLMAP ``cameras.txt`` to SIMPLE_PINHOLE.

    Every non-comment line whose model token is exactly ``OPENCV`` becomes
    ``<id> SIMPLE_PINHOLE <width> <height> <focal> <cx> <cy>``, where ``focal``
    is the mean of fx and fy; the distortion coefficients are dropped. Comment
    lines, blank lines and other camera models are kept verbatim. The file is
    rewritten only when at least one line was actually converted.

    Best-effort: any error is printed as a warning and swallowed so a malformed
    file does not abort the pipeline.

    Args:
        cameras_txt_path: Path to a COLMAP text-format ``cameras.txt``.
    """
    print(f"Checking camera model in {cameras_txt_path}")

    try:
        with open(cameras_txt_path, 'r') as f:
            content = f.read()

        new_lines = []
        changed = False
        for line in content.split('\n'):
            parts = line.split()
            # Match the model token exactly (not a substring of the whole file)
            # so e.g. OPENCV_FISHEYE is left alone and does not trigger a
            # pointless rewrite / misleading "fixed" message.
            if line.startswith('#') or len(parts) < 5 or parts[1] != "OPENCV":
                new_lines.append(line)
                continue

            if not changed:
                print("⚠️ Found OPENCV camera model, changing to SIMPLE_PINHOLE")
            changed = True

            camera_id, width, height = parts[0], parts[2], parts[3]
            # OPENCV params are fx fy cx cy k1 k2 p1 p2; missing ones fall back
            # to fx / the image centre.
            fx = float(parts[4])
            fy = float(parts[5]) if len(parts) > 5 else fx
            cx = float(parts[6]) if len(parts) > 6 else float(width) / 2
            cy = float(parts[7]) if len(parts) > 7 else float(height) / 2

            focal = (fx + fy) / 2
            new_line = f"{camera_id} SIMPLE_PINHOLE {width} {height} {focal} {cx} {cy}"
            new_lines.append(new_line)
            print(f"  Changed: {line}")
            print(f"  To: {new_line}")

        if not changed:
            return

        # Joining with '\n' preserves the original trailing newline (if any)
        with open(cameras_txt_path, 'w') as f:
            f.write('\n'.join(new_lines))

        print("✓ Camera model fixed")

    except Exception as e:
        print(f"Warning: Could not fix camera model: {e}")


async def run_colmap(dataset_path: Path):
    """Run COLMAP structure-from-motion for *dataset_path* via ``convert.py``.

    ``run_colmap_convert`` is synchronous and long-running (it waits on a
    COLMAP subprocess), so it is executed in a worker thread with
    ``asyncio.to_thread``; calling it directly would block the server's event
    loop and stall every other request until it finished.

    Args:
        dataset_path: Dataset root containing the ``input`` images.

    Raises:
        Exception: if no sparse reconstruction was produced.
    """
    print("\n" + "="*60)
    print("Running COLMAP reconstruction...")
    print("="*60)

    success = await asyncio.to_thread(run_colmap_convert, dataset_path)

    if not success:
        raise Exception("COLMAP preprocessing failed")

    print("\n✅ COLMAP completed successfully!")


def _iteration_number(iteration_dir: Path) -> int:
    """Return N for a directory named ``iteration_N``; -1 if the name doesn't parse."""
    _, _, suffix = iteration_dir.name.partition("iteration_")
    return int(suffix) if suffix.isdigit() else -1


def _tail(text: str, limit: int = 2000) -> str:
    """Return at most the last *limit* characters of *text* (where tracebacks end)."""
    if len(text) <= limit:
        return text
    return "..." + text[-limit:]


async def train_gaussian_splatting(dataset_path: Path, iterations: int = 7000) -> Path:
    """Train a Gaussian Splatting model with the repository's ``train.py``.

    The model is written to ``<dataset_path>/output``. The blocking training
    subprocess runs in a worker thread (``asyncio.to_thread``) so the server's
    event loop keeps serving other requests while training runs.

    Args:
        dataset_path: Dataset root that COLMAP has already processed.
        iterations: Number of training iterations (also the save/test point).

    Returns:
        Path to ``point_cloud/iteration_<iterations>/point_cloud.ply``, or, if
        that is missing, the PLY from the highest saved iteration.

    Raises:
        Exception: if CUDA is unavailable, training exits non-zero (message
            carries the tail of stderr), or no PLY file was produced.
    """
    print("\n" + "="*60)
    print(f"Training Gaussian Splatting for {iterations} iterations...")
    print("="*60)

    output_path = dataset_path / "output"

    # Check CUDA before training
    import torch
    if not torch.cuda.is_available():
        raise Exception("CUDA not available for training")

    print(f"✓ CUDA available: {torch.cuda.get_device_name(0)}")

    # Build training command
    cmd = [
        "python", "/content/gaussian-splatting/train.py",
        "-s", str(dataset_path),
        "-m", str(output_path),
        "--iterations", str(iterations),
        "--save_iterations", str(iterations),
        "--test_iterations", str(iterations),
        "--quiet"
    ]

    print(f"Running: {' '.join(cmd[:6])}...")

    env = os.environ.copy()
    process = await asyncio.to_thread(
        subprocess.run, cmd, capture_output=True, text=True, env=env
    )

    print(f"Training return code: {process.returncode}")

    if process.returncode != 0:
        # Full stderr goes to the notebook log; only its tail goes into the
        # exception, whose message ends up in the HTTP error response.
        print(f"Training STDERR: {process.stderr}")
        raise Exception(f"Training failed: {_tail(process.stderr)}")

    ply_path = output_path / "point_cloud" / f"iteration_{iterations}" / "point_cloud.ply"
    if ply_path.exists():
        print(f"✅ Training successful! PLY created: {ply_path}")
        return ply_path

    # Expected iteration missing: fall back to the most recent saved iteration
    # (iterdir() order is arbitrary, so sort by iteration number explicitly).
    point_cloud_dir = output_path / "point_cloud"
    if point_cloud_dir.exists():
        iteration_dirs = sorted(
            (d for d in point_cloud_dir.iterdir() if d.is_dir()),
            key=_iteration_number,
            reverse=True,
        )
        print(f"Available iterations: {[d.name for d in iteration_dirs]}")

        for d in iteration_dirs:
            ply_candidate = d / "point_cloud.ply"
            if ply_candidate.exists():
                print(f"Found PLY at: {ply_candidate}")
                return ply_candidate

    raise Exception(f"No PLY file found. Expected: {ply_path}")

## 5️⃣ FastAPI Application

In [ ]:
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: announce start-up, serve, then announce shutdown."""
    print("Starting 3D Gaussian Splatting API...")  # before the first request
    yield  # the application serves requests while suspended here
    print("Shutting down...")


app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    description="Convert images/videos to 3D Gaussian Splatting models",
    title="3D Gaussian Splatting API",
)


def verify_api_key(api_key: str = Header(None)):
    """Reject the request unless *api_key* matches the configured ``API_KEY``.

    Fails closed: when no key is configured (``API_KEY`` is None/empty) every
    request is rejected. Otherwise a missing header (None) would compare equal
    to a missing key and bypass authentication. The comparison is
    constant-time to avoid leaking the key through response timing.

    Args:
        api_key: Value of the ``api-key`` request header.

    Returns:
        The accepted key.

    Raises:
        HTTPException: 401 if the key is missing, unconfigured or wrong.
    """
    import secrets

    if not API_KEY or not api_key:
        raise HTTPException(status_code=401, detail="Invalid API key")
    # compare_digest on bytes also handles non-ASCII input without TypeError
    if not secrets.compare_digest(str(api_key).encode(), str(API_KEY).encode()):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key


@app.get("/")
async def root():
    """Health/info endpoint: report that the API is up and list its routes."""
    info = dict(
        status="online",
        message="3D Gaussian Splatting API is running",
        endpoints=["/", "/process"],
        gpu="T4 (16GB VRAM)",
    )
    return info


async def _save_upload(upload: UploadFile, destination: Path, chunk_size: int = 1024 * 1024) -> None:
    """Stream *upload* to *destination* in chunks instead of reading it whole into RAM."""
    with open(destination, "wb") as out:
        while True:
            chunk = await upload.read(chunk_size)
            if not chunk:
                break
            out.write(chunk)


def _safe_upload_name(filename: Optional[str]) -> str:
    """Return the bare file name of a client-supplied upload name.

    Directory components are stripped so a name such as ``../../x`` cannot
    escape the job directory.

    Raises:
        HTTPException: 400 if no usable file name remains.
    """
    name = Path(filename or "").name
    if name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail=f"Invalid upload filename: {filename!r}")
    return name


@app.post("/process")
async def process_gaussian_splatting(
    files: List[UploadFile] = File(...),
    iterations: Optional[int] = Form(1000),  # Reduced default for testing
    extract_fps: Optional[int] = Form(2),
    api_key: str = Header(None)
):
    """Process images/video to create 3D Gaussian Splatting model.

    Pipeline: save uploads → extract video frames → build the dataset →
    COLMAP → train → copy the PLY (plus a zip of the whole output folder) to
    Google Drive.

    Args:
        files: Images and/or videos (.mp4/.avi/.mov/.mkv).
        iterations: Training iterations; must be a positive integer.
        extract_fps: Frames per second sampled from videos; must be positive
            when a video is uploaded.
        api_key: ``api-key`` header; must match the configured API_KEY.

    Returns:
        JSON with the job id, model file name and Google Drive paths.

    Raises:
        HTTPException: 401 for a bad key, 400 for invalid input or fewer than
            3 usable images, 500 for any pipeline failure.
    """
    import uuid

    verify_api_key(api_key)

    # Validate up front: None/non-positive would otherwise fail later as an
    # opaque training error.
    if iterations is None or iterations <= 0:
        raise HTTPException(status_code=400, detail="iterations must be a positive integer")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Random suffix so two requests in the same second get separate job dirs
    job_name = f"gs_{timestamp}_{uuid.uuid4().hex[:6]}"
    job_dir = UPLOAD_DIR / job_name
    job_dir.mkdir(exist_ok=True)

    try:
        # Raw uploads and the image set live in separate subfolders so an
        # upload can never collide with the "images" directory itself.
        raw_dir = job_dir / "uploads"
        raw_dir.mkdir(exist_ok=True)
        image_dir = job_dir / "images"
        image_dir.mkdir(exist_ok=True)

        for file in files:
            filename = _safe_upload_name(file.filename)
            file_path = raw_dir / filename
            await _save_upload(file, file_path)

            if filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
                if extract_fps is None or extract_fps <= 0:
                    raise HTTPException(
                        status_code=400,
                        detail="extract_fps must be a positive integer"
                    )
                # Frame extraction is blocking; keep it off the event loop
                await asyncio.to_thread(
                    extract_frames_from_video, file_path, image_dir, fps=extract_fps
                )
            else:
                shutil.copy(file_path, image_dir)

        # Prepare dataset structure, then count what actually landed in it:
        # only supported image formats are copied, so counting every file in
        # image_dir could pass the check with nothing usable for COLMAP.
        dataset_path = prepare_dataset_from_images(image_dir, job_name)
        image_count = sum(1 for p in (dataset_path / "input").iterdir() if p.is_file())
        if image_count < 3:
            shutil.rmtree(dataset_path, ignore_errors=True)
            raise HTTPException(
                status_code=400,
                detail=f"Need at least 3 images, got {image_count}"
            )

        # For optimal results, warn if too few images
        if image_count < 10:
            print(f"Warning: Only {image_count} images. For better results, use 10+ images with good overlap.")

        print(f"Processing {image_count} images...")

        await run_colmap(dataset_path)
        ply_path = await train_gaussian_splatting(dataset_path, iterations)

        # Copy results to Google Drive
        output_filename = f"{job_name}_gaussian_splatting.ply"
        drive_path = OUTPUT_DIR / output_filename
        shutil.copy(ply_path, drive_path)

        # Also save the entire output directory as zip (optional, might be large)
        output_zip = f"{job_name}_full_output"
        shutil.make_archive(
            str(OUTPUT_DIR / output_zip),
            'zip',
            dataset_path / "output"
        )

        return JSONResponse({
            "status": "success",
            "job_id": job_name,
            "model_file": output_filename,
            "download_path": str(drive_path),
            "full_output_zip": f"{output_zip}.zip",
            "images_processed": image_count,
            "iterations": iterations,
            "completed_at": datetime.now().isoformat()
        })

    except HTTPException:
        # Client errors raised above (4xx) must reach the caller unchanged
        # rather than being re-wrapped as a 500 by the generic handler below.
        raise

    except Exception as e:
        # Log the full error for debugging
        import traceback
        print(f"Error processing job {job_name}:")
        print(traceback.format_exc())

        raise HTTPException(
            status_code=500,
            detail=f"Processing failed: {str(e)}"
        ) from e

    finally:
        # Clean up temporary upload files (the dataset is kept for debugging)
        shutil.rmtree(job_dir, ignore_errors=True)

## 6️⃣ Launch Server with ngrok

In [ ]:
# Start the server
import socket
import threading
import time


def run_server():
    """Serve ``app`` with uvicorn on port 8000 (blocks; run it in a thread)."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


def _wait_for_port(host, port, timeout=30.0):
    """Return True once a TCP connection to host:port succeeds, False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1):
                return True
        except OSError:
            time.sleep(0.5)
    return False


# Start server thread (daemon, so it does not keep the interpreter alive)
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()

# Wait until uvicorn actually accepts connections instead of a fixed sleep, so
# the tunnel is not opened to a port that is not listening yet.
if not _wait_for_port("127.0.0.1", 8000):
    print("⚠️ Server did not start listening on port 8000 within 30 s; check the output above")

# Create ngrok tunnel
tunnel = ngrok.connect(8000)
public_url = tunnel.public_url

print("\n" + "="*60)
print(f"🚀 3D Gaussian Splatting API is live at: {public_url}")
print("="*60)
print("\nTest with:")
print(f"export PUBLIC_URL='{public_url}'")
# Never echo the secret itself: cell outputs are saved with the notebook and
# may be shared, which would defeat storing it in Colab secrets.
print("export API_KEY='<value of your API_KEY Colab secret>'")
print("\nFor multiple images:")
print('curl -X POST $PUBLIC_URL/process -H "Api-Key: $API_KEY" -F "files=@img1.jpg" -F "files=@img2.jpg" -F "files=@img3.jpg"')
print("\nFor video:")
print('curl -X POST $PUBLIC_URL/process -H "Api-Key: $API_KEY" -F "files=@video.mp4" -F "extract_fps=2"')
print("\n" + "="*60)

In [None]:
# Keep alive: block in this cell so the Colab runtime (and with it the daemon
# server thread and the ngrok tunnel) stays up, printing a heartbeat each minute.
print("\n⏰ Server is running. Keep this cell executing to maintain the connection.")
print("⚠️  Processing may take 10-30 minutes depending on image count and settings.")
print("Press 'Stop' to shutdown the server.\n")

try:
    while True:
        time.sleep(60)
        stamp = datetime.now().strftime('%H:%M:%S')
        print(f"[{stamp}] Server alive at: {public_url}")
except KeyboardInterrupt:
    # Interrupting the cell lands here: close the tunnel and stop ngrok
    print("\nShutting down server...")
    ngrok.disconnect(public_url)
    ngrok.kill()