# QuantumHead — 3DGS Talking Head Avatar

**Train on Colab A100 → Export weights to GCS → Serve on spike2 (8GB GPU)**

Pipeline:
1. Install deps (verified for Colab Python 3.12 + CUDA 13.0)
2. Download FLAME model + pretrained face reconstruction
3. Upload selfies → FLAME fitting → dataset
4. Train QuantumHead (UV-Space Gaussians + Audio2FLAME)
5. Export weights to GCS bucket
6. Test inference locally

Total VRAM at inference: ~2GB fp16 (fits on RTX 2080 SUPER)

## 1. Configuration

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================
import os

# GCS bucket for weight transfer (Colab → GCS → spike2)
GCS_BUCKET = "veo-spotless"
GCS_PREFIX = "quantumhead/weights"  # object-name prefix for the FLAME pickle and uploaded checkpoints
GCS_PROJECT = "spotlessbinco"

# Training config
BATCH_SIZE = 4  # DataLoader batch size
NUM_GAUSSIANS = 256_000  # NOTE(review): not referenced by any cell visible in this section
UV_MAP_SIZE = 256  # side length of the square UV attribute maps
EXPRESSION_DIM = 256  # size of the expression latent z_exp
IDENTITY_DIM = 512  # size of the identity latent z_id
LEARNING_RATE = 1e-4
NUM_ITERATIONS = 50_000  # total optimizer steps
GUIDE_MESH_VERTICES = 5023  # FLAME vertex count

# Paths
OUTPUT_DIR = "/content/quantumhead_output"
CHECKPOINT_DIR = f"{OUTPUT_DIR}/checkpoints"
DATA_ROOT = "/content/quantumhead_data"
WEIGHTS_DIR = "/content/weights"  # holds the FLAME generic_model.pkl
SELFIE_DIR = f"{DATA_ROOT}/selfies"  # source images
FLAME_DIR = f"{DATA_ROOT}/flame_params"  # one fitted .npz per selfie

# Create every working directory up front so later cells can write into them
for d in [OUTPUT_DIR, CHECKPOINT_DIR, DATA_ROOT, WEIGHTS_DIR, SELFIE_DIR, FLAME_DIR]:
    os.makedirs(d, exist_ok=True)

print("✓ Configuration set")

## 2. GPU Check

In [None]:
import subprocess, torch

# Print the driver/GPU table when nvidia-smi is available. On a runtime
# without the NVIDIA driver the binary does not exist; swallow that one error
# so we still reach the actionable "No GPU!" message below instead of dying
# on FileNotFoundError.
try:
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
    print(result.stdout)
except FileNotFoundError:
    print("nvidia-smi not found — no NVIDIA driver on this runtime")

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    raise RuntimeError("No GPU! Go to Runtime → Change runtime type → A100.")

gpu_name = torch.cuda.get_device_name(0)
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"GPU: {gpu_name}")
print(f"VRAM: {vram_gb:.1f} GB")
# Explicit raise instead of assert: asserts are stripped under `python -O`.
if vram_gb <= 30:
    raise RuntimeError(f"Need A100 (40/80GB), got {vram_gb:.1f}GB. Change Runtime → A100.")
print("✓ A100 confirmed")

## 3. Install Dependencies

In [None]:
# ============================================================
# INSTALL DEPENDENCIES — verified for Colab A100 (Feb 2026)
# Colab pre-installs PyTorch, so we only add what's missing
# ============================================================

# Install in groups to isolate failures
import subprocess, sys

def pip_install(*pkgs, **kwargs):
    """Run a quiet ``pip install`` for *pkgs* using the current interpreter.

    An optional ``extra_args`` keyword supplies additional pip command-line
    arguments, placed before the package names.

    Returns True when pip exits with status 0. Otherwise prints the failed
    package list plus every stderr line that mentions "error", and returns
    False.
    """
    command = [sys.executable, '-m', 'pip', 'install', '-q']
    command.extend(kwargs.get('extra_args', []))
    command.extend(pkgs)

    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode == 0:
        return True

    print(f"  ⚠ Failed: {' '.join(pkgs)}")
    # Only the error lines are shown; the rest of pip's output is noise here.
    for stderr_line in proc.stderr.split('\n'):
        if 'error' in stderr_line.lower():
            print(f"    {stderr_line.strip()}")
    return False

print("Installing dependencies...")

# Each entry is (progress label, [pip calls]). Calls inside a group stay
# separate so one unbuildable wheel only fails its own call.
# (No pytorch3d — too fragile, we don't need it.)
install_plan = [
    ("  [1/6] 3D & rendering...", [('trimesh', 'plyfile', 'einops')]),
    ("  [2/6] Face reconstruction...", [
        ('face-alignment', 'mediapipe'),
        ('insightface', 'onnxruntime-gpu'),  # insightface needs specific onnxruntime
    ]),
    ("  [3/6] Audio processing...", [('transformers', 'librosa', 'soundfile')]),
    ("  [4/6] Training utils...", [('diffusers', 'accelerate', 'lpips', 'wandb')]),
    ("  [5/6] GCS & helpers...", [('google-cloud-storage', 'gdown', 'huggingface_hub', 'gtts')]),
    ("  [6/6] Image processing...", [('opencv-python-headless', 'scikit-image')]),
]

# Remember which pip calls failed so the final status line is truthful.
failed_installs = []
for label, pip_calls in install_plan:
    print(label)
    for pkgs in pip_calls:
        if not pip_install(*pkgs):
            failed_installs.append(' '.join(pkgs))

# Verify critical imports. face_alignment is included because the FLAME
# fitting cell imports it unconditionally.
print("\nVerifying imports...")
critical = ['torch', 'cv2', 'numpy', 'einops', 'transformers', 'librosa', 'tqdm',
            'face_alignment']
missing_modules = []
for mod in critical:
    try:
        __import__(mod)
    except ImportError:
        missing_modules.append(mod)
        print(f"  ✗ {mod} — MISSING")
    except Exception as err:
        # A broken install can fail with something other than ImportError.
        missing_modules.append(mod)
        print(f"  ✗ {mod} — import failed ({type(err).__name__}: {err})")
    else:
        print(f"  ✓ {mod}")

if missing_modules:
    print(f"\n✗ Critical modules unavailable: {', '.join(missing_modules)} — fix before continuing")
elif failed_installs:
    print(f"\n⚠ Critical imports OK, but these installs failed: {'; '.join(failed_installs)}")
else:
    print("\n✓ Dependencies installed")

## 4. Download FLAME Model

In [None]:
# ============================================================
# DOWNLOAD FLAME MODEL
# Source 1: GCS bucket (most reliable — pre-uploaded)
# Source 2: gsutil CLI fallback
# Source 3: Manual upload
# ============================================================
import os

FLAME_MODEL_PATH = f"{WEIGHTS_DIR}/generic_model.pkl"


def _flame_file_ok():
    """Return True when the FLAME pickle exists and is larger than 1 MB.

    The size floor rejects empty or truncated files left by a failed download.
    """
    return os.path.exists(FLAME_MODEL_PATH) and os.path.getsize(FLAME_MODEL_PATH) > 1_000_000


if os.path.exists(FLAME_MODEL_PATH):
    print(f"✓ FLAME already exists: {FLAME_MODEL_PATH}")
else:
    print("Downloading FLAME model...")
    downloaded = False

    # Authenticate with GCS first (best effort: outside Colab the import fails
    # and the sources below are still tried).
    try:
        from google.colab import auth
        auth.authenticate_user()
        print("  ✓ GCS authenticated")
    except Exception as e:
        print(f"  GCS auth warning: {e}")

    # Source 1: GCS Python API
    try:
        from google.cloud import storage
        client = storage.Client(project=GCS_PROJECT)
        bucket = client.bucket(GCS_BUCKET)
        blob = bucket.blob(f"{GCS_PREFIX}/generic_model.pkl")
        if blob.exists():
            blob.download_to_filename(FLAME_MODEL_PATH)
            downloaded = _flame_file_ok()
            if downloaded:
                print(f"  ✓ FLAME from GCS: {FLAME_MODEL_PATH}")
            else:
                print("  GCS download produced a missing or truncated file")
        else:
            print(f"  Not found in bucket: gs://{GCS_BUCKET}/{GCS_PREFIX}/generic_model.pkl")
    except Exception as e:
        print(f"  GCS API failed: {e}")

    # Source 2: gsutil CLI
    if not downloaded:
        try:
            import subprocess
            r = subprocess.run(
                ['gsutil', 'cp', f'gs://{GCS_BUCKET}/{GCS_PREFIX}/generic_model.pkl', FLAME_MODEL_PATH],
                capture_output=True, text=True, timeout=120
            )
            downloaded = _flame_file_ok()
            if downloaded:
                print(f"  ✓ FLAME via gsutil")
            else:
                print(f"  gsutil failed: {r.stderr[:200]}")
        except Exception as e:
            print(f"  gsutil failed: {e}")

    # Source 3: Manual upload
    if not downloaded:
        print("  ⚠ Auto-download failed. Upload generic_model.pkl:")
        try:
            from google.colab import files as colab_files
            print("    Click upload button below:")
            uploaded = colab_files.upload()
            for name, data in uploaded.items():
                if 'generic_model' in name or name.endswith('.pkl'):
                    with open(FLAME_MODEL_PATH, 'wb') as f:
                        f.write(data)
                    downloaded = True
                    print(f"  ✓ Uploaded: {name}")
                    # One model file is enough; without the break a second
                    # .pkl in the same upload would overwrite the first.
                    break
        except Exception as e:
            # Report instead of silently swallowing, so the user knows why
            # the upload widget never appeared.
            print(f"    Upload not available: {e}")
        if not downloaded:
            print(f"    Move file to: {FLAME_MODEL_PATH}")

# Verify FLAME: load the pickle once and report its size and mesh dimensions.
# SECURITY: pickle.load runs arbitrary code embedded in the file. Only load a
# generic_model.pkl that came from a source you trust (your own bucket/upload).
# NOTE(review): FLAME pickles are commonly reported to contain chumpy objects,
# which would need `chumpy` importable here — confirm for the file in use.
if os.path.exists(FLAME_MODEL_PATH):
    size_mb = os.path.getsize(FLAME_MODEL_PATH) / 1e6
    import pickle
    with open(FLAME_MODEL_PATH, 'rb') as f:
        # encoding='latin1' — presumably because the pickle was written by Python 2
        flame_data = pickle.load(f, encoding='latin1')
    n_verts = flame_data['v_template'].shape[0]  # rows of the template mesh
    n_faces = flame_data['f'].shape[0]  # rows of the face-index array
    print(f"  FLAME model: {size_mb:.1f} MB — {n_verts} vertices, {n_faces} faces")
    print("✓ FLAME model verified")
else:
    print("✗ FLAME model not found — cannot proceed")

## 5. Upload Source Images

Upload your selfies. You need at least 3-5 photos of your face from different angles.
Methods (in order of preference):
1. **GCS bucket** — if you already uploaded selfies to `gs://veo-spotless/quantumhead/selfies/`
2. **Direct upload** — use the Colab file picker
3. **Webcam** — capture directly (not implemented in the cell below; only methods 1 and 2 are attempted)

In [None]:
# ============================================================
# LOAD SOURCE IMAGES
# ============================================================
import os
from pathlib import Path

os.makedirs(SELFIE_DIR, exist_ok=True)

IMAGE_EXTS = ('.jpg', '.jpeg', '.png')


def _selfie_dest(filename):
    """Return the path in SELFIE_DIR for *filename*, with a lower-cased extension.

    The counting below and the fitting cell glob for lower-case '*.jpg',
    '*.png' and '*.jpeg', so a phone photo named 'IMG_1.JPG' has to be stored
    as 'IMG_1.jpg' or it is silently ignored.
    """
    stem, ext = os.path.splitext(os.path.basename(filename))
    return os.path.join(SELFIE_DIR, stem + ext.lower())


def _count_selfies():
    """Count the .jpg / .jpeg / .png files currently in SELFIE_DIR."""
    folder = Path(SELFIE_DIR)
    return sum(len(list(folder.glob(f'*{ext}'))) for ext in IMAGE_EXTS)


# --- Method 1: GCS bucket ---
try:
    from google.colab import auth
    from google.cloud import storage
    auth.authenticate_user()
    client = storage.Client(project=GCS_PROJECT)
    bucket = client.bucket(GCS_BUCKET)
    blobs = list(bucket.list_blobs(prefix='quantumhead/selfies'))
    # The extension filter also drops "directory" placeholder objects.
    img_blobs = [b for b in blobs if b.name.lower().endswith(IMAGE_EXTS)]
    n_downloaded = 0
    for blob in img_blobs:
        blob.download_to_filename(_selfie_dest(blob.name))
        n_downloaded += 1
    if n_downloaded:
        print(f"✓ Downloaded {n_downloaded} selfies from GCS")
    else:
        print("No selfies in GCS bucket, trying direct upload...")
except Exception as e:
    print(f"GCS not available ({e}), trying direct upload...")

# --- Method 2: Direct upload ---
# Only prompt when nothing is there yet. The count includes .jpeg, so a
# JPEG-only download does not trigger a needless second prompt.
n_existing = _count_selfies()
if n_existing == 0:
    try:
        from google.colab import files as colab_files
        print("\nUpload your selfies (JPG/PNG, at least 3-5 photos):")
        uploaded = colab_files.upload()
        for name, data in uploaded.items():
            with open(_selfie_dest(name), 'wb') as f:
                f.write(data)
            print(f"  ✓ {name}")
    except Exception as e:
        print(f"Direct upload not available in this environment ({e})")

# Count
n_images = _count_selfies()
print(f"\n{'✓' if n_images >= 3 else '⚠'} {n_images} source images in {SELFIE_DIR}")
if n_images == 0:
    print("  Upload selfies before proceeding!")
elif n_images < 3:
    print("  Recommend at least 3-5 photos for quality results")

## 6. FLAME Parametric Head Model

In [None]:
# ============================================================
# FLAME MODEL — Parametric Head
# shape(β): 300-dim identity, expression(ψ): 100-dim, pose(θ): 15-dim
# ============================================================
import torch
import torch.nn as nn
import numpy as np
import pickle


class FLAMEModel(nn.Module):
    """Linear-blendshape part of the FLAME head model.

    The pickle at *flame_path* is read once and its arrays are stored as
    buffers, so they follow ``.cuda()`` / ``.to()`` without being trained.
    The first *n_shape* directions of ``shapedirs`` serve as the identity
    basis; the *n_exp* directions starting at index 300 serve as the
    expression basis. The joint regressor, skinning weights and faces are
    loaded as well but are not used by ``forward``.

    NOTE: ``pickle.load`` executes code embedded in the file — only point
    this at a FLAME file from a trusted source.
    """

    def __init__(self, flame_path, n_shape=300, n_exp=100):
        super().__init__()
        with open(flame_path, 'rb') as fh:
            raw = pickle.load(fh, encoding='latin1')

        def as_float(value):
            # Route through np.array so array-like wrappers are materialised.
            return torch.tensor(np.array(value), dtype=torch.float32)

        blend_dirs = raw['shapedirs']
        self.register_buffer('v_template', as_float(raw['v_template']))
        self.register_buffer('shapedirs', as_float(blend_dirs[:, :, :n_shape]))
        self.register_buffer('exprdirs', as_float(blend_dirs[:, :, 300:300 + n_exp]))

        # Joint regressor (densified when it exposes todense())
        regressor = raw['J_regressor']
        if hasattr(regressor, 'todense'):
            regressor = regressor.todense()
        self.register_buffer('J_regressor', as_float(regressor))

        # Skinning weights for LBS
        self.register_buffer('lbs_weights', as_float(raw['weights']))

        # Triangle indices, stored as int64
        face_idx = np.array(raw['f']).astype(np.int64)
        self.register_buffer('faces_tensor', torch.tensor(face_idx))

        self.n_vertices = self.v_template.shape[0]
        self.n_shape = n_shape
        self.n_exp = n_exp

    def forward(self, shape_params, expression_params):
        """Return template vertices displaced by shape and expression blendshapes.

        Args:
            shape_params: (B, n_shape) identity coefficients.
            expression_params: (B, n_exp) expression coefficients.

        Returns:
            (B, n_vertices, 3) vertex positions. No pose / linear blend
            skinning is applied.
        """
        shape_offset = torch.einsum('bl,mkl->bmk', shape_params, self.shapedirs)
        expr_offset = torch.einsum('bl,mkl->bmk', expression_params, self.exprdirs)
        return self.v_template.unsqueeze(0) + shape_offset + expr_offset


# Test
# Smoke test: run zero shape/expression coefficients through the blendshape
# forward pass on the GPU and print the resulting vertex tensor shape.
flame = FLAMEModel(FLAME_MODEL_PATH).cuda()  # kept as a global: the fitting and training cells pass it on
dummy_shape = torch.zeros(1, 300).cuda()
dummy_exp = torch.zeros(1, 100).cuda()
verts = flame(dummy_shape, dummy_exp)
print(f"✓ FLAME model loaded: {verts.shape} vertices")
# Drop the temporaries and release cached GPU memory before the next cell
del dummy_shape, dummy_exp, verts
torch.cuda.empty_cache()

## 7. FLAME Fitting (Image → Parameters)

Fits FLAME shape and expression parameters to each selfie using 68-point `face-alignment` landmarks (head and jaw pose are not optimized).
No DECA dependency — uses a direct landmark-to-FLAME optimization loop.

In [None]:
# ============================================================
# FLAME FITTING — Optimize FLAME params to match face landmarks
# Uses face_alignment library for 68-point landmarks
# ============================================================
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import face_alignment
from pathlib import Path
from tqdm import tqdm

# 2D landmark detector on the GPU; fit_flame_to_image reads it as a global
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device='cuda')

# FLAME landmark indices (subset of 5023 vertices → 68 face landmarks)
# Standard 68 landmark mapping to FLAME vertex indices
# Entry i is the mesh vertex compared against detected landmark i.
# NOTE(review): the provenance of these indices is not shown in this file —
# verify them against an official FLAME landmark embedding before trusting fits.
FLAME_LMK_IDX = [
    # Jaw contour (0-16)
    3564, 3500, 3418, 3332, 3246, 3165, 3084, 2994, 2907,
    2820, 2737, 2654, 2575, 2496, 2421, 2352, 2289,
    # Right eyebrow (17-21)
    3863, 3853, 3839, 3822, 3800,
    # Left eyebrow (22-26)
    2178, 2196, 2218, 2237, 2259,
    # Nose bridge (27-30)
    3541, 3529, 3510, 3496,
    # Nose bottom (31-35)
    2370, 2390, 2412, 2432, 2454,
    # Right eye (36-41)
    3716, 3722, 3737, 3750, 3744, 3729,
    # Left eye (42-47)
    2088, 2094, 2109, 2122, 2116, 2101,
    # Outer mouth (48-59)
    1694, 1700, 1722, 1736, 1750, 1772, 1778, 1770, 1756, 1740, 1724, 1706,
    # Inner mouth (60-67)
    1708, 1726, 1738, 1752, 1768, 1754, 1742, 1728,
]


def fit_flame_to_image(image_path, flame_model, n_iters=200, lr=0.01):
    """Fit FLAME shape/expression coefficients to one image's 2D landmarks.

    Landmarks come from the module-level ``fa`` detector (first detected
    face). The FLAME vertices listed in ``FLAME_LMK_IDX`` are projected
    orthographically (x, y) through a learnable weak-perspective camera
    (uniform scale + 2D translation) and matched to the landmarks with an MSE
    loss plus L2 priors on both coefficient vectors.

    The camera is required because the mesh lives in model units while the
    landmarks live in image space; without it the shape coefficients have to
    absorb global scale and offset. Image coordinates are centred, divided by
    a single factor (half the longer image side, so non-square images are not
    stretched), and flipped in y because image y grows downward while the
    mesh's y axis points up.

    Args:
        image_path: str or Path of the image.
        flame_model: callable ``(shape, expression) -> (1, V, 3)`` vertices.
            ``n_shape`` / ``n_exp`` attributes are used when present,
            otherwise 300 / 100.
        n_iters: number of Adam steps (0 returns the initial state).
        lr: Adam learning rate.

    Returns:
        None when the image cannot be read or no face is found; otherwise a
        dict with 'shape', 'expression' (flat numpy arrays), 'loss' (float,
        evaluated at the final parameters) and the fitted camera as 'scale'
        and 'translation'.
    """
    image_path = Path(image_path)  # callers may pass a plain string
    img = cv2.imread(str(image_path))
    if img is None:
        return None
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Detect landmarks
    lmks = fa.get_landmarks(img_rgb)
    if lmks is None or len(lmks) == 0:
        print(f"  No face detected in {image_path.name}")
        return None
    pixels = torch.tensor(lmks[0], dtype=torch.float32).cuda()  # (68, 2)

    # Normalize: origin at the image centre, one scale for both axes, y up.
    h, w = img.shape[:2]
    half_extent = max(h, w) / 2.0
    target = torch.stack([
        (pixels[:, 0] - w / 2.0) / half_extent,
        (h / 2.0 - pixels[:, 1]) / half_extent,
    ], dim=1)

    n_shape = getattr(flame_model, 'n_shape', 300)
    n_exp = getattr(flame_model, 'n_exp', 100)
    shape = torch.zeros(1, n_shape, device='cuda', requires_grad=True)
    exp = torch.zeros(1, n_exp, device='cuda', requires_grad=True)

    # Initialise the camera in closed form from the neutral mesh so the
    # optimizer starts roughly aligned: match centroids and overall spread.
    with torch.no_grad():
        rest_lmks = flame_model(shape, exp)[0, FLAME_LMK_IDX, :2]
        rest_center = rest_lmks.mean(dim=0)
        target_center = target.mean(dim=0)
        rest_spread = (rest_lmks - rest_center).norm().clamp_min(1e-8)
        init_scale = (target - target_center).norm() / rest_spread
        init_trans = target_center - init_scale * rest_center
    scale = init_scale.reshape(1).clone().requires_grad_(True)
    trans = init_trans.clone().requires_grad_(True)

    optimizer = torch.optim.Adam([shape, exp, scale, trans], lr=lr)

    def objective():
        verts = flame_model(shape, exp)  # (1, V, 3)
        pred_lmks = scale * verts[0, FLAME_LMK_IDX, :2] + trans  # (68, 2)
        data_term = F.mse_loss(pred_lmks, target)
        # Keep the coefficients small
        prior = 0.001 * (shape ** 2).sum() + 0.0001 * (exp ** 2).sum()
        return data_term + prior

    for _ in range(n_iters):
        optimizer.zero_grad()
        objective().backward()
        optimizer.step()

    # Evaluate once more so the reported loss matches the returned parameters
    # (and is defined even when n_iters == 0).
    with torch.no_grad():
        final_loss = objective().item()

    return {
        'shape': shape.detach().cpu().numpy().flatten(),
        'expression': exp.detach().cpu().numpy().flatten(),
        'loss': final_loss,
        'scale': scale.detach().cpu().numpy().flatten(),
        'translation': trans.detach().cpu().numpy().flatten(),
    }


# Fit all selfies
print(f"Fitting FLAME to selfies in {SELFIE_DIR}...")
# Match extensions case-insensitively (FaceDataset pairs files the same way),
# so a file such as 'IMG_0001.JPG' is not silently skipped.
selfie_paths = sorted(
    p for p in Path(SELFIE_DIR).iterdir()
    if p.is_file() and p.suffix.lower() in ('.jpg', '.jpeg', '.png')
)

if not selfie_paths:
    print("⚠ No selfies found! Go back to step 5 and upload images.")
else:
    # Count successes from this run only; globbing FLAME_DIR afterwards would
    # also count stale .npz files left by earlier runs.
    n_fitted = 0
    for img_path in tqdm(selfie_paths):
        result = fit_flame_to_image(img_path, flame)
        if result is None:
            continue
        out_path = Path(FLAME_DIR) / f"{img_path.stem}.npz"
        np.savez(str(out_path), **result)
        n_fitted += 1
        print(f"  ✓ {img_path.name} → loss={result['loss']:.6f}")

    status = '✓' if n_fitted else '⚠'
    print(f"\n{status} FLAME fitting complete: {n_fitted}/{len(selfie_paths)} images")

## 8. QuantumHead Model (UV-Space Gaussians)

In [None]:
# ============================================================
# UV-SPACE GAUSSIAN AVATAR MODEL
# Architecture: GAGAvatar + UHAP patterns
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """3x3 same-padding convolution → instance norm → LeakyReLU(0.2).

    The spatial size is preserved; only the channel count changes from
    *in_ch* to *out_ch*.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.InstanceNorm2d(out_ch)
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.norm(features)
        return self.act(normalized)


class UVDecoder(nn.Module):
    """Decodes latent → UV-space Gaussian attribute maps.

    Output channels: position_offset(3) + rotation(4) + scale(3) + opacity(1) + color(3) = 14

    A linear layer produces a 256-channel 4x4 seed, which is doubled in
    resolution by upsample + two ConvBlocks per stage until it reaches
    ``uv_size``. The number of stages is derived from ``uv_size``, so the
    output really is ``uv_size`` x ``uv_size``. For the default 256 this is
    the original 6-stage layout (same module names and shapes).

    Args:
        z_dim: size of the input latent vector.
        uv_size: output side length; a power of two, at least 8.
        out_channels: number of attribute channels in the output map.

    Raises:
        ValueError: if ``uv_size`` is not a power of two >= 8.
    """
    def __init__(self, z_dim, uv_size=256, out_channels=14):
        super().__init__()
        if uv_size < 8 or uv_size & (uv_size - 1):
            raise ValueError(f"uv_size must be a power of two >= 8, got {uv_size}")
        self.uv_size = uv_size
        self.init_size = 4  # start from 4x4
        self.fc = nn.Linear(z_dim, 256 * self.init_size * self.init_size)

        # One stage per doubling: 4→8→…→uv_size (6 stages for 256).
        n_blocks = (uv_size // self.init_size).bit_length() - 1
        base_channels = [256, 256, 128, 128, 64, 64]
        # Smaller maps use a prefix of the schedule; larger maps repeat the
        # narrowest width for the extra stages.
        extra = max(0, n_blocks - len(base_channels))
        channels = (base_channels + [base_channels[-1]] * extra)[:n_blocks]

        self.blocks = nn.ModuleList()
        in_ch = 256
        for out_ch in channels:
            self.blocks.append(nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                ConvBlock(in_ch, out_ch),
                ConvBlock(out_ch, out_ch),
            ))
            in_ch = out_ch
        self.head = nn.Conv2d(channels[-1], out_channels, 1)

    def forward(self, z):
        """Map latents (B, z_dim) to attribute maps (B, out_channels, uv_size, uv_size)."""
        x = self.fc(z).view(-1, 256, self.init_size, self.init_size)
        for block in self.blocks:
            x = block(x)
        return self.head(x)  # (B, 14, uv_size, uv_size)


class ExpressionEncoder(nn.Module):
    """Encodes UV attribute difference maps → expression latent Z_exp.

    Stride-2 convolutions halve the map until it is 4x4, then two linear
    heads give the mean and log-variance of a diagonal Gaussian. The number
    of stages is derived from ``uv_size`` so the flattened feature always has
    256 * 4 * 4 entries. For the default 256 this is the original 6-stage
    layout (same module names and shapes).

    Args:
        in_channels: channels of the input difference map.
        z_dim: size of the expression latent.
        uv_size: input side length; a power of two, at least 8.

    Raises:
        ValueError: if ``uv_size`` is not a power of two >= 8.
    """
    def __init__(self, in_channels=14, z_dim=256, uv_size=256):
        super().__init__()
        if uv_size < 8 or uv_size & (uv_size - 1):
            raise ValueError(f"uv_size must be a power of two >= 8, got {uv_size}")

        # One stride-2 stage per halving: uv_size→…→8→4 (6 stages for 256).
        n_down = (uv_size // 4).bit_length() - 1
        base_channels = [32, 64, 64, 128, 128, 256]
        if n_down <= len(base_channels):
            # Smaller inputs keep the deepest stages, so the last one is 256-wide.
            channels = base_channels[len(base_channels) - n_down:]
        else:
            # Larger inputs get extra narrow stages in front.
            channels = [base_channels[0]] * (n_down - len(base_channels)) + base_channels

        layers = []
        in_ch = in_channels
        for out_ch in channels:
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, 2, 1),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            in_ch = out_ch
        self.encoder = nn.Sequential(*layers)
        self.fc_mu = nn.Linear(channels[-1] * 4 * 4, z_dim)
        self.fc_logvar = nn.Linear(channels[-1] * 4 * 4, z_dim)

    def forward(self, uv_diff):
        """Encode (B, in_channels, uv_size, uv_size) → (z, mu, logvar), each (B, z_dim).

        In training mode ``z`` is sampled with the reparameterization trick;
        in eval mode ``z`` is the mean.
        """
        h = self.encoder(uv_diff).flatten(1)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        # Reparameterize
        if self.training:
            std = torch.exp(0.5 * logvar)
            z = mu + std * torch.randn_like(std)
        else:
            z = mu
        return z, mu, logvar


class QuantumHeadModel(nn.Module):
    """Maps an identity latent and an expression latent to UV Gaussian maps.

    Sub-modules:
        neutral_decoder: z_id → 14-channel neutral UV map.
        expr_decoder:    [z_id, z_exp] → 14-channel delta added to the neutral map.
        expr_encoder:    UV difference map → expression latent (VAE path).
        mesh_decoder:    [z_id, z_exp] → per-vertex offsets for the guide mesh.
    """

    def __init__(self, z_id_dim=512, z_exp_dim=256, uv_size=256, n_vertices=5023):
        super().__init__()
        self.z_id_dim = z_id_dim
        self.z_exp_dim = z_exp_dim
        self.n_vertices = n_vertices
        latent_dim = z_id_dim + z_exp_dim

        # Neutral, identity-only appearance
        self.neutral_decoder = UVDecoder(z_id_dim, uv_size, out_channels=14)

        # Expression-dependent deformation on top of the neutral maps
        self.expr_decoder = UVDecoder(latent_dim, uv_size, out_channels=14)

        # Encoder for the VAE path
        self.expr_encoder = ExpressionEncoder(14, z_exp_dim, uv_size)

        # MLP producing guide-mesh vertex offsets
        self.mesh_decoder = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, n_vertices * 3),
        )

    def forward(self, z_id, z_exp):
        """Decode latents into Gaussian attribute maps and mesh offsets.

        Args:
            z_id: (B, z_id_dim) identity latent.
            z_exp: (B, z_exp_dim) expression latent.

        Returns:
            uv_maps: (B, 14, K, K) — neutral map plus expression delta.
            mesh_offsets: (B, V, 3) — vertex offsets.
        """
        latent = torch.cat((z_id, z_exp), dim=-1)
        uv_maps = self.neutral_decoder(z_id) + self.expr_decoder(latent)
        flat_offsets = self.mesh_decoder(latent)
        return uv_maps, flat_offsets.view(-1, self.n_vertices, 3)

    def encode_expression(self, uv_target, uv_neutral):
        """Encode ``uv_target - uv_neutral``; returns ``(z_exp, mu, logvar)``."""
        z_exp, mu, logvar = self.expr_encoder(uv_target - uv_neutral)
        return z_exp, mu, logvar


# Test instantiation
# `model` stays a global: the training and inference cells use it.
model = QuantumHeadModel(
    z_id_dim=IDENTITY_DIM,
    z_exp_dim=EXPRESSION_DIM,
    uv_size=UV_MAP_SIZE,
    n_vertices=GUIDE_MESH_VERTICES,
).cuda()

n_params = sum(p.numel() for p in model.parameters()) / 1e6  # in millions
print(f"✓ QuantumHeadModel: {n_params:.1f}M params")

# Quick forward test
# Random latents, batch of 1; only the output shapes are inspected.
z_id = torch.randn(1, IDENTITY_DIM).cuda()
z_exp = torch.randn(1, EXPRESSION_DIM).cuda()
uv, mesh = model(z_id, z_exp)
print(f"  UV maps: {uv.shape}")
print(f"  Mesh offsets: {mesh.shape}")
# Read allocated memory while the outputs are still alive, before the cleanup below
vram_mb = torch.cuda.memory_allocated() / 1e6
print(f"  VRAM used: {vram_mb:.0f} MB")
del z_id, z_exp, uv, mesh
torch.cuda.empty_cache()

## 9. Audio → FLAME Transformer

In [None]:
# ============================================================
# AUDIO-TO-FLAME TRANSFORMER
# Wav2Vec2 features → FLAME expression params
# ============================================================
import torch
import torch.nn as nn
import math


class Audio2FLAMETransformer(nn.Module):
    """Predicts FLAME expression+jaw params from Wav2Vec2 audio features.

    The 1024-dim audio features are projected to ``d_model``, given
    sinusoidal positional encodings and passed through a transformer decoder
    stack that uses the sequence as both target and memory. A causal mask is
    applied to self-attention AND cross-attention, so frame ``t`` only
    depends on frames ``<= t``.

    Args:
        d_model: transformer width (must be even for the sin/cos encoding).
        nhead: attention heads per layer.
        num_layers: number of decoder layers.
        n_flame_params: output size per frame (default 53 = 50 expression +
            3 jaw pose).
        max_len: longest supported sequence (positional-encoding table size).
    """

    def __init__(self, d_model=512, nhead=8, num_layers=6, n_flame_params=53,
                 max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.n_flame_params = n_flame_params

        # Wav2Vec2 output projection (1024 → d_model)
        self.audio_proj = nn.Linear(1024, d_model)

        # Sinusoidal positional encoding, precomputed for max_len positions
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

        # Transformer
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=2048,
            dropout=0.1, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output head → n_flame_params per frame
        self.head = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, n_flame_params),
        )

    def forward(self, audio_features):
        """
        Args: audio_features (B, T, 1024) from Wav2Vec2
        Returns: flame_params (B, T, n_flame_params)
        Raises: ValueError if T exceeds the positional-encoding table.
        """
        T = audio_features.shape[1]
        max_len = self.pe.shape[1]
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds the positional-encoding table ({max_len})")
        x = self.audio_proj(audio_features) + self.pe[:, :T]

        # Causal mask on both attention paths. The memory is the same sequence
        # as the target, so leaving cross-attention unmasked would let every
        # frame see the future and defeat the self-attention mask.
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        out = self.decoder(x, x, tgt_mask=mask, memory_mask=mask)

        return self.head(out)  # (B, T, n_flame_params)


# Test
# Smoke test only: instantiate on the GPU, report the parameter count and
# check output shape on random input.
# NOTE(review): audio_model is not passed to train() in the cells visible here.
audio_model = Audio2FLAMETransformer().cuda()
n_params = sum(p.numel() for p in audio_model.parameters()) / 1e6  # in millions
print(f"✓ Audio2FLAMETransformer: {n_params:.1f}M params")

dummy_audio = torch.randn(1, 50, 1024).cuda()  # (batch, frames, feature dim)
flame_out = audio_model(dummy_audio)
print(f"  Input: {dummy_audio.shape} → Output: {flame_out.shape}")
# Drop the temporaries and release cached GPU memory
del dummy_audio, flame_out
torch.cuda.empty_cache()

## 10. Gaussian Splatting Renderer

In [None]:
# ============================================================
# GAUSSIAN RENDERER — UV maps → image
# Uses differentiable neural rendering (no gsplat dependency)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianRenderer(nn.Module):
    """Differentiable renderer: UV Gaussian maps → 2D image.

    For training we use a lightweight neural renderer (2D convolutions).
    At inference time this can be swapped for full 3D splatting.

    The upsampling stage resizes to ``image_size`` explicitly, so the output
    resolution follows the constructor argument. For a 256 input and
    ``image_size=512`` this is the same x2 bilinear upsample as before.

    Args:
        uv_size: nominal side length of the input UV maps (stored only).
        image_size: side length of the rendered square image.
    """

    def __init__(self, uv_size=256, image_size=512):
        super().__init__()
        self.uv_size = uv_size
        self.image_size = image_size

        # Neural renderer: takes 14-ch UV maps → RGB image
        self.renderer = nn.Sequential(
            # UV map processing
            ConvBlock(14, 64),
            ConvBlock(64, 128),
            # Resize to the target resolution (256→512 by default)
            nn.Upsample(size=(image_size, image_size), mode='bilinear', align_corners=False),
            ConvBlock(128, 64),
            ConvBlock(64, 32),
            nn.Conv2d(32, 3, 1),
            nn.Sigmoid(),  # RGB in [0, 1]
        )

    def forward(self, uv_maps):
        """
        Args: uv_maps (B, 14, K, K) UV Gaussian attribute maps
        Returns: rendered (B, 3, image_size, image_size) RGB image in [0, 1]
        """
        return self.renderer(uv_maps)


# `renderer` stays a global: the training and inference cells use it.
renderer = GaussianRenderer(uv_size=UV_MAP_SIZE, image_size=512).cuda()
n_params = sum(p.numel() for p in renderer.parameters()) / 1e6  # in millions
print(f"✓ GaussianRenderer: {n_params:.1f}M params")

# Test render
# Random 14-channel UV map, batch of 1; only the output shape is inspected.
dummy_uv = torch.randn(1, 14, UV_MAP_SIZE, UV_MAP_SIZE).cuda()
img = renderer(dummy_uv)
print(f"  UV maps {dummy_uv.shape} → Image {img.shape}")
# Drop the temporaries and release cached GPU memory
del dummy_uv, img
torch.cuda.empty_cache()

## 11. Dataset

In [None]:
# ============================================================
# DATASET — Face frames + FLAME params
# ============================================================
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from pathlib import Path


class FaceDataset(Dataset):
    """Loads selfie images + their fitted FLAME parameters.

    An image in *image_dir* (.jpg/.jpeg/.png, any case) is used only when
    *flame_dir* contains an ``.npz`` with the same stem. When no pair exists
    the dataset reports length 1 and serves a placeholder sample, so a
    DataLoader can still be constructed.
    """

    def __init__(self, image_dir, flame_dir, image_size=512):
        self.image_size = image_size
        self.samples = []  # list of (image path, npz path) strings

        image_dir = Path(image_dir)
        flame_dir = Path(flame_dir)

        for img_path in sorted(image_dir.glob('*')):
            if img_path.suffix.lower() not in ('.jpg', '.jpeg', '.png'):
                continue
            flame_path = flame_dir / f"{img_path.stem}.npz"
            if flame_path.exists():
                self.samples.append((str(img_path), str(flame_path)))

        print(f"  Dataset: {len(self.samples)} paired samples")

    def __len__(self):
        return max(len(self.samples), 1)  # at least 1 for DataLoader

    def __getitem__(self, idx):
        """Return {'image': (3, S, S) float in [0, 1], 'shape': (300,), 'expression': (100,)}.

        Raises:
            RuntimeError: if the image file cannot be decoded.
        """
        if not self.samples:
            return self._dummy()

        img_path, flame_path = self.samples[idx % len(self.samples)]

        # Load image
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure with None; fail with the file name
            # rather than an opaque error inside cvtColor.
            raise RuntimeError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size))
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

        # Load FLAME params. The context manager closes the .npz archive; a
        # bare np.load keeps the file handle open until garbage collection,
        # which adds up over a long training run.
        with np.load(flame_path) as params:
            shape = torch.from_numpy(params['shape'][:300]).float()
            expression = torch.from_numpy(params['expression'][:100]).float()

        return {'image': img, 'shape': shape, 'expression': expression}

    def _dummy(self):
        """Placeholder sample used when no paired data exists (noise image, zero params)."""
        return {
            'image': torch.randn(3, self.image_size, self.image_size),
            'shape': torch.zeros(300),
            'expression': torch.zeros(100),
        }


dataset = FaceDataset(SELFIE_DIR, FLAME_DIR, image_size=512)

# drop_last=True discards an incomplete batch. With fewer samples than
# BATCH_SIZE (e.g. 3 selfies, batch 4) the loader would yield zero batches
# and the training loop, which re-iterates the loader until it has done
# enough steps, would never advance. Clamp the batch size to the dataset size.
effective_batch_size = max(1, min(BATCH_SIZE, len(dataset)))
dataloader = DataLoader(
    dataset, batch_size=effective_batch_size, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True,
)
print(f"✓ DataLoader ready: {len(dataset)} samples, batch_size={effective_batch_size}")

## 12. Training Loop

In [None]:
# ============================================================
# TRAINING LOOP
# ============================================================
import torch
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm
import os

# Optional perceptual loss: train() checks `lpips_fn is not None` and falls
# back to L1 only when the package is missing.
try:
    import lpips
    lpips_fn = lpips.LPIPS(net='vgg').cuda()  # VGG-backed LPIPS on the GPU
    print("✓ LPIPS perceptual loss loaded")
except ImportError:
    lpips_fn = None
    print("⚠ LPIPS not available, using L1 only")


def train(model, renderer, flame_model, dataloader, num_iterations=50000,
          lr=1e-4, save_every=5000):
    """Train the QuantumHead model together with its neural renderer.

    Each step decodes UV maps from a learned identity code plus the sample's
    FLAME expression coefficients (zero-padded/truncated to EXPRESSION_DIM),
    renders them and minimises L1 + 0.1 * LPIPS (when available) plus small
    L2 penalties on the identity code and the mesh offsets.

    Args:
        model: QuantumHeadModel to optimise.
        renderer: GaussianRenderer to optimise jointly.
        flame_model: accepted for interface compatibility; not used by the loop.
        dataloader: yields dicts with 'image' (B, 3, H, W), 'shape' and
            'expression' tensors.
        num_iterations: total optimizer steps.
        lr: Adam learning rate for all parameter groups.
        save_every: checkpoint interval in steps.

    Returns:
        Path of the final checkpoint ('quantumhead_final.pt' in CHECKPOINT_DIR).

    Raises:
        RuntimeError: if the dataloader yields no batches (previously this
            spun forever).
    """

    device = torch.device('cuda')
    model = model.to(device).train()
    renderer = renderer.to(device).train()

    optimizer = Adam(
        list(model.parameters()) + list(renderer.parameters()),
        lr=lr
    )

    # Identity code bank (optimized during training). Sized from the dataset
    # behind the dataloader that was passed in, not from a global.
    # NOTE(review): every sample currently looks up row 0 (one shared
    # identity); the remaining rows are never trained.
    train_set = getattr(dataloader, 'dataset', None)
    n_subjects = max(len(getattr(train_set, 'samples', ())), 1)
    z_id_bank = torch.nn.Embedding(n_subjects, IDENTITY_DIM).to(device)
    torch.nn.init.normal_(z_id_bank.weight, 0, 0.01)
    optimizer.add_param_group({'params': z_id_bank.parameters(), 'lr': lr})

    def _snapshot(at_step, with_optimizer):
        """Build the checkpoint dict; optimizer state only for resumable saves."""
        state = {
            'step': at_step,
            'model': model.state_dict(),
            'renderer': renderer.state_dict(),
            'z_id_bank': z_id_bank.state_dict(),
        }
        if with_optimizer:
            state['optimizer'] = optimizer.state_dict()
        state['config'] = {
            'z_id_dim': IDENTITY_DIM,
            'z_exp_dim': EXPRESSION_DIM,
            'uv_size': UV_MAP_SIZE,
            'n_vertices': GUIDE_MESH_VERTICES,
        }
        return state

    step = 0
    pbar = tqdm(total=num_iterations, desc='Training')

    try:
        while step < num_iterations:
            step_at_epoch_start = step
            for batch in dataloader:
                if step >= num_iterations:
                    break

                images = batch['image'].to(device)            # (B, 3, 512, 512)
                expr_params = batch['expression'].to(device)   # (B, 100)
                B = images.shape[0]

                # Get identity codes
                subject_idx = torch.zeros(B, dtype=torch.long, device=device)
                z_id = z_id_bank(subject_idx)  # (B, 512)

                # Use FLAME expression as conditioning: truncate or zero-pad
                # the coefficients to the EXPRESSION_DIM-wide latent.
                n_expr = expr_params.shape[1]
                z_exp = F.pad(expr_params[:, :EXPRESSION_DIM],
                              (0, max(0, EXPRESSION_DIM - n_expr)))

                # Forward pass
                uv_maps, mesh_offsets = model(z_id, z_exp)

                # Render
                rendered = renderer(uv_maps)  # (B, 3, 512, 512)

                # === Losses ===
                # L1 reconstruction
                loss_l1 = F.l1_loss(rendered, images)

                # Perceptual loss (LPIPS expects inputs in [-1, 1])
                if lpips_fn is not None:
                    loss_perc = lpips_fn(rendered * 2 - 1, images * 2 - 1).mean()
                else:
                    loss_perc = torch.tensor(0.0, device=device)

                # Regularization
                loss_reg = 0.001 * (z_id ** 2).mean()
                loss_mesh = 0.01 * (mesh_offsets ** 2).mean()

                # Total
                loss = loss_l1 + 0.1 * loss_perc + loss_reg + loss_mesh

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Log
                if step % 100 == 0:
                    pbar.set_postfix({
                        'L1': f'{loss_l1.item():.4f}',
                        'perc': f'{loss_perc.item():.4f}',
                        'total': f'{loss.item():.4f}',
                    })

                # Save checkpoint
                if step > 0 and step % save_every == 0:
                    path = os.path.join(CHECKPOINT_DIR, f'checkpoint_{step:06d}.pt')
                    torch.save(_snapshot(step, with_optimizer=True), path)
                    print(f"\n  ✓ Saved checkpoint: {path}")

                step += 1
                pbar.update(1)

            if step == step_at_epoch_start:
                # A full pass produced nothing: an empty loader, or
                # drop_last=True with fewer samples than the batch size.
                raise RuntimeError(
                    "dataloader produced no batches — reduce batch_size, "
                    "set drop_last=False, or add training samples")
    finally:
        pbar.close()

    # Save final
    final_path = os.path.join(CHECKPOINT_DIR, 'quantumhead_final.pt')
    torch.save(_snapshot(step, with_optimizer=False), final_path)
    print(f"\n✓ Training complete! Final: {final_path}")
    return final_path

print("✓ Training function defined")

## 13. Run Training

In [None]:
# ============================================================
# START TRAINING
# ============================================================
# Summary of the run configuration (parameter counts are in millions)
print(f"Model params:  {sum(p.numel() for p in model.parameters())/1e6:.1f}M")
print(f"Renderer:      {sum(p.numel() for p in renderer.parameters())/1e6:.1f}M")
print(f"Dataset:       {len(dataset)} samples")
print(f"Batch size:    {BATCH_SIZE}")
print(f"Iterations:    {NUM_ITERATIONS}")
print(f"Checkpoints:   {CHECKPOINT_DIR}")
print()

# `final_path` stays a global: the upload cell reads it.
final_path = train(
    model=model,
    renderer=renderer,
    flame_model=flame,
    dataloader=dataloader,
    num_iterations=NUM_ITERATIONS,
    lr=LEARNING_RATE,
    save_every=5000,
)

## 14. Upload Weights to GCS

In [None]:
# ============================================================
# UPLOAD TRAINED WEIGHTS TO GCS
# ============================================================
import os
from google.cloud import storage
from google.colab import auth

# Authenticate this Colab runtime so the storage client below can obtain
# credentials (presumably opens the interactive Google login — confirm).
auth.authenticate_user()
# Module-level client and bucket handle; `gcs_bucket` is what upload_to_gcs()
# uses to create blobs.
gcs_client = storage.Client(project=GCS_PROJECT)
gcs_bucket = gcs_client.bucket(GCS_BUCKET)


def upload_to_gcs(local_path, prefix=GCS_PREFIX):
    """Upload one local file to the configured bucket and report where it went.

    The object name is ``<prefix>/<basename of local_path>``; any directory
    part of ``local_path`` is dropped.

    Args:
        local_path: Path of the file to upload.
        prefix: Object-name prefix inside the bucket (defaults to GCS_PREFIX).

    Returns:
        The ``gs://`` URI of the uploaded object.
    """
    object_name = f"{prefix}/{os.path.basename(local_path)}"
    gcs_bucket.blob(object_name).upload_from_filename(local_path)

    # Size is read after the upload, purely for the progress line.
    gcs_uri = f"gs://{GCS_BUCKET}/{object_name}"
    megabytes = os.path.getsize(local_path) / 1e6
    print(f"  ✓ {gcs_uri} ({megabytes:.1f} MB)")
    return gcs_uri


# `final_path` only exists if the training cell ran in this session. After a
# runtime restart (or when jumping straight to this cell) fall back to the
# location where training writes its final checkpoint, so the cell either
# uploads the existing file or prints the warning below instead of dying with
# a NameError.
try:
    final_path
except NameError:
    final_path = os.path.join(CHECKPOINT_DIR, 'quantumhead_final.pt')

if os.path.exists(final_path):
    print("Uploading weights to GCS...")
    upload_to_gcs(final_path)
    print()
    print("✅ Weights uploaded!")
    print(f"   Bucket: gs://{GCS_BUCKET}/{GCS_PREFIX}/")
    print()
    # Copy-paste instructions for the serving host.
    print("To deploy on spike2:")
    print(f"  gsutil cp gs://{GCS_BUCKET}/{GCS_PREFIX}/*.pt /root/quantumhead/weights/")
    print("  curl -X POST http://localhost:8100/quantumhead/pull-weights")
else:
    print("⚠ No checkpoint found. Run training first.")

## 15. Test Inference

In [None]:
# ============================================================
# TEST INFERENCE — Generate a talking head video locally
# ============================================================
import torch
import numpy as np
import cv2
from IPython.display import display, HTML
import base64

# Put both networks into evaluation mode before rendering test frames.
model.eval()
renderer.eval()


def generate_test_video(text="Hello, I am QuantumHead!", n_frames=75):
    """Render a short, audio-less test clip and save it as an MP4.

    One random identity code is sampled, then the first four expression
    dimensions are driven by sine waves of different frequencies so the
    output visibly changes from frame to frame. The clip is written to
    ``{OUTPUT_DIR}/test_inference.mp4`` at 25 fps, embedded in the notebook
    output, and uploaded to GCS on a best-effort basis.

    Args:
        text: Currently unused (no audio/TTS is involved in this test); kept
            so existing callers keep working.
        n_frames: Number of frames to render (75 frames = 3 s at 25 fps).

    Returns:
        Path of the written MP4 file.

    Raises:
        RuntimeError: If OpenCV cannot open the output file for writing.
    """
    print(f"Generating {n_frames} frames...")

    # Simple audio-less test: oscillate expression params
    z_id = torch.randn(1, IDENTITY_DIM, device='cuda')
    frames = []

    with torch.no_grad():
        for t in range(n_frames):
            # Oscillate expression
            phase = t / 25.0  # 25fps
            z_exp = torch.zeros(1, EXPRESSION_DIM, device='cuda')
            # Animate first few dims (mouth open/close, smile, etc.)
            z_exp[0, 0] = 0.5 * np.sin(2 * np.pi * 1.0 * phase)   # jaw
            z_exp[0, 1] = 0.3 * np.sin(2 * np.pi * 0.5 * phase)   # lip
            z_exp[0, 2] = 0.2 * np.sin(2 * np.pi * 0.3 * phase)   # brow
            z_exp[0, 3] = 0.1 * np.sin(2 * np.pi * 0.7 * phase)   # smile

            uv_maps, _ = model(z_id, z_exp)
            rendered = renderer(uv_maps)

            # Clamp to [0, 1] before scaling: the float -> uint8 cast does not
            # saturate, so any overshoot from the renderer would otherwise show
            # up as wrapped-around garbage pixels.
            image = rendered[0].permute(1, 2, 0).clamp(0, 1)
            frame = (image.cpu().numpy() * 255).astype(np.uint8)
            frames.append(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    # Write video
    out_path = f"{OUTPUT_DIR}/test_inference.mp4"
    # Take the frame size from what was actually rendered: OpenCV expects
    # every written frame to match the size given to the writer. Fall back to
    # 512x512 only when there are no frames to measure.
    if frames:
        height, width = frames[0].shape[:2]
    else:
        height, width = 512, 512
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(out_path, fourcc, 25, (width, height))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {out_path}")
    try:
        for frame in frames:
            writer.write(frame)
    finally:
        # Always finalize the container, even if a write fails.
        writer.release()

    print(f"✓ Video saved: {out_path}")

    # Display in notebook
    with open(out_path, 'rb') as video_file:
        video_bytes = video_file.read()
    b64 = base64.b64encode(video_bytes).decode()
    display(HTML(f'''
    <video width="512" controls autoplay loop>
        <source src="data:video/mp4;base64,{b64}" type="video/mp4">
    </video>'''))

    # Upload to GCS (best effort: a failed upload must not lose the local clip)
    try:
        upload_to_gcs(out_path)
    except Exception as e:
        print(f"  GCS upload skipped: {e}")

    return out_path


# Run the smoke test with its default arguments and keep the path of the
# written clip.
test_path = generate_test_video()

## 16. Export fp16 Inference Package

In [None]:
# ============================================================
# EXPORT FP16 INFERENCE PACKAGE (for spike2 8GB GPU)
# ============================================================
import torch, json, os

def _half_floating_tensors(state_dict):
    """Return a copy of ``state_dict`` with floating-point tensors cast to fp16.

    Integer/bool tensors and any non-tensor entries are passed through
    unchanged: casting them to half precision would corrupt counters and
    index tables instead of just reducing precision.
    """
    return {
        name: value.half()
        if torch.is_tensor(value) and value.is_floating_point()
        else value
        for name, value in state_dict.items()
    }


def export_fp16_package(model, renderer, output_dir):
    """Export fp16 weights + config for spike2 deployment.

    Writes three files into ``<output_dir>/inference_package``:
    ``quantumhead_fp16.pt`` (model state dict), ``renderer_fp16.pt``
    (renderer state dict) and ``config.json``. It then prints the size of
    every file in that directory and uploads each of them to
    ``gs://<GCS_BUCKET>/<GCS_PREFIX>/inference_package/``.

    Args:
        model: Network whose ``state_dict()`` and ``parameters()`` are exported.
        renderer: Renderer whose ``state_dict()`` is exported.
        output_dir: Directory in which ``inference_package/`` is created.

    Returns:
        Path of the local package directory.
    """
    pkg_dir = os.path.join(output_dir, 'inference_package')
    os.makedirs(pkg_dir, exist_ok=True)

    # Save weights in fp16 (floating-point tensors only, see helper above).
    # NOTE(review): training checkpoints also store a 'z_id_bank' state dict
    # that is not part of this package — confirm inference does not need it.
    torch.save(
        _half_floating_tensors(model.state_dict()),
        os.path.join(pkg_dir, 'quantumhead_fp16.pt'),
    )
    torch.save(
        _half_floating_tensors(renderer.state_dict()),
        os.path.join(pkg_dir, 'renderer_fp16.pt'),
    )

    # Config; total_params_M counts the model's parameters only (renderer
    # parameters are not included).
    config = {
        'z_id_dim': IDENTITY_DIM,
        'z_exp_dim': EXPRESSION_DIM,
        'uv_size': UV_MAP_SIZE,
        'n_vertices': GUIDE_MESH_VERTICES,
        'image_size': 512,
        'dtype': 'float16',
        'total_params_M': sum(p.numel() for p in model.parameters()) / 1e6,
    }
    with open(os.path.join(pkg_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    # File sizes (covers everything present in the directory, including files
    # left over from an earlier export).
    package_files = sorted(os.listdir(pkg_dir))
    total = 0
    for fname in package_files:
        size = os.path.getsize(os.path.join(pkg_dir, fname))
        total += size
        print(f"  {fname}: {size/1e6:.1f} MB")
    print(f"  Total: {total/1e6:.1f} MB")

    # Upload to GCS in the same, deterministic order as the listing above.
    print("\nUploading to GCS...")
    for fname in package_files:
        upload_to_gcs(os.path.join(pkg_dir, fname), prefix=f"{GCS_PREFIX}/inference_package")

    print(f"\n✓ Inference package at gs://{GCS_BUCKET}/{GCS_PREFIX}/inference_package/")
    return pkg_dir


# Build the fp16 package from the in-memory networks; the call also uploads
# the package files to GCS.
pkg = export_fp16_package(model, renderer, OUTPUT_DIR)
print("\n✅ Ready for spike2 deployment!")
# Copy-paste command for pulling the package onto the serving host.
print(f"  gsutil -m cp -r gs://{GCS_BUCKET}/{GCS_PREFIX}/inference_package/* /root/quantumhead/weights/")