# 🧠 QuantumHead — 3DGS Talking Head Avatar Pipeline

**Train on Colab (A100) → Push weights to spike2 → Serve inference from cloud**

Architecture based on SOTA research:
- **GAGAvatar** (246 FPS one-shot) — UV-space Gaussians + PanoHead GAN
- **GaussianHeadTalk** (wobble-free) — Audio → FLAME params via transformer
- **UHAP** (universal prior) — Expression latent encodes geometry + appearance

Pipeline:
```
Image → DECA (FLAME fit) → PanoHead (full-head tri-plane) → UV Gaussians
Audio → Wav2Vec2 → Transformer → FLAME params → Animate → 3DGS Render
                                                              ↓
                                              Push weights to spike2
                                              spike2 serves inference
```

## 0. Setup & Configuration

In [None]:
# ============================================================
# CONFIGURATION — adjust these values before running the notebook
# ============================================================

# --- spike2: the server that hosts the weights and serves inference ---
SPIKE2_HOST = "voice.quantum-forge.io"            # public hostname
SPIKE2_SSH = "spike2"                             # SSH alias; must exist in ~/.ssh/config
SPIKE2_WEIGHTS_DIR = "/root/quantumhead/weights"  # destination directory on spike2
SPIKE2_API_PORT = 8000

# --- avatar training ---
BATCH_SIZE = 4
NUM_GAUSSIANS = 256 * 1000    # full 256K Gaussians (no reduction for VRAM)
UV_MAP_SIZE = 256             # side length K of the K×K UV attribute maps
EXPRESSION_DIM = 256          # size of the expression latent Z_exp
IDENTITY_DIM = 512            # size of the identity latent Z_id
LEARNING_RATE = 1e-4
NUM_ITERATIONS = 50 * 1000    # starting budget of 50K; scale up to 300K later
GUIDE_MESH_VERTICES = 7306

# --- diffusion (audio → expression) ---
DIFFUSION_STEPS = 500
AUDIO_CONTEXT_WINDOW = 120    # in frames
AUDIO_OVERLAP = 30            # in frames

# --- where training artefacts are written ---
OUTPUT_DIR = "/content/quantumhead_output"
CHECKPOINT_DIR = "/".join([OUTPUT_DIR, "checkpoints"])

print("‚úì Configuration set")

In [None]:
# ============================================================
# GPU CHECK
# Prints the nvidia-smi report and the PyTorch/CUDA status, and
# refuses to continue on a GPU with less than ~30 GB of memory.
# ============================================================
import subprocess

# nvidia-smi does not exist on CPU-only runtimes; report that instead of
# crashing so the PyTorch summary below is still printed.
try:
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
except FileNotFoundError:
    result = None
    print('nvidia-smi not found (no NVIDIA driver on this runtime)')
else:
    print(result.stdout)

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # The device-properties field holding the memory size in bytes is
    # `total_memory`; `total_mem` does not exist and raised AttributeError.
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {props.total_memory / 1e9:.1f} GB")
    assert props.total_memory > 30e9, "Need A100 (40/80GB). Change runtime!"

In [None]:
%%bash
# ============================================================
# INSTALL DEPENDENCIES
# ============================================================
# NOTE: `%%bash` has to be the very first line of the cell. IPython only
# recognises a cell magic at the top of a cell; with comment lines above it
# the cell is parsed as Python and fails instead of running this script.

# Abort on the first failing install so a broken environment is noticed here
# rather than later, in the middle of training.
set -e

# Core ML
pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# 3D / Rendering
pip install -q \
  pytorch3d \
  trimesh \
  pyrender \
  open3d \
  plyfile

# FLAME / Face
pip install -q \
  chumpy \
  face-alignment \
  mediapipe \
  insightface \
  onnxruntime-gpu

# Audio
pip install -q \
  transformers \
  librosa \
  soundfile

# Diffusion
pip install -q \
  diffusers \
  accelerate

# Gaussian Splatting (build from source)
pip install -q \
  gsplat

# Utils
pip install -q \
  opencv-python-headless \
  scikit-image \
  einops \
  lpips \
  paramiko \
  scp \
  tqdm \
  wandb

echo "‚úì All dependencies installed"

In [None]:
# ============================================================
# CLONE COMPONENT REPOS
# Shallow-clones each third-party repository into /content/repos,
# skipping the ones that are already present.
# ============================================================
import os
import subprocess

os.makedirs('/content/repos', exist_ok=True)
os.chdir('/content/repos')

repos = {
    'gaussian-avatars': 'https://github.com/ShenhanQian/GaussianAvatars.git',
    'DECA': 'https://github.com/yfeng95/DECA.git',
    'FaceFormer': 'https://github.com/EvelynFan/FaceFormer.git',
}

failed = []
try:
    for name, url in repos.items():
        if os.path.exists(name):
            print(f'‚úì {name} already exists')
            continue
        # Argument list instead of a shell string: nothing in the name/URL is
        # interpreted by a shell, and the exit status is available to check.
        completed = subprocess.run(['git', 'clone', '--depth', '1', url, name])
        if completed.returncode == 0:
            print(f'‚úì Cloned {name}')
        else:
            # Previously a failed clone was still reported as a success.
            failed.append(name)
            print(f'git clone failed for {name} (exit code {completed.returncode})')
finally:
    # Always return to /content, even if a clone raised.
    os.chdir('/content')

if failed:
    raise RuntimeError(f'Could not clone: {", ".join(failed)}')
print('\n‚úì All repos ready')

In [None]:
# ============================================================
# DOWNLOAD PRETRAINED WEIGHTS
# Declares where each pretrained asset is expected and reports
# what is currently present in /content/weights.
# ============================================================
import os
os.makedirs('/content/weights', exist_ok=True)

# FLAME head model — needs a registered account at https://flame.is.tue.mpg.de/;
# place generic_model.pkl here by manual upload or from a mounted Drive.
FLAME_MODEL_PATH = '/content/weights/generic_model.pkl'

# Pretrained DECA checkpoint
DECA_CKPT = '/content/weights/deca_model.tar'

# Wav2Vec2 model id (the transformers library downloads it on demand)
WAV2VEC_MODEL = 'facebook/wav2vec2-large-960h'

# Inventory of the weights directory with sizes in megabytes
print('Weight files:')
for f in os.listdir('/content/weights'):
    size = os.path.getsize(os.path.join('/content/weights', f)) / 1e6
    print(f'  {f}: {size:.1f} MB')

# FLAME cannot be fetched automatically, so explain how to obtain it
if not os.path.exists(FLAME_MODEL_PATH):
    print('\n‚ö†Ô∏è  FLAME model not found. Upload generic_model.pkl to /content/weights/')
    print('   Get it from: https://flame.is.tue.mpg.de/')
    print('   Or mount Google Drive with: drive.mount("/content/drive")')

## 1. FLAME Parametric Head Model

In [None]:
# ============================================================
# FLAME MODEL WRAPPER
# ============================================================
import torch
import torch.nn as nn
import numpy as np
import pickle


class FLAMEModel(nn.Module):
    """FLAME parametric head model (blendshape stage only).

    Parameter groups:
        shape (β): identity shape coefficients, ``n_shape`` values (default 300)
        expression (ψ): expression coefficients, ``n_exp`` values (default 100)
        pose (θ): 15 values (global + jaw + neck + eyes); accepted by
            ``forward`` but not applied yet, because skinning is not implemented.

    All model arrays are registered as buffers, so they follow ``.to(device)``
    and are stored in the state dict, but are not trained.
    """

    def __init__(self, flame_path, n_shape=300, n_exp=100):
        """Load the FLAME arrays from a pickle file.

        Args:
            flame_path: path to the FLAME pickle (e.g. ``generic_model.pkl``).
            n_shape: number of leading shape components to keep.
            n_exp: number of expression components to keep; they are read
                from ``shapedirs`` starting at component index 300.
        """
        super().__init__()
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only load a model file obtained from a source you trust.
        with open(flame_path, 'rb') as f:
            flame_data = pickle.load(f, encoding='latin1')

        # Template mesh vertices
        self.register_buffer('v_template', torch.tensor(
            np.array(flame_data['v_template']), dtype=torch.float32))

        # Shape blendshapes: first n_shape components of 'shapedirs'
        shapedirs = np.array(flame_data['shapedirs'][:, :, :n_shape])
        self.register_buffer('shapedirs', torch.tensor(shapedirs, dtype=torch.float32))

        # Expression blendshapes: components 300 .. 300 + n_exp of 'shapedirs'
        exprdirs = np.array(flame_data['shapedirs'][:, :, 300:300+n_exp])
        self.register_buffer('exprdirs', torch.tensor(exprdirs, dtype=torch.float32))

        # Pose blendshapes, flattened to (n_pose_components, n_vertices * 3)
        posedirs = np.array(flame_data['posedirs'])
        self.register_buffer('posedirs', torch.tensor(
            posedirs.reshape(posedirs.shape[0] * 3, -1).T, dtype=torch.float32))

        # Skinning weights
        self.register_buffer('lbs_weights', torch.tensor(
            np.array(flame_data['weights']), dtype=torch.float32))

        # Joint regressor (stored sparse in the pickle, densified here)
        J_regressor = np.array(flame_data['J_regressor'].todense())
        self.register_buffer('J_regressor', torch.tensor(J_regressor, dtype=torch.float32))

        # Kinematic tree
        self.register_buffer('kintree_table', torch.tensor(
            np.array(flame_data['kintree_table']).astype(np.int64)))

        # Triangle faces
        self.register_buffer('faces_tensor', torch.tensor(
            np.array(flame_data['f']).astype(np.int64)))

        self.n_vertices = self.v_template.shape[0]  # 5023 for the standard FLAME mesh
        self.n_shape = n_shape
        self.n_exp = n_exp

    def forward(self, shape_params, expression_params, pose_params):
        """Deform the template mesh with shape and expression blendshapes.

        Args:
            shape_params: (B, n_shape) identity shape coefficients.
            expression_params: (B, n_exp) expression coefficients.
            pose_params: (B, 15) pose (global + jaw + neck + eyes). Currently
                ignored: linear blend skinning is not implemented yet.
        Returns:
            vertices: (B, n_vertices, 3) deformed vertices.
        """
        # Template plus the linear shape and expression offsets. The joint
        # regression that used to run here was removed: its result was
        # discarded on every call, since no skinning step consumes it yet
        # (the J_regressor buffer is kept for when LBS is added).
        vertices = self.v_template.unsqueeze(0) + \
            torch.einsum('bl,mkl->bmk', shape_params, self.shapedirs) + \
            torch.einsum('bl,mkl->bmk', expression_params, self.exprdirs)

        # TODO: full LBS (Rodrigues rotations + kinematic chain) using pose_params.
        return vertices

    @property
    def faces(self):
        """Triangle index tensor of the mesh."""
        return self.faces_tensor


print('‚úì FLAME model defined')

## 2. UV-Space Gaussian Avatar Model

In [None]:
# ============================================================
# UV-SPACE GAUSSIAN AVATAR (GAGAvatar + UHAP pattern)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """3x3 convolution, instance norm and LeakyReLU(0.2).

    With ``upsample=True`` the input is first enlarged 2x with bilinear
    interpolation.
    """

    def __init__(self, in_ch, out_ch, upsample=False):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.InstanceNorm2d(out_ch)
        self.act = nn.LeakyReLU(negative_slope=0.2)
        self.upsample = upsample

    def forward(self, x):
        features = x
        if self.upsample:
            features = F.interpolate(
                features, scale_factor=2, mode='bilinear', align_corners=False)
        features = self.conv(features)
        features = self.norm(features)
        return self.act(features)


class NeutralDecoder(nn.Module):
    """Turns the identity code Z_id into a pyramid of bias maps (UHAP pattern).

    A linear layer produces a 256x4x4 seed; each following block doubles the
    spatial size with a stride-2 transposed convolution, and the output of
    every block is returned so it can be added into the Gaussian decoder.
    """

    def __init__(self, z_dim=512, num_scales=8):
        super().__init__()
        self.fc = nn.Linear(z_dim, 256 * 4 * 4)
        # Output width of each scale; at most len(channels) scales are available.
        channels = [256, 256, 128, 128, 64, 64, 32, 16]
        self.blocks = nn.ModuleList()
        prev_ch = 256  # width of the seed produced by self.fc
        for scale_idx in range(num_scales):
            next_ch = channels[scale_idx]
            upsampler = nn.Sequential(
                nn.ConvTranspose2d(prev_ch, next_ch, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
            )
            self.blocks.append(upsampler)
            prev_ch = next_ch

    def forward(self, z_id):
        """Return the list of bias maps, ordered from coarsest to finest."""
        current = self.fc(z_id).view(-1, 256, 4, 4)
        pyramid = []
        for upsampler in self.blocks:
            current = upsampler(current)
            pyramid.append(current)
        return pyramid


class GuideMeshDecoder(nn.Module):
    """MLP mapping the concatenated (Z_id, Z_exp) codes to per-vertex 3D offsets."""

    def __init__(self, z_id_dim=512, z_exp_dim=256, n_vertices=7306):
        super().__init__()
        latent_dim = z_id_dim + z_exp_dim
        layers = [
            nn.Linear(latent_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, n_vertices * 3),  # one xyz triple per vertex
        ]
        self.net = nn.Sequential(*layers)
        self.n_vertices = n_vertices

    def forward(self, z_id, z_exp):
        """Return offsets of shape (B, n_vertices, 3)."""
        latent = torch.cat((z_id, z_exp), dim=-1)
        flat_offsets = self.net(latent)
        return flat_offsets.view(-1, self.n_vertices, 3)


class GaussianAvatarDecoder(nn.Module):
    """Decodes Z_id + Z_exp (+ identity bias maps) → UV Gaussian attribute maps.

    Outputs a 14-channel UV map:
      - position offset (3)
      - rotation quaternion (4)
      - scale (3)
      - opacity (1)
      - color RGB (3)

    Two branches share the latent input: a view-independent branch producing
    the 11 geometry channels and a view-dependent branch producing RGB. Each
    branch starts from a 256x8x8 seed and doubles the resolution in every
    block; the result is resized to ``uv_size`` at the end.
    """

    def __init__(self, z_id_dim=512, z_exp_dim=256, uv_size=256):
        super().__init__()
        self.uv_size = uv_size

        # View-independent decoder (geometry: pos, rot, scale, opacity = 11ch)
        self.fc_vi = nn.Linear(z_id_dim + z_exp_dim, 256 * 8 * 8)
        vi_channels = [256, 128, 128, 64, 64, 32, 16, 11]
        self.vi_blocks = nn.ModuleList()
        for i, out_ch in enumerate(vi_channels):
            in_ch = 256 if i == 0 else vi_channels[i-1]
            self.vi_blocks.append(nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
                # raw (unbounded) values on the last block; they are activated
                # later when converted to Gaussian parameters
                nn.LeakyReLU(0.2) if i < len(vi_channels) - 1 else nn.Identity()
            ))

        # Appearance decoder (color: RGB = 3ch, view-dependent)
        self.fc_rgb = nn.Linear(z_id_dim + z_exp_dim + 3, 256 * 8 * 8)  # +3 for view dir
        rgb_channels = [256, 128, 128, 64, 64, 32, 16, 3]
        self.rgb_blocks = nn.ModuleList()
        for i, out_ch in enumerate(rgb_channels):
            in_ch = 256 if i == 0 else rgb_channels[i-1]
            self.rgb_blocks.append(nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
                # Sigmoid keeps the final colors in [0, 1]
                nn.LeakyReLU(0.2) if i < len(rgb_channels) - 1 else nn.Sigmoid()
            ))

    @staticmethod
    def _inject_bias(x, bias_maps):
        """Add the bias map whose (C, H, W) equals that of ``x``, if there is one.

        Matching is done on shape instead of on list position. The bias
        pyramid and the decoder blocks are offset by one resolution level
        (bias map ``i`` is half the size of block ``i``'s output), so the
        previous same-index lookup never found a match and the identity bias
        was silently never applied.
        """
        if bias_maps is None:
            return x
        for bias in bias_maps:
            if bias.shape[1:] == x.shape[1:]:
                return x + bias
        return x

    def forward(self, z_id, z_exp, view_dir=None, bias_maps=None):
        """Decode latents into a (B, 14, uv_size, uv_size) attribute map.

        Args:
            z_id: (B, z_id_dim) identity code.
            z_exp: (B, z_exp_dim) expression code.
            view_dir: optional (B, 3) viewing direction; zeros when omitted.
            bias_maps: optional list of (B, C, H, W) identity bias maps; each
                one is added to the feature map of identical (C, H, W) in
                both branches. Maps without a counterpart are ignored.
        Returns:
            uv_maps: geometry channels (11) followed by RGB (3).
        """
        B = z_id.shape[0]
        z = torch.cat([z_id, z_exp], dim=-1)
        uv_hw = (self.uv_size, self.uv_size)

        # View-independent (geometry)
        x_vi = self.fc_vi(z).view(B, 256, 8, 8)
        for block in self.vi_blocks:
            x_vi = self._inject_bias(block(x_vi), bias_maps)

        # Resize (bilinear) to the target UV resolution
        x_vi = F.interpolate(x_vi, size=uv_hw, mode='bilinear', align_corners=False)

        # View-dependent (color)
        if view_dir is None:
            view_dir = torch.zeros(B, 3, device=z.device)  # frontal
        z_rgb = torch.cat([z, view_dir], dim=-1)
        x_rgb = self.fc_rgb(z_rgb).view(B, 256, 8, 8)
        for block in self.rgb_blocks:
            x_rgb = self._inject_bias(block(x_rgb), bias_maps)

        x_rgb = F.interpolate(x_rgb, size=uv_hw, mode='bilinear', align_corners=False)

        # Combine: (B, 14, K, K)
        return torch.cat([x_vi, x_rgb], dim=1)


class ExpressionEncoder(nn.Module):
    """VAE encoder: UV texture/geometry difference maps → Z_exp.

    Eight conv + average-pool stages each halve the resolution, so a 512x512
    input is reduced to 2x2 before the two linear heads (the heads expect
    exactly 256 * 2 * 2 features).
    """

    def __init__(self, z_dim=256):
        super().__init__()
        stage_widths = [32, 32, 64, 64, 128, 128, 256, 256]
        # The first stage sees 6 channels: 3 texture-diff + 3 geometry-diff.
        stage_inputs = [6] + stage_widths[:-1]
        stages = []
        for c_in, c_out in zip(stage_inputs, stage_widths):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.LeakyReLU(0.2),
                nn.AvgPool2d(2),
            ]
        self.encoder = nn.Sequential(*stages)  # 512 → 2×2
        flat_features = 256 * 2 * 2
        self.fc_mu = nn.Linear(flat_features, z_dim)
        self.fc_logvar = nn.Linear(flat_features, z_dim)

    def forward(self, delta_tex, delta_geo):
        """Encode the two difference maps and sample a latent.

        Returns:
            (z_exp, mu, logvar): the sampled code and the Gaussian parameters
            it was drawn from.
        """
        stacked = torch.cat((delta_tex, delta_geo), dim=1)  # (B, 6, 512, 512)
        features = self.encoder(stacked).flatten(1)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        # Reparameterization trick: z = mu + eps * sigma with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std, mu, logvar


class QuantumHeadModel(nn.Module):
    """Top-level QuantumHead model: expression encoder plus the three decoders."""

    def __init__(self, config=None):
        """Build the sub-modules.

        Args:
            config: optional mapping with the keys 'z_id_dim', 'z_exp_dim',
                'uv_size' and 'guide_vertices'. Missing keys, or a missing or
                empty config, fall back to 512 / 256 / 256 / 7306.
        """
        super().__init__()
        settings = config if config else {}
        z_id = settings.get('z_id_dim', 512)
        z_exp = settings.get('z_exp_dim', 256)
        uv = settings.get('uv_size', 256)
        n_verts = settings.get('guide_vertices', 7306)

        self.expression_encoder = ExpressionEncoder(z_dim=z_exp)
        self.neutral_decoder = NeutralDecoder(z_dim=z_id)
        self.guide_mesh_decoder = GuideMeshDecoder(z_id, z_exp, n_verts)
        self.gaussian_decoder = GaussianAvatarDecoder(z_id, z_exp, uv)

    def forward(self, z_id, z_exp, view_dir=None):
        """Map latents to ``(uv_maps, guide_offsets)``.

        The identity code is first decoded into bias maps, which are handed
        to the Gaussian decoder together with both latents.
        """
        identity_bias = self.neutral_decoder(z_id)
        mesh_offsets = self.guide_mesh_decoder(z_id, z_exp)
        attribute_maps = self.gaussian_decoder(z_id, z_exp, view_dir, identity_bias)
        return attribute_maps, mesh_offsets

    def encode_expression(self, delta_tex, delta_geo):
        """Run the expression encoder; returns ``(z_exp, mu, logvar)``."""
        return self.expression_encoder(delta_tex, delta_geo)


# Smoke test: build the default model, push a random batch of two through it
# and report the output shapes and the parameter count.
model = QuantumHeadModel()
z_id, z_exp = torch.randn(2, 512), torch.randn(2, 256)
uv_maps, guide = model(z_id, z_exp)
print(f'UV maps: {uv_maps.shape}')     # expected (2, 14, 256, 256)
print(f'Guide mesh: {guide.shape}')     # expected (2, 7306, 3)
params = sum(weight.numel() for weight in model.parameters())
print(f'Total params: {params/1e6:.1f}M')
# Drop the large temporaries so they do not linger in the notebook namespace
del model, z_id, z_exp, uv_maps, guide
print('‚úì QuantumHead model verified')

## 3. Audio → FLAME Motion Model

In [None]:
# ============================================================
# AUDIO-TO-FLAME TRANSFORMER (GaussianHeadTalk pattern)
# ============================================================
import torch
import torch.nn as nn
import math


class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with an extra fixed-period component.

    The table is precomputed for ``max_len`` positions and stored as the
    buffer ``pe`` of shape (1, max_len, d_model).
    """

    def __init__(self, d_model, max_len=5000, period=25):
        super().__init__()
        steps = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        table = torch.zeros(max_len, d_model)
        # Classic transformer encoding: sine on even, cosine on odd channels
        table[:, 0::2] = torch.sin(steps * inv_freq)
        table[:, 1::2] = torch.cos(steps * inv_freq)
        # Periodic term that repeats every `period` positions
        table[:, 0::2] += torch.sin(steps * 2 * math.pi / period)
        table[:, 1::2] += torch.cos(steps * 2 * math.pi / period)
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the encoding of the first ``x.size(1)`` positions to ``x``."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]


class StyleEncoder(nn.Module):
    """Two-layer MLP that embeds a flattened template mesh into a style vector."""

    def __init__(self, n_vertices=5023, d_model=512):
        super().__init__()
        mesh_dim = n_vertices * 3  # xyz per vertex, flattened
        self.net = nn.Sequential(
            nn.Linear(mesh_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, d_model),
        )

    def forward(self, template_mesh):
        """Flatten everything after the batch dimension and embed it."""
        flat_mesh = template_mesh.flatten(1)
        return self.net(flat_mesh)


class Audio2FLAMETransformer(nn.Module):
    """Wav2Vec2 features → Transformer decoder → FLAME expression params.

    Predicts FLAME params directly (not vertices) for stability.
    Output: 53 params (50 expression + 3 jaw pose)

    The decoder queries are the style embedding repeated for every frame,
    plus the positional encoding; the projected audio features act as the
    cross-attention memory.
    """

    def __init__(self, d_model=512, nhead=8, num_layers=6, n_flame_params=53):
        super().__init__()
        self.d_model = d_model

        # Audio feature projection (Wav2Vec2 output is 1024-dim)
        self.audio_proj = nn.Linear(1024, d_model)

        # Positional encoding
        self.pos_enc = PeriodicPositionalEncoding(d_model, period=25)  # 25fps

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=2048, dropout=0.1,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_layers
        )

        # Style embedding
        self.style_encoder = StyleEncoder(d_model=d_model)

        # Output heads
        self.flame_head = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Linear(256, n_flame_params)
        )

    def forward(self, audio_features, style_embedding, causal_mask=None):
        """
        Args:
            audio_features: (B, T, 1024) from Wav2Vec2
            style_embedding: (B, d_model) from StyleEncoder
            causal_mask: optional (T, T) self-attention mask; a causal
                (square subsequent) mask is generated when omitted
        Returns:
            flame_params: (B, T, n_flame_params) — by default 50 expression + 3 jaw
        """
        _, T, _ = audio_features.shape

        # Project audio to d_model and tag every frame with its position
        audio = self.pos_enc(self.audio_proj(audio_features))  # (B, T, d_model)

        # One query per frame: the style vector plus that frame's positional
        # encoding. Without the positional term all T queries are identical,
        # so causal self-attention and cross-attention return the same vector
        # at every step and the predicted motion is constant over time.
        style = style_embedding.unsqueeze(1).expand(-1, T, -1)  # (B, T, d_model)
        queries = self.pos_enc(style)

        # Generate causal mask
        if causal_mask is None:
            causal_mask = nn.Transformer.generate_square_subsequent_mask(
                T, device=audio.device)

        # Decode
        output = self.transformer_decoder(
            tgt=queries,
            memory=audio,
            tgt_mask=causal_mask
        )

        # Predict FLAME params per frame
        return self.flame_head(output)  # (B, T, n_flame_params)


# Smoke test: run 100 random Wav2Vec2-sized frames for a batch of two through
# a default Audio2FLAME model and report the output shape and model size.
a2f = Audio2FLAMETransformer()
audio_feat, style = torch.randn(2, 100, 1024), torch.randn(2, 512)
params = a2f(audio_feat, style)
print(f'FLAME params: {params.shape}')  # expected (2, 100, 53)
n_params = sum(weight.numel() for weight in a2f.parameters())
print(f'Audio2FLAME params: {n_params/1e6:.1f}M')
# Free the temporaries again
del a2f, audio_feat, style, params
print('‚úì Audio2FLAME transformer verified')

## 4. Expression Diffusion Model (UHAP pattern)

In [None]:
# ============================================================
# EXPRESSION DIFFUSION MODEL (UHAP pattern — Audio → Z_exp)
# Maps Wav2Vec2 audio + lip vertices ‚Üí expression latent codes
# via DDPM for richer expressions beyond FLAME params
# ============================================================
import torch
import torch.nn as nn


class FiLMLayer(nn.Module):
    """Feature-wise Linear Modulation: scales and shifts a sequence using a
    per-sample conditioning vector (used here for the timestep embedding)."""

    def __init__(self, d_model, d_cond):
        super().__init__()
        self.gamma = nn.Linear(d_cond, d_model)  # produces the scale
        self.beta = nn.Linear(d_cond, d_model)   # produces the shift

    def forward(self, x, cond):
        """Return ``gamma(cond) * x + beta(cond)``, broadcast over dim 1 of ``x``."""
        # (B, D) → (B, 1, D) so one scale/shift pair applies to every position
        scale = self.gamma(cond).unsqueeze(1)
        shift = self.beta(cond).unsqueeze(1)
        return scale * x + shift


class ExpressionDiffusionTransformer(nn.Module):
    """DDPM backbone: denoises Z_exp conditioned on audio + lip vertices.

    Following UHAP's design:
    - self-attention over the noisy expression codes
    - cross-attention to the conditioning (audio features + lip vertices)
    - a FiLM layer after every transformer layer injects the timestep
    """

    def __init__(self, z_dim=256, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.z_dim = z_dim

        # Input projections into the shared model width
        self.z_proj = nn.Linear(z_dim, d_model)
        self.audio_proj = nn.Linear(1024, d_model)   # Wav2Vec2 features
        self.lip_proj = nn.Linear(338 * 3, d_model)  # 338 lip vertices × 3

        # MLP applied to the 256-dim sinusoidal timestep embedding
        self.time_embed = nn.Sequential(
            nn.Linear(256, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        # One FiLM layer per transformer layer
        self.layers = nn.ModuleList()
        self.film_layers = nn.ModuleList()
        for _ in range(num_layers):
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=2048,
                dropout=0.1,
                batch_first=True,
            )
            self.layers.append(decoder_layer)
            self.film_layers.append(FiLMLayer(d_model, d_model))

        # Back to the latent width
        self.out_proj = nn.Linear(d_model, z_dim)

    def get_timestep_embedding(self, timesteps, dim=256):
        """Sinusoidal embedding of integer timesteps.

        Returns a (B, 2 * (dim // 2)) tensor: the cosine half followed by the
        sine half.
        """
        half = dim // 2
        exponents = torch.arange(half, device=timesteps.device).float()
        freqs = torch.exp(-exponents * (torch.log(torch.tensor(10000.0)) / half))
        angles = timesteps.float()[:, None] * freqs[None, :]
        return torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)

    def forward(self, z_noisy, timestep, audio_features, lip_vertices):
        """
        Args:
            z_noisy: (B, T, z_dim) noisy expression codes
            timestep: (B,) diffusion timestep
            audio_features: (B, T, 1024) from Wav2Vec2
            lip_vertices: (B, T, 338*3) predicted lip vertices
        Returns:
            noise_pred: (B, T, z_dim) predicted noise
        """
        tokens = self.z_proj(z_noisy)

        # Conditioning memory: element-wise sum of the projected audio and
        # lip streams (they are added, not concatenated)
        memory = self.audio_proj(audio_features) + self.lip_proj(lip_vertices)

        # Timestep conditioning vector, one per sample
        t_emb = self.time_embed(self.get_timestep_embedding(timestep))

        # Alternate transformer layer and FiLM modulation
        hidden = tokens
        for decoder_layer, film in zip(self.layers, self.film_layers):
            hidden = film(decoder_layer(hidden, memory), t_emb)

        return self.out_proj(hidden)


# Smoke test: one denoising call on random inputs (batch 2, 50 frames),
# then report the output shape and the parameter count.
diff = ExpressionDiffusionTransformer()
z = torch.randn(2, 50, 256)        # noisy expression codes
t = torch.randint(0, 500, (2,))    # one diffusion timestep per sample
a = torch.randn(2, 50, 1024)       # audio features
l = torch.randn(2, 50, 338*3)      # flattened lip vertices
out = diff(z, t, a, l)
print(f'Noise pred: {out.shape}')  # expected (2, 50, 256)
n = sum(weight.numel() for weight in diff.parameters())
print(f'Diffusion params: {n/1e6:.1f}M')
# Free the temporaries again
del diff, z, t, a, l, out
print('‚úì Expression diffusion model verified')

## 5. Gaussian Splatting Renderer

In [None]:
# ============================================================
# 3D GAUSSIAN SPLATTING RENDERER
# Converts UV attribute maps → 3D Gaussians → rendered image
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianRenderer(nn.Module):
    """Renders 3D Gaussians from UV attribute maps.

    UV maps → one Gaussian per valid texel → rasterize → image.
    """

    def __init__(self, uv_size=256, image_size=512):
        super().__init__()
        self.uv_size = uv_size
        self.image_size = image_size

        # UV mask (which texels are valid). Currently every texel is valid;
        # in the full implementation this comes from the FLAME UV layout.
        self.register_buffer('uv_mask', torch.ones(uv_size, uv_size, dtype=torch.bool))

        # Regular [-1, 1] grid over the UV square, stored as (x, y) pairs.
        # Not read by the methods below.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1, 1, uv_size),
            torch.linspace(-1, 1, uv_size),
            indexing='ij'
        )
        self.register_buffer('sample_grid', torch.stack([grid_x, grid_y], dim=-1))

    def _flatten_uv(self, t):
        """(B, C, H, W) → (B, N_valid, C), keeping only texels set in ``uv_mask``."""
        B, C, H, W = t.shape
        return t.reshape(B, C, H * W).permute(0, 2, 1)[:, self.uv_mask.flatten()]

    def uv_to_gaussians(self, uv_maps, position_map):
        """Convert UV attribute maps to Gaussian parameters.

        Args:
            uv_maps: (B, 14, K, K) — pos_offset(3), rot(4), scale(3), opacity(1), color(3)
            position_map: (B, 3, K, K) — base 3D positions from FLAME mesh
        Returns:
            dict with 'positions', 'rotations' (unit-normalized), 'scales'
            (exp of the raw value), 'opacities' (sigmoid of the raw value) and
            'colors', each of shape (B, N_valid, C).
        """
        # Split channels
        pos_offset = uv_maps[:, 0:3]   # (B, 3, K, K)
        rotation = uv_maps[:, 3:7]     # (B, 4, K, K)
        scale = uv_maps[:, 7:10]       # (B, 3, K, K)
        opacity = uv_maps[:, 10:11]    # (B, 1, K, K)
        color = uv_maps[:, 11:14]      # (B, 3, K, K)

        # Final positions = base + offset
        positions = position_map + pos_offset

        return {
            'positions': self._flatten_uv(positions),
            'rotations': F.normalize(self._flatten_uv(rotation), dim=-1),
            'scales': torch.exp(self._flatten_uv(scale)),
            'opacities': torch.sigmoid(self._flatten_uv(opacity)),
            'colors': self._flatten_uv(color),
        }

    def render(self, gaussians, camera):
        """Render Gaussians to an image with gsplat, falling back to the
        simple splatting renderer when gsplat cannot be used.

        Only the first batch element is passed to gsplat.

        Args:
            gaussians: dict as returned by ``uv_to_gaussians``.
            camera: dict with the tensors 'viewmat' and 'K'.
        Returns:
            The first element of gsplat's return value, or the
            (B, 3, H, W) image from ``_neural_render`` on fallback.
            NOTE(review): the two paths appear to use different layouts
            (channels-last vs channels-first) — confirm before relying on
            the gsplat path.
        """
        try:
            import gsplat
        except ImportError:
            # gsplat is optional; the fallback keeps training runnable
            return self._neural_render(gaussians)

        try:
            rendered = gsplat.rasterization(
                means=gaussians['positions'][0],
                quats=gaussians['rotations'][0],
                scales=gaussians['scales'][0],
                opacities=gaussians['opacities'][0].squeeze(-1),
                colors=gaussians['colors'][0],
                viewmats=camera['viewmat'].unsqueeze(0),
                Ks=camera['K'].unsqueeze(0),
                width=self.image_size,
                height=self.image_size,
            )
        except Exception as err:
            # Still best-effort, but no longer silent: a broken camera dict or
            # a CUDA error would otherwise be indistinguishable from
            # "gsplat not installed".
            import warnings
            warnings.warn(
                f'gsplat rasterization failed ({err!r}); using the fallback renderer',
                RuntimeWarning,
            )
            return self._neural_render(gaussians)
        return rendered[0]

    def _neural_render(self, gaussians):
        """Simple splatting fallback.

        Drops the z coordinate (orthographic projection), maps x/y from
        [-1, 1] to pixel indices and accumulates ``color * opacity`` of every
        Gaussian into the pixel it lands on. Gradients reach colors and
        opacities, but not positions (pixel indices are integers).

        Returns:
            (B, 3, H, W) image clamped to [0, 1].
        """
        pos_2d = gaussians['positions'][:, :, :2]  # (B, N, 2)
        colors = gaussians['colors']               # (B, N, 3)
        opacities = gaussians['opacities']         # (B, N, 1)
        B = pos_2d.shape[0]
        H = W = self.image_size

        # Convert positions to pixel coords, clamped to the image
        px = ((pos_2d[:, :, 0] + 1) * 0.5 * W).long().clamp(0, W-1)
        py = ((pos_2d[:, :, 1] + 1) * 0.5 * H).long().clamp(0, H-1)

        # One scatter-add over the flattened image replaces the former
        # per-batch / per-channel Python loop.
        flat_idx = (py * W + px).unsqueeze(1).expand(-1, 3, -1)   # (B, 3, N)
        weighted = (colors * opacities).permute(0, 2, 1)          # (B, 3, N)
        img = torch.zeros(B, 3, H * W, device=pos_2d.device, dtype=weighted.dtype)
        img = img.scatter_add(2, flat_idx, weighted).view(B, 3, H, W)

        return img.clamp(0, 1)


print('‚úì Gaussian renderer defined')

## 6. Training Loop

In [None]:
# ============================================================
# TRAINING LOOP
# ============================================================
import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm


def train_quantumhead(
    model,
    audio_model,
    renderer,
    dataloader,
    num_iterations=50000,
    lr=1e-4,
    checkpoint_dir='/content/quantumhead_output/checkpoints',
    save_every=5000,
):
    """Main training loop for QuantumHead.

    Iterates over ``dataloader`` (restarting it as often as needed) until
    ``num_iterations`` optimizer steps have been taken.

    Args:
        model: QuantumHead model providing ``encode_expression`` and a forward
            pass returning ``(uv_maps, guide_offsets)``.
        audio_model: audio model; moved to the device, given an optimizer and
            checkpointed, but not part of the loss yet.
        renderer: object providing ``uv_to_gaussians`` and ``_neural_render``.
        dataloader: iterable of dict batches with the keys 'image', 'shape',
            'expression' and 'pose', and optionally 'delta_tex' / 'delta_geo'.
        num_iterations: total number of optimizer steps.
        lr: Adam learning rate for both optimizers.
        checkpoint_dir: directory for checkpoints (created if missing).
        save_every: checkpoint interval in steps.
    Returns:
        Path of the final checkpoint, ``<checkpoint_dir>/final.pt``.
    Raises:
        ValueError: if the dataloader yields no batches (this previously
            made the loop spin forever).
    """

    os.makedirs(checkpoint_dir, exist_ok=True)
    # Prefer the GPU, but stay runnable (slowly) on a CPU-only runtime.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = model.to(device)
    audio_model = audio_model.to(device)
    renderer = renderer.to(device)

    # Optimizers
    opt_model = Adam(model.parameters(), lr=lr)
    opt_audio = Adam(audio_model.parameters(), lr=lr)

    # Loss weights (from UHAP paper)
    w_rec = 1.0       # Image reconstruction (currently L1 only)
    w_neut = 0.5      # Neutral scan reconstruction (not wired in yet)
    w_kl = 0.001      # KL divergence
    w_geo = 1.0       # Guide mesh geometry (not wired in yet)
    w_perc = 0.1      # Perceptual (LPIPS)

    # LPIPS is optional; without it the perceptual term stays at zero.
    try:
        import lpips
        lpips_fn = lpips.LPIPS(net='vgg').to(device)
    except ImportError:
        lpips_fn = None

    model.train()
    audio_model.train()

    step = 0
    pbar = tqdm(total=num_iterations, desc='Training')

    try:
        while step < num_iterations:
            saw_batch = False
            for batch in dataloader:
                saw_batch = True
                if step >= num_iterations:
                    break

                # Unpack batch
                images = batch['image'].to(device)           # (B, 3, H, W)
                flame_shape = batch['shape'].to(device)      # (B, 300)
                flame_exp = batch['expression'].to(device)   # (B, 100)
                flame_pose = batch['pose'].to(device)        # (B, 15)
                n_samples = images.shape[0]

                # Difference maps are optional; build the zero default only
                # when a batch lacks them, and directly on the target device.
                delta_tex = batch.get('delta_tex')
                delta_tex = (torch.zeros(n_samples, 3, 512, 512, device=device)
                             if delta_tex is None else delta_tex.to(device))
                delta_geo = batch.get('delta_geo')
                delta_geo = (torch.zeros(n_samples, 3, 512, 512, device=device)
                             if delta_geo is None else delta_geo.to(device))

                # --- Forward ---
                # Encode expression
                z_exp, mu, logvar = model.encode_expression(delta_tex, delta_geo)

                # Identity code (random placeholder for now)
                z_id = torch.randn(n_samples, 512, device=device)  # TODO: per-subject optimization

                # Decode
                uv_maps, guide_offsets = model(z_id, z_exp)

                # Create position map from FLAME (zero placeholder for now)
                pos_map = torch.zeros_like(uv_maps[:, :3])  # TODO: from FLAME model

                # Render
                gaussians = renderer.uv_to_gaussians(uv_maps, pos_map)
                rendered = renderer._neural_render(gaussians)

                # --- Losses ---
                # Reconstruction against the image resized to the render size
                target = F.interpolate(images, size=rendered.shape[2:], mode='bilinear', align_corners=False)
                loss_rec = F.l1_loss(rendered, target)

                # KL divergence of the expression posterior to N(0, I)
                loss_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

                # Perceptual (inputs rescaled from [0, 1] to [-1, 1])
                loss_perc = torch.tensor(0.0, device=device)
                if lpips_fn is not None:
                    loss_perc = lpips_fn(rendered * 2 - 1, target * 2 - 1).mean()

                # Total
                loss = w_rec * loss_rec + w_kl * loss_kl + w_perc * loss_perc

                # --- Backward ---
                opt_model.zero_grad()
                opt_audio.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt_model.step()
                opt_audio.step()

                # Log
                if step % 100 == 0:
                    pbar.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'rec': f'{loss_rec.item():.4f}',
                        'kl': f'{loss_kl.item():.4f}',
                    })

                # Checkpoint (includes optimizer state so training can resume)
                if step > 0 and step % save_every == 0:
                    ckpt = {
                        'step': step,
                        'model': model.state_dict(),
                        'audio_model': audio_model.state_dict(),
                        'opt_model': opt_model.state_dict(),
                        'opt_audio': opt_audio.state_dict(),
                    }
                    path = os.path.join(checkpoint_dir, f'ckpt_{step:06d}.pt')
                    torch.save(ckpt, path)
                    print(f'\n‚úì Saved checkpoint: {path}')

                step += 1
                pbar.update(1)

            if not saw_batch:
                # An empty loader never advances `step`; fail loudly instead
                # of looping forever.
                raise ValueError('dataloader yielded no batches; cannot train')
    finally:
        # Close the progress bar even when training is interrupted
        pbar.close()

    # Final save (weights only)
    final_path = os.path.join(checkpoint_dir, 'final.pt')
    torch.save({
        'step': step,
        'model': model.state_dict(),
        'audio_model': audio_model.state_dict(),
    }, final_path)
    print(f'‚úì Final model saved: {final_path}')
    return final_path


print('‚úì Training loop defined')

## 7. Dataset & DataLoader

In [None]:
# ============================================================
# DATASET — Loads face video frames with FLAME annotations
# Supports: VFHQ, HDTF, VOCASET, or custom selfies
# ============================================================
import os
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from pathlib import Path


class FaceVideoDataset(Dataset):
    """Dataset of face video frames with optional per-frame FLAME parameters.

    Directory structure:
      data_root/
        subject_001/
          frames/        # extracted video frames (512x512)
          flame/         # FLAME params per frame (.npz)
          audio/         # audio features (.npy)

    Only ``frames/*.png`` and the same-named ``flame/*.npz`` files are read
    by this class; the ``audio/`` directory is not touched here.

    When no frames are found the dataset reports a length of 1 and serves a
    single synthetic sample, so the rest of the pipeline can be exercised
    without real data.

    Args:
        data_root: Directory containing one sub-directory per subject.
        image_size: Frames are resized to ``image_size x image_size``.
        split: Accepted for API compatibility; currently unused.
    """

    # Vector length used for each FLAME parameter when no annotation exists.
    _FLAME_DIMS = {'shape': 300, 'expression': 100, 'pose': 15}

    def __init__(self, data_root, image_size=512, split='train'):
        self.data_root = Path(data_root)
        self.image_size = image_size
        self.samples = []

        # Scan for frame/FLAME pairs
        for subject_dir in sorted(self.data_root.iterdir()):
            if not subject_dir.is_dir():
                continue
            frames_dir = subject_dir / 'frames'
            flame_dir = subject_dir / 'flame'
            if not frames_dir.exists():
                continue

            for frame_path in sorted(frames_dir.glob('*.png')):
                flame_path = flame_dir / frame_path.with_suffix('.npz').name
                self.samples.append({
                    'frame': str(frame_path),
                    'flame': str(flame_path) if flame_path.exists() else None,
                    'subject': subject_dir.name,
                })

        n_subjects = len({s['subject'] for s in self.samples})
        print(f'‚úì Dataset: {len(self.samples)} frames from {n_subjects} subjects')

    def __len__(self):
        # Never report 0 so a DataLoader can still draw the dummy sample.
        return max(len(self.samples), 1)

    def _default_flame(self):
        """Return all-zero FLAME parameter tensors (float32)."""
        return {key: torch.zeros(dim) for key, dim in self._FLAME_DIMS.items()}

    def _load_flame(self, flame_path):
        """Load FLAME params from ``flame_path``, zero-filling what is missing.

        The archive is opened in a ``with`` block so its file handle is
        released, and membership is tested through ``NpzFile.files`` because
        ``NpzFile.get`` is not available on older NumPy releases.
        """
        params = self._default_flame()
        if flame_path and os.path.exists(flame_path):
            with np.load(flame_path) as flame:
                for key in params:
                    if key in flame.files:
                        params[key] = torch.from_numpy(flame[key]).float()
        return params

    def __getitem__(self, idx):
        if not self.samples:
            # Return dummy data for testing
            return {
                'image': torch.randn(3, self.image_size, self.image_size),
                **self._default_flame(),
                'delta_tex': torch.zeros(3, 512, 512),
                'delta_geo': torch.zeros(3, 512, 512),
            }

        # Indices wrap around so any non-negative idx maps to a real sample.
        sample = self.samples[idx % len(self.samples)]

        # Load image; cv2.imread signals failure by returning None instead
        # of raising, so surface a clear error naming the offending file.
        img = cv2.imread(sample['frame'])
        if img is None:
            raise OSError(f"Could not read image file: {sample['frame']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size))
        # HWC uint8 -> CHW float in [0, 1]
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

        return {
            'image': img,
            **self._load_flame(sample['flame']),
            'delta_tex': torch.zeros(3, 512, 512),
            'delta_geo': torch.zeros(3, 512, 512),
        }


print('‚úì Dataset class defined')

In [None]:
# ============================================================
# DOWNLOAD & PREPARE HDTF DATASET (audio-visual talking heads)
# ============================================================
import os

DATA_ROOT = '/content/quantumhead_data'
os.makedirs(DATA_ROOT, exist_ok=True)

# Informational menu of supported data sources. Nothing is downloaded here;
# empty entries render as blank separator lines.
print('\n'.join([
    '=== Dataset Options ===',
    '1. HDTF (High-Definition Talking Face) ‚Äî ~15.8 hours, 720p+',
    '   git clone https://github.com/MRzzm/HDTF.git',
    '',
    '2. VFHQ (Video Face Super-Resolution HQ) ‚Äî large-scale face videos',
    '   Requires application at: https://liangbinxie.github.io/projects/vfhq/',
    '',
    '3. VOCASET (VOice Controlled Avatars) ‚Äî audio + 3D tracked FLAME',
    '   https://voca.is.tue.mpg.de/',
    '',
    '4. Custom selfies from spike2 (your 12 selfies)',
    '   Will auto-download from spike2 server',
    '',
    'For quick start, we\'ll use your selfies + HDTF samples.',
]))
print('Upload data to:', DATA_ROOT)

In [None]:
# ============================================================
# PULL SELFIES FROM SPIKE2
# ============================================================
import os
import requests

# Local target directory for the selfie frames (laid out like a dataset subject).
SELFIE_DIR = os.path.join(DATA_ROOT, 'riley_selfies/frames')
os.makedirs(SELFIE_DIR, exist_ok=True)

# Base URL of the spike2 HTTP API
SPIKE2_URL = f'https://{SPIKE2_HOST}'

# This cell only queries and prints the avatar models spike2 advertises; it
# does not download any files. Any failure (network, TLS, bad JSON) is
# reported and the notebook carries on.
# NOTE(review): verify=False disables TLS certificate verification.
try:
    resp = requests.get(f'{SPIKE2_URL}/avatar/models', timeout=10, verify=False)
    if not resp.ok:
        print(f'Could not reach spike2 API: {resp.status_code}')
    else:
        models = resp.json()
        print(f'Found {len(models)} avatar models on spike2')
        for m in models:
            print(f'  - {m}')
except Exception as e:
    print(f'spike2 not reachable: {e}')
    print('Will use dummy data for architecture testing')

## 8. FLAME Fitting with DECA

In [None]:
# ============================================================
# DECA-BASED FLAME FITTING
# Single image → FLAME shape, expression, pose params
# ============================================================
import sys
sys.path.insert(0, '/content/repos/DECA')


def fit_flame_to_image(image_path, deca_model=None):
    """Fit FLAME parameters to a single image using DECA.

    Args:
        image_path: Path of the image file to encode.
        deca_model: Optional pre-built DECA instance. Pass one when fitting
            many images so the model is not rebuilt on every call. When
            ``None`` a new DECA (texture disabled) is constructed.

    Returns:
        dict with 'shape', 'expression' and 'pose' NumPy arrays. If DECA is
        unavailable or fitting fails for any reason, all-zero arrays of
        length 300 / 100 / 15 are returned instead of raising.
        NOTE(review): the lengths of the arrays DECA itself returns are not
        checked here - confirm they match what downstream code expects.
    """
    try:
        from decalib.deca import DECA
        from decalib.utils.config import cfg as deca_cfg

        # Prefer the GPU but do not hard-fail on CPU-only runtimes.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if deca_model is None:
            deca_cfg.model.use_tex = False
            deca_model = DECA(config=deca_cfg, device=device)

        # cv2.imread returns None (no exception) for missing/corrupt files;
        # fail here with a message that names the file.
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f'could not read image: {image_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (224, 224))
        # HWC uint8 -> 1xCxHxW float in [0, 1]
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        img_tensor = img_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            codedict = deca_model.encode(img_tensor)

        return {
            'shape': codedict['shape'][0].cpu().numpy(),
            'expression': codedict['exp'][0].cpu().numpy(),
            'pose': codedict['pose'][0].cpu().numpy(),
        }
    except Exception as e:
        # Deliberate best-effort: the notebook must keep running without
        # DECA. No landmark-based approximation is implemented, so the
        # fallback is a neutral (all-zero) parameter set.
        print(f'DECA fitting failed: {e}')
        print('Falling back to zero (neutral) FLAME params.')
        return {
            'shape': np.zeros(300),
            'expression': np.zeros(100),
            'pose': np.zeros(15),
        }


def batch_fit_flame(image_dir, output_dir):
    """Fit FLAME params for all images in a directory.

    Every ``*.png`` and ``*.jpg`` directly inside ``image_dir`` is passed to
    ``fit_flame_to_image`` and the result is written to
    ``output_dir/<image stem>.npz``. A PNG and a JPG sharing the same stem
    therefore write to the same output file.

    Args:
        image_dir: Directory containing the input images.
        output_dir: Directory for the ``.npz`` files (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    image_root = Path(image_dir)
    images = sorted(image_root.glob('*.png')) + sorted(image_root.glob('*.jpg'))

    # Build DECA once and reuse it for every image instead of letting
    # fit_flame_to_image construct a fresh model per call. If it cannot be
    # built here, deca_model stays None and the per-image call retries and
    # reports the failure itself.
    deca_model = None
    if images:
        try:
            from decalib.deca import DECA
            from decalib.utils.config import cfg as deca_cfg

            deca_cfg.model.use_tex = False
            deca_model = DECA(config=deca_cfg, device='cuda')
        except Exception:
            deca_model = None

    print(f'Fitting FLAME to {len(images)} images...')
    for img_path in tqdm(images):
        params = fit_flame_to_image(str(img_path), deca_model=deca_model)
        out_path = Path(output_dir) / img_path.with_suffix('.npz').name
        np.savez(str(out_path), **params)

    print(f'‚úì Saved FLAME params to {output_dir}')


print('‚úì FLAME fitting pipeline defined')

## 9. Run Training

In [None]:
# ============================================================
# INITIALIZE & RUN TRAINING
# ============================================================
import torch
from torch.utils.data import DataLoader

# Config: latent sizes and UV resolution handed to QuantumHeadModel
config = {
    'z_id_dim': IDENTITY_DIM,       # 512
    'z_exp_dim': EXPRESSION_DIM,     # 256
    'uv_size': UV_MAP_SIZE,          # 256
    'guide_vertices': GUIDE_MESH_VERTICES,  # 7306
}

# Initialize models
model = QuantumHeadModel(config)
audio_model = Audio2FLAMETransformer()
renderer = GaussianRenderer(uv_size=UV_MAP_SIZE, image_size=512)

# Dataset
dataset = FaceVideoDataset(DATA_ROOT, image_size=512)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    # Dropping the last partial batch is only safe when at least one full
    # batch exists. With fewer samples than BATCH_SIZE (e.g. the dataset's
    # single dummy sample) drop_last=True would leave the loader empty.
    drop_last=len(dataset) >= BATCH_SIZE,
)

# Print model sizes
def count_params(m):
    """Return the number of parameters of ``m`` in millions."""
    return sum(p.numel() for p in m.parameters()) / 1e6

print(f'QuantumHead model:  {count_params(model):.1f}M params')
print(f'Audio2FLAME model:  {count_params(audio_model):.1f}M params')
print(f'Dataset size:       {len(dataset)} samples')
print(f'Batch size:         {BATCH_SIZE}')
print(f'Iterations:         {NUM_ITERATIONS}')
print(f'\nStarting training...')

# Train; returns the path of the final checkpoint
final_path = train_quantumhead(
    model=model,
    audio_model=audio_model,
    renderer=renderer,
    dataloader=dataloader,
    num_iterations=NUM_ITERATIONS,
    lr=LEARNING_RATE,
    checkpoint_dir=CHECKPOINT_DIR,
    save_every=5000,
)

print(f'\n‚úÖ Training complete! Final model: {final_path}')

## 10. Push Weights to spike2

In [None]:
# ============================================================
# PUSH TRAINED WEIGHTS TO SPIKE2
# spike2 serves inference — Colab just trains
# ============================================================
import os
import json


def _push_via_api(checkpoint_path, spike2_host):
    """POST the file to the spike2 upload endpoint; True on a success status."""
    import requests
    url = f'https://{spike2_host}/quantumhead/upload-weights'

    file_size = os.path.getsize(checkpoint_path) / 1e6
    print(f'  File size: {file_size:.1f} MB')

    # NOTE(review): verify=False disables TLS certificate verification.
    with open(checkpoint_path, 'rb') as f:
        resp = requests.post(
            url,
            files={'weights': (os.path.basename(checkpoint_path), f)},
            timeout=300,
            verify=False,
        )

    if resp.ok:
        result = resp.json()
        print(f'  ‚úì Uploaded to spike2: {result}')
        return True
    print(f'  ‚úó Upload failed: {resp.status_code} {resp.text}')
    return False


def _push_via_scp(checkpoint_path, weights_dir, ssh_alias):
    """Copy the file to ``weights_dir`` on the host behind ``ssh_alias``."""
    import posixpath
    import paramiko
    from scp import SCPClient

    # paramiko does not read ~/.ssh/config on its own, so an alias such as
    # "spike2" has to be resolved explicitly before connecting.
    connect_kwargs = {'hostname': ssh_alias}
    ssh_config_path = os.path.expanduser('~/.ssh/config')
    if os.path.exists(ssh_config_path):
        ssh_config = paramiko.SSHConfig()
        with open(ssh_config_path) as cfg_file:
            ssh_config.parse(cfg_file)
        host_cfg = ssh_config.lookup(ssh_alias)
        connect_kwargs['hostname'] = host_cfg.get('hostname', ssh_alias)
        if 'user' in host_cfg:
            connect_kwargs['username'] = host_cfg['user']
        if 'port' in host_cfg:
            connect_kwargs['port'] = int(host_cfg['port'])
        if 'identityfile' in host_cfg:
            connect_kwargs['key_filename'] = host_cfg['identityfile']

    ssh = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy trusts unknown host keys without prompting.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh.connect(**connect_kwargs)
        # The remote side is addressed with POSIX separators regardless of
        # the local platform.
        remote_path = posixpath.join(weights_dir, os.path.basename(checkpoint_path))
        with SCPClient(ssh.get_transport()) as scp_client:
            scp_client.put(checkpoint_path, remote_path)
            print(f'  ‚úì SCP\'d to spike2:{remote_path}')
    finally:
        # Close the connection even when connect/put raises.
        ssh.close()
    return True


def _copy_to_drive(checkpoint_path):
    """Copy the file into a Google Drive folder (Colab only)."""
    import shutil
    from google.colab import drive
    drive.mount('/content/drive')
    drive_path = '/content/drive/MyDrive/quantumhead_weights/'
    os.makedirs(drive_path, exist_ok=True)
    shutil.copy2(checkpoint_path, drive_path)
    print(f'  ‚úì Saved to Drive: {drive_path}')
    return True


def push_to_spike2(checkpoint_path, spike2_host, weights_dir, ssh_alias=None):
    """Push trained weights to spike2, trying three transports in order.

    1. HTTPS upload to ``https://<spike2_host>/quantumhead/upload-weights``.
    2. SCP into ``weights_dir`` over SSH.
    3. Copy to Google Drive for manual transfer.

    Each method is best-effort: any exception is printed and the next
    method is tried.

    Args:
        checkpoint_path: Local file to push.
        spike2_host: Public host name used for the HTTPS upload.
        weights_dir: Remote directory used by the SCP fallback.
        ssh_alias: SSH host or ``~/.ssh/config`` alias for the SCP fallback.
            Defaults to the notebook-level ``SPIKE2_SSH``.

    Returns:
        True as soon as one method succeeds, otherwise False. Note that a
        successful Drive copy also returns True although the file has not
        reached spike2 in that case.
    """

    # Method 1: Direct upload via spike2 API
    print(f'Uploading {checkpoint_path} to spike2...')
    try:
        if _push_via_api(checkpoint_path, spike2_host):
            return True
    except Exception as e:
        print(f'  API upload failed: {e}')

    # Method 2: SCP via SSH
    print('  Trying SCP fallback...')
    try:
        return _push_via_scp(
            checkpoint_path,
            weights_dir,
            SPIKE2_SSH if ssh_alias is None else ssh_alias,
        )
    except Exception as e:
        print(f'  SCP failed: {e}')

    # Method 3: Save to Google Drive (manual transfer)
    print('  Saving to Google Drive as fallback...')
    try:
        return _copy_to_drive(checkpoint_path)
    except Exception as e:
        print(f'  Drive save failed: {e}')

    return False


# Push the final checkpoint produced by training, if there is one on disk.
if not os.path.exists(final_path):
    print('No checkpoint found. Run training first.')
else:
    success = push_to_spike2(final_path, SPIKE2_HOST, SPIKE2_WEIGHTS_DIR)
    if not success:
        # Every automatic route failed: tell the user how to copy it by hand.
        print('\n‚ö†Ô∏è  Auto-push failed. Download the checkpoint manually:')
        print(f'    {final_path}')
        print(f'    Then SCP to spike2: scp {final_path} spike2:{SPIKE2_WEIGHTS_DIR}/')
    else:
        print('\n‚úÖ Weights pushed to spike2! Inference server can now load them.')

## 11. Export for Inference

In [None]:
# ============================================================
# EXPORT MODEL FOR INFERENCE
# Creates a self-contained inference package for spike2
# ============================================================
import torch
import json
import os


def export_inference_package(model, audio_model, config, output_dir):
    """Write weights, config and architecture metadata for inference.

    Creates ``<output_dir>/inference_package`` containing:
      - quantumhead.pt / audio2flame.pt : the two models' state dicts
      - config.json                     : ``config`` serialized as JSON
      - architecture.json               : class names, parameter counts and
                                          a description of the pipeline

    The size of every file found in the package directory is printed.

    Returns:
        Path of the package directory.
    """

    def _n_params(module):
        # Total number of scalar parameters in a module.
        return sum(p.numel() for p in module.parameters())

    def _write_json(path, payload):
        with open(path, 'w') as f:
            json.dump(payload, f, indent=2)

    pkg_dir = os.path.join(output_dir, 'inference_package')
    os.makedirs(pkg_dir, exist_ok=True)

    # 1. Model weights
    for module, filename in ((model, 'quantumhead.pt'), (audio_model, 'audio2flame.pt')):
        torch.save(module.state_dict(), os.path.join(pkg_dir, filename))

    # 2. Config
    _write_json(os.path.join(pkg_dir, 'config.json'), config)

    # 3. Model architecture info for loading. The audio2flame
    #    hyper-parameters are fixed values recorded here; they are not read
    #    back from audio_model.
    pipeline_stages = [
        '1. DECA: image ‚Üí FLAME shape params',
        '2. Wav2Vec2: audio ‚Üí features (1024-dim)',
        '3. Audio2FLAME: features ‚Üí expression params (53-dim/frame)',
        '4. QuantumHead: Z_id + Z_exp ‚Üí UV Gaussian maps (14ch, 256¬≤)',
        '5. GaussianRenderer: UV maps ‚Üí rendered image (512¬≤)',
    ]
    arch_info = {
        'quantumhead': {
            'class': 'QuantumHeadModel',
            'params': _n_params(model),
        },
        'audio2flame': {
            'class': 'Audio2FLAMETransformer',
            'd_model': 512,
            'nhead': 8,
            'num_layers': 6,
            'n_flame_params': 53,
            'params': _n_params(audio_model),
        },
        'pipeline': {
            'input': 'audio_wav (16kHz) + source_image (512x512)',
            'output': 'rendered_frames (512x512x3, 25fps)',
            'stages': pipeline_stages,
        },
    }
    _write_json(os.path.join(pkg_dir, 'architecture.json'), arch_info)

    # 4. Package size
    total_size = 0
    for entry in os.listdir(pkg_dir):
        entry_bytes = os.path.getsize(os.path.join(pkg_dir, entry))
        total_size += entry_bytes
        print(f'  {entry}: {entry_bytes/1e6:.1f} MB')

    print(f'\n‚úì Inference package: {pkg_dir} ({total_size/1e6:.1f} MB total)')
    return pkg_dir


# Build the inference package from the in-memory models
pkg_dir = export_inference_package(model, audio_model, config, OUTPUT_DIR)

# Push every file of the package individually; each call runs the full
# upload fallback chain on its own and its result is not inspected here.
print('\nPushing inference package to spike2...')
for f_name in os.listdir(pkg_dir):
    f_path = os.path.join(pkg_dir, f_name)
    push_to_spike2(
        f_path,
        SPIKE2_HOST,
        SPIKE2_WEIGHTS_DIR,
    )

## 12. Test Inference (via spike2 API)

In [None]:
# ============================================================
# TEST INFERENCE VIA SPIKE2
# ============================================================
import requests
import time
import base64
from IPython.display import display, Image as IPImage, Video


def test_quantumhead_inference(text, source_image_path=None):
    """Test the full pipeline via the spike2 API.

    1. TTS: POST ``text`` to ``/speak`` and read the audio URL from the reply.
    2. QuantumHead: POST the audio URL (plus the optional base64-encoded
       source image) to ``/quantumhead/generate``.
    3. Download the returned video to /tmp/quantumhead_test.mp4 and display
       it inline.

    Failures are printed and the function returns early; nothing is raised
    for HTTP error statuses.

    Args:
        text: Sentence to synthesize and animate.
        source_image_path: Optional image file sent along with the request.

    Returns:
        None.
    """
    # NOTE(review): every request below uses verify=False, which disables
    # TLS certificate verification.
    base_url = f'https://{SPIKE2_HOST}'

    # Step 1: Generate speech audio
    print('Step 1: Generating speech...')
    tts_resp = requests.post(
        f'{base_url}/speak',
        json={
            'text': text,
            'voice_id': '960f89fc',  # Riley's voice
        },
        timeout=60,
        verify=False,
    )
    if not tts_resp.ok:
        print(f'TTS failed: {tts_resp.status_code}')
        return

    audio_data = tts_resp.json()
    # The reply may carry the URL under either key.
    audio_url = audio_data.get('audio_url', audio_data.get('url'))
    print(f'  Audio: {audio_url}')
    if not audio_url:
        # Without audio there is nothing to animate; do not post a null URL.
        print('  TTS response contained no audio URL; aborting.')
        return

    # Step 2: Generate avatar video via QuantumHead
    print('Step 2: Generating avatar video...')
    qh_payload = {
        'audio_url': audio_url,
        'model': 'quantumhead',
    }
    if source_image_path:
        with open(source_image_path, 'rb') as f:
            qh_payload['source_image'] = base64.b64encode(f.read()).decode()

    qh_resp = requests.post(
        f'{base_url}/quantumhead/generate',
        json=qh_payload,
        timeout=120,
        verify=False,
    )
    if not qh_resp.ok:
        print(f'  ‚úó Generation failed: {qh_resp.status_code}')
        print(f'    {qh_resp.text}')
        return

    result = qh_resp.json()
    video_url = result.get('video_url')
    print(f'  ‚úì Video: {video_url}')
    if not video_url:
        return

    # Download and display. A timeout keeps the cell from hanging forever,
    # and the status check avoids saving an error page as the .mp4.
    vid_resp = requests.get(video_url, timeout=120, verify=False)
    if not vid_resp.ok:
        print(f'  Video download failed: {vid_resp.status_code}')
        return
    with open('/tmp/quantumhead_test.mp4', 'wb') as f:
        f.write(vid_resp.content)
    display(Video('/tmp/quantumhead_test.mp4', embed=True, width=512))

# Test it! Runs the end-to-end check against spike2 with a fixed sentence.
# No source image is passed, so the generate request carries only the audio
# URL and the model name.
test_quantumhead_inference(
    "Hello! I'm a 3D Gaussian Splatting avatar powered by QuantumHead. "
    "Built with FLAME parametric models and trained on an A100 GPU."
)