# **CPU AvatarArtist5: Avatar Generation & Inference**

https://kumapowerliu.github.io/AvatarArtist/

---

## **Step 5 Pipeline Explanation: 3D Avatar Generation & Inference**

## Overview
This is the **final inference pipeline** that brings everything together. It takes your **trained DiT model from Step 4** and a **2D avatar image**, predicts a **3D triplane**, and converts it into a **3D mesh** (OBJ file) that can be viewed from any angle and imported into standard 3D tools. Expression control and animation are outlined under Advanced Usage as future enhancements.

---

## Pipeline Flow

```
Input: 2D Avatar Image (resized to 256×256)
   ↓
[DiT Model] Diffusion Sampling (1000-step schedule, 50 sampling steps)
   ↓
Output: 3D Triplane (96×256×256)
   ↓
[TriplaneToMesh] Sample the triplane on a 3D grid
   ↓
Output: 3D Volume (64×64×64 density grid)
   ↓
[Marching Cubes] Surface Extraction
   ↓
Output: 3D Mesh (vertices + faces)
   ↓
Save: OBJ file, Visualizations
```

---

## Key Components

### **1. Model Loading & Initialization**

```python
generator = AvatarGenerator(
    model_path="dit_checkpoints/best_model.pt",
    device='cuda',
    model_config=None  # Auto-detected
)
```

**Auto-Detection Logic:**
- Counts model parameters
- **< 50M params** → Small model (384-dim, 6 layers)
- **> 50M params** → Standard model (768-dim, 12 layers)
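A minimal sketch of this heuristic, assuming the checkpoint's weights are available as a `{name: array}` dict. The 50M cutoff and the two presets come from the list above; `detect_config`, `count_parameters`, and the dict keys are hypothetical names, not the script's actual API.

```python
import numpy as np

# Hypothetical presets mirroring the list above (not the script's real config keys)
SMALL_CONFIG = {"hidden_size": 384, "depth": 6}
STANDARD_CONFIG = {"hidden_size": 768, "depth": 12}

def count_parameters(state_dict):
    """Total number of scalar weights in a {name: array} mapping."""
    return sum(int(np.prod(w.shape)) for w in state_dict.values())

def detect_config(state_dict, cutoff=50_000_000):
    """Pick a model preset from the parameter count (below the cutoff -> small)."""
    n = count_parameters(state_dict)
    return SMALL_CONFIG if n < cutoff else STANDARD_CONFIG
```

Counting directly from the saved weights means no model has to be instantiated before the architecture is known.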

**Loaded Components:**
- **DiT Model**: Trained diffusion transformer
- **Diffusion Scheduler**: 1000-step linear noise schedule (β from 0.0001 to 0.02)
- **Mesh Generator**: Triplane → Volume → Mesh converter
- **Image Preprocessor**: Resize to 256×256, normalize to [-1, 1]
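The script's preprocessing is `transforms.Resize((256, 256))`, `ToTensor()`, and `Normalize([0.5]*3, [0.5]*3)`. A NumPy-only sketch of the last two steps (resizing omitted) shows why the output lands in [-1, 1]: `ToTensor` scales 8-bit pixels to [0, 1] and moves channels first, and normalizing with mean 0.5 and std 0.5 computes `(x - 0.5) / 0.5`. `to_model_input` is a hypothetical name.

```python
import numpy as np

def to_model_input(image_hwc_uint8):
    """HWC uint8 image -> CHW float32 array in [-1, 1] (ToTensor + Normalize(0.5, 0.5))."""
    x = image_hwc_uint8.astype(np.float32) / 255.0   # ToTensor: [0, 255] -> [0, 1]
    x = np.transpose(x, (2, 0, 1))                    # HWC -> CHW
    return (x - 0.5) / 0.5                            # Normalize: [0, 1] -> [-1, 1]
```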

---

### **2. Diffusion Sampling Process**

#### **What Happens:**
```python
# Start: random noise with the triplane shape [1, 96, 256, 256]
x = torch.randn(1, 96, 256, 256)

# Iteratively denoise over 50 strided timesteps: 980, 960, ..., 20, 0
for t in range(980, -1, -20):
    noise_pred = model(x, t, condition_image)   # predict the noise present at step t
    x = remove_noise(x, noise_pred, t)          # one reverse-diffusion update (illustrative)

# Final: clean triplane [1, 96, 256, 256]
```
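The script builds a linear β schedule (1e-4 to 0.02 over 1000 steps) and visits every 20th timestep when `num_sampling_steps=50`. The sketch below reproduces that schedule in NumPy and shows one deterministic DDIM update (η = 0), a standard way to jump between non-adjacent timesteps. It is an illustration, not necessarily the exact update the script applies; `strided_timesteps` and `ddim_step` are hypothetical names.

```python
import numpy as np

# Linear beta schedule with the script's defaults
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas_cumprod = np.cumprod(1.0 - betas)

def strided_timesteps(num_steps, T=T):
    """Timesteps visited by strided sampling, e.g. 980, 960, ..., 0 for 50 steps."""
    step = T // num_steps
    return list(range(0, T, step))[::-1]

def ddim_step(x_t, eps, a_t, a_prev):
    """Deterministic DDIM update (eta = 0) from cumulative alpha a_t to a_prev."""
    x0_pred = (x_t - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)        # estimate the clean sample
    return np.sqrt(a_prev) * x0_pred + np.sqrt(1.0 - a_prev) * eps    # re-noise to the earlier level
```

With the true noise and `a_prev = 1.0` (the value used after the final step), `ddim_step` recovers the clean sample, which is a quick sanity check of the coefficients. When `num_steps` does not divide 1000 evenly, this stride rule visits slightly more steps than requested (31 for `num_steps=30`).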

**Key Parameters:**

| Parameter | Default | Fast | High Quality | Purpose |
|-----------|---------|------|--------------|---------|
| `num_sampling_steps` | 50 | 25 | 100-250 | Denoising iterations |
| `seed` | None | 42 | Random | Reproducibility |

**Time vs Quality Trade-off** (approximate GPU timings; on CPU expect minutes, see Problem 4):
- **25 steps**: 5-10 seconds, preview quality
- **50 steps**: 15-20 seconds, good quality (recommended)
- **100 steps**: 30-40 seconds, high quality
- **250 steps**: 1-2 minutes, maximum quality

---

### **3. Triplane Structure**

#### **Understanding Triplanes:**

A triplane is a **3D representation using three 2D planes**:

```
Triplane [96 channels, 256×256]
├─ XY Plane [32 channels]: Front view features
├─ XZ Plane [32 channels]: Top view features
└─ YZ Plane [32 channels]: Side view features
```

**Why Three Planes?**
- **Efficient**: 3×256×256 << 256×256×256 (dense 3D volume, per channel)
- **Complete**: Any 3D point can be projected onto all three planes
- **Expressive**: Captures 3D structure with 2D operations
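The efficiency claim is easy to quantify per feature channel:

```python
res = 256
triplane_values = 3 * res * res   # three 256x256 planes
volume_values = res ** 3          # dense 256x256x256 grid
ratio = volume_values / triplane_values

print(triplane_values, volume_values, round(ratio, 2))  # 196608 16777216 85.33
```

So a triplane stores about 85× fewer values per channel than a dense grid at the same resolution; across all 96 channels the triplane tensor holds 96 × 65,536 = 6,291,456 values.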

**Channel Decomposition:**
```
Static Channels (67): Identity features
  - Face shape
  - Bone structure
  - Skin texture
  - Hair style (unchanging)

Dynamic Channels (29): Expression/Pose features
  - Facial expressions
  - Eye gaze
  - Mouth movement
  - Head rotation
```
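The plane layout and the static/dynamic layout slice the same 96 channels differently. The sketch below places them side by side, assuming both splits are contiguous, as in the script's `triplane_to_volume` (`C // 3` channels per plane) and in the Expression Control example (`[:67]` / `[67:]`):

```python
import numpy as np

# [C, H, W] with the batch dimension removed; spatial size shrunk to 8x8 for the example
triplane = np.zeros((96, 8, 8), dtype=np.float32)

# Plane split used by triplane_to_volume: C // 3 contiguous channels per plane
c = triplane.shape[0] // 3
plane_xy, plane_xz, plane_yz = triplane[:c], triplane[c:2 * c], triplane[2 * c:]

# Static/dynamic split described above
static, dynamic = triplane[:67], triplane[67:]

# Count how many dynamic channels fall inside the last (YZ) group of the plane split
yz_channels = set(range(2 * c, triplane.shape[0]))
dynamic_in_yz = sum(ch in yz_channels for ch in range(67, 96))
print(plane_xy.shape[0], plane_xz.shape[0], plane_yz.shape[0],
      static.shape[0], dynamic.shape[0], dynamic_in_yz)  # 32 32 32 67 29 29
```

With contiguous splits, all 29 dynamic channels fall inside the YZ group of 32, so the two views should not be read as the same partition.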

---

### **4. Triplane → Volume Conversion**

#### **Process:**

```python
triplane_to_volume(triplane, grid_size=64)
```

**Step-by-Step:**

1. **Create 3D Grid**
   ```python
   # Generate 64×64×64 sampling points in [-1, 1]³
   coords = create_grid(64)  # [262,144 points]
   ```

2. **Sample from Each Plane**
   ```python
   # For each 3D point (x, y, z):
   feat_xy = sample_plane(xy_plane, (x, y))
   feat_xz = sample_plane(xz_plane, (x, z))
   feat_yz = sample_plane(yz_plane, (y, z))
   ```

3. **Aggregate Features**
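   ```python
   features = feat_xy + feat_xz + feat_yz   # sum the three plane features
   density = features.mean(dim=1)           # average over the 32 channels
   density = (density - density.min()) / (density.max() - density.min())  # min-max → [0, 1]
   # (the script falls back to sigmoid when max - min is ~0)
   ```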

4. **Reshape to Volume**
   ```python
   volume = density.reshape(64, 64, 64)
   ```

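**Volume Interpretation** (values are min-max normalized, so they are relative to each volume):
- **Near 0.0**: Empty space (air)
- **Intermediate values**: Around the surface; the mesh is cut at a threshold, by default a percentile of the volume
- **Near 1.0**: Solid material (inside head)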
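For experimenting without a GPU, the four steps above can be reproduced in NumPy. This is a sketch, not the script's code: `bilinear_sample` re-implements `F.grid_sample(mode='bilinear', align_corners=False, padding_mode='zeros')` for a single plane, and the triplane is assumed to be `[C, H, W]` with `C` divisible by 3. Like the script, it averages the summed features over channels and min-max normalizes, falling back to a sigmoid when the density range is nearly zero.

```python
import numpy as np

def bilinear_sample(plane, u, v):
    """Sample plane [C, H, W] at normalized coords u (width axis) and v (height axis).

    Mirrors grid_sample with align_corners=False and zero padding; returns [C, N].
    """
    plane = np.asarray(plane, dtype=np.float64)
    C, H, W = plane.shape
    # align_corners=False: -1 and 1 are the outer edges of the border pixels
    px = ((u + 1) * W - 1) / 2
    py = ((v + 1) * H - 1) / 2
    x0, y0 = np.floor(px).astype(int), np.floor(py).astype(int)
    fx, fy = px - x0, py - y0
    out = np.zeros((C, len(u)))
    for dx, dy, w in ((0, 0, (1 - fx) * (1 - fy)), (1, 0, fx * (1 - fy)),
                      (0, 1, (1 - fx) * fy), (1, 1, fx * fy)):
        xi, yi = x0 + dx, y0 + dy
        valid = (xi >= 0) & (xi < W) & (yi >= 0) & (yi < H)  # out of bounds reads as 0
        vals = np.zeros((C, len(u)))
        vals[:, valid] = plane[:, yi[valid], xi[valid]]
        out += vals * w
    return out

def triplane_to_volume(triplane, grid_size=64):
    """[C, H, W] triplane -> (grid_size, grid_size, grid_size) density volume in [0, 1]."""
    c = triplane.shape[0] // 3
    coords = np.linspace(-1, 1, grid_size)
    X, Y, Z = np.meshgrid(coords, coords, coords, indexing='ij')
    x, y, z = X.ravel(), Y.ravel(), Z.ravel()
    features = (bilinear_sample(triplane[:c], x, y)          # XY plane
                + bilinear_sample(triplane[c:2 * c], x, z)   # XZ plane
                + bilinear_sample(triplane[2 * c:], y, z))   # YZ plane
    density = features.mean(axis=0)                          # average over channels
    span = density.max() - density.min()
    if span > 1e-8:
        density = (density - density.min()) / span          # min-max to [0, 1]
    else:
        density = 1 / (1 + np.exp(-density))                 # sigmoid fallback
    return density.reshape(grid_size, grid_size, grid_size)
```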

---

### **5. Mesh Extraction (Marching Cubes)**

#### **Algorithm:**

**Marching Cubes** is a classic algorithm that converts volume density into triangle meshes:

```python
extract_mesh(volume, threshold=0.5)
```

**How It Works:**

1. **Scan the Volume**
   - Process each 2×2×2 cube of neighboring voxels
   - 8 corner vertices per cube

2. **Classify Corners**
   ```
   Corner > threshold → Inside surface
   Corner < threshold → Outside surface
   ```

3. **Generate Triangles**
   - Look up triangle configuration (256 possible cases)
   - Place vertices on edges where surface crosses
   - Connect vertices to form triangles

4. **Output Mesh**
   ```python
   vertices: [N, 3]  # 3D coordinates
   faces: [M, 3]     # Triangle indices
   ```
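Steps 1 and 2 can be sketched in NumPy: each cell's eight corner classifications are packed into an 8-bit index, which is why there are 256 cases. Cells with index 0 (all outside) or 255 (all inside) produce no triangles. Step 3 needs the 256-entry triangle table and is left to a library such as `skimage.measure.marching_cubes`, which the script uses. `cube_indices` is a hypothetical helper.

```python
import numpy as np

def cube_indices(volume, threshold):
    """Marching Cubes case index (0-255) for every 2x2x2 cell of a 3D volume.

    Bit k is set when corner k lies above the threshold (inside the surface).
    The corner order here is arbitrary; a real implementation must match the
    order its triangle lookup table expects.
    """
    inside = volume > threshold
    nx, ny, nz = (s - 1 for s in volume.shape)        # cells per axis
    corners = [(dx, dy, dz) for dz in (0, 1) for dy in (0, 1) for dx in (0, 1)]
    index = np.zeros((nx, ny, nz), dtype=np.uint8)
    for bit, (dx, dy, dz) in enumerate(corners):
        corner_inside = inside[dx:dx + nx, dy:dy + ny, dz:dz + nz]
        index |= corner_inside.astype(np.uint8) * np.uint8(1 << bit)
    return index

# Example: a sphere whose surface (value 0.5) has radius 0.5
g = np.linspace(-1, 1, 16)
X, Y, Z = np.meshgrid(g, g, g, indexing='ij')
sphere = 1 - np.sqrt(X**2 + Y**2 + Z**2)
idx = cube_indices(sphere, 0.5)
crossing = (idx != 0) & (idx != 255)   # only these cells produce triangles
print(idx.shape, int(crossing.sum()))
```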

**Threshold Selection:**

| Threshold | Effect | Use Case |
|-----------|--------|----------|
| Low (30th percentile) | Thicker mesh, more detail | High-quality models |
| Medium (50th percentile) | Balanced | Default |
| High (70th percentile) | Thinner mesh, cleaner | Simplified models |

**Adaptive Strategy:**
```python
# Try multiple thresholds automatically
for percentile in [50, 60, 70, 80]:
    threshold = np.percentile(volume, percentile)
    try:
        vertices, faces = marching_cubes(volume, threshold)
        break  # Success!
    except:
        continue  # Try next threshold
```

---

## Output Files Explained

### **Generated Files:**

```
avatar_results/
├── input_image.png              # Original 2D image
├── triplane_visualization.png   # Three projection planes
├── volume_slices.png            # 3D volume cross-sections
└── avatar_mesh.obj              # 3D mesh (importable)
```

---

### **1. input_image.png**
- **Content**: Your original 2D avatar image
- **Purpose**: Reference for comparison

---

### **2. triplane_visualization.png**

**Shows three 2D feature maps:**

```
[XY Plane]  [XZ Plane]  [YZ Plane]
   Front       Top         Side
```

**What Each Plane Represents:**
- **Bright areas**: High activation (important features)
- **Dark areas**: Low activation (background/empty)

**Quality Indicators:**

✅ **Good Triplane:**
- Clear facial structure visible
- Symmetrical patterns (left-right balance)
- Smooth gradients (no noise)
- Distinct features (eyes, nose, mouth regions)

❌ **Poor Triplane:**
- Random noise patterns
- No recognizable structure
- All black or all white (collapsed)
- Checkerboard artifacts

**Example Interpretation:**
```
XY Plane: Should show face outline, eye positions
XZ Plane: Should show head top-to-bottom structure
YZ Plane: Should show side profile of face
```

---

### **3. volume_slices.png**

**Shows three cross-sections of the 3D volume:**

```
[Slice X]   [Slice Y]   [Slice Z]
  Sagittal    Coronal    Transverse
```

**What to Look For:**

✅ **Healthy Volume:**
- **Slice X (Left-Right)**: Oval shape (head from side)
- **Slice Y (Front-Back)**: Circular shape (head from front)
- **Slice Z (Top-Bottom)**: Circular shape (head from top)
- Smooth boundaries, no holes
- Clear distinction between inside (bright) and outside (dark)

❌ **Problem Volume:**
- Irregular shapes
- Multiple disconnected regions
- Uniform gray (no structure)
- Sharp discontinuities

**Density Interpretation:**
```
White (1.0): Solid (inside head)
Gray (0.5):  Surface boundary
Black (0.0): Empty space
```

---

### **4. avatar_mesh.obj**

**Standard 3D file format containing:**
```obj
# Vertex positions
v -0.5 0.3 0.2
v 0.1 -0.2 0.4
...
# Triangle faces (1-based vertex indices)
f 1 2 3
f 4 5 6
...
```
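The script's `save_obj` writes one `v` line per vertex and one `f` line per triangle, adding 1 to each index because OBJ indices are 1-based. A minimal round-trip sketch; `obj_text` and `parse_obj` are hypothetical helpers, and the parser handles only triangles and ignores texture/normal indices:

```python
import numpy as np

def obj_text(vertices, faces):
    """Serialize a triangle mesh to OBJ text; OBJ face indices are 1-based."""
    lines = [f"v {x} {y} {z}" for x, y, z in vertices]
    lines += [f"f {a + 1} {b + 1} {c + 1}" for a, b, c in faces]
    return "\n".join(lines) + "\n"

def parse_obj(text):
    """Read back 'v' and 'f' records (triangles only), converting faces to 0-based."""
    verts, faces = [], []
    for line in text.splitlines():
        parts = line.split()
        if not parts or parts[0].startswith("#"):
            continue
        if parts[0] == "v":
            verts.append([float(p) for p in parts[1:4]])
        elif parts[0] == "f":
            # Keep only the vertex index from "v", "v/vt" or "v/vt/vn" tokens
            faces.append([int(p.split("/")[0]) - 1 for p in parts[1:4]])
    return np.array(verts), np.array(faces, dtype=int)
```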

**Usage:**
- Import into **Blender**, **Maya**, **3ds Max**
- View in online viewers (e.g., https://3dviewer.net/)
- Use in game engines (**Unity**, **Unreal**)
- 3D print preparation

**Typical Mesh Stats:**

| Quality | Vertices | Faces | File Size | Detail Level |
|---------|----------|-------|-----------|--------------|
| Low | 5,000 | 10,000 | 500 KB | Preview |
| Medium | 15,000 | 30,000 | 1.5 MB | Standard |
| High | 50,000 | 100,000 | 5 MB | High detail |

---

## Diagnostic Features

### **Model Diagnostics**

```python
generator.diagnose_model()
```

**Checks Performed:**

1. **Parameter Count**
   ```
   Total parameters: 40,123,456
   Trainable parameters: 40,123,456
   ```

2. **NaN Detection in Weights**
   ```
   ✓ No NaN values in model weights
   ```
   OR
   ```
   ⚠️ WARNING: NaN found in model.blocks.3.mlp.0.weight
   ```

3. **Test Forward Pass**
   ```
   Output shape: [1, 96, 256, 256]
   Output stats: min=-2.34, mean=0.12, max=3.45
   ✓ Model forward pass successful
   ```
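The checks above can be sketched over a plain `{name: array}` dict of weights and a model output array. This is a NumPy illustration, not the script's `diagnose_model()`; the helper names are hypothetical. An all-zero output deserves particular attention with this script: `initialize_weights` zeroes the final linear layer's weight and bias, so an untrained model returns exactly zero.

```python
import numpy as np

def diagnose_weights(state_dict):
    """Return the parameter count and the names of tensors containing NaN or Inf."""
    total = sum(w.size for w in state_dict.values())
    bad = [name for name, w in state_dict.items() if not np.all(np.isfinite(w))]
    return total, bad

def output_stats(output):
    """min / mean / max of a model output, plus flags for NaN and the all-zero case."""
    return {
        "min": float(output.min()),
        "mean": float(output.mean()),
        "max": float(output.max()),
        "all_zero": bool(np.all(output == 0)),
        "has_nan": bool(np.isnan(output).any()),
    }
```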

**Interpretation:**

| Check Result | Meaning | Action |
|--------------|---------|--------|
| No NaN in weights | ✅ Weights are numerically valid | Proceed |
| NaN in weights | ❌ Training diverged | Retrain Step 4 |
| Normal output range | ✅ Model producing valid results | Proceed |
| NaN in output | ❌ Model broken | Check checkpoint |
| Output all zeros | ⚠️ Model looks untrained (the final layer is zero-initialized) | Check the checkpoint; retrain longer |

---

## Common Issues & Solutions

### **Problem 1: NaN in Triplane**

**Symptoms:**
```
⚠️ WARNING: Triplane contains NaN values!
NaN count: 12,543
```

**Causes:**
1. Model didn't train properly (Step 4 loss didn't converge)
2. Using wrong checkpoint
3. Numerical instability during sampling

**Solutions:**
```python
# 1. Try different checkpoint
MODEL_PATH = "dit_checkpoints/checkpoint_0080.pt"  # Earlier epoch

# 2. Reduce sampling steps
NUM_SAMPLING_STEPS = 25  # Less chance for instability

# 3. Use different seed
SEED = 123  # Try multiple seeds

# 4. If all fail → Retrain Step 4 with:
#    - Lower learning rate (lr=5e-5)
#    - More epochs (num_epochs=150)
#    - Check training loss < 0.05
```

---

### **Problem 2: Empty or Incorrect Mesh**

**Symptoms:**
```
❌ ERROR: RuntimeError in marching_cubes
```
OR
```
Extracted mesh: 0 vertices, 0 faces
```

**Causes:**
1. Volume is all zeros or all ones (no surface)
2. Threshold too high or too low
3. Triplane didn't capture 3D structure

**Solutions:**

**A. Check Volume Statistics:**
```python
print(f"Volume min: {volume.min()}")
print(f"Volume max: {volume.max()}")
print(f"Volume mean: {volume.mean()}")
```

| Statistics | Interpretation | Action |
|------------|----------------|--------|
| min = max | Uniform volume, no structure | Retrain model |
| All > 0.9 | Nearly all solid | Raise the threshold into the data range |
| All < 0.1 | Nearly all empty | Lower the threshold into the data range |
| Values spread over [0, 1] | Normal | Tune the threshold if needed |
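Marching cubes can only find a surface when the threshold lies strictly between the volume's minimum and maximum. A small sketch that turns the statistics into a next step; the function and its return labels are hypothetical:

```python
import numpy as np

def volume_action(volume, threshold=0.5, eps=1e-8):
    """Suggest a next step from volume statistics and the intended threshold."""
    vmin, vmax = float(volume.min()), float(volume.max())
    if vmax - vmin < eps:
        return "retrain"            # uniform volume: nothing to extract
    if threshold <= vmin:
        return "raise_threshold"    # every voxel counts as inside
    if threshold >= vmax:
        return "lower_threshold"    # every voxel counts as outside
    return "ok"                     # threshold lies inside the data range
```

Choosing the threshold as a percentile of the volume, as the script's adaptive mode does, usually keeps it inside the range automatically as long as the volume is not uniform.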

**B. Manual Threshold:**
```python
# Try lower threshold
vertices, faces = mesh_generator.extract_mesh(volume, threshold=0.3)

# Try higher threshold
vertices, faces = mesh_generator.extract_mesh(volume, threshold=0.7)
```

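**C. Recognize the Fallback Volume:**
If the log shows "Creating fallback spherical volume...", it means:
- Grid sampling failed, or the density still contained NaN after NaN/Inf values in the triplane were replaced with zeros
- The mesh is a placeholder sphere, not the model's output
- **Check the checkpoint and, if the triplane is invalid, retrain Step 4**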

---

### **Problem 3: Mesh Doesn't Resemble Input**

**Symptoms:**
- Mesh looks random/blobby
- No facial features visible
- Generic sphere instead of face

**Causes:**
1. **Model underfitted** (Step 4 trained too few epochs)
2. **Model overfitted** (memorized training set, can't generalize)
3. **Input image very different** from training style

**Solutions:**

**A. Check Training Loss:**
```
If Step 4 final loss > 0.08:
  → Underfitted, train 50-100 more epochs
  
If train loss < 0.02 but val loss > 0.10:
  → Overfitted, use earlier checkpoint or get more data
```
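The same rules of thumb as a function; the thresholds are copied from the text above and the function name is hypothetical:

```python
def diagnose_fit(train_loss, val_loss=None):
    """Classify Step 4 training using the loss rules of thumb above."""
    if train_loss > 0.08:
        return "underfitted"   # train 50-100 more epochs
    if val_loss is not None and train_loss < 0.02 and val_loss > 0.10:
        return "overfitted"    # use an earlier checkpoint or more data
    return "ok"
```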

**B. Use Training Set Image:**
```python
# Test with an image from training data
TEST_IMAGE = "dit_training_data/images/image_000001.png"
```
If this works but new images don't → Model overfitted

**C. Check Input Similarity:**
- Input must match **training style** from Step 1
- If you use a realistic photo instead of an avatar-style image → Won't work
- Solution: Run Step 1 style transfer on new image first

---

### **Problem 4: Slow Generation**

**Expected Times:**

| Device | 50 Steps | 100 Steps | 250 Steps |
|--------|----------|-----------|-----------|
| CPU | 10-20 min | 20-40 min | 1-2 hours |
| GTX 1080 | 30 sec | 60 sec | 2-3 min |
| RTX 3090 | 15 sec | 30 sec | 60-90 sec |
| A100 | 10 sec | 20 sec | 40-50 sec |

**Speed Optimization:**

```python
# 1. Reduce steps (quality vs speed)
NUM_SAMPLING_STEPS = 25  # ~2× faster than 50 steps, slight quality loss

# 2. Use smaller grid (faster mesh)
volume = mesh_generator.triplane_to_volume(triplane, grid_size=48)  # vs 64

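# 3. Mixed precision (GPU only); model.half() alone fails because the
#    sampler creates its noise tensor in float32
with torch.autocast(device_type='cuda', dtype=torch.float16):
    results = generator.generate_from_image(image_path="avatar1.png", num_sampling_steps=50)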
```

---

### **Problem 5: Out of Memory**

**Symptoms:**
```
RuntimeError: CUDA out of memory
```

**Solutions:**

```python
# 1. Reduce grid size
grid_size = 48  # vs 64: (48/64)³ ≈ 42% of the grid points, i.e. ~58% less volume memory

# 2. Use CPU for mesh generation
volume = mesh_generator.triplane_to_volume(
    triplane.cpu(), 
    grid_size=64
)

# 3. Clear cache after generation
torch.cuda.empty_cache()
```
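A rough way to see why `grid_size` matters: `triplane_to_volume` samples every plane at all `grid_size³` points, producing feature tensors of shape `[B, 32, 1, grid_size³]`, and several of them are alive at once (the three samples plus their sum). Assuming batch size 1 and float32 (`feature_bytes` is a hypothetical helper):

```python
def feature_bytes(grid_size, channels_per_plane=32, bytes_per_value=4, batch=1):
    """Bytes in one sampled feature tensor [batch, channels, 1, grid_size**3]."""
    return batch * channels_per_plane * grid_size ** 3 * bytes_per_value

for g in (64, 48):
    print(f"grid {g}: {feature_bytes(g) / 2**20:.1f} MiB per sampled plane")
# grid 64: 32.0 MiB per sampled plane
# grid 48: 13.5 MiB per sampled plane
```

The ratio is (48/64)³ = 0.421875, so dropping to 48 cuts these buffers by roughly 58%.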

---

## Result Quality Assessment

### **Excellent Results:**

✅ **Triplane Visualization:**
- Clear three-plane structure
- Recognizable facial features in all views
- Smooth gradients, no artifacts

✅ **Volume Slices:**
- Closed, continuous shapes
- Symmetric (left-right balance)
- Clear inside/outside boundary

✅ **3D Mesh:**
- 15,000+ vertices
- Smooth surface
- Resembles input image from front view
- Realistic from all angles
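The symmetry criterion can be checked numerically. A sketch that compares the inside/outside mask of a volume with its mirror image; which grid axis is left-right is an assumption you should confirm from `volume_slices.png`, and `mirror_symmetry` is a hypothetical helper:

```python
import numpy as np

def mirror_symmetry(volume, axis=0, threshold=0.5):
    """Fraction of voxels whose inside/outside label matches the mirrored voxel (1.0 = symmetric)."""
    inside = volume > threshold
    return float(np.mean(inside == np.flip(inside, axis=axis)))
```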

---

### **Poor Results:**

❌ **Triplane:**
- Random noise
- No structure
- All uniform color

❌ **Volume:**
- Disconnected regions
- No clear surface
- Uniform gray

❌ **Mesh:**
- < 1,000 vertices
- Irregular shape
- No resemblance to input

**→ Action: Retrain Step 4 or use an earlier checkpoint**

---

## Advanced Usage

### **Batch Generation**

```python
# Generate multiple avatars
images = [
    "avatar1.png",
    "avatar2.png",
    "avatar3.png"
]

for i, img_path in enumerate(images):
    results = generator.generate_from_image(
        image_path=img_path,
        num_sampling_steps=50,
        seed=42 + i
    )
    generator.visualize_results(
        results, 
        save_dir=f"avatar_results_{i}"
    )
```

### **Expression Control** (Future Enhancement)

```python
# Separate static and dynamic channels (triplane: [96, 256, 256], batch dim removed)
static = triplane[:67]   # Identity (fixed)
dynamic = triplane[67:]  # Expression (modify)

# Amplify expression
dynamic_amplified = dynamic * 1.5

# Create new triplane
triplane_new = torch.cat([static, dynamic_amplified], dim=0)

# Generate volume with new expression
volume_new = mesh_generator.triplane_to_volume(
    triplane_new.unsqueeze(0), 
    grid_size=64
)
```

### **Multi-View Rendering** (Future Enhancement)

```python
# Generate from different angles
for angle in [0, 45, 90, 135, 180]:
    # Rotate triplane or camera
    # Render view
    # Save image
    pass  # placeholder until rendering is implemented
```
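One way to get multiple views without touching the triplane is to rotate the extracted mesh before rendering it. A sketch of the rotation step only, assuming the mesh's vertical axis is y (the actual orientation depends on how the triplane axes were defined); rendering is not shown and both helpers are hypothetical:

```python
import numpy as np

def rotation_y(angle_deg):
    """3x3 rotation matrix about the vertical (y) axis."""
    a = np.deg2rad(angle_deg)
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s],
                     [0.0, 1.0, 0.0],
                     [-s, 0.0, c]])

def rotate_vertices(vertices, angle_deg):
    """Rotate an [N, 3] vertex array about y (row vectors, so multiply by R transposed)."""
    return vertices @ rotation_y(angle_deg).T
```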

---

## Integration with 3D Software

### **Blender Import:**

```python
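# In the Blender Python console (Blender 3.2+):
import bpy

# OBJ importer; the legacy bpy.ops.import_scene.obj was removed in Blender 4.0
bpy.ops.wm.obj_import(filepath="avatar_results/avatar_mesh.obj")

# The importer leaves the imported object selected; make it active
mesh = bpy.context.selected_objects[0]
bpy.context.view_layer.objects.active = mesh

# Add material
mat = bpy.data.materials.new(name="AvatarMaterial")
mesh.data.materials.append(mat)

# Add smooth shading (applies to the selected objects)
bpy.ops.object.shade_smooth()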
```

### **Unity Import:**

1. Copy `avatar_mesh.obj` to `Assets/Models/`
2. Drag into scene
3. Add materials and textures
4. Attach animation controller

---

## Summary

This pipeline:

✅ **Takes** a 2D avatar image  
✅ **Generates** a 3D triplane using the trained DiT model  
✅ **Converts** triplane to 3D volume  
✅ **Extracts** surface mesh using Marching Cubes  
✅ **Outputs** OBJ file + visualizations  
✅ **Includes** extensive diagnostics and error handling

**Quality Indicators:**
- **Triplane**: Shows clear 3-plane structure
- **Volume**: Closed shapes in all three slices
- **Mesh**: 15,000+ vertices, smooth surface
- **Resemblance**: Front view matches input image

**Performance:**
- **GPU (RTX 3090)**: ~15 seconds for 50-step generation
- **CPU**: ~10-20 minutes (not recommended)

**Use Cases:**
- Create 3D avatars from 2D concept art
- Generate game assets
- VR/AR avatar creation
- 3D printing preparation
- Animation and rigging base

**Final Note**: Quality depends mainly on Step 4 training (and on the input matching the Step 1 training style). If results are poor, check Step 4 training metrics (loss should be < 0.05) and consider retraining with more data or epochs.


---

!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install diffusers transformers accelerate
!pip install controlnet-aux opencv-python pillow
!pip install mediapipe==0.10.9
!pip install einops scikit-image

#final_model_path = '/kaggle/input/avatarartist-next3d-4d-gan-fine-tuning/next3d_checkpoints/final_model.pt'
gd_path = '/kaggle/input/cpu-avatarartist1-2d-domain-transfer'
dit_checkpoints = '/kaggle/input/cpu-avatarartist2-next3d-4d-gan-fine-tuning/next3d_checkpoints'
dit_training_data = '/kaggle/input/cpu-avatarartist3-triplane-decomposition/dit_training_data'
best_model_path = '/kaggle/input/cpu-avatarartist4-diffusion-transformer-train/dit_checkpoints/best_model.pt'

"""
3D Avatar Generation & Inference Script
Generates a 3D avatar from a 2D image using a trained DiT model.

Features:
1. 2D Image → 3D Triplane prediction
2. Triplane → 3D Mesh generation
3. Expression animation generation
4. Visualization and saving of results
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torchvision.transforms as transforms
from tqdm import tqdm
import json
import cv2
from einops import rearrange, repeat
import math


# ==================== DiT Model Imports (from Training Script) ====================

class TimestepEmbedding(nn.Module):
    """Sinusoidal Embedding for time steps"""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, timesteps):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb)
        emb = timesteps[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return self.mlp(emb)


class PatchEmbed(nn.Module):
    """Split image into patches and embed them"""

    def __init__(self, img_size=256, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = rearrange(x, 'b d h w -> b (h w) d')
        return x


class DiTBlock(nn.Module):
    """DiT Transformer Block"""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim)
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=-1)

        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h, _ = self.attn(h, h, h)
        x = x + gate_msa.unsqueeze(1) * h

        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        x = x + gate_mlp.unsqueeze(1) * h

        return x


class FinalLayer(nn.Module):
    """Final Output Layer"""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.linear(x)
        return x


class DiffusionTransformer(nn.Module):
    """Diffusion Transformer (DiT)"""

    def __init__(
        self,
        img_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 96,
        condition_channels: int = 3,
        hidden_size: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_patches = (img_size // patch_size) ** 2

        self.x_embedder = PatchEmbed(img_size, patch_size, in_channels, hidden_size)
        self.c_embedder = PatchEmbed(img_size, patch_size, condition_channels, hidden_size)
        self.t_embedder = TimestepEmbedding(hidden_size)

        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size))

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio)
            for _ in range(depth)
        ])

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        x = rearrange(x, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=h, p1=p, p2=p)
        return x

    def forward(self, x, t, c):
        x = self.x_embedder(x) + self.pos_embed
        c_embed = self.c_embedder(c)
        t_embed = self.t_embedder(t)

        x = x + c_embed
        c = t_embed

        for block in self.blocks:
            x = block(x, c)

        x = self.final_layer(x, c)
        x = self.unpatchify(x)

        return x


class GaussianDiffusion:
    """DDPM Gaussian Diffusion Implementation"""

    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cuda'):
        self.timesteps = timesteps
        self.device = device

        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    @torch.no_grad()
    def p_sample(self, model, x, t, condition):
        predicted_noise = model(x, t, condition)

        alpha = self.alphas[t][:, None, None, None]
        alpha_cumprod = self.alphas_cumprod[t][:, None, None, None]
        beta = self.betas[t][:, None, None, None]

        pred_x0 = (x - torch.sqrt(1 - alpha_cumprod) * predicted_noise) / torch.sqrt(alpha_cumprod)

        if t[0] > 0:
            noise = torch.randn_like(x)
            sigma = torch.sqrt(self.posterior_variance[t])[:, None, None, None]
        else:
            noise = 0
            sigma = 0

        mean = (1 / torch.sqrt(alpha)) * (x - (beta / torch.sqrt(1 - alpha_cumprod)) * predicted_noise)

        return mean + sigma * noise

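    @torch.no_grad()
    def sample(self, model, shape, condition, num_steps=50):
        """
        Fast sampling with a deterministic DDIM update (eta = 0).

        Calling p_sample on a strided schedule would apply single-step DDPM
        coefficients across multi-step jumps, so this update uses the
        cumulative alphas of the current and next visited timesteps instead.

        Args:
            num_steps: Number of sampling steps (fewer steps are faster)
        """
        batch_size = shape[0]
        device = condition.device

        # Start from random noise
        x = torch.randn(shape, device=device)

        # Strided timesteps, e.g. 980, 960, ..., 0 for 50 steps
        step_size = max(self.timesteps // num_steps, 1)
        timesteps = list(range(0, self.timesteps, step_size))[::-1]

        for idx, i in enumerate(tqdm(timesteps, desc='Generating 3D Avatar')):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            predicted_noise = model(x, t, condition)

            a_t = self.alphas_cumprod[i]
            # Cumulative alpha of the next visited timestep (1.0 after the last step)
            if idx + 1 < len(timesteps):
                a_prev = self.alphas_cumprod[timesteps[idx + 1]]
            else:
                a_prev = torch.ones_like(a_t)

            pred_x0 = (x - torch.sqrt(1 - a_t) * predicted_noise) / torch.sqrt(a_t)
            x = torch.sqrt(a_prev) * pred_x0 + torch.sqrt(1 - a_prev) * predicted_noise

        return x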


# ==================== 3D Mesh Generation ====================

class TriplaneToMesh:
    """Generates 3D mesh from triplane representations"""

    def __init__(self, resolution: int = 256):
        self.resolution = resolution

    def triplane_to_volume(self, triplane: torch.Tensor, grid_size: int = 64):
        """
        Generates a 3D volume from a triplane with NaN handling

        Args:
            triplane: [B, C, H, W] triplane tensor
            grid_size: Resolution of the 3D grid

        Returns:
            volume: [B, grid_size, grid_size, grid_size] density field
        """
        B, C, H, W = triplane.shape
        device = triplane.device

        # === DIAGNOSTIC: Check triplane for NaN/Inf ===
        if torch.isnan(triplane).any():
            print(f"⚠️  WARNING: Triplane contains NaN values!")
            print(f"   NaN count: {torch.isnan(triplane).sum().item()}")
            # Replace NaN with zeros
            triplane = torch.nan_to_num(triplane, nan=0.0)
            
        if torch.isinf(triplane).any():
            print(f"⚠️  WARNING: Triplane contains Inf values!")
            # Replace Inf with zeros
            triplane = torch.nan_to_num(triplane, posinf=0.0, neginf=0.0)

        # Clamp to reasonable range
        triplane = torch.clamp(triplane, -10.0, 10.0)
        
        print(f"[Volume] Triplane stats: min={triplane.min().item():.4f}, "
              f"mean={triplane.mean().item():.4f}, max={triplane.max().item():.4f}")

        # Generate 3D grid coordinates
        coords = torch.linspace(-1, 1, grid_size, device=device)
        grid = torch.stack(torch.meshgrid(coords, coords, coords, indexing='ij'), dim=-1)
        grid = grid.reshape(-1, 3)  # [N, 3]

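        # Sample features from each plane; grid_sample expects a grid of shape
        # [B, H_out, W_out, 2], so expand the shared coordinates across the batch
        xy_coords = grid[:, [0, 1]].unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)  # XY plane
        xz_coords = grid[:, [0, 2]].unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)  # XZ plane
        yz_coords = grid[:, [1, 2]].unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)  # YZ plane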

        # Split triplane into three components
        C_per_plane = C // 3
        plane_xy = triplane[:, :C_per_plane]
        plane_xz = triplane[:, C_per_plane:2*C_per_plane]
        plane_yz = triplane[:, 2*C_per_plane:]

        # Feature sampling with error handling
        try:
            feat_xy = F.grid_sample(plane_xy, xy_coords, align_corners=False, mode='bilinear', padding_mode='zeros')
            feat_xz = F.grid_sample(plane_xz, xz_coords, align_corners=False, mode='bilinear', padding_mode='zeros')
            feat_yz = F.grid_sample(plane_yz, yz_coords, align_corners=False, mode='bilinear', padding_mode='zeros')
        except Exception as e:
            print(f"⚠️  Error during grid sampling: {e}")
            # Fallback: create simple volume
            return self._create_fallback_volume(B, grid_size, device)

        # Aggregate features
        features = feat_xy + feat_xz + feat_yz  # [B, C_per_plane, 1, N]
        
        # Check for NaN after sampling
        if torch.isnan(features).any():
            print(f"⚠️  WARNING: Features contain NaN after sampling!")
            features = torch.nan_to_num(features, nan=0.0)

        # Convert to density
        density = features.mean(dim=1).squeeze(1)  # [B, N]
        
        # Normalize to [0, 1] range more carefully
        density_min = density.min()
        density_max = density.max()
        
        if density_max - density_min > 1e-8:
            # Normalize to [0, 1]
            density = (density - density_min) / (density_max - density_min + 1e-8)
        else:
            print(f"⚠️  WARNING: Density range is too small, using sigmoid")
            density = torch.sigmoid(density)
        
        # Final NaN check
        if torch.isnan(density).any():
            print(f"⚠️  WARNING: Density contains NaN, replacing with fallback")
            return self._create_fallback_volume(B, grid_size, device)

        # Reshape back to 3D grid
        volume = density.reshape(B, grid_size, grid_size, grid_size)
        
        print(f"[Volume] Density stats: min={volume.min().item():.4f}, "
              f"mean={volume.mean().item():.4f}, max={volume.max().item():.4f}")

        return volume

    def _create_fallback_volume(self, B, grid_size, device):
        """Create a simple spherical volume as fallback"""
        print("🔄 Creating fallback spherical volume...")
        
        coords = torch.linspace(-1, 1, grid_size, device=device)
        x, y, z = torch.meshgrid(coords, coords, coords, indexing='ij')
        
        # Create a sphere
        radius = torch.sqrt(x**2 + y**2 + z**2)
        volume = torch.exp(-radius * 3)  # Gaussian sphere
        volume = volume.unsqueeze(0).repeat(B, 1, 1, 1)
        
        return volume

    def extract_mesh(self, volume: torch.Tensor, threshold: float = None):
        """
        Extract mesh from volume using Marching Cubes
        """
        from skimage import measure
        import numpy as np

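        volume_np = volume.detach().cpu().float().numpy()
        # marching_cubes needs a 3D array; drop a leading batch dimension of size 1
        if volume_np.ndim == 4 and volume_np.shape[0] == 1:
            volume_np = volume_np[0]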

        # Volume statistics
        vmin, vmax = volume_np.min(), volume_np.max()
        vmean = volume_np.mean()

        print(f"[Mesh] volume stats: min={vmin:.4f}, mean={vmean:.4f}, max={vmax:.4f}")

        # Check for NaN
        if np.isnan(volume_np).any():
            print("❌ ERROR: Volume contains NaN values!")
            print("   This usually means the diffusion model didn't train properly.")
            print("   Try using a different checkpoint or retraining the model.")
            raise ValueError("Cannot extract mesh from NaN volume")

        # Adaptive iso-value
        if threshold is None:
            # Try multiple thresholds
            percentiles = [50, 60, 70, 80]
            for p in percentiles:
                threshold = np.percentile(volume_np, p)
                print(f"[Mesh] trying {p}th percentile threshold = {threshold:.4f}")
                
                try:
                    vertices, faces, normals, values = measure.marching_cubes(
                        volume_np,
                        level=threshold
                    )
                    print(f"✓ Success with {p}th percentile!")
                    break
                except (RuntimeError, ValueError):
                    # skimage raises when no surface exists at this level
                    continue
            else:
                # All failed, try mean
                threshold = vmean
                print(f"[Mesh] trying mean threshold = {threshold:.4f}")
                vertices, faces, normals, values = measure.marching_cubes(
                    volume_np,
                    level=threshold
                )
        else:
            vertices, faces, normals, values = measure.marching_cubes(
                volume_np,
                level=threshold
            )

        # Vertices are in voxel-index units [0, size - 1]; map each axis to [-1, 1]
        vertices = vertices / (np.array(volume_np.shape) - 1) * 2 - 1

        print(f"‚úì Extracted mesh: {len(vertices)} vertices, {len(faces)} faces")
        
        return vertices, faces

    def save_obj(self, vertices, faces, filepath: str):
        """Save vertices and triangle faces in Wavefront OBJ format"""
        with open(filepath, 'w') as f:
            for v in vertices:
                f.write(f"v {v[0]} {v[1]} {v[2]}\n")
            # OBJ face indices are 1-based; Marching Cubes indices are 0-based
            for face in faces:
                f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
        print(f"✓ Mesh saved: {filepath}")


# ==================== Avatar Generator ====================

class AvatarGenerator:
    """Main class for 3D Avatar Generation"""

    def __init__(
        self,
        model_path: str,
        device: str = 'cuda',
        model_config: Optional[Dict] = None
    ):
        """
        Args:
            model_path: Path to the trained DiT checkpoint
            device: Computing device ('cuda' or 'cpu')
            model_config: Model configuration; auto-detected from the parameter count if None
        """
        self.device = device

        print("=" * 60)
        print("Avatar Generator Initialization")
        print("=" * 60)

        # Load Model
        print(f"\nLoading model: {model_path}")
        self.load_model(model_path, model_config)

        # Diffusion setup
        self.diffusion = GaussianDiffusion(
            timesteps=1000,
            beta_start=0.0001,
            beta_end=0.02,
            device=device
        )

        # Mesh generator setup
        self.mesh_generator = TriplaneToMesh()

        # Image preprocessing: resize to 256x256 and normalize to [-1, 1]
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        print("\n✓ Initialization complete!")
        print("=" * 60 + "\n")

    def load_model(self, model_path: str, model_config: Optional[Dict] = None):
        """Loads the model from a checkpoint"""
        checkpoint = torch.load(model_path, map_location=self.device)

        # Infer model settings if not provided
        if model_config is None:
            state_dict = checkpoint.get('model', checkpoint)

            # Guess config from the parameter count (skip any non-tensor entries)
            num_params = sum(p.numel() for p in state_dict.values() if torch.is_tensor(p))

            if num_params < 50_000_000:  # Under 50M
                model_config = {
                    'hidden_size': 384,
                    'depth': 6,
                    'num_heads': 6,
                }
                print("  Detected: SMALL model")
            else:
                model_config = {
                    'hidden_size': 768,
                    'depth': 12,
                    'num_heads': 12,
                }
                print("  Detected: STANDARD model")

        # Construct model
        self.model = DiffusionTransformer(
            img_size=256,
            patch_size=16,
            in_channels=96,  # Triplane channels (3 planes x 32)
            condition_channels=3,
            **model_config
        ).to(self.device)

        # Load weights
        self.model.load_state_dict(checkpoint.get('model', checkpoint))
        self.model.eval()

        print(f"  Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    @torch.no_grad()
    def generate_from_image(
        self,
        image_path: str,
        num_sampling_steps: int = 50,
        seed: Optional[int] = None
    ) -> Dict:
        """
        Generates 3D avatar from a 2D image

        Args:
            image_path: Input image path
            num_sampling_steps: Number of diffusion sampling steps
            seed: Random seed

        Returns:
            results: Dict with the input 'image' (PIL), the generated 'triplane' [96, 256, 256]
                and the decoded 'volume' (64^3 density grid)
        """
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)

        print(f"\nGenerating 3D avatar from: {image_path}")

        # Load image
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        
        # Check input
        print(f"[Input] Image tensor stats: min={image_tensor.min().item():.4f}, "
              f"mean={image_tensor.mean().item():.4f}, max={image_tensor.max().item():.4f}")

        # Generate triplane
        print("Generating triplane...")
        triplane = self.diffusion.sample(
            self.model,
            shape=(1, 96, 256, 256),
            condition=image_tensor,
            num_steps=num_sampling_steps
        )
        
        # Check triplane output
        print(f"[Triplane] Generated triplane stats: min={triplane.min().item():.4f}, "
              f"mean={triplane.mean().item():.4f}, max={triplane.max().item():.4f}")
        
        if torch.isnan(triplane).any():
            print("⚠️  WARNING: Triplane contains NaN! Model may not be properly trained.")
        if torch.isinf(triplane).any():
            print("⚠️  WARNING: Triplane contains Inf values!")

        # Replace NaN with 0 and +/-Inf with +/-1 so the volume decoder and Marching Cubes can run
        triplane = torch.nan_to_num(triplane, nan=0.0, posinf=1.0, neginf=-1.0)

        # Generate volume
        print("Converting to 3D volume...")
        volume = self.mesh_generator.triplane_to_volume(triplane, grid_size=64)

        return {
            'image': image,
            'triplane': triplane[0],
            'volume': volume[0],
        }

    def visualize_results(
        self,
        results: Dict,
        save_dir: str = "/content/avatar_results"
    ):
        """Visualize and save results"""
        os.makedirs(save_dir, exist_ok=True)

        # 1. Input Image
        results['image'].save(os.path.join(save_dir, "input_image.png"))

        # 2. Triplane Visualization
        self.visualize_triplane(results['triplane'], save_dir)

        # 3. Volume Visualization
        self.visualize_volume(results['volume'], save_dir)

        # 4. Mesh Generation and Saving
        print("\nExtracting mesh...")
        vertices, faces = self.mesh_generator.extract_mesh(results['volume'])
        obj_path = os.path.join(save_dir, "avatar_mesh.obj")
        self.mesh_generator.save_obj(vertices, faces, obj_path)

        print(f"\n✓ Results saved to: {save_dir}")
        return save_dir

    def visualize_triplane(self, triplane: torch.Tensor, save_dir: str):
        """Visualize the triplane components"""
        C, H, W = triplane.shape

        # Split into three planes
        C_per_plane = C // 3
        plane_xy = triplane[:C_per_plane].mean(0).cpu().numpy()
        plane_xz = triplane[C_per_plane:2*C_per_plane].mean(0).cpu().numpy()
        plane_yz = triplane[2*C_per_plane:].mean(0).cpu().numpy()

        # Normalization
        def normalize(p): return (p - p.min()) / (p.max() - p.min() + 1e-8)
        plane_xy = normalize(plane_xy)
        plane_xz = normalize(plane_xz)
        plane_yz = normalize(plane_yz)

        # Plotting
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(plane_xy, cmap='viridis')
        axes[0].set_title('XY Plane')
        axes[0].axis('off')

        axes[1].imshow(plane_xz, cmap='viridis')
        axes[1].set_title('XZ Plane')
        axes[1].axis('off')

        axes[2].imshow(plane_yz, cmap='viridis')
        axes[2].set_title('YZ Plane')
        axes[2].axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, "triplane_visualization.png"), dpi=150, bbox_inches='tight')
        plt.close()

        print("✓ Triplane visualized")

    def visualize_volume(self, volume: torch.Tensor, save_dir: str):
        """Visualize slices of the 3D volume"""
        volume_np = volume.cpu().numpy()

        # Middle slice
        mid = volume_np.shape[0] // 2

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(volume_np[mid, :, :], cmap='gray')
        axes[0].set_title('Slice X')
        axes[0].axis('off')

        axes[1].imshow(volume_np[:, mid, :], cmap='gray')
        axes[1].set_title('Slice Y')
        axes[1].axis('off')

        axes[2].imshow(volume_np[:, :, mid], cmap='gray')
        axes[2].set_title('Slice Z')
        axes[2].axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, "volume_slices.png"), dpi=150, bbox_inches='tight')
        plt.close()

        print("✓ Volume visualized")

    @torch.no_grad()
    def diagnose_model(self):
        """Run diagnostics on the model"""
        print("\n" + "="*60)
        print("Model Diagnostics")
        print("="*60)
        
        # Check model parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        
        # Check for NaN in weights
        nan_count = 0
        for name, param in self.model.named_parameters():
            if torch.isnan(param).any():
                print(f"⚠️  WARNING: NaN found in {name}")
                nan_count += 1
        
        if nan_count == 0:
            print("✓ No NaN values in model weights")
        else:
            print(f"❌ Found NaN in {nan_count} parameter tensors!")
        
        # Test forward pass with dummy data
        print("\nTesting forward pass with dummy data...")
        try:
            dummy_x = torch.randn(1, 96, 256, 256).to(self.device)
            dummy_t = torch.tensor([500]).to(self.device)
            dummy_c = torch.randn(1, 3, 256, 256).to(self.device)
            
            output = self.model(dummy_x, dummy_t, dummy_c)
            
            print(f"Output shape: {output.shape}")
            print(f"Output stats: min={output.min().item():.4f}, "
                  f"mean={output.mean().item():.4f}, max={output.max().item():.4f}")
            
            if torch.isnan(output).any():
                print("❌ Model output contains NaN!")
            else:
                print("✓ Model forward pass successful")
        except Exception as e:
            print(f"❌ Error during forward pass: {e}")
        
        print("="*60 + "\n")


# ==================== Main Execution ====================

def main():
    """Main execution entry point"""

    # ========== Path Settings (best_model_path and gd_path are defined in earlier cells) ==========

    # Trained model checkpoint
    MODEL_PATH = best_model_path
    # Alternative: MODEL_PATH = "/kaggle/working/dit_checkpoints/final_model.pt"

    # Test image (use one from input data)
    TEST_IMAGE = f'{gd_path}/output_styled/styled_Chris Evans42_1217.jpg'

    # Output directory
    OUTPUT_DIR = "avatar_results"

    # ========== Path Verification ==========

    print("=" * 60)
    print("Path Verification")
    print("=" * 60)
    print(f"Model: {MODEL_PATH}")
    print(f"  Exists: {os.path.exists(MODEL_PATH)}")
    if os.path.exists(MODEL_PATH):
        size_mb = os.path.getsize(MODEL_PATH) / (1024 * 1024)
        print(f"  Size: {size_mb:.1f} MB")

    print(f"\nTest Image: {TEST_IMAGE}")
    print(f"  Exists: {os.path.exists(TEST_IMAGE)}")

    print(f"\nOutput: {OUTPUT_DIR}")
    print("=" * 60 + "\n")

    if not os.path.exists(MODEL_PATH):
        print("❌ Model not found!")
        print("\nAvailable models:")
        checkpoint_dir = f'{gd_path}/dit_checkpoints'
        if os.path.exists(checkpoint_dir):
            for f in os.listdir(checkpoint_dir):
                if f.endswith('.pt'):
                    path = os.path.join(checkpoint_dir, f)
                    size = os.path.getsize(path) / (1024 * 1024)
                    print(f"  {path} ({size:.1f} MB)")
        return

    if not os.path.exists(TEST_IMAGE):
        print("❌ Test image not found!")
        print("\nSearching for alternative images...")
        data_dir = f'{gd_path}/dit_training_data'
        img_dir = os.path.join(data_dir, "images")
        images = []
        if os.path.exists(img_dir):
            images = sorted(
                f for f in os.listdir(img_dir)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))
            )[:5]
            print(f"Available images: {images}")
        if not images:
            print("No alternative images found.")
            return
        TEST_IMAGE = os.path.join(img_dir, images[0])
        print(f"Using: {TEST_IMAGE}")

    # ========== Configuration ==========

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    NUM_SAMPLING_STEPS = 50  # Default; use 25 for faster runs or 100-250 for higher quality
    SEED = 42

    # ========== Avatar Generation Process ==========

    try:
        # Initialize generator
        generator = AvatarGenerator(
            model_path=MODEL_PATH,
            device=DEVICE
        )

        # Run diagnostics
        generator.diagnose_model()

        # Generation
        results = generator.generate_from_image(
            image_path=TEST_IMAGE,
            num_sampling_steps=NUM_SAMPLING_STEPS,
            seed=SEED
        )

        # Visualize and save
        output_dir = generator.visualize_results(results, OUTPUT_DIR)

        print("\n" + "=" * 60)
        print("Generation Complete!")
        print("=" * 60)
        print(f"Output directory: {output_dir}")
        print("\nGenerated files:")
        for f in sorted(os.listdir(output_dir)):
            print(f"  - {f}")
        print("=" * 60)

        # Show a download link when running inside a notebook (e.g. Kaggle or Colab)
        try:
            from IPython.display import FileLink, display
            print("\n📥 Download links:")
            display(FileLink(os.path.join(output_dir, "avatar_mesh.obj")))
        except ImportError:
            pass

    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
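`TriplaneToMesh.triplane_to_volume` is defined in an earlier cell and its body is not shown here. The sketch below illustrates the general idea under stated assumptions: the 96 channels split into three 32-channel planes in the XY, XZ, YZ order that `visualize_triplane` uses, every grid point is projected onto the three planes, and the looked-up features are summed. It uses nearest-neighbour lookup and a channel mean as a stand-in decoder; triplane renderers typically use bilinear sampling and a small learned MLP instead, so `triplane_to_density` is a hypothetical helper, not the notebook's implementation.

```python
import numpy as np


def triplane_to_density(triplane: np.ndarray, grid_size: int = 64) -> np.ndarray:
    """Hypothetical sketch: sample a (3*C, R, R) triplane on a grid_size^3 lattice.

    Assumes square planes stored in XY, XZ, YZ channel order (as visualize_triplane
    splits them), nearest-neighbour lookup, summed plane features, and a channel
    mean as a stand-in for the learned decoder.
    """
    c3, h, w = triplane.shape
    if h != w or c3 % 3 != 0:
        raise ValueError("expected square planes and a channel count divisible by 3")
    c = c3 // 3
    plane_xy, plane_xz, plane_yz = triplane[:c], triplane[c:2 * c], triplane[2 * c:]

    # Grid coordinates in [-1, 1] mapped to the nearest pixel index in [0, R-1]
    coords = np.linspace(-1.0, 1.0, grid_size)
    idx = np.rint((coords + 1.0) / 2.0 * (w - 1)).astype(int)

    # Look up each plane on a 2-D grid of its two axes: result shape (C, G, G)
    f_xy = plane_xy[:, idx[:, None], idx[None, :]]  # indexed [c, x, y]
    f_xz = plane_xz[:, idx[:, None], idx[None, :]]  # indexed [c, x, z]
    f_yz = plane_yz[:, idx[:, None], idx[None, :]]  # indexed [c, y, z]

    # Broadcast each plane along its missing axis and sum: shape (C, G, G, G)
    features = f_xy[:, :, :, None] + f_xz[:, :, None, :] + f_yz[:, None, :, :]

    # Stand-in decoder: average over channels to get one density value per voxel
    return features.mean(axis=0)
```

With the notebook's tensors this would be called as `triplane_to_density(results['triplane'].cpu().numpy())`, giving a 64×64×64 array that could then be passed to Marching Cubes.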

import os
import matplotlib.pyplot as plt
from PIL import Image

def show_image(image_dir):
    """Display up to six images from image_dir in a 2x3 grid."""
    image_paths = [
        os.path.join(image_dir, f)
        for f in sorted(os.listdir(image_dir))
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ][:6]
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    axes = axes.flatten()
    for ax, img_path in zip(axes, image_paths):
        img = Image.open(img_path)
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(os.path.basename(img_path), fontsize=9)
    for ax in axes[len(image_paths):]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()
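`show_image` always builds a 2×3 grid and shows at most six files. If the grid should follow the number of images instead, a hypothetical helper such as `grid_shape` (not part of the notebook) can compute the layout:

```python
import math


def grid_shape(n_images: int, max_cols: int = 3) -> tuple:
    """Hypothetical helper: (rows, cols) for a subplot grid holding n_images."""
    n = max(n_images, 1)          # always return at least a 1x1 grid
    cols = min(n, max_cols)       # do not create more columns than images
    rows = math.ceil(n / cols)    # enough rows for the remainder
    return rows, cols
```

Passing the result to `plt.subplots(rows, cols, squeeze=False)` keeps `axes` a 2-D array even for a single image, so the existing `axes.flatten()` loop works unchanged.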

show_image(f'{gd_path}/output_styled')
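`generate_from_image` asks `GaussianDiffusion.sample` for `num_sampling_steps` (50 by default) denoising steps, although the schedule has 1000 timesteps. `GaussianDiffusion` is defined in an earlier cell, so how it chooses that subset is not visible here. One common approach, assumed in this sketch, is deterministic DDIM (eta = 0) over uniformly strided timesteps, using the same linear beta schedule (0.0001 to 0.02 over 1000 steps). `strided_timesteps`, `ddim_step` and `ddim_sample` are hypothetical names, and `predict_eps` stands in for the DiT's noise prediction.

```python
import numpy as np

# Linear beta schedule matching GaussianDiffusion(timesteps=1000, beta_start=0.0001, beta_end=0.02)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)  # cumulative product of (1 - beta) for t = 0..T-1


def strided_timesteps(num_steps: int, total: int = T) -> np.ndarray:
    """Uniformly spaced timesteps in descending order, from total-1 down to 0."""
    return np.linspace(total - 1, 0, num_steps).round().astype(int)


def ddim_step(x_t, eps_pred, t, t_prev):
    """Deterministic DDIM update (eta = 0) from timestep t to t_prev (-1 means 'clean')."""
    ab_t = alpha_bar[t]
    ab_prev = alpha_bar[t_prev] if t_prev >= 0 else 1.0
    # Estimate the clean sample, then re-noise it to the level of t_prev
    x0_pred = (x_t - np.sqrt(1.0 - ab_t) * eps_pred) / np.sqrt(ab_t)
    return np.sqrt(ab_prev) * x0_pred + np.sqrt(1.0 - ab_prev) * eps_pred


def ddim_sample(predict_eps, shape, num_steps=50, seed=0):
    """Start from Gaussian noise and denoise over num_steps strided timesteps."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)
    ts = strided_timesteps(num_steps)
    for i, t in enumerate(ts):
        t_prev = ts[i + 1] if i + 1 < len(ts) else -1
        x = ddim_step(x, predict_eps(x, t), t, t_prev)
    return x
```

Because each step first estimates the clean sample and then re-noises it to the next, smaller timestep, the number of steps can be changed at inference time without retraining; fewer steps mean fewer network evaluations at some cost in quality.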