# 3D Gaussian Splatting API on Google Colab

This notebook creates a FastAPI server that:
- Accepts image/video uploads via HTTP POST
- Extracts frames from videos (if needed)
- Runs COLMAP for camera pose estimation
- Trains a 3D Gaussian Splatting model
- Returns the trained model (.ply file) via Google Drive

**Note**: Colab's T4 GPU has 16 GB of VRAM, less than the 24 GB recommended for Gaussian Splatting training, so this notebook uses reduced training settings.

## 1️⃣ Install System Dependencies

In [ ]:
# Install system dependencies
# Refreshes the apt package index, then installs the development libraries
# and tools listed below (including cmake, imagemagick and ffmpeg) in a
# single non-interactive (-y) apt-get call.
!apt-get update && apt-get install -y \
    libglew-dev \
    libassimp-dev \
    libboost-all-dev \
    libgtk-3-dev \
    libopencv-dev \
    libglfw3-dev \
    libavdevice-dev \
    libavcodec-dev \
    libeigen3-dev \
    libxxf86vm-dev \
    libembree-dev \
    cmake \
    imagemagick \
    ffmpeg

# Fix ImageMagick policy for PDF/PS files (sometimes needed)
# Deletes the line matching 'disable ghostscript format types' and the 6 lines
# after it from policy.xml; `|| true` keeps the cell going if sed fails
# (for example when the file is not present).
!sed -i '/disable ghostscript format types/,+6d' /etc/ImageMagick-6/policy.xml || true

In [None]:
# Install COLMAP (prebuilt binary for faster setup)
# Steps: download the 3.8 release tarball, unpack it, copy its contents into
# /usr/local, then remove both the tarball and the unpacked directory.
!wget https://github.com/colmap/colmap/releases/download/3.8/colmap-3.8-linux-cuda.tar.gz
!tar -xzf colmap-3.8-linux-cuda.tar.gz
!cp -r colmap-3.8-linux-cuda/* /usr/local/
!rm -rf colmap-3.8-linux-cuda*

## 2️⃣ Clone and Setup Gaussian Splatting

In [None]:
# Clone the repository
# --recursive also fetches the git submodules (the later pip cell installs
# packages from the submodules/ directory).
!git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive
# Change the notebook's working directory into the clone; the path is relative,
# so this cell assumes it is run from the directory the clone was created in.
%cd gaussian-splatting

In [None]:
# Install Python dependencies
# First line: packages used for data handling; second line: packages used by
# the API server cells (FastAPI, uvicorn, ngrok tunnel, multipart uploads).
!pip install -q plyfile tqdm opencv-python joblib
!pip install -q fastapi uvicorn pyngrok nest-asyncio python-multipart aiofiles

# Install submodules
# Paths are relative, so this relies on the working directory being the
# gaussian-splatting checkout (set by the %cd in the clone cell).
!pip install -q submodules/diff-gaussian-rasterization
!pip install -q submodules/simple-knn

## 3️⃣ Setup API Server

In [ ]:
import os
import sys
import shutil
import asyncio
import nest_asyncio
import subprocess
import cv2
from pathlib import Path
from datetime import datetime
from contextlib import asynccontextmanager
from typing import List, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Header, Form
from fastapi.responses import JSONResponse
import uvicorn
from pyngrok import ngrok
from google.colab import drive, userdata

# Add gaussian-splatting to path
sys.path.append('/content/gaussian-splatting')

# Enable nested event loops
nest_asyncio.apply()

In [ ]:
# Configuration
# Get secrets from Colab userdata
# NOTE: both secrets are read inside one try block, so a failure while reading
# either of them resets BOTH to None (and the ValueError below then reports
# NGROK_AUTHTOKEN even if it was API_KEY that could not be loaded).
try:
    NGROK_AUTHTOKEN = userdata.get('NGROK_AUTHTOKEN')
    API_KEY = userdata.get('API_KEY')
except Exception as e:
    print("⚠️  Warning: Could not load secrets from Colab userdata")
    print("Please set NGROK_AUTHTOKEN and API_KEY in Colab secrets")
    print("Settings → Secrets → Add new secret")
    NGROK_AUTHTOKEN = None
    API_KEY = None

# Working directories: raw uploads, prepared datasets, and the Drive folder
# that receives the final results.
UPLOAD_DIR = Path("/content/uploads")
DATASET_DIR = Path("/content/datasets")
OUTPUT_DIR = Path("/content/drive/MyDrive/gaussian_splatting_outputs")

# Create directories
# OUTPUT_DIR is intentionally not created here; it is created after Google
# Drive has been mounted.
UPLOAD_DIR.mkdir(exist_ok=True)
DATASET_DIR.mkdir(exist_ok=True)

# Set ngrok auth token
# The notebook cannot expose the server without a token, so stop here if it
# is missing.
if NGROK_AUTHTOKEN:
    ngrok.set_auth_token(NGROK_AUTHTOKEN)
else:
    raise ValueError("NGROK_AUTHTOKEN not found in Colab secrets")

In [None]:
# Mount Google Drive
drive.mount('/content/drive')
# Create the results folder (and any missing parents) inside the mounted Drive.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

## 4️⃣ Processing Functions

In [ ]:
def extract_frames_from_video(video_path: Path, output_dir: Path, fps: int = 2) -> List[Path]:
    """Extract frames from a video at approximately ``fps`` frames per second.

    Every ``int(video_fps / fps)``-th frame is written to ``output_dir`` as
    ``frame_NNNN.jpg``. Existing ``frame_NNNN.jpg`` files in ``output_dir``
    are never overwritten; numbering continues at the next free index, so
    several videos can be extracted into the same directory.

    Args:
        video_path: Path of the video file to read.
        output_dir: Directory that receives the JPEG frames (created if missing).
        fps: Target number of frames to keep per second of video; must be > 0.

    Returns:
        Paths of the frames written by this call, in extraction order. The
        list is empty when no frame could be read from the video.

    Raises:
        ValueError: if ``fps`` is ``None`` or not positive.
    """
    # Validate before touching the filesystem or OpenCV.
    if fps is None or fps <= 0:
        raise ValueError(f"fps must be a positive number, got {fps!r}")

    output_dir.mkdir(parents=True, exist_ok=True)

    frame_paths = []
    cap = cv2.VideoCapture(str(video_path))
    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        if video_fps and video_fps > 0:
            # Clamp to 1: when the video frame rate is below the requested
            # rate the integer ratio is 0, which would make the modulo below
            # raise ZeroDivisionError.
            frame_interval = max(1, int(video_fps / fps))
        else:
            # Frame rate unavailable (0 or NaN): keep every frame that can be read.
            frame_interval = 1

        frame_count = 0
        next_index = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_interval == 0:
                # Skip indices already present on disk so earlier extractions
                # into the same directory are preserved.
                frame_path = output_dir / f"frame_{next_index:04d}.jpg"
                while frame_path.exists():
                    next_index += 1
                    frame_path = output_dir / f"frame_{next_index:04d}.jpg"

                # Only report frames that were actually written.
                if cv2.imwrite(str(frame_path), frame):
                    frame_paths.append(frame_path)

            frame_count += 1
    finally:
        # Release the capture handle even if reading or writing raised.
        cap.release()

    print(f"Extracted {len(frame_paths)} frames from video")
    return frame_paths


def prepare_dataset_from_images(image_dir: Path, dataset_name: str) -> Path:
    """Build the on-disk dataset layout for a job and fill its ``input`` folder.

    Creates ``DATASET_DIR / dataset_name`` with an ``input`` subdirectory and
    copies every ``*.jpg`` and ``*.png`` file found directly in ``image_dir``
    into it.

    Returns:
        The dataset root directory (the parent of ``input``).
    """
    dataset_root = DATASET_DIR / dataset_name
    colmap_input = dataset_root / "input"

    # Parent first, then the COLMAP input folder inside it.
    for directory in (dataset_root, colmap_input):
        directory.mkdir(exist_ok=True)

    # Collect the sources up front (JPEGs first, then PNGs) before copying.
    sources = [*image_dir.glob("*.jpg"), *image_dir.glob("*.png")]
    for source in sources:
        shutil.copy(source, colmap_input)

    print(f"Copied {len(sources)} images to {colmap_input}")
    return dataset_root


def validate_dataset_structure(dataset_path: Path) -> bool:
    """Check that ``dataset_path`` has the layout needed for training.

    Prints a short directory listing followed by one line per check. The
    dataset is accepted when it contains an ``images`` directory and a
    ``sparse`` directory whose highest-numbered subdirectory holds
    ``cameras``, ``images`` and ``points3D`` files in either ``.bin`` or
    ``.txt`` form.

    Returns:
        ``True`` when every check passes, ``False`` at the first failed check.
    """
    print(f"\nValidating dataset structure at: {dataset_path}")

    # Diagnostic listing: top-level directories plus any directories among the
    # first five (sorted) entries of each.
    print("\nDataset directories:")
    top_level_dirs = (entry for entry in sorted(dataset_path.iterdir()) if entry.is_dir())
    for top_dir in top_level_dirs:
        print(f"  - {top_dir.name}/")
        first_children = sorted(top_dir.iterdir())[:5]
        for child in first_children:
            if child.is_dir():
                print(f"    - {child.name}/")

    # Images directory (created by convert.py)
    images_root = dataset_path / "images"
    if images_root.exists():
        file_total = sum(1 for _ in images_root.glob("*"))
        print(f"✓ Found images directory with {file_total} files")
    else:
        print("❌ Missing required directory: images")
        # Mention alternative image folders to help with debugging.
        for candidate in ("images_2", "images_4", "images_8"):
            if (dataset_path / candidate).exists():
                print(f"  Found: {candidate}")
        return False

    # Sparse reconstruction directory
    sparse_root = dataset_path / "sparse"
    if not sparse_root.exists():
        print("❌ Missing required directory: sparse")
        return False

    numbered_models = [
        entry for entry in sparse_root.iterdir()
        if entry.name.isdigit() and entry.is_dir()
    ]
    if len(numbered_models) == 0:
        print("❌ No numbered sparse directories found")
        print(f"  Sparse directory contents: {list(sparse_root.iterdir())}")
        return False

    # The model with the highest numeric name is the one that gets checked.
    chosen_model = max(numbered_models, key=lambda entry: int(entry.name))
    print(f"✓ Using sparse directory: {chosen_model}")

    # Each required file may be present as .bin or as .txt.
    absent = []
    for stem in ("cameras", "images", "points3D"):
        binary_name = f"{stem}.bin"
        text_name = f"{stem}.txt"
        if (chosen_model / binary_name).exists() or (chosen_model / text_name).exists():
            print(f"  ✓ Found: {binary_name}")
        else:
            absent.append(binary_name)

    if absent:
        print(f"❌ Missing required files: {absent}")
        return False

    print("✓ Dataset structure validated successfully")
    return True


async def run_colmap(dataset_path: Path):
    """Run COLMAP for structure from motion using convert.py.

    Expects the source images (``*.jpg`` / ``*.png``) in
    ``dataset_path / "input"``, runs the repository's ``convert.py`` on the
    dataset and then checks the result with ``validate_dataset_structure``.

    The conversion subprocess is executed in a worker thread so the asyncio
    event loop (and therefore every other request handled by the server)
    stays responsive while COLMAP is running.

    Args:
        dataset_path: Dataset root containing the ``input`` directory.

    Raises:
        Exception: if no input images are found, if ``convert.py`` exits with
            a non-zero return code, or if the dataset layout is still invalid
            after the fallback directory setup.
    """
    print("\n" + "="*60)
    print("Running COLMAP reconstruction...")
    print("="*60)

    # Verify input images exist
    input_dir = dataset_path / "input"
    input_images = list(input_dir.glob("*.jpg")) + list(input_dir.glob("*.png"))

    if not input_images:
        raise Exception(f"No images found in {input_dir}")

    print(f"Found {len(input_images)} input images")

    # Check image quality (basic check): only the first 5 images are decoded
    # and their dimensions printed; this is purely informational.
    print("\nChecking images:")
    for img_path in input_images[:5]:
        img = cv2.imread(str(img_path))
        if img is not None:
            h, w = img.shape[:2]
            print(f"  - {img_path.name}: {w}x{h}")
        else:
            print(f"  - {img_path.name}: Failed to load!")

    # Run convert.py which handles all COLMAP processing
    cmd = [
        "python", "/content/gaussian-splatting/convert.py",
        "-s", str(dataset_path)
        # No --resize: full resolution is used by default
    ]

    print(f"\nRunning command: {' '.join(cmd)}")
    # subprocess.run blocks until the process exits; running it directly in
    # this coroutine would freeze the event loop for the whole reconstruction.
    process = await asyncio.to_thread(
        subprocess.run, cmd, capture_output=True, text=True
    )

    # Always print output for debugging
    print("\n--- COLMAP Output ---")
    if process.stdout:
        print("STDOUT:")
        print(process.stdout)
    if process.stderr:
        print("\nSTDERR:")
        print(process.stderr)
    print("--- End COLMAP Output ---\n")

    if process.returncode != 0:
        # Try to identify specific error
        if "Could not register" in process.stderr:
            print("⚠️  COLMAP failed to find enough feature matches between images.")
            print("Tips: Ensure images have good overlap (>70%) and are not too blurry.")
        raise Exception(f"COLMAP failed with return code {process.returncode}")

    # Validate the output structure
    if not validate_dataset_structure(dataset_path):
        print("\n⚠️  Standard convert.py failed. Trying manual COLMAP approach...")

        # NOTE(review): this fallback only creates directories and copies the
        # input images; it does not produce cameras/images/points3D files, so
        # the re-validation below can only pass if those files already exist
        # in the highest-numbered sparse directory.
        (dataset_path / "distorted").mkdir(exist_ok=True)
        (dataset_path / "images").mkdir(exist_ok=True)
        (dataset_path / "sparse").mkdir(exist_ok=True)
        (dataset_path / "stereo").mkdir(exist_ok=True)

        # Copy images to images directory as fallback
        for img in input_images:
            shutil.copy(img, dataset_path / "images" / img.name)

        # Make sure sparse/0 exists
        sparse_0 = dataset_path / "sparse" / "0"
        sparse_0.mkdir(exist_ok=True)

        # Check again
        if not validate_dataset_structure(dataset_path):
            raise Exception("COLMAP output structure is invalid")

    print("\n✓ COLMAP completed successfully")


async def train_gaussian_splatting(dataset_path: Path, iterations: int = 7000) -> Path:
    """Train a Gaussian Splatting model on a prepared dataset.

    Validates the dataset with ``validate_dataset_structure``, runs the
    repository's ``train.py`` with ``dataset_path / "output"`` as the model
    directory, and returns the point cloud saved at the final iteration.

    The training subprocess is executed in a worker thread so the asyncio
    event loop stays responsive for the duration of the training run.

    Args:
        dataset_path: Dataset root produced by the COLMAP step.
        iterations: Total number of training iterations.

    Returns:
        Path to ``output/point_cloud/iteration_<iterations>/point_cloud.ply``.

    Raises:
        Exception: if the dataset is invalid, if ``train.py`` exits with a
            non-zero return code, or if the expected ``.ply`` file is missing.
    """
    print(f"\nTraining Gaussian Splatting for {iterations} iterations...")

    # Validate dataset before training
    if not validate_dataset_structure(dataset_path):
        raise Exception("Dataset structure is invalid for training")

    output_path = dataset_path / "output"

    # Base command; schedule-related options are tied to the requested
    # iteration count so they never point past the end of training.
    cmd = [
        "python", "/content/gaussian-splatting/train.py",
        "-s", str(dataset_path),
        "-m", str(output_path),
        "--iterations", str(iterations),
        "--densify_until_iter", str(min(iterations, 5000)),  # Don't exceed iterations
        "--densification_interval", "100",
        "--position_lr_max_steps", str(iterations),
        "--save_iterations", str(iterations),  # Save at final iteration
        "--quiet"  # Reduce output verbosity
    ]

    if iterations <= 1000:  # For quick testing
        cmd.extend([
            "--test_iterations", str(iterations),
            "--checkpoint_iterations", str(iterations),
            "--sh_degree", "2"  # Reduced spherical harmonics degree for short runs
        ])
    else:
        # NOTE(review): these milestones are fixed values; when iterations is
        # below 7000 the checkpoint milestone lies beyond the end of the run.
        cmd.extend([
            "--test_iterations", "1000", "3000", "5000", "7000",
            "--checkpoint_iterations", "7000"
        ])

    print(f"Running training command: {' '.join(cmd)}")
    # subprocess.run blocks until training finishes; run it off the event loop
    # so other requests are still served in the meantime.
    process = await asyncio.to_thread(
        subprocess.run, cmd, capture_output=True, text=True
    )

    if process.returncode != 0:
        print(f"Training stderr: {process.stderr}")
        raise Exception(f"Training failed: {process.stderr}")

    # Find the output PLY file
    ply_path = output_path / "point_cloud" / f"iteration_{iterations}" / "point_cloud.ply"

    if not ply_path.exists():
        # List what was actually produced to make the failure easier to debug.
        point_cloud_dir = output_path / "point_cloud"
        if point_cloud_dir.exists():
            print(f"Available iterations: {[d.name for d in point_cloud_dir.iterdir() if d.is_dir()]}")
        raise Exception(f"Expected output not found at {ply_path}")

    print(f"✓ Training completed successfully. Output: {ply_path}")
    return ply_path

## 5️⃣ FastAPI Application

In [ ]:
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log a message on startup and on shutdown."""
    print("Starting 3D Gaussian Splatting API...")
    # Control returns to the framework here; the code after the yield runs
    # when the context manager exits.
    yield
    print("Shutting down...")

# The FastAPI application; `lifespan` above is registered as its lifespan handler.
app = FastAPI(
    title="3D Gaussian Splatting API",
    description="Convert images/videos to 3D Gaussian Splatting models",
    version="1.0.0",
    lifespan=lifespan
)


def verify_api_key(api_key: str = Header(None)):
    """Check the caller-supplied API key against the configured ``API_KEY``.

    Args:
        api_key: Key sent by the client (``None`` when the header is absent).

    Returns:
        The supplied key when it matches.

    Raises:
        HTTPException: 401 when the key is missing or wrong, or when no
            ``API_KEY`` is configured on the server.
    """
    import hmac  # local import keeps this block self-contained

    # Fail closed: with no configured key, a request without the header would
    # otherwise compare None to None and be accepted.
    if not API_KEY or not api_key:
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not hmac.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key


@app.get("/")
async def root():
    # Health/status endpoint: reports that the API is up, which routes exist,
    # and the GPU this notebook targets. The payload is static.
    available_endpoints = ["/", "/process"]
    return dict(
        status="online",
        message="3D Gaussian Splatting API is running",
        endpoints=available_endpoints,
        gpu="T4 (16GB VRAM)",
    )


@app.post("/process")
async def process_gaussian_splatting(
    files: List[UploadFile] = File(...),
    iterations: Optional[int] = Form(1000),  # Reduced default for testing
    extract_fps: Optional[int] = Form(2),
    api_key: str = Header(None)
):
    """Process images/video to create 3D Gaussian Splatting model"""
    # Pipeline: authenticate -> save uploads (extracting frames from videos)
    # -> prepare dataset -> COLMAP -> training -> copy results to OUTPUT_DIR.
    # Responses: 400 for invalid input, 401 for a bad API key, 500 for any
    # failure inside the pipeline. The per-job upload directory is always
    # removed afterwards; the dataset directory is kept for debugging.

    verify_api_key(api_key)

    # Normalise and validate the numeric form fields before doing any work;
    # a zero or negative value would only fail later, deep inside the pipeline.
    if iterations is None:
        iterations = 1000
    if extract_fps is None:
        extract_fps = 2
    if iterations < 1:
        raise HTTPException(
            status_code=400,
            detail=f"iterations must be at least 1, got {iterations}"
        )
    if extract_fps < 1:
        raise HTTPException(
            status_code=400,
            detail=f"extract_fps must be at least 1, got {extract_fps}"
        )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    job_name = f"gs_{timestamp}"
    job_dir = UPLOAD_DIR / job_name
    job_dir.mkdir(exist_ok=True)

    video_suffixes = ('.mp4', '.avi', '.mov', '.mkv')

    try:
        # Process uploaded files
        image_dir = job_dir / "images"
        image_dir.mkdir(exist_ok=True)

        for file in files:
            # The filename is client-controlled: keep only its final path
            # component so values such as "../../x" or "/etc/x" cannot make
            # the server write outside the job directory.
            safe_name = Path((file.filename or "").replace("\\", "/")).name
            if not safe_name or safe_name in (".", ".."):
                raise HTTPException(
                    status_code=400,
                    detail="Uploaded file has no usable filename"
                )
            file_path = job_dir / safe_name

            # Save uploaded file in 1 MiB chunks so a large video is never
            # held in memory in one piece.
            with open(file_path, "wb") as f:
                while True:
                    chunk = await file.read(1024 * 1024)
                    if not chunk:
                        break
                    f.write(chunk)

            # Check if video or image
            if safe_name.lower().endswith(video_suffixes):
                # Extract frames from video
                extract_frames_from_video(file_path, image_dir, fps=extract_fps)
            else:
                # Copy image to image directory
                shutil.copy(file_path, image_dir)

        # Check if we have enough images
        image_count = len(list(image_dir.glob("*")))
        if image_count < 3:
            raise HTTPException(
                status_code=400,
                detail=f"Need at least 3 images, got {image_count}"
            )

        # For optimal results, warn if too few images
        if image_count < 10:
            print(f"Warning: Only {image_count} images. For better results, use 10+ images with good overlap.")

        print(f"Processing {image_count} images...")

        # Prepare dataset structure
        dataset_path = prepare_dataset_from_images(image_dir, job_name)

        # Run COLMAP reconstruction
        await run_colmap(dataset_path)

        # Train Gaussian Splatting model
        ply_path = await train_gaussian_splatting(dataset_path, iterations)

        # Copy results to Google Drive
        output_filename = f"{job_name}_gaussian_splatting.ply"
        drive_path = OUTPUT_DIR / output_filename
        shutil.copy(ply_path, drive_path)

        # Also save the entire output directory as zip (optional, might be large)
        output_zip = f"{job_name}_full_output"
        shutil.make_archive(
            str(OUTPUT_DIR / output_zip),
            'zip',
            dataset_path / "output"
        )

        return JSONResponse({
            "status": "success",
            "job_id": job_name,
            "model_file": output_filename,
            "download_path": str(drive_path),
            "full_output_zip": f"{output_zip}.zip",
            "images_processed": image_count,
            "iterations": iterations,
            "completed_at": datetime.now().isoformat()
        })

    except HTTPException:
        # Deliberate client-facing errors (the 400s above) keep their own
        # status code; they must not be rewrapped as a 500.
        raise

    except Exception as e:
        # Log the full error for debugging
        import traceback
        print(f"Error processing job {job_name}:")
        print(traceback.format_exc())

        # Return detailed error for debugging
        raise HTTPException(
            status_code=500,
            detail=f"Processing failed: {str(e)}"
        ) from e

    finally:
        # Clean up temporary files on success and on failure alike
        # (the dataset directory is kept for debugging if needed).
        if job_dir.exists():
            shutil.rmtree(job_dir, ignore_errors=True)

## 6️⃣ Launch Server with ngrok

In [ ]:
# Start the server
import threading

def run_server():
    """Serve the FastAPI app with uvicorn on port 8000 (blocks while serving)."""
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Start server thread
# daemon=True so the server thread does not keep the process alive on exit.
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()

# Wait for server to start (fixed delay; there is no readiness check)
import time
time.sleep(3)

# Create ngrok tunnel to the local port
tunnel = ngrok.connect(8000)
public_url = tunnel.public_url

print("\n" + "="*60)
print(f"🚀 3D Gaussian Splatting API is live at: {public_url}")
print("="*60)
print(f"\nTest with:")
print(f"export PUBLIC_URL='{public_url}'")
# Do not echo the real key: cell output is stored in the notebook file and
# would leak the secret to anyone the notebook is shared with.
print("export API_KEY='<value of your API_KEY Colab secret>'")
print(f"\nFor multiple images:")
print(f'curl -X POST $PUBLIC_URL/process -H "Api-Key: $API_KEY" -F "files=@img1.jpg" -F "files=@img2.jpg" -F "files=@img3.jpg"')
print(f"\nFor video:")
print(f'curl -X POST $PUBLIC_URL/process -H "Api-Key: $API_KEY" -F "files=@video.mp4" -F "extract_fps=2"')
print("\n" + "="*60)

In [None]:
# Keep alive
# Relies on `time` and `public_url` defined by the server launch cell.
print("\n⏰ Server is running. Keep this cell executing to maintain the connection.")
print("⚠️  Processing may take 10-30 minutes depending on image count and settings.")
print("Press 'Stop' to shutdown the server.\n")

try:
    # Print a heartbeat line once per minute until the cell is interrupted.
    while True:
        time.sleep(60)
        print(f"[{datetime.now().strftime('%H:%M:%S')}] Server alive at: {public_url}")
except KeyboardInterrupt:
    # On interrupt, close the tunnel and stop the ngrok process.
    print("\nShutting down server...")
    ngrok.disconnect(public_url)
    ngrok.kill()