# **CPU AvatarArtist3: Triplane Decomposition**

https://kumapowerliu.github.io/AvatarArtist/

---

## **Step 3 Pipeline Explanation: Data Synthesis for DiT Training**

## Overview
This pipeline uses your **trained Next3D model from Step 2** to synthesize a large dataset of (image, triplane) pairs. This synthetic data will be used to train a **DiT (Diffusion Transformer)** model in Step 4 for controllable 3D-aware generation.

---

## Pipeline Architecture

### **1. Device Management (CPU/GPU Compatible)**

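```python
class DeviceManager:
    """Selects the compute device and handles memory housekeeping.

    - Auto-detects GPU/CPU (or forces CPU on request)
    - Returns a default batch size for the selected device
    - Frees cached memory between batches
    """
```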

**Key Features:**
- **GPU Mode**: Faster synthesis (8-16 samples/batch)
- **CPU Mode**: Slower but accessible (2-4 samples/batch)
- Automatic memory clearing to prevent OOM errors

**Device Selection:**
```python
FORCE_CPU = False  # Set to True to force CPU
# Auto-detects: CUDA if available, else CPU
```

---

### **2. Model Components**

#### **A) Generator Loading**
```python
load_generator(generator_path)
```
- Loads your trained Next3D model from Step 2
- Checkpoint: `./next3d_checkpoints/final_model.pt`
- Sets model to **eval mode** (no training, only inference)
- Disables gradients for efficiency

#### **B) Triplane Decomposer**
```python
ParametricTriplaneDecomposer(
    triplane_channels=96,
    static_ratio=0.7  # 70% static, 30% dynamic
)
```

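**Purpose**: Splits the triplane along the channel axis into two components. The first `int(96 * 0.7) = 67` channels are treated as static and the remaining 29 as dynamic: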

| Component | Channels | Represents | Examples |
|-----------|----------|------------|----------|
| **Static** | 67 (70%) | Identity features (unchanging) | Face shape, bone structure, skin texture |
| **Dynamic** | 29 (30%) | Expression/Pose features | Smile, frown, head rotation, eye gaze |

**Why decompose?**
- Enables **disentangled control** in DiT model
- Static features: Control "who" the person is
- Dynamic features: Control "what expression/pose" they have
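
As a minimal sketch of what this split amounts to, the snippet below uses NumPy arrays in place of the generator's tensors. The names `split_triplane` and `merge_triplane` are illustrative and not part of the pipeline. The decomposition is a slice along the channel axis, and concatenation inverts it exactly:

```python
import numpy as np

TRIPLANE_CHANNELS = 96
STATIC_RATIO = 0.7


def split_triplane(triplane, static_ratio=STATIC_RATIO):
    """Slice a [C, H, W] triplane into (static, dynamic) along the channel axis."""
    n_static = int(triplane.shape[0] * static_ratio)  # int(96 * 0.7) == 67
    return triplane[:n_static], triplane[n_static:]


def merge_triplane(static, dynamic):
    """Inverse of split_triplane: concatenate the two parts along channels."""
    return np.concatenate([static, dynamic], axis=0)


rng = np.random.default_rng(0)
triplane = rng.standard_normal((TRIPLANE_CHANNELS, 64, 64)).astype(np.float32)
static, dynamic = split_triplane(triplane)
print(static.shape, dynamic.shape)
```

Because the split is a fixed slice, any separation of identity from expression depends on what the generator learned to place in those channels, not on the decomposer itself.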

---

### **3. Data Synthesis Process**

#### **Workflow Per Sample:**

```
1. Sample Random Inputs:
   ├─ z: Random latent code [512-dim]
   ├─ shape: 3DMM identity parameters [80-dim]
   ├─ exp: Expression parameters [64-dim]
   └─ pose: Head pose parameters [6-dim] (saved with the sample; not passed to the generator)

2. Generate with Next3D:
   ├─ Pass (z, shape, exp) → Generator
   ├─ Output: Image (512×512) + Triplane (96×64×64)
   └─ Image represents the avatar face

3. Decompose Triplane:
   ├─ Split triplane → Static + Dynamic
   ├─ Static: Identity-related features
   └─ Dynamic: Expression/pose-related features

4. Save Outputs:
   ├─ Image: image_000001.png
   ├─ Triplane data: triplane_000001.npz
   │   ├─ static: [67, 64, 64]
   │   ├─ dynamic: [29, 64, 64]
   │   ├─ z: [512]
   │   ├─ shape: [80]
   │   ├─ exp: [64]
   │   └─ pose: [6]
   └─ Metadata: Links image to triplane
```
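
Step 1 of the workflow can be sketched with NumPy as follows. The distributions mirror `sample_3dmm_params()` in the synthesis code (Gaussian shape scaled by 0.5, Gaussian expression scaled by `exp_strength`, uniform pose in `[-pose_range, pose_range)`); the function name `sample_inputs` and the `seed` argument are assumptions added for illustration:

```python
import numpy as np


def sample_inputs(batch_size, pose_range=0.3, exp_strength=1.0, seed=None):
    """Draw one batch of random generator inputs as NumPy arrays."""
    rng = np.random.default_rng(seed)
    return {
        "z": rng.standard_normal((batch_size, 512)),            # latent code
        "shape": rng.standard_normal((batch_size, 80)) * 0.5,   # 3DMM identity
        "exp": rng.standard_normal((batch_size, 64)) * exp_strength,
        "pose": rng.uniform(-pose_range, pose_range, size=(batch_size, 6)),
    }


inputs = sample_inputs(batch_size=4, seed=0)
for name, value in inputs.items():
    print(name, value.shape)
```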

---

## Key Parameters

### **Configuration Settings**

```python
# Required
GENERATOR_PATH = "./next3d_checkpoints/final_model.pt"
OUTPUT_DIR = "./dit_training_data"

# Synthesis settings
NUM_SAMPLES = 30              # Number of samples to generate
BATCH_SIZE = None             # Auto-adjusted (GPU: 8, CPU: 2)
TRIPLANE_RESOLUTION = 256     # Triplane spatial resolution
USE_3DMM = True               # Use 3DMM parameters

# Output format
SAVE_IMAGES = True            # Save PNG images
SAVE_FORMAT = 'npz'           # 'npz' or 'pth' for triplanes
```

### **Parameter Guide**

| Parameter | Recommended | Purpose | Notes |
|-----------|-------------|---------|-------|
| `NUM_SAMPLES` | 1000-10000 | DiT training samples | More = better DiT quality |
| `BATCH_SIZE` | GPU: 8-16, CPU: 2 | Synthesis speed | Auto-adjusted if None |
| `TRIPLANE_RESOLUTION` | 256 | Recorded in `dataset_info.json` | Must match Step 2; the backbone in this notebook always emits 64×64 planes |
| `USE_3DMM` | True | Enable control | False for simpler pipeline |
| `SAVE_FORMAT` | 'npz' | Compression | 'npz' is smaller than 'pth' |
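
How `NUM_SAMPLES` and `BATCH_SIZE` interact can be illustrated with a small helper. It reproduces the ceiling division used in `synthesize_dataset()`; the function name `batch_sizes` is hypothetical:

```python
def batch_sizes(num_samples, batch_size):
    """Return the size of every batch needed to produce num_samples items."""
    if num_samples < 0 or batch_size <= 0:
        raise ValueError("num_samples must be >= 0 and batch_size must be > 0")
    num_batches = (num_samples + batch_size - 1) // batch_size  # ceiling division
    return [
        min(batch_size, num_samples - i * batch_size)  # last batch may be smaller
        for i in range(num_batches)
    ]


gpu_plan = batch_sizes(30, 8)
cpu_plan = batch_sizes(30, 2)
print(gpu_plan)        # [8, 8, 8, 6]
print(len(cpu_plan))   # 15
```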

---

## Output Structure

### **Directory Layout**

```
dit_training_data/
├── images/                    # Generated avatar images
│   ├── image_000000.png      # Sample 0
│   ├── image_000001.png      # Sample 1
│   └── ...
│
├── triplanes/                 # 3D representations
│   ├── triplane_000000.npz   # Compressed triplane data
│   ├── triplane_000001.npz
│   └── ...
│
├── metadata.json              # Links images to triplanes
└── dataset_info.json          # Dataset configuration
```

### **File Contents**

#### **1. Images (PNG files)**
- **Resolution**: 512×512 pixels
- **Format**: RGB, 8-bit
- **Content**: Generated avatar faces in your artistic style
- **Purpose**: Visual verification + optional conditioning
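
The generator ends in a `Tanh`, so its output lies in `[-1, 1]` and has to be rescaled before it can be written as an 8-bit PNG. A NumPy sketch of that conversion, equivalent to the arithmetic in the saving loop (the helper name `to_uint8_image` is an assumption):

```python
import numpy as np


def to_uint8_image(chw):
    """Map a [3, H, W] array in [-1, 1] to an [H, W, 3] uint8 array in [0, 255]."""
    hwc = np.transpose(chw, (1, 2, 0))          # channels-first -> channels-last
    scaled = (hwc + 1.0) / 2.0 * 255.0          # [-1, 1] -> [0, 255]
    return np.clip(scaled, 0, 255).astype(np.uint8)


# Three constant channels at -1, 0 and +1 make the mapping easy to read off.
fake = np.stack([np.full((4, 4), -1.0), np.zeros((4, 4)), np.full((4, 4), 1.0)])
pixels = to_uint8_image(fake)
print(pixels.shape, pixels[0, 0])
```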

#### **2. Triplane Data (NPZ files)**
Each `.npz` file contains:
```python
{
    'static': [67, 64, 64],     # Identity features
    'dynamic': [29, 64, 64],    # Expression/pose features
    'z': [512],                  # Latent code
    'shape': [80],               # 3DMM shape params
    'exp': [64],                 # 3DMM expression params
    'pose': [6]                  # 3DMM pose params
}
```
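
To make the file layout concrete, here is a hedged round-trip sketch that writes a record with these keys using `np.savez_compressed` and reads it back. It uses zero-filled placeholder arrays and an in-memory buffer instead of a real file, so nothing here depends on the synthesized data:

```python
import io

import numpy as np

# Placeholder record with the documented keys and shapes.
record = {
    "static": np.zeros((67, 64, 64), dtype=np.float32),
    "dynamic": np.zeros((29, 64, 64), dtype=np.float32),
    "z": np.zeros(512, dtype=np.float32),
    "shape": np.zeros(80, dtype=np.float32),
    "exp": np.zeros(64, dtype=np.float32),
    "pose": np.zeros(6, dtype=np.float32),
}

buffer = io.BytesIO()                      # stands in for triplane_000000.npz
np.savez_compressed(buffer, **record)
buffer.seek(0)

with np.load(buffer) as data:
    loaded = {key: data[key] for key in data.files}

for key, value in loaded.items():
    print(key, value.shape, value.dtype)
```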

#### **3. metadata.json**
```json
[
  {
    "index": 0,
    "triplane_path": "./dit_training_data/triplanes/triplane_000000.npz",
    "image_path": "./dit_training_data/images/image_000000.png",
    "static_shape": [67, 64, 64],
    "dynamic_shape": [29, 64, 64]
  },
  ...
]
```
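
One entry of this list could be assembled as sketched below. The helper `metadata_entry` is hypothetical; it follows the synthesis code in joining paths onto the output directory and in storing array shapes, which `json` serializes as lists even when they start out as tuples:

```python
import json
import os


def metadata_entry(index, output_dir, static_shape, dynamic_shape, save_images=True):
    """Build the metadata record for one sample."""
    triplane_path = os.path.join(output_dir, "triplanes", f"triplane_{index:06d}.npz")
    image_path = (
        os.path.join(output_dir, "images", f"image_{index:06d}.png")
        if save_images
        else None  # becomes null in the JSON file
    )
    return {
        "index": index,
        "triplane_path": triplane_path,
        "image_path": image_path,
        "static_shape": static_shape,
        "dynamic_shape": dynamic_shape,
    }


entry = metadata_entry(0, "dit_training_data", (67, 64, 64), (29, 64, 64))
round_trip = json.loads(json.dumps(entry))
print(json.dumps(entry, indent=2))
```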

#### **4. dataset_info.json**
```json
{
  "num_samples": 1000,
  "triplane_resolution": 256,
  "static_channels": 67,
  "dynamic_channels": 29,
  "use_3dmm": true,
  "format": "npz",
  "device": "cuda"
}
```
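
Since `dataset_info.json` and `metadata.json` describe the same dataset, they can be cross-checked before training. The following is a sketch only: `check_dataset` is a hypothetical helper, and the total of 96 channels is taken from the decomposer settings above.

```python
def check_dataset(info, metadata, total_channels=96):
    """Return a list of inconsistencies between dataset_info and metadata."""
    problems = []
    if info["static_channels"] + info["dynamic_channels"] != total_channels:
        problems.append("static + dynamic channels do not add up to the total")
    if info["num_samples"] != len(metadata):
        problems.append(
            f"num_samples is {info['num_samples']} but metadata has {len(metadata)} entries"
        )
    for entry in metadata:
        if entry["static_shape"][0] != info["static_channels"]:
            problems.append(f"sample {entry['index']}: unexpected static channel count")
        if entry["dynamic_shape"][0] != info["dynamic_channels"]:
            problems.append(f"sample {entry['index']}: unexpected dynamic channel count")
    return problems


info = {"num_samples": 2, "static_channels": 67, "dynamic_channels": 29}
metadata = [
    {"index": i, "static_shape": [67, 64, 64], "dynamic_shape": [29, 64, 64]}
    for i in range(2)
]
problems = check_dataset(info, metadata)
print(problems or "consistent")
```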

---

## What the Result Images Indicate

### **Quality Indicators**

#### **✅ Good Results:**

1. **Diversity**
   - Different face shapes and identities
   - Various expressions (neutral, smiling, etc.)
   - Multiple head poses (frontal, profile, tilted)
   - Different ages/genders (if trained on diverse data)

2. **Consistency with Style**
   - Images match your avatar style from Step 1
   - Artistic qualities preserved (line style, coloring, shading)
   - Character design elements maintained

3. **Quality Metrics**
   - Clear facial features (eyes, nose, mouth)
   - No artifacts or distortions
   - Proper proportions
   - Sharp details

#### **❌ Problem Signs:**

1. **Mode Collapse**
   - All faces look nearly identical
   - Limited variation in expression/pose
   - **Solution**: Retrain Step 2 with more epochs or different settings

2. **Quality Issues**
   - Blurry images
   - Missing facial features
   - Distorted proportions
   - **Solution**: Step 2 training didn't converge properly

3. **Style Inconsistency**
   - Images don't match target style
   - Realistic instead of stylized
   - **Solution**: Check Step 1 style transfer quality

---

## Sample Analysis

### **Example Output Interpretation**

Imagine you generated 30 samples:

```
Sample 0: Female face, neutral expression, frontal view
Sample 1: Male face, smiling, slight left turn
Sample 2: Female face, serious, tilted head
...
```

**What to check:**

| Aspect | What to Look For | Interpretation |
|--------|------------------|----------------|
| **Identity Variety** | 30 different-looking people | Static triplane is diverse ✓ |
| **Expression Range** | Neutral, happy, sad, surprised | Dynamic triplane captures expressions ✓ |
| **Pose Diversity** | Frontal, profile, tilted views | Model handles 3D rotations ✓ |
| **Style Consistency** | All images look like your avatar style | Step 2 training successful ✓ |
| **Image Quality** | Sharp, clear, no artifacts | Generator is stable ✓ |

---

## Performance Expectations

### **Synthesis Speed**

| Hardware | Batch Size | Time/Sample | 1000 Samples |
|----------|-----------|-------------|--------------|
| CPU (8 cores) | 2 | ~10 sec | ~2.8 hours |
| GPU (GTX 1080) | 8 | ~0.5 sec | ~8 minutes |
| GPU (RTX 3090) | 16 | ~0.2 sec | ~3 minutes |
| GPU (A100) | 32 | ~0.1 sec | ~2 minutes |

### **Storage Requirements**

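For 1000 samples (raw sizes before PNG/NPZ compression, assuming float32 triplanes):
- **Images (PNG)**: up to ~0.8 GB (512×512×3 bytes each)
- **Triplanes (NPZ)**: up to ~1.6 GB (96×64×64 floats each)
- **Total**: up to ~2.4 GB

For 10,000 samples:
- **Total**: up to ~24 GB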
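
A rough per-sample size can be derived from the array shapes alone. The sketch below assumes float32 triplanes and 8-bit RGB images and ignores PNG/NPZ compression, so actual files are usually smaller; `raw_bytes_per_sample` is a hypothetical helper.

```python
def raw_bytes_per_sample(channels=96, plane_res=64, image_res=512, bytes_per_float=4):
    """Return (triplane_record_bytes, image_bytes) for one uncompressed sample."""
    triplane = channels * plane_res * plane_res * bytes_per_float
    params = (512 + 80 + 64 + 6) * bytes_per_float   # z, shape, exp, pose
    image = image_res * image_res * 3                # 8-bit RGB
    return triplane + params, image


triplane_bytes, image_bytes = raw_bytes_per_sample()
for n in (1_000, 10_000):
    total_gb = n * (triplane_bytes + image_bytes) / 1e9
    print(f"{n} samples: about {total_gb:.1f} GB uncompressed")
```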

---

## Usage After Synthesis

### **Data Inspection**

```python
# Load a sample
import numpy as np
from PIL import Image

# Load image
img = Image.open('dit_training_data/images/image_000000.png')
img.show()

# Load triplane
data = np.load('dit_training_data/triplanes/triplane_000000.npz')
static = data['static']    # [67, 64, 64]
dynamic = data['dynamic']  # [29, 64, 64]
shape = data['shape']      # [80]
exp = data['exp']          # [64]

print(f"Static range: [{static.min():.3f}, {static.max():.3f}]")
print(f"Dynamic range: [{dynamic.min():.3f}, {dynamic.max():.3f}]")
```

### **Validation Checks**

```python
# Check dataset completeness (run from the directory that contains dit_training_data/)
import json
import os

with open('dit_training_data/metadata.json') as f:
    metadata = json.load(f)

print(f"Total samples: {len(metadata)}")

# Verify all files exist
for item in metadata:
    assert os.path.exists(item['triplane_path'])
    if item['image_path']:
        assert os.path.exists(item['image_path'])

print("✓ All files present")
```
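
Beyond file existence, the contents of each record can be checked as well. This is a minimal sketch under the shapes documented above; `find_problems` and `EXPECTED_SHAPES` are illustrative names. It flags missing keys, wrong shapes, and non-finite values:

```python
import numpy as np

EXPECTED_SHAPES = {
    "static": (67, 64, 64),
    "dynamic": (29, 64, 64),
    "z": (512,),
    "shape": (80,),
    "exp": (64,),
    "pose": (6,),
}


def find_problems(record, expected=EXPECTED_SHAPES):
    """Return a list of human-readable problems found in one triplane record."""
    problems = []
    for key, shape in expected.items():
        if key not in record:
            problems.append(f"missing key: {key}")
            continue
        value = np.asarray(record[key])
        if value.shape != shape:
            problems.append(f"{key}: expected shape {shape}, got {value.shape}")
        elif not np.isfinite(value).all():
            problems.append(f"{key}: contains NaN or Inf")
    return problems


good = {key: np.zeros(shape, dtype=np.float32) for key, shape in EXPECTED_SHAPES.items()}
bad = dict(good, pose=np.full(6, np.nan, dtype=np.float32))
del bad["z"]

print(find_problems(good))
print(find_problems(bad))
```

For a file on disk, the record can be built with `{k: data[k] for k in data.files}` after `data = np.load(path)`.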

---

## Common Issues & Solutions

### **Problem 1: Out of Memory (OOM)**

**Symptoms:**
```
RuntimeError: CUDA out of memory
```

**Solutions:**
1. Reduce `BATCH_SIZE` manually (e.g., 4 → 2)
2. Lower `NUM_SAMPLES` per run, synthesize in multiple runs
3. Set `FORCE_CPU = True` (slower but works)
4. Clear memory between batches (already implemented)
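
The first suggestion can also be automated. The wrapper below is a framework-agnostic sketch, not part of the synthesis code: `run_with_backoff` is a hypothetical helper that halves the batch size whenever the callable raises an out-of-memory error, detected here by the message text. The demo replaces the generator with a stand-in function.

```python
def run_with_backoff(run_batch, batch_size, min_batch_size=1):
    """Call run_batch(batch_size), halving batch_size after each OOM error.

    Returns (result, batch_size_that_worked).
    """
    while True:
        try:
            return run_batch(batch_size), batch_size
        except (MemoryError, RuntimeError) as err:
            is_oom = isinstance(err, MemoryError) or "out of memory" in str(err).lower()
            if not is_oom or batch_size <= min_batch_size:
                raise  # unrelated error, or nothing left to shrink
            batch_size = max(min_batch_size, batch_size // 2)


def fake_batch(batch_size):
    """Stand-in for the generator call: pretends only 2 samples fit in memory."""
    if batch_size > 2:
        raise RuntimeError("CUDA out of memory")
    return [0] * batch_size


result, used = run_with_backoff(fake_batch, batch_size=8)
print(used)  # 2
```

In a real run it would make sense to call `device_manager.clear_memory()` before each retry so that cached allocations are released first.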

---

### **Problem 2: Poor Quality Images**

**Symptoms:**
- Blurry faces
- Missing features
- Style doesn't match

**Solutions:**
1. **Retrain Step 2**: Model didn't converge properly
2. **Check Step 1**: Style transfer quality was poor
3. **Verify checkpoint**: Ensure you're using the correct model file

---

### **Problem 3: Low Diversity**

**Symptoms:**
- All faces look similar
- Same expression/pose repeated

**Solutions:**
1. Check `sample_3dmm_params()` ranges:
   ```python
   pose_range=0.3      # Increase to 0.5 for more pose variety
   exp_strength=1.0    # Increase to 1.5 for stronger expressions
   ```
2. Step 2 may have mode collapse - retrain with stronger R1 regularization

---

### **Problem 4: Slow Synthesis on CPU**

**Expected:** CPU is 50-100× slower than GPU

**Solutions:**
1. **Use GPU**: Cloud GPU (Google Colab, AWS, etc.)
2. **Reduce samples**: Start with 50-100 for testing
3. **Increase batch size**: CPU can handle 2-4 if you have 16GB+ RAM
4. **Synthesize overnight**: Let it run for large datasets
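
Long CPU runs are easier to manage when they can be split or resumed. The synthesis loop as written numbers samples from 0 on every run, so a resumed run would need a start offset. The sketch below shows one way to find it from the files already on disk: `next_start_index` is a hypothetical helper, and the file-name pattern is taken from the output layout described earlier.

```python
import re
import tempfile
from pathlib import Path

_TRIPLANE_NAME = re.compile(r"triplane_(\d{6})\.(npz|pth)")


def next_start_index(triplanes_dir):
    """Return one past the highest sample index found, or 0 if none exist."""
    directory = Path(triplanes_dir)
    if not directory.is_dir():
        return 0
    indices = [
        int(match.group(1))
        for path in directory.iterdir()
        if (match := _TRIPLANE_NAME.fullmatch(path.name))
    ]
    return max(indices) + 1 if indices else 0


# Demo on a temporary directory with two finished samples and an unrelated file.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("triplane_000000.npz", "triplane_000004.npz", "notes.txt"):
        (Path(tmp) / name).touch()
    found = next_start_index(tmp)
    empty = next_start_index(Path(tmp) / "does_not_exist")

print(found, empty)  # 5 0
```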

---

## Preparation for Step 4 (DiT Training)

### **What You Need:**

✅ **Generated dataset** in `dit_training_data/`  
✅ **Minimum 1000 samples** (more is better)  
✅ **Quality validation**: Check image diversity and quality  
✅ **Disk space**: Ensure 10-100 GB available for larger datasets

### **Next Step Preview:**

Step 4 will:
1. Load your synthetic (image, triplane) pairs
2. Train a DiT model to **predict triplanes from noise**
3. Enable controllable generation: "Generate an avatar with expression X and pose Y"
4. Support **progressive refinement** during inference

---

## Summary

This pipeline:

✅ **Loads** your trained Next3D generator from Step 2  
✅ **Generates** thousands of (image, triplane) pairs  
✅ **Decomposes** triplanes into static (identity) + dynamic (expression) components  
✅ **Saves** everything needed for DiT training  
✅ **Works on** both CPU and GPU (with auto-optimization)

**Result Images Indicate:**
- **Diversity** → Model learned to generate varied identities/expressions
- **Quality** → Step 2 training was successful
- **Style match** → Your avatar style is preserved
- **Variety in poses** → 3D consistency is maintained

**Recommended:** Generate 1000-5000 samples for optimal DiT training quality in Step 4.

---

In [None]:
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install diffusers transformers accelerate
!pip install controlnet-aux opencv-python pillow
!pip install mediapipe==0.10.13

In [None]:
final_model_path ='/kaggle/input/cpu-avatarartist2-next3d-4d-gan-fine-tuning/next3d_checkpoints/final_model.pt'

In [None]:
import shutil
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm
import json
import gc
import psutil
import warnings

# ==================== CPU/GPU Compatibility Helper ====================

class DeviceManager:
    """Manages device selection and memory optimization"""
    
    def __init__(self, force_cpu: bool = False):
        self.force_cpu = force_cpu
        self.device = self._select_device()
        self.is_cuda = self.device.type == 'cuda'
        
        self._log_device_info()
    
    def _select_device(self):
        """Select appropriate device"""
        if self.force_cpu:
            return torch.device('cpu')
        
        if torch.cuda.is_available():
            return torch.device('cuda')
        else:
            return torch.device('cpu')
    
    def _log_device_info(self):
        """Log device information"""
        print(f"\n{'='*60}")
        print(f"Device Configuration")
        print(f"{'='*60}")
        print(f"Selected Device: {self.device}")
        
        if self.is_cuda:
            print(f"GPU Name: {torch.cuda.get_device_name(0)}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        else:
            mem = psutil.virtual_memory()
            print(f"CPU RAM: {mem.total / 1e9:.2f} GB")
            print(f"Available RAM: {mem.available / 1e9:.2f} GB")
            print(f"\n⚠ Running on CPU - synthesis will be slower")
        
        print(f"{'='*60}\n")
    
    def get_optimal_batch_size(self, default_gpu: int = 8, default_cpu: int = 2):
        """Get optimal batch size based on device"""
        if self.is_cuda:
            return default_gpu
        else:
            # CPU: Use smaller batch size
            return default_cpu
    
    def clear_memory(self):
        """Clear GPU/CPU memory"""
        if self.is_cuda:
            torch.cuda.empty_cache()
        gc.collect()


# ==================== Model Definitions (CPU Compatible) ====================

class MappingNetwork(nn.Module):
    """StyleGAN2-style Mapping Network"""
    
    def __init__(self, z_dim: int = 512, w_dim: int = 512, num_layers: int = 8):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_dim = z_dim if i == 0 else w_dim
            layers.extend([
                nn.Linear(in_dim, w_dim),
                nn.LeakyReLU(0.2)
            ])
        self.net = nn.Sequential(*layers)
    
    def forward(self, z):
        return self.net(z)


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))
    
    def forward(self, x):
        noise = torch.randn(x.shape[0], 1, x.shape[2], x.shape[3], 
                           device=x.device, dtype=x.dtype)
        return x + self.weight * noise


class StyleBlock(nn.Module):
    """StyleGAN2-style Block"""
    
    def __init__(self, in_ch, out_ch, w_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.style1 = nn.Linear(w_dim, in_ch)
        self.style2 = nn.Linear(w_dim, out_ch)
        self.noise1 = NoiseInjection()
        self.noise2 = NoiseInjection()
        self.activation = nn.LeakyReLU(0.2)
    
    def forward(self, x, w):
        s1 = self.style1(w).unsqueeze(-1).unsqueeze(-1)
        x = self.conv1(x * s1)
        x = self.noise1(x)
        x = self.activation(x)
        
        s2 = self.style2(w).unsqueeze(-1).unsqueeze(-1)
        x = self.conv2(x * s2)
        x = self.noise2(x)
        x = self.activation(x)
        
        return x


class TriplaneBackbone(nn.Module):
    """Triplane Generation Backbone"""
    
    def __init__(self, w_dim: int, channels: int, resolution: int):
        super().__init__()
        self.w_dim = w_dim
        self.channels = channels
        self.resolution = resolution
        
        self.const = nn.Parameter(torch.randn(1, channels, 4, 4))
        
        self.blocks = nn.ModuleList([
            StyleBlock(channels, channels, w_dim),
            StyleBlock(channels, channels, w_dim),
            StyleBlock(channels, channels, w_dim),
            StyleBlock(channels, channels, w_dim),
        ])
        
        self.to_features = nn.Conv2d(channels, channels, 1)
    
    def forward(self, w):
        B = w.shape[0]
        x = self.const.repeat(B, 1, 1, 1)
        
        for block in self.blocks:
            x = block(x, w)
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        
        features = self.to_features(x)
        return features


class SuperResolutionModule(nn.Module):
    """Super-resolution module: 64x64 -> 512x512"""
    
    def __init__(self, in_channels, output_resolution=512):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            nn.Conv2d(128, 64, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            nn.Conv2d(64, 32, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.conv_blocks(x)


class TriplaneGenerator(nn.Module):
    """Next3D Triplane Generator"""
    
    def __init__(
        self,
        z_dim: int = 512,
        w_dim: int = 512,
        triplane_channels: int = 96,
        triplane_resolution: int = 256,
        use_3dmm: bool = True,
        shape_dim: int = 80,
        exp_dim: int = 64
    ):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.use_3dmm = use_3dmm
        
        self.mapping = MappingNetwork(z_dim, w_dim)
        
        if use_3dmm:
            self.shape_encoder = nn.Linear(shape_dim, w_dim)
            self.exp_encoder = nn.Linear(exp_dim, w_dim)
            self.condition_fusion = nn.Linear(w_dim * 3, w_dim)
        
        self.triplane_generator = TriplaneBackbone(
            w_dim=w_dim,
            channels=triplane_channels,
            resolution=triplane_resolution
        )
        
        self.superres = SuperResolutionModule(
            triplane_channels, 
            output_resolution=512
        )
    
    def forward(
        self,
        z: torch.Tensor,
        shape: Optional[torch.Tensor] = None,
        exp: Optional[torch.Tensor] = None,
        c: Optional[torch.Tensor] = None
    ):
        w = self.mapping(z)
        
        if self.use_3dmm and shape is not None and exp is not None:
            shape_feat = self.shape_encoder(shape)
            exp_feat = self.exp_encoder(exp)
            w_conditioned = self.condition_fusion(
                torch.cat([w, shape_feat, exp_feat], dim=1)
            )
        else:
            w_conditioned = w
        
        triplane_features = self.triplane_generator(w_conditioned)
        image = self.superres(triplane_features)
        
        return {
            'image': image,
            'triplane': triplane_features,
            'w': w_conditioned
        }


class ParametricTriplaneDecomposer(nn.Module):
    """Decomposes Triplanes into Static and Dynamic components"""
    
    def __init__(
        self,
        triplane_channels: int = 96,
        triplane_resolution: int = 256,
        static_ratio: float = 0.7
    ):
        super().__init__()
        self.channels = triplane_channels
        self.resolution = triplane_resolution
        
        self.static_channels = int(triplane_channels * static_ratio)
        self.dynamic_channels = triplane_channels - self.static_channels
        
        print(f"Triplane Decomposition:")
        print(f"  Static: {self.static_channels}ch")
        print(f"  Dynamic: {self.dynamic_channels}ch")
    
    def decompose(
        self,
        triplane: torch.Tensor,
        pose: torch.Tensor,
        exp: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        static = triplane[:, :self.static_channels]
        dynamic = triplane[:, self.static_channels:]
        
        return {
            'static': static,
            'dynamic': dynamic,
            'pose': pose,
            'exp': exp
        }
    
    def reconstruct(self, decomposed: Dict[str, torch.Tensor]) -> torch.Tensor:
        return torch.cat([
            decomposed['static'],
            decomposed['dynamic']
        ], dim=1)


# ==================== CPU-Optimized Data Synthesizer ====================

class Next3DDataSynthesizer:
    """Synthesizes training data pairs from a Next3D model (CPU Compatible)"""
    
    def __init__(
        self,
        generator_path: str,
        device: Optional[str] = None,
        triplane_resolution: int = 256,
        use_3dmm: bool = True,
        force_cpu: bool = False
    ):
        print("=" * 60)
        print("Next3D Data Synthesizer Initialization (CPU Compatible)")
        print("=" * 60)
        
        # Device setup
        self.device_manager = DeviceManager(force_cpu=force_cpu)
        self.device = self.device_manager.device
        self.resolution = triplane_resolution
        self.use_3dmm = use_3dmm
        
        # Load model
        print(f"\nLoading model: {generator_path}")
        self.load_generator(generator_path)
        
        # Decomposer
        self.decomposer = ParametricTriplaneDecomposer(
            triplane_channels=96,
            triplane_resolution=triplane_resolution
        )
        
        print("\n✓ Initialization Complete!")
    
    def load_generator(self, path: str):
        """Loads the Generator model"""
        # Load to CPU first to avoid OOM on GPU
        checkpoint = torch.load(path, map_location='cpu')
        
        self.generator = TriplaneGenerator(
            z_dim=512,
            w_dim=512,
            triplane_channels=96,
            triplane_resolution=self.resolution,
            use_3dmm=self.use_3dmm
        )
        
        # Load weights
        if 'generator' in checkpoint:
            state_dict = checkpoint['generator']
        else:
            state_dict = checkpoint
        
        try:
            self.generator.load_state_dict(state_dict, strict=True)
            print("✓ Generator loaded successfully (strict mode)")
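        except RuntimeError:
            print("⚠ Strict loading failed, trying flexible loading...")
            # strict=False skips missing/unexpected keys; report them so a
            # partially initialized generator does not go unnoticed
            result = self.generator.load_state_dict(state_dict, strict=False)
            print(f"✓ Generator loaded (flexible mode): "
                  f"{len(result.missing_keys)} missing, "
                  f"{len(result.unexpected_keys)} unexpected keys")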
        
        # Move to device after loading
        self.generator = self.generator.to(self.device)
        self.generator.eval()
        
        total_params = sum(p.numel() for p in self.generator.parameters())
        print(f"✓ Total parameters: {total_params:,}")
        
        # Disable gradients (eval mode was already set above)
        for param in self.generator.parameters():
            param.requires_grad = False
    
    def sample_3dmm_params(
        self,
        batch_size: int = 1,
        pose_range: float = 0.3,
        exp_strength: float = 1.0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Samples 3DMM parameters"""
        shape = torch.randn(batch_size, 80, device=self.device) * 0.5
        exp = torch.randn(batch_size, 64, device=self.device) * exp_strength
        pose = torch.rand(batch_size, 6, device=self.device) * 2 * pose_range - pose_range
        
        return shape, exp, pose
    
    def generate_sample(
        self,
        z: Optional[torch.Tensor] = None,
        shape: Optional[torch.Tensor] = None,
        exp: Optional[torch.Tensor] = None,
        pose: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Generates a batch of samples (batch size is taken from z; defaults to 1)"""
        with torch.no_grad():
            if z is None:
                z = torch.randn(1, 512, device=self.device)
            
            if self.use_3dmm:
                if shape is None or exp is None or pose is None:
                    shape, exp, pose = self.sample_3dmm_params(z.shape[0])
            else:
                shape, exp, pose = None, None, None
            
            output = self.generator(z, shape, exp)
            
            image = output['image']
            triplane = output['triplane']
            
            if self.use_3dmm:
                decomposed = self.decomposer.decompose(triplane, pose, exp)
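            else:
                # Without 3DMM conditioning, all 96 channels are stored as 'static'
                # and 'dynamic' is a 10-channel zero placeholder. Note that this
                # differs from the 67/29 split recorded in dataset_info.json.
                decomposed = {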
                    'static': triplane,
                    'dynamic': torch.zeros_like(triplane[:, :10]),
                    'pose': torch.zeros(z.shape[0], 6, device=self.device),
                    'exp': torch.zeros(z.shape[0], 64, device=self.device)
                }
            
            return {
                'image': image,
                'triplane_static': decomposed['static'],
                'triplane_dynamic': decomposed['dynamic'],
                'z': z,
                'shape': shape if shape is not None else torch.zeros(z.shape[0], 80, device=self.device),
                'exp': exp if exp is not None else torch.zeros(z.shape[0], 64, device=self.device),
                'pose': pose if pose is not None else torch.zeros(z.shape[0], 6, device=self.device)
            }
    
    def synthesize_dataset(
        self,
        num_samples: int,
        output_dir: str,
        batch_size: Optional[int] = None,
        save_images: bool = True,
        save_format: str = 'npz'
    ):
        """Synthesizes a large volume of data pairs"""
        
        # Auto-adjust batch size for CPU
        if batch_size is None:
            batch_size = self.device_manager.get_optimal_batch_size()
        
        print(f"\n{'='*60}")
        print(f"Dataset Synthesis Configuration")
        print(f"{'='*60}")
        print(f"Samples: {num_samples}")
        print(f"Batch Size: {batch_size}")
        print(f"Device: {self.device}")
        print(f"Output: {output_dir}")
        print(f"Format: {save_format}")
        print(f"{'='*60}\n")
        
        os.makedirs(output_dir, exist_ok=True)
        
        if save_images:
            img_dir = os.path.join(output_dir, 'images')
            os.makedirs(img_dir, exist_ok=True)
        
        triplanes_dir = os.path.join(output_dir, 'triplanes')
        os.makedirs(triplanes_dir, exist_ok=True)
        
        num_batches = (num_samples + batch_size - 1) // batch_size
        sample_idx = 0
        metadata = []
        
        for batch_idx in tqdm(range(num_batches), desc="Synthesizing"):
            current_batch_size = min(batch_size, num_samples - sample_idx)
            
            # Generate batch
            z = torch.randn(current_batch_size, 512, device=self.device)
            
            if self.use_3dmm:
                shape, exp, pose = self.sample_3dmm_params(current_batch_size)
            else:
                shape, exp, pose = None, None, None
            
            samples = self.generate_sample(z, shape, exp, pose)
            
            # Save each sample
            for i in range(current_batch_size):
                idx = sample_idx + i
                
                triplane_data = {
                    'static': samples['triplane_static'][i].cpu().numpy(),
                    'dynamic': samples['triplane_dynamic'][i].cpu().numpy(),
                    'z': samples['z'][i].cpu().numpy(),
                    'shape': samples['shape'][i].cpu().numpy(),
                    'exp': samples['exp'][i].cpu().numpy(),
                    'pose': samples['pose'][i].cpu().numpy(),
                }
                
                triplane_path = os.path.join(
                    triplanes_dir,
                    f'triplane_{idx:06d}.{save_format}'
                )
                
                if save_format == 'npz':
                    np.savez_compressed(triplane_path, **triplane_data)
                elif save_format == 'pth':
                    torch.save({
                        k: torch.from_numpy(v) for k, v in triplane_data.items()
                    }, triplane_path)
                
                if save_images:
                    image = samples['image'][i]
                    image = (image.cpu().permute(1, 2, 0).numpy() + 1) / 2 * 255
                    image = image.clip(0, 255).astype(np.uint8)
                    
                    img_path = os.path.join(img_dir, f'image_{idx:06d}.png')
                    Image.fromarray(image).save(img_path)
                
                metadata.append({
                    'index': idx,
                    'triplane_path': triplane_path,
                    'image_path': os.path.join(img_dir, f'image_{idx:06d}.png') if save_images else None,
                    'static_shape': triplane_data['static'].shape,
                    'dynamic_shape': triplane_data['dynamic'].shape,
                })
            
            sample_idx += current_batch_size
            
            # Clear memory after each batch (important for CPU)
            del samples, z
            if shape is not None:
                del shape, exp, pose
            self.device_manager.clear_memory()
        
        # Save metadata
        metadata_path = os.path.join(output_dir, 'metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        
        info = {
            'num_samples': num_samples,
            'triplane_resolution': self.resolution,
            'static_channels': self.decomposer.static_channels,
            'dynamic_channels': self.decomposer.dynamic_channels,
            'use_3dmm': self.use_3dmm,
            'format': save_format,
            'device': str(self.device)
        }
        
        info_path = os.path.join(output_dir, 'dataset_info.json')
        with open(info_path, 'w') as f:
            json.dump(info, f, indent=2)
        
        print(f"\n{'='*60}")
        print("✓ Dataset Synthesis Complete!")
        print(f"{'='*60}")
        print(f"Generated Samples: {num_samples}")
        print(f"Triplanes: {triplanes_dir}")
        if save_images:
            print(f"Images: {img_dir}")
        print(f"Metadata: {metadata_path}")
        print(f"Dataset Info: {info_path}")
        print(f"{'='*60}\n")

# ==================== Main Execution ====================

def main():
    """Execute Data Synthesis"""
    
    # Path settings
    GENERATOR_PATH = final_model_path
    
    if not os.path.exists(GENERATOR_PATH):
        print(f"❌ Model file not found: {GENERATOR_PATH}")
        return
    
    print(f"✓ Model found: {GENERATOR_PATH}")
    
    # Configuration
    OUTPUT_DIR = "./dit_training_data"
    NUM_SAMPLES = 30  # Adjust based on your needs
    SAVE_IMAGES = True
    SAVE_FORMAT = 'npz'
    TRIPLANE_RESOLUTION = 256
    USE_3DMM = True
    
    # Force CPU mode (set to False to auto-detect)
    FORCE_CPU = False  # Set to True to force CPU usage
    
    # Initialize synthesizer
    try:
        synthesizer = Next3DDataSynthesizer(
            generator_path=GENERATOR_PATH,
            triplane_resolution=TRIPLANE_RESOLUTION,
            use_3dmm=USE_3DMM,
            force_cpu=FORCE_CPU
        )
    except Exception as e:
        print(f"\n❌ Initialization failed: {e}")
        import traceback
        traceback.print_exc()
        return
    
    # Synthesize dataset (batch_size will be auto-adjusted)
    synthesizer.synthesize_dataset(
        num_samples=NUM_SAMPLES,
        output_dir=OUTPUT_DIR,
        batch_size=None,  # Auto-adjust based on device
        save_images=SAVE_IMAGES,
        save_format=SAVE_FORMAT
    )
    
    print("\n✓ All operations complete!")


if __name__ == "__main__":
    main()