# **CPU AvatarArtist2: Next3D (4D GAN) Fine-tuning**

---

## **What is AvatarArtist?**
AvatarArtist is a research system for open-domain 4D avatarization: it turns a portrait image, in an arbitrary artistic style, into a 3D-aware, animatable avatar.

https://kumapowerliu.github.io/AvatarArtist/

---

## **Step 2 Pipeline Explanation: Next3D Fine-tuning**

## Overview
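This pipeline trains a simplified **Next3D**-style model (a 3D-aware GAN built on tri-plane features) on your style-transferred avatar images, so that it generates new faces in that artistic style. It fine-tunes only when a compatible `pretrained_path` is passed to `train_next3d`; otherwise the generator starts from random weights. This simplified version decodes the tri-plane features directly into a 2D image instead of volume-rendering them, so on its own it does not enforce multi-view 3D consistency.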

---

## Pipeline Architecture

### 1. **Data Loading**
```python
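dataset = StyleFaceDataset(
    image_dir="./output_styled",  # Your stylized images from Step 1
    mesh_dir=None,                # Optional directory of 3DMM .pkl files
    resolution=512,               # Images are resized to 512x512
    use_3dmm=False                # train_next3d() sets this to (mesh_dir is not None)
)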
```

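**Purpose**: Loads every `.jpg`/`.png` in `image_dir`, resizes it to `resolution`×`resolution`, and normalizes pixels to [-1, 1]. If `mesh_dir` is given, it also loads 3DMM parameters (shape, expression, pose) from a `.pkl` file with the same stem, falling back to zeros when that file is missing. The pose vector is loaded but is not currently passed to the generator.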
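When `mesh_dir` is set, `StyleFaceDataset` looks for a pickle named after each image's stem (`avatar_001.png` → `avatar_001.pkl`) holding a dict of 1-D arrays under `'shape'` (80 values), `'exp'` (64) and `'pose'` (6). The sketch below writes one such file and reads it back. The file name and the all-zero coefficients are placeholders; the notebook does not specify which 3DMM fitter should produce the values or what the six pose numbers encode.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

mesh_dir = Path(tempfile.mkdtemp())

# Hypothetical example file; keys and sizes follow StyleFaceDataset.__getitem__
params = {
    "shape": np.zeros(80, dtype=np.float32),  # identity coefficients
    "exp": np.zeros(64, dtype=np.float32),    # expression coefficients
    "pose": np.zeros(6, dtype=np.float32),    # head pose (encoding unspecified)
}
with open(mesh_dir / "avatar_001.pkl", "wb") as f:
    pickle.dump(params, f)

# Read it back the same way the dataset does
with open(mesh_dir / "avatar_001.pkl", "rb") as f:
    loaded = pickle.load(f)
print({k: v.shape for k, v in loaded.items()})
```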

---

### 2. **Generator Architecture (TriplaneGenerator)**

#### Components:

**a) Mapping Network**
- Converts random noise `z` (512-dim) → intermediate latent `w` (512-dim); `z` is first L2-normalized to unit length
- 8 fully-connected layers with LeakyReLU activation
- Inspired by StyleGAN2
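A minimal sketch of the mapping step, mirroring `MappingNetwork` (8 × `Linear(512, 512)` + `LeakyReLU(0.2)`, with `z` L2-normalized first). It uses PyTorch's default initialization instead of the notebook's N(0, 0.02) init. Because of the normalization, rescaling `z` has no effect on `w`, so the `* 0.5` applied to `z` in `Next3DTrainer.train_step` does not actually lower the latent variance.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layers = []
for _ in range(8):
    layers += [nn.Linear(512, 512), nn.LeakyReLU(0.2)]
mlp = nn.Sequential(*layers)


def mapping(z: torch.Tensor) -> torch.Tensor:
    """z -> w, with z projected onto the unit sphere first (as in MappingNetwork)."""
    return mlp(F.normalize(z, dim=1))


z = torch.randn(2, 512)
with torch.no_grad():
    w_full = mapping(z)
    w_half = mapping(0.5 * z)  # same direction, half the length
print(torch.allclose(w_full, w_half, atol=1e-6))  # True
```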

**b) 3DMM Conditioning Module** (Optional)
- Encodes face shape (80-dim) and expression (64-dim) parameters
- Each is projected to 512 dims and concatenated with `w`; a linear layer fuses the 1536-dim result back to 512 dims
- Intended to give explicit control over identity and expression
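The fusion step in `TriplaneGenerator.forward` can be reproduced in isolation (random weights, batch of 2). It is a single linear mix of three 512-dim vectors, so how strongly `shape` and `exp` steer the output is learned during training rather than guaranteed by the architecture.

```python
import torch
import torch.nn as nn

w_dim, shape_dim, exp_dim = 512, 80, 64
shape_encoder = nn.Linear(shape_dim, w_dim)
exp_encoder = nn.Linear(exp_dim, w_dim)
condition_fusion = nn.Linear(3 * w_dim, w_dim)

B = 2
w = torch.randn(B, w_dim)          # output of the mapping network
shape = torch.randn(B, shape_dim)  # 3DMM identity coefficients
exp = torch.randn(B, exp_dim)      # 3DMM expression coefficients

fused = torch.cat([w, shape_encoder(shape), exp_encoder(exp)], dim=1)  # [B, 1536]
w_conditioned = condition_fusion(fused)                                # [B, 512]
print(tuple(fused.shape), tuple(w_conditioned.shape))
```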

**c) Triplane Backbone**
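- Generates the tri-plane feature map, starting from a learned 4×4 constant
- Four simplified StyleGAN2-style synthesis blocks (style scaling of the input + noise injection), each followed by 2× bilinear upsampling (4→8→16→32→64)
- Output: 64×64 feature maps with 96 channels (the `triplane_resolution` argument is currently unused)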

**d) Super-Resolution Module**
- Upsamples 64×64 → 512×512
- Three upsampling stages (2× each)
- Final output: RGB image in [-1, 1] range
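To make the resolution bookkeeping concrete, this sketch traces tensor shapes through the backbone's four 2× upsampling steps and a layer-for-layer copy of the super-resolution convolutions (LeakyReLUs omitted, weights random). The last lines show how an EG3D-style renderer would read 96 channels as 3 planes × 32 channels; that reshape is an illustrative assumption and is not performed by this simplified code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def up2x() -> nn.Module:
    return nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)


x = torch.randn(1, 96, 4, 4)  # stands in for the learned 4x4 constant
for _ in range(4):            # each StyleBlock is followed by a 2x bilinear upsample
    x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
print("backbone:", tuple(x.shape))  # (1, 96, 64, 64)

superres = nn.Sequential(
    nn.Conv2d(96, 128, 3, padding=1), up2x(),   # 64 -> 128
    nn.Conv2d(128, 64, 3, padding=1), up2x(),   # 128 -> 256
    nn.Conv2d(64, 32, 3, padding=1), up2x(),    # 256 -> 512
    nn.Conv2d(32, 3, 3, padding=1), nn.Tanh(),  # RGB in [-1, 1]
)
with torch.no_grad():
    image = superres(x)
print("image:", tuple(image.shape))  # (1, 3, 512, 512)

# Illustration only (EG3D convention, not done here): 96 = 3 planes x 32 channels
planes = x.view(1, 3, 32, 64, 64)
print("planes:", tuple(planes.shape))  # (1, 3, 32, 64, 64)
```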

---

### 3. **Discriminator Architecture**
- Progressive downsampling (StyleGAN2-inspired)
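- 5 discriminator blocks (two 3×3 convs + 2× average pooling each): spatial size 512→256→128→64→32→16, channels 3→64→128→256→512→512
- Adaptive average pooling to 4×4, then a linear layer
- Output: a single real/fake logit per image (no sigmoid; the losses apply softplus)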
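The spatial sizes can be checked by applying the same pooling sequence to a dummy tensor: only the `AvgPool2d(2)` at the end of each block changes resolution, because the 3×3 convs use `padding=1`. After five blocks a 512×512 input is 16×16, which the adaptive pooling reduces to 4×4 before `Flatten` and `Linear(512 * 4 * 4, 1)`.

```python
import torch
import torch.nn as nn

x = torch.zeros(1, 1, 512, 512)  # channel count does not affect spatial size
sizes = [x.shape[-1]]
for _ in range(5):               # each DiscriminatorBlock ends with AvgPool2d(2)
    x = nn.AvgPool2d(2)(x)
    sizes.append(x.shape[-1])
x = nn.AdaptiveAvgPool2d((4, 4))(x)  # final pooling before Flatten + Linear
print(sizes, tuple(x.shape[-2:]))    # [512, 256, 128, 64, 32, 16] (4, 4)
```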

---

## Key Training Components

### 4. **Loss Functions**

#### **a) Adversarial Loss (Logistic Loss)**
```python
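import torch.nn.functional as F
# d_real, d_fake: raw discriminator logits, shape [B, 1]
d_loss = F.softplus(d_fake).mean() + F.softplus(-d_real).mean()  # Discriminator
g_loss = F.softplus(-d_fake).mean()                                # Generator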
```
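- Non-saturating logistic loss: `softplus(-x) = -log(sigmoid(x))`
- The generator still gets strong gradients when D confidently rejects its samples, unlike the original minimax form `log(1 - sigmoid(d_fake))`, which saturates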
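A small check of why the non-saturating form is preferred: when D confidently rejects a fake (logit -5), the gradient of the minimax generator objective log(1 - sigmoid(d)) = -softplus(d) nearly vanishes, while softplus(-d) still yields a gradient close to -1.

```python
import torch
import torch.nn.functional as F

# D is confident the fake is fake: large negative logit
d_fake = torch.tensor(-5.0, requires_grad=True)

# Minimax generator objective: minimize log(1 - sigmoid(d)) = -softplus(d)
saturating = -F.softplus(d_fake)
(grad_sat,) = torch.autograd.grad(saturating, d_fake)

# Non-saturating objective used in this notebook: minimize softplus(-d)
non_saturating = F.softplus(-d_fake)
(grad_ns,) = torch.autograd.grad(non_saturating, d_fake)

print(f"saturating: {grad_sat.item():.4f}, non-saturating: {grad_ns.item():.4f}")
```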

#### **b) R1 Regularization**
```python
lambda_r1 = 1.0  # Gradient penalty weight
```
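- Penalizes the squared norm of D's gradient with respect to (augmented) real images, which keeps D from becoming too sharp around the data
- Stabilizes training
- Applied to real images only; computed on every D step here (StyleGAN2 instead uses lazy regularization)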
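A quick way to see what the R1 term computes: for a linear "discriminator" D(x) = w·x the input gradient is w for every sample, so the penalty should equal ‖w‖². The sketch uses the same reduction as `r1_regularization` (squared gradient summed per sample, averaged over the batch). StyleGAN2 writes the penalty as (γ/2)·E‖∇D‖², so `lambda_r1 = 1.0` here corresponds to γ = 2 in that convention.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D = nn.Linear(12, 1, bias=False)  # toy linear discriminator: dD/dx = D.weight
real = torch.randn(5, 12, requires_grad=True)
d_real = D(real)

# create_graph=True keeps the penalty differentiable w.r.t. D's weights
(grad,) = torch.autograd.grad(d_real.sum(), real, create_graph=True)
r1 = grad.pow(2).reshape(grad.shape[0], -1).sum(1).mean()

expected = D.weight.pow(2).sum()  # ||w||^2
print(r1.item(), expected.item())
```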

#### **c) Density Regularization**
```python
lambda_den = 0.1  # Density penalty weight
```
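- Simplified here to an L1 penalty, `lambda_den * mean(|triplane|)`, which pushes tri-plane feature magnitudes toward zero
- It does not directly enforce smoothness; the full Next3D/EG3D density regularization compares predicted densities at nearby 3D points, which requires the volume renderer omitted here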
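Because this penalty only looks at absolute feature values, it cannot distinguish a smooth feature map from a noisy one. In this sketch (a standalone copy of `density_regularization`), a constant map and a random ±1 map receive the same penalty, 0.1 × 1 = 0.1.

```python
import torch


def density_regularization(triplane: torch.Tensor, lambda_den: float = 0.1) -> torch.Tensor:
    """Copy of the notebook's simplified penalty: lambda * mean(|features|)."""
    return lambda_den * torch.abs(triplane).mean()


torch.manual_seed(0)
smooth = torch.ones(1, 96, 64, 64)                                # constant features
noisy = torch.randint(0, 2, (1, 96, 64, 64)).float() * 2 - 1     # random +/-1 features
print(density_regularization(smooth).item(), density_regularization(noisy).item())
```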

#### **d) Adaptive Data Augmentation (ADA)**
```python
target_accuracy = 0.6
adjustment_speed = 0.001
```
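- Tracks an EMA of D's accuracy on real images and raises the augmentation probability `p` by `adjustment_speed` per step while that EMA exceeds `target_accuracy` (and lowers it otherwise)
- Prevents the discriminator from overfitting a small dataset
- With probability `p`, the whole batch is augmented: horizontal flip (50%) and a random 90°/180°/270° rotation (25%)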
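To get a feel for how fast `p` moves, this sketch copies the update rule of `AdaptiveAugmentation.update` into a hypothetical standalone class and feeds it a discriminator that classifies every real image correctly. The EMA starts at 0.5 and first exceeds 0.6 at step 23, so after one epoch of 100 images at batch size 1, `p` has only reached 0.078. Note also that the original ADA heuristic r_t = E[sign(D(real))] equals 2 × accuracy - 1, so an accuracy target of 0.6 corresponds to r_t = 0.2, lower than the r_t = 0.6 target used in the ADA paper.

```python
class AdaControllerSketch:
    """Hypothetical standalone copy of AdaptiveAugmentation.update()."""

    def __init__(self, target_accuracy=0.6, adjustment_speed=0.001):
        self.target = target_accuracy
        self.speed = adjustment_speed
        self.p = 0.0
        self.accuracy_ema = 0.5

    def update(self, real_acc):
        self.accuracy_ema = 0.99 * self.accuracy_ema + 0.01 * real_acc
        if self.accuracy_ema > self.target:
            self.p = min(1.0, self.p + self.speed)
        else:
            self.p = max(0.0, self.p - self.speed)


ada = AdaControllerSketch()
for _ in range(100):          # one epoch: 100 images at batch size 1
    ada.update(real_acc=1.0)  # D gets every real image right
print(f"EMA={ada.accuracy_ema:.3f}, p={ada.p:.3f}")
```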

---

## Parameter Settings

### **Critical Parameters**

| Parameter | Default Value | Purpose | Tuning Guide |
|-----------|--------------|---------|--------------|
| `lr_g` | 0.0002 | Generator learning rate | Lower for stability (0.0001); set in the `Next3DTrainer(...)` call inside `train_next3d` |
| `lr_d` | 0.0002 | Discriminator learning rate | Lower for stability (0.0001); set in the same call |
| `lambda_r1` | 1.0 | R1 regularization strength | Increase if D gradients explode (2.0-5.0) |
| `lambda_den` | 0.1 | Tri-plane L1 penalty weight | Increase to shrink tri-plane feature magnitudes further (0.2-0.5) |
| `batch_size` | CPU: 1, GPU: 4 | Batch size | GPU 12GB: 2-4, GPU 24GB: 4-8 |
| `num_epochs` | CPU: 10, GPU: 100 | Training epochs | With more images, each epoch has more steps, so fewer epochs are needed |
| `grad_clip` | 1.0 | Gradient-norm clipping threshold | Hard-coded as `Next3DTrainer.grad_clip`; lower to 0.5 if NaN occurs |

### **Device-Specific Settings**

**CPU Mode:**
```python
BATCH_SIZE = 1
NUM_EPOCHS = 10  # For testing
num_workers = 0
```
⚠️ **Warning**: Training on CPU is roughly **one to two orders of magnitude slower** than on a GPU

**GPU Mode (Recommended):**
```python
BATCH_SIZE = 4  # Adjust based on VRAM
NUM_EPOCHS = 100
num_workers = 2
```
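Rather than editing these constants by hand, a notebook cell can choose them from the available hardware. A minimal sketch; the constant names match the snippets above and the values are the suggested defaults:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    BATCH_SIZE, NUM_EPOCHS = 4, 100  # adjust BATCH_SIZE to your VRAM
else:
    BATCH_SIZE, NUM_EPOCHS = 1, 10   # short test run on CPU
print(device, BATCH_SIZE, NUM_EPOCHS)
```

These can then be passed as `train_next3d(batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, device=device)`; `num_workers` is already chosen inside `train_next3d` based on `device`.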

---

## Training Process

### **Step-by-Step Execution**

1. **Initialization**
   - Creates the generator (about 3.4M parameters, or about 4.2M with the 3DMM encoders enabled)
   - Creates the discriminator (about 9.4M parameters)
   - Initializes optimizers (Adam with β₁=0.5, β₂=0.999)

2. **Training Loop** (per batch)
   ```
   For each batch:
     1. Train Discriminator:
        - Augment real images (ADA) and evaluate → d_real
        - Generate fake images (no gradient), augment → d_fake
        - Compute adversarial loss + R1 penalty
        - Update discriminator weights
     
     2. Update ADA:
        - Monitor discriminator accuracy
        - Adjust augmentation strength
     
     3. Train Generator:
        - Generate new fake images
        - Compute adversarial loss + density penalty
        - Update generator weights
   ```

3. **Safety Features**
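   - **NaN Detection**: Checks outputs and losses for NaN/Inf and skips the step when one is found (a NaN R1 or density loss is replaced by zero instead); prints a warning after more than 10 occurrences
   - **Gradient Clipping**: Clips the global gradient norm of G and D to `grad_clip` (1.0)
   - **Weight Initialization**: Linear and Conv2d weights drawn from N(0, 0.02), biases set to zero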
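The parameter counts can be reproduced by hand from the layer definitions in the code cell (the trainer also prints them at start-up). This arithmetic sketch assumes the settings used by `train_next3d` (`z_dim = w_dim = 512`, 96 tri-plane channels); `conv_params` and `linear_params` are hypothetical helper names.

```python
def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights + bias of nn.Conv2d(c_in, c_out, k)."""
    return c_in * c_out * k * k + c_out


def linear_params(d_in: int, d_out: int) -> int:
    """Weights + bias of nn.Linear(d_in, d_out)."""
    return d_in * d_out + d_out


C, W = 96, 512  # tri-plane channels, latent size

# Generator with use_3dmm=False
mapping = 8 * linear_params(W, W)
style_block = 2 * conv_params(C, C) + 2 * linear_params(W, C) + 2  # + two NoiseInjection weights
backbone = C * 4 * 4 + 4 * style_block + conv_params(C, C, k=1)    # const + 4 blocks + to_features
superres = conv_params(C, 128) + conv_params(128, 64) + conv_params(64, 32) + conv_params(32, 3)
g_total = mapping + backbone + superres

# Extra layers when use_3dmm=True
cond = linear_params(80, W) + linear_params(64, W) + linear_params(3 * W, W)

# Discriminator: five blocks of two 3x3 convs, then Linear(512 * 4 * 4, 1)
d_channels = [3, 64, 128, 256, 512, 512]
d_total = sum(conv_params(i, o) + conv_params(o, o) for i, o in zip(d_channels, d_channels[1:]))
d_total += linear_params(512 * 4 * 4, 1)

print(f"G: {g_total:,}  G+3DMM: {g_total + cond:,}  D: {d_total:,}")
```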

---

## Expected Results

### **Training Metrics**

**Healthy Training Signs:**
```
Epoch 1-10:
  D Loss: 1.2-0.8 (decreasing)
  G Loss: 1.5-1.0 (decreasing then stabilizing)
  R1 Loss: 0.5-0.3
  Density Loss: 0.1-0.05
  ADA p: 0.0-0.3 (gradually increasing)
  Real Accuracy: 0.6-0.7
```

**Convergence Indicators:**
- D Loss stabilizes around 0.6-0.8
- G Loss stabilizes around 0.8-1.2
- ADA p settles around 0.3-0.5
- No NaN warnings
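As a reference point for these numbers: with the logistic losses used here, a discriminator that outputs logit 0 for every image (no better than chance) gives a D loss of 2 ln 2 ≈ 1.386 and a G loss of ln 2 ≈ 0.693. A D loss well below 1.386 means D is separating real from fake. The check below evaluates the notebook's loss formulas at that chance point:

```python
import math

import torch
import torch.nn.functional as F

# A discriminator that cannot tell real from fake outputs logit 0 everywhere
d_real = torch.zeros(4, 1)
d_fake = torch.zeros(4, 1)

d_loss = F.softplus(d_fake).mean() + F.softplus(-d_real).mean()
g_loss = F.softplus(-d_fake).mean()

print(f"D loss at chance: {d_loss.item():.4f} (2*ln2 = {2 * math.log(2):.4f})")
print(f"G loss at chance: {g_loss.item():.4f} (ln2 = {math.log(2):.4f})")
```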

### **Output Files**

```
next3d_checkpoints/
├── checkpoint_0010.pt  # Every 10 epochs
├── checkpoint_0020.pt
├── ...
├── checkpoint_0100.pt
└── final_model.pt      # Final trained model
```

**Checkpoint Contents:**
- Generator weights
- Discriminator weights
- Optimizer states
- ADA augmentation probability `p` (the accuracy EMA is not saved)
- Current epoch
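The checkpoint is a plain dict written by `torch.save`, so it can be inspected or used to resume without the training script. The sketch below writes a checkpoint with the same keys as `Next3DTrainer.save_checkpoint`, using tiny `nn.Linear` stand-ins (hypothetical) for the real networks, and reads it back. Since only `ada_p` is stored, the ADA accuracy EMA restarts at 0.5 after resuming.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-ins; the real script saves TriplaneGenerator / Discriminator
G, D = nn.Linear(4, 4), nn.Linear(4, 1)
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

path = os.path.join(tempfile.mkdtemp(), "checkpoint_0010.pt")
torch.save({
    "epoch": 10,
    "generator": G.state_dict(),
    "discriminator": D.state_dict(),
    "opt_g": opt_g.state_dict(),
    "opt_d": opt_d.state_dict(),
    "ada_p": 0.12,
}, path)

ckpt = torch.load(path, map_location="cpu")
print(sorted(ckpt.keys()))
G.load_state_dict(ckpt["generator"])
opt_g.load_state_dict(ckpt["opt_g"])
print("resume after epoch", ckpt["epoch"], "with ADA p =", ckpt["ada_p"])
```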

---

## Common Issues & Solutions

### **Problem 1: NaN Loss**
**Symptoms:**
```
WARNING: NaN or Inf detected in d_real
CRITICAL: Too many NaN occurrences!
```

**Solutions:**
1. Reduce learning rates to 0.0001
2. Decrease batch size to 1
3. Lower gradient clipping to 0.5
4. Check input data for NaN values
5. Restart training from scratch
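Solutions 1, 3 and 4 can also be applied without rebuilding anything: learning rates live in each optimizer's `param_groups`, and the clip threshold is the plain attribute `grad_clip`. The sketch uses a stand-in model and optimizer; on the real trainer the same loop would target `trainer.opt_g` and `trainer.opt_d`, and the threshold would be set with `trainer.grad_clip = 0.5`. The helpers `set_lr` and `all_finite` are hypothetical names.

```python
import numpy as np
import torch
import torch.nn as nn


def set_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
    """Change the learning rate of an existing optimizer in place."""
    for group in optimizer.param_groups:
        group["lr"] = lr


def all_finite(array) -> bool:
    """True if an array (e.g. loaded 3DMM coefficients) contains no NaN/Inf."""
    return bool(np.isfinite(np.asarray(array, dtype=np.float64)).all())


model = nn.Linear(8, 1)  # stand-in for trainer.G / trainer.D
opt = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
set_lr(opt, 1e-4)
grad_clip = 0.5          # on the real trainer: trainer.grad_clip = 0.5

print([g["lr"] for g in opt.param_groups])
print(all_finite(np.zeros(80)), all_finite([0.0, float("nan")]))
```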

---

### **Problem 2: Mode Collapse**
**Symptoms:**
- G Loss stuck at high value (>2.0)
- Generated images all look similar

**Solutions:**
1. Increase R1 regularization (λ_r1 = 2.0-5.0)
2. Enable ADA (should happen automatically)
3. Use more diverse training data
4. Reduce discriminator learning rate
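"Generated images all look similar" can be tracked with a number. One simple diagnostic (hypothetical, not part of the notebook) is the mean pairwise L2 distance within a batch of generated images: it drops toward zero as samples collapse onto each other. It is most useful for comparing checkpoints with each other rather than against an absolute threshold.

```python
import torch


def mean_pairwise_distance(images: torch.Tensor) -> float:
    """Mean L2 distance over all ordered pairs in a batch [N, C, H, W], N >= 2."""
    flat = images.flatten(start_dim=1)
    dists = torch.cdist(flat, flat)  # [N, N], zeros on the diagonal
    n = flat.shape[0]
    return (dists.sum() / (n * (n - 1))).item()


collapsed = torch.zeros(1, 3, 8, 8).repeat(4, 1, 1, 1)  # four identical "images"
diverse = torch.rand(4, 3, 8, 8) * 2 - 1                # four different images in [-1, 1]
print(mean_pairwise_distance(collapsed), mean_pairwise_distance(diverse))
```

With the trained generator, it could be computed on, for example, `generator(torch.randn(16, 512))['image']` under `torch.no_grad()`.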

---

### **Problem 3: Discriminator Overfitting**
**Symptoms:**
- D Loss → 0 quickly
- G Loss keeps increasing
- Real Accuracy > 0.9

**Solutions:**
1. ADA will automatically increase (check `ada_p`)
2. Manually increase λ_r1
3. Add more augmentations

---

### **Problem 4: Slow Training**
**On CPU:**
- Each epoch: ~30-60 minutes (with 100 images)
- **Recommendation**: Use Google Colab GPU or cloud GPU

**On GPU:**
- Each epoch: 1-3 minutes
- Can complete 100 epochs in 2-5 hours

---

## Usage After Training

### **Generate New Faces**
```python
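import torch

# Rebuild the generator with the same settings used for training
# (use_3dmm must match, otherwise load_state_dict reports missing/unexpected keys)
generator = TriplaneGenerator(z_dim=512, w_dim=512, use_3dmm=False)
checkpoint = torch.load("next3d_checkpoints/final_model.pt", map_location="cpu")
generator.load_state_dict(checkpoint['generator'])
generator.eval()

# Generate
with torch.no_grad():
    z = torch.randn(1, 512)   # Random latent code
    output = generator(z)
image = output['image']       # [1, 3, 512, 512] in [-1, 1], in the avatar style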
```
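The generator's `Tanh` output lies in [-1, 1], so it has to be mapped back to 8-bit pixels before saving or display. A minimal sketch; `to_pil` is a hypothetical helper and a random tensor stands in for `output['image']`:

```python
import torch
from PIL import Image


def to_pil(image: torch.Tensor) -> Image.Image:
    """Convert one generator output [1, 3, H, W] in [-1, 1] to a PIL RGB image."""
    x = (image[0].detach().clamp(-1, 1) + 1) * 127.5  # [-1, 1] -> [0, 255]
    x = x.round().to(torch.uint8)                     # 8-bit pixel values
    x = x.permute(1, 2, 0).contiguous().cpu()         # CHW -> HWC
    return Image.fromarray(x.numpy())


fake = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for output['image']
pil = to_pil(fake)
print(pil.size, pil.mode)
# pil.save("avatar_sample.png")
```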

### **Control Generation** (if using 3DMM)
```python
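# Only for a generator built and trained with use_3dmm=True (i.e. mesh_dir was set);
# z is the latent code from the previous example
shape = torch.randn(1, 80) * 0.5  # Identity variation
exp = torch.randn(1, 64) * 0.3    # Expression variation
with torch.no_grad():
    output = generator(z, shape, exp)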
```

---

## Performance Expectations

| Hardware | Batch Size | Time/Epoch | 100 Epochs |
|----------|-----------|------------|------------|
| CPU (8 cores) | 1 | 30-60 min | 50-100 hours |
| GPU (GTX 1080) | 2 | 3-5 min | 5-8 hours |
| GPU (RTX 3090) | 8 | 1-2 min | 2-3 hours |
| GPU (A100) | 16 | 30-60 sec | 1-2 hours |

---

## Summary

This pipeline:
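✅ Trains a simplified, Next3D-style tri-plane GAN on your avatar-style images  
✅ Learns a tri-plane feature representation (true multi-view synthesis needs the full Next3D volume renderer, which this version omits)  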
✅ Includes stability features (ADA, R1, gradient clipping)  
✅ Can optionally use 3DMM parameters for explicit control  
✅ Produces checkpoints every 10 epochs for evaluation  

**Best Practice**: Start with 10-20 epochs on GPU to verify training stability before committing to full 100-epoch training.

---

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install diffusers transformers accelerate
!pip install controlnet-aux opencv-python pillow
!pip install mediapipe==0.10.13

In [None]:
import shutil
import os

paths = []
for dirname, _, filenames in os.walk('/kaggle/input/cpu-avatarartist1-2d-domain-transfer/output_styled'):
    for filename in filenames:
        paths.append(os.path.join(dirname, filename))
print(paths[0:6])

os.makedirs("output_styled", exist_ok=True)
for path in paths:
    shutil.copy(path, "output_styled")

for dirname, _, filenames in os.walk('./output_styled'):
    for filename in filenames:
        print(filename)
        
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Dict, Tuple, Optional
import torchvision.transforms as transforms
from tqdm import tqdm
import pickle


# ==================== Dataset ====================

class StyleFaceDataset(Dataset):
    """Dataset for style-transferred face images + 3DMM parameters"""
    
    def __init__(
        self,
        image_dir: str,
        mesh_dir: Optional[str] = None,
        resolution: int = 512,
        use_3dmm: bool = True
    ):
        """
        Args:
            image_dir: Directory for style-transferred images (output_styled)
            mesh_dir: Directory for 3DMM parameters (optional)
            resolution: Image resolution
            use_3dmm: Whether to use 3DMM parameters
        """
        self.image_dir = Path(image_dir)
        self.mesh_dir = Path(mesh_dir) if mesh_dir else None
        self.resolution = resolution
        self.use_3dmm = use_3dmm
        
        # List image files
        self.image_files = sorted(list(self.image_dir.glob("*.jpg")) + 
                                  list(self.image_dir.glob("*.png")))
        
        # Transformations
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        
        print(f"Dataset initialized: {len(self.image_files)} images found.")
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        # Load image
        img_path = self.image_files[idx]
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)
        
        data = {'image': image, 'index': idx}
        
        # Load 3DMM parameters (if they exist)
        if self.use_3dmm and self.mesh_dir:
            mesh_file = self.mesh_dir / f"{img_path.stem}.pkl"
            if mesh_file.exists():
                with open(mesh_file, 'rb') as f:
                    mesh_params = pickle.load(f)
                data['shape'] = torch.FloatTensor(mesh_params.get('shape', np.zeros(80)))
                data['exp'] = torch.FloatTensor(mesh_params.get('exp', np.zeros(64)))
                data['pose'] = torch.FloatTensor(mesh_params.get('pose', np.zeros(6)))
            else:
                # Dummy parameters
                data['shape'] = torch.zeros(80)
                data['exp'] = torch.zeros(64)
                data['pose'] = torch.zeros(6)
        
        return data


# ==================== Next3D Model Components ====================

class TriplaneGenerator(nn.Module):
    """Next3D Tri-plane Generator (Simplified version)"""
    
    def __init__(
        self,
        z_dim: int = 512,
        w_dim: int = 512,
        triplane_channels: int = 96,
        triplane_resolution: int = 256,  # Stored by TriplaneBackbone but unused: output is fixed at 64x64
        use_3dmm: bool = True,
        shape_dim: int = 80,
        exp_dim: int = 64
    ):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.use_3dmm = use_3dmm
        
        # Mapping Network (z -> w)
        self.mapping = MappingNetwork(z_dim, w_dim)
        
        # 3DMM Conditioning Module (Dynamic components)
        if use_3dmm:
            self.shape_encoder = nn.Linear(shape_dim, w_dim)
            self.exp_encoder = nn.Linear(exp_dim, w_dim)
            self.condition_fusion = nn.Linear(w_dim * 3, w_dim)
        
        # Tri-plane generation network
        self.triplane_generator = TriplaneBackbone(
            w_dim=w_dim,
            channels=triplane_channels,
            resolution=triplane_resolution
        )
        
        # Super-resolution module (Low-res -> High-res)
        self.superres = SuperResolutionModule(
            triplane_channels, 
            output_resolution=512
        )
        
        # Weight initialization
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        """Proper weight initialization"""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, 0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    
    def forward(
        self,
        z: torch.Tensor,
        shape: Optional[torch.Tensor] = None,
        exp: Optional[torch.Tensor] = None,
        c: Optional[torch.Tensor] = None
    ):
        """
        Args:
            z: Latent code [B, z_dim]
            shape: 3DMM shape parameters [B, shape_dim]
            exp: 3DMM expression parameters [B, exp_dim]
            c: Camera parameters [B, 25] (unused in this simplified version)
        
        Returns:
            output: Dictionary containing generated image and tri-plane
        """
        # Mapping network
        w = self.mapping(z)
        
        # 3DMM Conditioning
        if self.use_3dmm and shape is not None and exp is not None:
            shape_feat = self.shape_encoder(shape)
            exp_feat = self.exp_encoder(exp)
            w_conditioned = self.condition_fusion(
                torch.cat([w, shape_feat, exp_feat], dim=1)
            )
        else:
            w_conditioned = w
        
        # Tri-plane generation
        triplane_features = self.triplane_generator(w_conditioned)
        
        # Rendering (Simplified: Actual version uses volume rendering)
        # Placeholder: generating image directly from features
        image = self.superres(triplane_features)
        
        return {
            'image': image,
            'triplane': triplane_features,  # Used for density regularization
            'w': w_conditioned
        }


class MappingNetwork(nn.Module):
    """StyleGAN2-style Mapping Network"""
    
    def __init__(self, z_dim: int = 512, w_dim: int = 512, num_layers: int = 8):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_dim = z_dim if i == 0 else w_dim
            layers.extend([
                nn.Linear(in_dim, w_dim),
                nn.LeakyReLU(0.2)
            ])
        self.net = nn.Sequential(*layers)
    
    def forward(self, z):
        # Normalize input to improve numerical stability
        z = F.normalize(z, dim=1)
        return self.net(z)


class TriplaneBackbone(nn.Module):
    """Backbone for Tri-plane generation"""
    
    def __init__(self, w_dim: int, channels: int, resolution: int):
        super().__init__()
        self.w_dim = w_dim
        self.channels = channels
        self.resolution = resolution
        
        # Synthesis network inspired by StyleGAN2
        self.const = nn.Parameter(torch.randn(1, channels, 4, 4) * 0.01)
        
        self.blocks = nn.ModuleList([
            StyleBlock(channels, channels, w_dim),
            StyleBlock(channels, channels, w_dim),
            StyleBlock(channels, channels, w_dim),
            StyleBlock(channels, channels, w_dim),
        ])
        
        # Output: 64x64 with channels aligned for super-resolution
        self.to_features = nn.Conv2d(channels, channels, 1)
    
    def forward(self, w):
        B = w.shape[0]
        x = self.const.repeat(B, 1, 1, 1)
        
        for block in self.blocks:
            x = block(x, w)
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        
        # Final output: [B, channels, 64, 64]
        features = self.to_features(x)
        return features


class StyleBlock(nn.Module):
    """StyleGAN2-style Synthesis Block"""
    
    def __init__(self, in_ch, out_ch, w_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.style1 = nn.Linear(w_dim, in_ch)
        self.style2 = nn.Linear(w_dim, out_ch)
        self.noise1 = NoiseInjection()
        self.noise2 = NoiseInjection()
        self.activation = nn.LeakyReLU(0.2)
    
    def forward(self, x, w):
        # First conv + style modulation
        s1 = self.style1(w).unsqueeze(-1).unsqueeze(-1)
        # Add 1.0 so a near-zero style leaves the input unscaled (simplified input scaling, not weight demodulation)
        x = self.conv1(x * (s1 + 1.0))
        x = self.noise1(x)
        x = self.activation(x)
        
        # Second conv + style modulation
        s2 = self.style2(w).unsqueeze(-1).unsqueeze(-1)
        x = self.conv2(x * (s2 + 1.0))
        x = self.noise2(x)
        x = self.activation(x)
        
        return x


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))
    
    def forward(self, x):
        noise = torch.randn(x.shape[0], 1, x.shape[2], x.shape[3], device=x.device)
        # Keep noise influence small
        return x + self.weight * noise * 0.1


class SuperResolutionModule(nn.Module):
    """Super-resolution module: 64x64 -> 512x512"""
    
    def __init__(self, in_channels, output_resolution=512):
        super().__init__()
        # 64x64 -> 512x512 requires 3 upsampling steps (2^3 = 8x)
        self.conv_blocks = nn.Sequential(
            # 64x64 -> 128x128
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            # 128x128 -> 256x256
            nn.Conv2d(128, 64, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            # 256x256 -> 512x512
            nn.Conv2d(64, 32, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            # Final conv to RGB
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.conv_blocks(x)


class Discriminator(nn.Module):
    """StyleGAN2-style Discriminator"""
    
    def __init__(self, resolution: int = 512, channels: int = 3):
        super().__init__()
        
        # Progressive discriminator blocks
        self.blocks = nn.ModuleList([
            DiscriminatorBlock(channels, 64),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
            DiscriminatorBlock(512, 512),
        ])
        
        # Adaptive pooling to ensure consistent output size
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        
        self.final = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1)
        )
        
        # Weight initialization
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        """Proper weight initialization"""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, 0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.pool(x)  # Ensure 4x4 spatial size
        return self.final(x)


class DiscriminatorBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.downsample = nn.AvgPool2d(2)
    
    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        x = self.downsample(x)
        return x


# ==================== Loss Functions ====================

class AdaptiveAugmentation:
    """ADA (Adaptive Discriminator Augmentation)"""
    
    def __init__(self, target_accuracy=0.6, adjustment_speed=0.001):
        self.target = target_accuracy
        self.speed = adjustment_speed
        self.p = 0.0
        self.accuracy_ema = 0.5
    
    def update(self, real_acc):
        """Adjust augmentation strength based on Discriminator accuracy"""
        self.accuracy_ema = 0.99 * self.accuracy_ema + 0.01 * real_acc
        
        if self.accuracy_ema > self.target:
            self.p = min(1.0, self.p + self.speed)
        else:
            self.p = max(0.0, self.p - self.speed)
    
    def __call__(self, images):
        """Apply random augmentations"""
        if self.p == 0.0:
            return images
        
        B, C, H, W = images.shape
        
        # Apply random transformations
        if torch.rand(1).item() < self.p:
            # Horizontal flip
            if torch.rand(1).item() < 0.5:
                images = torch.flip(images, [3])
            
            # 90-degree rotations
            if torch.rand(1).item() < 0.25:
                k = torch.randint(1, 4, (1,)).item()
                images = torch.rot90(images, k, [2, 3])
        
        return images


def r1_regularization(d_real, real_images):
    """R1 Gradient Penalty"""
    grad_real = torch.autograd.grad(
        outputs=d_real.sum(),
        inputs=real_images,
        create_graph=True,
        only_inputs=True
    )[0]
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
    return grad_penalty


def density_regularization(triplane, lambda_den=0.1):
    """Density Regularization for Tri-plane"""
    # Simplified: L1 penalty on tri-plane feature magnitudes (no volume density is computed here)
    density_loss = torch.abs(triplane).mean()
    return lambda_den * density_loss


def check_nan(tensor, name="tensor"):
    """Helper function for NaN checks"""
    if torch.isnan(tensor).any() or torch.isinf(tensor).any():
        print(f"WARNING: NaN or Inf detected in {name}")
        return True
    return False


# ==================== Training Loop ====================

class Next3DTrainer:
    """Trainer for Next3D Fine-tuning"""
    
    def __init__(
        self,
        generator: nn.Module,
        discriminator: nn.Module,
        device: str = 'cpu',
        lr_g: float = 0.0002,  # Generator learning rate
        lr_d: float = 0.0002,  # Discriminator learning rate
        lambda_r1: float = 1.0,  # R1 penalty weight
        lambda_den: float = 0.1,  # Tri-plane L1 penalty weight
        batch_size: int = 1  # Currently unused by the trainer
    ):
        self.G = generator.to(device)
        self.D = discriminator.to(device)
        self.device = device
        self.lambda_r1 = lambda_r1
        self.lambda_den = lambda_den
        
        # Optimizers (more stable betas)
        self.opt_g = torch.optim.Adam(self.G.parameters(), lr=lr_g, betas=(0.5, 0.999), eps=1e-8)
        self.opt_d = torch.optim.Adam(self.D.parameters(), lr=lr_d, betas=(0.5, 0.999), eps=1e-8)
        
        # ADA
        self.ada = AdaptiveAugmentation()
        
        # Gradient clipping threshold
        self.grad_clip = 1.0
        
        # NaN detection counter
        self.nan_count = 0
        
        print(f"Trainer initialized on device: {device}")
        print(f"  Learning rates: G={lr_g}, D={lr_d}")
        print(f"  Regularization: R1={lambda_r1}, Density={lambda_den}")
        print(f"  Generator parameters: {sum(p.numel() for p in self.G.parameters()):,}")
        print(f"  Discriminator parameters: {sum(p.numel() for p in self.D.parameters()):,}")
    
    def train_step(self, real_data: Dict[str, torch.Tensor]):
        """Single training step"""
        real_images = real_data['image'].to(self.device)
        B = real_images.shape[0]
        
        # ==================== Train Discriminator ====================
        self.opt_d.zero_grad()
        
        # Real images
        real_images_grad = real_images.clone().detach().requires_grad_(True)
        real_aug = self.ada(real_images_grad)
        d_real = self.D(real_aug)
        
        if check_nan(d_real, "d_real"):
            return self._return_nan_metrics()
        
        # Fake images
        z = torch.randn(B, self.G.z_dim, device=self.device) * 0.5  # No effect: MappingNetwork L2-normalizes z
        
        # Use 3DMM parameters if available
        shape = real_data.get('shape', None)
        exp = real_data.get('exp', None)
        if shape is not None:
            shape = shape.to(self.device)
            exp = exp.to(self.device)
        
        with torch.no_grad():
            fake_output = self.G(z, shape, exp)
        fake_images = fake_output['image'].detach()
        
        if check_nan(fake_images, "fake_images"):
            return self._return_nan_metrics()
        
        fake_aug = self.ada(fake_images)
        d_fake = self.D(fake_aug)
        
        if check_nan(d_fake, "d_fake"):
            return self._return_nan_metrics()
        
        # Adversarial loss (logistic loss)
        d_loss = F.softplus(d_fake).mean() + F.softplus(-d_real).mean()
        
        # R1 regularization
        r1_loss = r1_regularization(d_real, real_images_grad)
        
        if check_nan(r1_loss, "r1_loss"):
            r1_loss = torch.tensor(0.0, device=self.device)
        
        d_loss_total = d_loss + self.lambda_r1 * r1_loss
        
        if check_nan(d_loss_total, "d_loss_total"):
            return self._return_nan_metrics()
        
        d_loss_total.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.D.parameters(), self.grad_clip)
        
        self.opt_d.step()
        
        # Update ADA
        real_acc = (d_real.detach() > 0).float().mean().item()
        self.ada.update(real_acc)
        
        # ==================== Train Generator ====================
        self.opt_g.zero_grad()
        
        z = torch.randn(B, self.G.z_dim, device=self.device) * 0.5  # No effect: MappingNetwork L2-normalizes z
        fake_output = self.G(z, shape, exp)
        fake_images = fake_output['image']
        triplane = fake_output['triplane']
        
        if check_nan(fake_images, "g_fake_images"):
            return self._return_nan_metrics()
        
        fake_aug = self.ada(fake_images)
        d_fake = self.D(fake_aug)
        
        if check_nan(d_fake, "g_d_fake"):
            return self._return_nan_metrics()
        
        # Adversarial loss
        g_loss = F.softplus(-d_fake).mean()
        
        # Density regularization
        den_loss = density_regularization(triplane, self.lambda_den)
        
        if check_nan(den_loss, "den_loss"):
            den_loss = torch.tensor(0.0, device=self.device)
        
        g_loss_total = g_loss + den_loss
        
        if check_nan(g_loss_total, "g_loss_total"):
            return self._return_nan_metrics()
        
        g_loss_total.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.G.parameters(), self.grad_clip)
        
        self.opt_g.step()
        
        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'r1_loss': r1_loss.item(),
            'den_loss': den_loss.item(),
            'ada_p': self.ada.p,
            'real_acc': real_acc
        }
    
    def _return_nan_metrics(self):
        """Return values when NaN occurs"""
        self.nan_count += 1
        if self.nan_count > 10:
            print("\n" + "="*60)
            print("CRITICAL: Too many NaN occurrences!")
            print("Suggestions:")
            print("1. Reduce learning rate further (e.g., 0.0001)")
            print("2. Check your input data for NaN/Inf values")
            print("3. Use smaller batch size")
            print("4. Restart training from scratch")
            print("="*60 + "\n")
        
        return {
            'd_loss': 0.0,
            'g_loss': 0.0,
            'r1_loss': 0.0,
            'den_loss': 0.0,
            'ada_p': self.ada.p,
            'real_acc': 0.0
        }
    
    def save_checkpoint(self, path: str, epoch: int):
        """Save training checkpoint"""
        torch.save({
            'epoch': epoch,
            'generator': self.G.state_dict(),
            'discriminator': self.D.state_dict(),
            'opt_g': self.opt_g.state_dict(),
            'opt_d': self.opt_d.state_dict(),
            'ada_p': self.ada.p,
        }, path)
        print(f"Checkpoint saved: {path}")
    
    def load_checkpoint(self, path: str):
        """Load training checkpoint"""
        checkpoint = torch.load(path, map_location=self.device)
        self.G.load_state_dict(checkpoint['generator'])
        self.D.load_state_dict(checkpoint['discriminator'])
        self.opt_g.load_state_dict(checkpoint['opt_g'])
        self.opt_d.load_state_dict(checkpoint['opt_d'])
        self.ada.p = checkpoint.get('ada_p', 0.0)
        print(f"Checkpoint loaded: {path}")
        return checkpoint['epoch']


# ==================== Main Training Function ====================

def train_next3d(
    data_dir: str = "./output_styled",
    mesh_dir: Optional[str] = None,
    output_dir: str = "./next3d_checkpoints",
    pretrained_path: Optional[str] = None,
    batch_size: int = 1,
    num_epochs: int = 100,
    save_every: int = 10,
    device: str = 'cpu'
):
    """
    Main function for Next3D fine-tuning
    
    Args:
        data_dir: Directory for styled images (output_styled)
        mesh_dir: Directory for 3DMM parameters
        output_dir: Output path for checkpoints
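        pretrained_path: Path to a checkpoint saved by this script (state dict with matching keys);
            official Next3D checkpoints use a different architecture and are not expected to load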
        batch_size: Batch size (CPU: 1-2, GPU 12GB: 2, GPU 24GB: 4-8)
        num_epochs: Number of epochs
        save_every: Interval for saving checkpoints
        device: Target device ('cpu' or 'cuda')
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Dataset
    dataset = StyleFaceDataset(
        image_dir=data_dir,
        mesh_dir=mesh_dir,
        resolution=512,
        use_3dmm=(mesh_dir is not None)
    )
    
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0 if device == 'cpu' else 2,  # CPU: 0, GPU: 2-4
        pin_memory=(device != 'cpu')  # Pinned memory only speeds up host-to-GPU copies
    )
    
    # Model
    generator = TriplaneGenerator(
        z_dim=512,
        w_dim=512,
        use_3dmm=(mesh_dir is not None)
    )
    
    discriminator = Discriminator(resolution=512)
    
    # Load pre-trained model
    if pretrained_path and os.path.exists(pretrained_path):
        print(f"Loading pretrained model: {pretrained_path}")
        checkpoint = torch.load(pretrained_path, map_location=device)
        generator.load_state_dict(checkpoint.get('generator', checkpoint))
        print("Pretrained model loaded!")
    
    # Trainer
    trainer = Next3DTrainer(
        generator=generator,
        discriminator=discriminator,
        device=device,
        lr_g=0.0002,  # Generator learning rate
        lr_d=0.0002,  # Discriminator learning rate
        lambda_r1=1.0,  # R1 penalty weight
        lambda_den=0.1,  # Tri-plane L1 penalty weight
        batch_size=batch_size
    )
    
    # Training Loop
    print(f"\n{'='*60}")
    print(f"Starting Training on {device.upper()}")
    if device == 'cpu':
        print("WARNING: Training on CPU will be significantly slower!")
        print("Consider using GPU for faster training.")
    print(f"{'='*60}\n")
    
    for epoch in range(num_epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        epoch_metrics = {
            'd_loss': 0, 'g_loss': 0, 
            'r1_loss': 0, 'den_loss': 0
        }
        
        valid_batches = 0
        
        for i, batch in enumerate(pbar):
            metrics = trainer.train_step(batch)
            
            # Count only batches the trainer did not skip (assumes train_step
            # returns 0.0 for both losses when a batch produced NaN)
            if metrics['d_loss'] != 0.0 or metrics['g_loss'] != 0.0:
                valid_batches += 1
                # Aggregate metrics
                for k in epoch_metrics:
                    epoch_metrics[k] += metrics[k]
            
            # Update progress bar
            if i % 10 == 0:
                pbar.set_postfix({
                    'D': f"{metrics['d_loss']:.4f}",
                    'G': f"{metrics['g_loss']:.4f}",
                    'ADA': f"{metrics['ada_p']:.3f}"
                })
        
        # End of epoch statistics
        if valid_batches > 0:
            print(f"\nEpoch {epoch+1} Completed (Valid batches: {valid_batches}/{len(dataloader)}):")
            print(f"  D Loss: {epoch_metrics['d_loss']/valid_batches:.4f}")
            print(f"  G Loss: {epoch_metrics['g_loss']/valid_batches:.4f}")
            print(f"  R1 Loss: {epoch_metrics['r1_loss']/valid_batches:.4f}")
            print(f"  Density Loss: {epoch_metrics['den_loss']/valid_batches:.4f}")
        else:
            print(f"\nEpoch {epoch+1}: All batches failed with NaN!")
        
        # Save checkpoint
        if (epoch + 1) % save_every == 0:
            save_path = os.path.join(output_dir, f"checkpoint_{epoch+1:04d}.pt")
            trainer.save_checkpoint(save_path, epoch+1)
    
    # Save final model
    final_path = os.path.join(output_dir, "final_model.pt")
    trainer.save_checkpoint(final_path, num_epochs)
    
    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"Final model saved at: {final_path}")
    print(f"{'='*60}\n")


if __name__ == "__main__":
    # Settings
    DATA_DIR = "./output_styled"  # Output from step 1
    MESH_DIR = None  # 3DMM parameters (if available)
    OUTPUT_DIR = "./next3d_checkpoints"
    PRETRAINED_PATH = None  # Path to Next3D pre-trained on FFHQ
    
    # CPU/GPU auto-detection
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Recommended settings based on device
    if DEVICE == 'cpu':
        BATCH_SIZE = 1  # Minimal batch size for CPU
        NUM_EPOCHS = 10  # Fewer epochs for testing on CPU
        print("Running on CPU - using minimal settings")
    else:
        BATCH_SIZE = 4  # GPU settings (adjust based on VRAM)
        NUM_EPOCHS = 100
        print("Running on GPU - using standard settings")
    
    # Execute training
    train_next3d(
        data_dir=DATA_DIR,
        mesh_dir=MESH_DIR,
        output_dir=OUTPUT_DIR,
        pretrained_path=PRETRAINED_PATH,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        device=DEVICE
    )
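
The epoch loop above decides whether a batch is valid by checking that `d_loss` or `g_loss` is non-zero, which only works if `Next3DTrainer.train_step` reports 0.0 for batches it skipped because of NaN. If a trainer instead returns the raw (possibly NaN or infinite) losses, testing with `math.isfinite` is more direct. This is a minimal sketch, assuming each step returns a dict of float losses with the keys used in the loop; `average_valid_metrics` is a hypothetical helper, not part of the pipeline.

```python
import math

# Loss keys tracked by the epoch loop (assumed to be floats in each step's dict)
LOSS_KEYS = ("d_loss", "g_loss", "r1_loss", "den_loss")


def average_valid_metrics(step_metrics, keys=LOSS_KEYS):
    """Average per-step losses, skipping steps with any NaN/inf value.

    Returns (averages, number_of_valid_steps); averages is {} when no step is valid.
    Extra keys in a step's dict (e.g. 'ada_p') are ignored.
    """
    totals = {k: 0.0 for k in keys}
    valid = 0
    for metrics in step_metrics:
        values = [float(metrics[k]) for k in keys]
        # Drop the whole step if any tracked loss is not a finite number
        if not all(math.isfinite(v) for v in values):
            continue
        valid += 1
        for k, v in zip(keys, values):
            totals[k] += v
    if valid == 0:
        return {}, 0
    return {k: totals[k] / valid for k in keys}, valid
```

In the loop, one would append each `metrics` dict to a list and call `average_valid_metrics` at the end of the epoch. The trade-off is that it requires `train_step` to return the raw losses rather than zeros for failed batches.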


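
The progress bar shows `metrics['ada_p']`, presumably the augmentation probability p of adaptive discriminator augmentation (ADA). Assuming `Next3DTrainer` follows the StyleGAN2-ADA heuristic, p is nudged up when the discriminator becomes overconfident on real images and nudged down otherwise. This is a minimal sketch of one update step; the function name, its arguments, and the defaults (`target=0.6`, `interval=4`, `ada_kimg=500`, taken from common StyleGAN2-ADA settings) are assumptions and are not read from this trainer.

```python
def update_ada_p(p, real_logits, batch_size, target=0.6, interval=4, ada_kimg=500):
    """Return the adjusted augmentation probability after one ADA update.

    real_logits: discriminator outputs on real images collected over the
    last `interval` training steps (hypothetical bookkeeping).
    """
    if not real_logits:
        raise ValueError("need at least one discriminator output")
    # Overfitting heuristic r_t = E[sign(D(x_real))], which lies in [-1, 1]
    r_t = sum((x > 0) - (x < 0) for x in real_logits) / len(real_logits)
    # Step size chosen so p could move from 0 to 1 over ada_kimg thousand real images
    step = batch_size * interval / (ada_kimg * 1000)
    # Raise p when D is overconfident on reals (r_t above target), lower it when below
    direction = (r_t > target) - (r_t < target)
    # Keep p a valid probability
    return min(max(p + direction * step, 0.0), 1.0)
```

With a small fine-tuning set such as a folder of stylized avatars, the discriminator can overfit quickly, so under this heuristic one would expect the `ADA` value in the progress bar to rise over training.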

import os
import matplotlib.pyplot as plt
from PIL import Image

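def show_image(image_dir):
    """Preview up to six images from image_dir in a 2x3 grid."""
    # Keep common image formats, in sorted order; the grid below holds six
    image_paths = [
        os.path.join(image_dir, f)
        for f in sorted(os.listdir(image_dir))
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ][:6]
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    axes = axes.flatten()
    for ax, img_path in zip(axes, image_paths):
        # Convert to RGB so grayscale or RGBA files render consistently
        with Image.open(img_path) as img:
            ax.imshow(img.convert("RGB"))
        ax.axis("off")
        ax.set_title(os.path.basename(img_path), fontsize=9)
    # Hide the unused cells when there are fewer than six images
    for ax in axes[len(image_paths):]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()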

show_image('output_styled')
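
`show_image` always draws a fixed 2x3 grid, so it previews at most six files. To preview a different number of images (for example, more of the Step 1 outputs), the grid can be sized from the image count. This is a minimal sketch; the helper names `select_image_names` and `grid_shape` are hypothetical and not part of the pipeline.

```python
import math

# Same extensions that show_image accepts
IMAGE_EXTS = (".png", ".jpg", ".jpeg")


def select_image_names(filenames, limit=None):
    """Return image file names in sorted order (case-insensitive extension match), optionally truncated."""
    names = sorted(f for f in filenames if f.lower().endswith(IMAGE_EXTS))
    return names if limit is None else names[:limit]


def grid_shape(n_images, max_cols=3):
    """Return (rows, cols) for a grid just large enough to hold n_images."""
    n = max(n_images, 1)  # an empty directory still gets a single (blank) cell
    cols = min(n, max_cols)
    rows = math.ceil(n / cols)
    return rows, cols
```

A possible usage: `names = select_image_names(os.listdir(image_dir), limit=12)`, then `rows, cols = grid_shape(len(names))` and `fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)`. Passing `squeeze=False` keeps `axes` a 2-D array, so `axes.flatten()` also works for a 1x1 grid.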