# FF++ Staged Fine-tuning (Robust Version)

**Data**: 17GB FF++ dataset (7011 files)

**Stages:**
- **A**: Head-only stabilization (2 epochs)
- **B**: Partial unfreeze - layer4 (8 epochs)
- **C**: Optional deeper unfreeze (5 epochs)

**Estimated Time**: ~2-3 hours total on T4 GPU

In [None]:
# 1. Environment Check & Setup
import os
import sys

# Training needs CUDA, so stop right here when no GPU is attached
# instead of failing later in the notebook.
import torch

has_gpu = torch.cuda.is_available()
if not has_gpu:
    for line in (
        "ERROR: No GPU detected!",
        "Go to: Runtime > Change runtime type > Hardware accelerator > GPU",
    ):
        print(line)
    raise SystemExit("GPU required")

# Report the device (index 0) and the software stack for the run log.
device_props = torch.cuda.get_device_properties(0)
gpu_name = torch.cuda.get_device_name(0)
gpu_mem = device_props.total_memory / 1e9  # bytes -> decimal GB
print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")

In [None]:
# 2. Mount Drive & Clone Repo
from google.colab import drive
drive.mount('/content/drive')

# Clone fresh repo
!rm -rf /content/Team-Converge
!git clone https://github.com/Incharajayaram/Team-Converge.git /content/Team-Converge
%cd /content/Team-Converge/Finetune1

# Verify we're in the right place
assert os.path.exists('config.yaml'), "ERROR: config.yaml not found!"
assert os.path.exists('train_staged.py'), "ERROR: train_staged.py not found!"
print("Repo cloned successfully!")

In [None]:
# 3. Install Dependencies
!pip install -q mediapipe pyyaml tqdm gdown scikit-learn

# Verify critical imports
try:
    import mediapipe
    import yaml
    import gdown
    from sklearn.metrics import roc_auc_score
    print("All dependencies installed!")
except ImportError as e:
    print(f"ERROR: Missing dependency - {e}")
    raise

In [None]:
# 4. Download FF++ Data from Drive (17GB - takes 5-10 min)
import os
import subprocess

# Updated FILE_ID - Jan 2026
FILE_ID = "1a7X9Cjv3gsj4qC6kcDq6VLNl7eR3osoy"
ZIP_PATH = "/content/ffpp_data.zip"
EXPECTED_SIZE_GB = 17.0

# Remove any previous partial/corrupted download
if os.path.exists(ZIP_PATH):
    print(f"Removing existing {ZIP_PATH}...")
    os.remove(ZIP_PATH)

# Upgrade gdown for large file support
!pip install -q --upgrade gdown

print(f"Downloading {EXPECTED_SIZE_GB}GB file from Google Drive...")
print("This will take 5-10 minutes depending on connection speed.")

# Download with explicit ID flag for large files
!gdown --id {FILE_ID} --output {ZIP_PATH} --fuzzy

# Verify download
if not os.path.exists(ZIP_PATH):
    print("\nERROR: Download failed! File not found.")
    print("\nTry manual method:")
    print("1. Upload ffpp_data_new.zip to your Google Drive")
    print("2. Run: !cp '/content/drive/MyDrive/ffpp_data_new.zip' /content/ffpp_data.zip")
    raise FileNotFoundError("Download failed")

size_gb = os.path.getsize(ZIP_PATH) / 1e9
print(f"\nDownloaded: {size_gb:.2f} GB")

if size_gb < 15:
    print("\nWARNING: File too small! Download may have failed.")
    print("The file might be an HTML error page.")
    print("\nCheck file content:")
    !head -c 200 {ZIP_PATH}
    print("\n\nTry manual method:")
    print("!cp '/content/drive/MyDrive/ffpp_data_new.zip' /content/ffpp_data.zip")
    raise ValueError(f"File too small: {size_gb:.2f} GB")
else:
    print("Download OK!")

In [None]:
# 5. Extract Data (takes 2-3 min for 17GB)
import zipfile

EXTRACT_PATH = "/content/data/raw/ffpp"
ZIP_PATH = "/content/ffpp_data.zip"

# Clean previous extraction
!rm -rf {EXTRACT_PATH}
!mkdir -p {EXTRACT_PATH}

print("Extracting (this takes 2-3 minutes)...")

# Test zip integrity first
try:
    with zipfile.ZipFile(ZIP_PATH, 'r') as zf:
        # Quick test - just read the file list
        file_count = len(zf.namelist())
        print(f"Zip contains {file_count} files")
except zipfile.BadZipFile:
    print("\nERROR: Zip file is corrupted!")
    print("The download was incomplete or corrupted.")
    print("\nSolutions:")
    print("1. Re-run the download cell")
    print("2. Or manually copy: !cp '/content/drive/MyDrive/ffpp_data_new.zip' /content/ffpp_data.zip")
    raise

# Extract using unzip (faster than Python for large files)
!unzip -q {ZIP_PATH} -d {EXTRACT_PATH}

# Verify extraction
items = os.listdir(EXTRACT_PATH)
print(f"\nExtracted {len(items)} top-level items: {items}")

# Find the FF++ data folder
ffpp_folder = None
for item in items:
    if 'FaceForensics' in item or 'ffpp' in item.lower():
        ffpp_folder = os.path.join(EXTRACT_PATH, item)
        break

if ffpp_folder:
    print(f"FF++ data found at: {ffpp_folder}")
    # Update EXTRACT_PATH for training
    FFPP_ROOT = ffpp_folder
else:
    FFPP_ROOT = EXTRACT_PATH
    
# Count total videos
total_videos = 0
for root, dirs, files in os.walk(FFPP_ROOT):
    total_videos += sum(1 for f in files if f.endswith('.mp4'))
print(f"Total video files: {total_videos}")

In [None]:
# 6. Verify Data Structure
import os

# Folder names that mark a directory as the FF++ root.
METHOD_FOLDERS = ('Deepfakes', 'Face2Face', 'FaceSwap', 'NeuralTextures', 'original', 'DeepFakeDetection')

# Auto-detect the correct ffpp root
possible_roots = [
    "/content/data/raw/ffpp",
    "/content/data/raw/ffpp/FaceForensics++_C23",
    "/content/data/raw/ffpp/ffpp_data"
]

# If the extraction cell already located a dataset folder in this session,
# try it as a last candidate instead of throwing that result away.
detected_root = globals().get('FFPP_ROOT')
if detected_root and detected_root not in possible_roots:
    possible_roots.append(detected_root)

FFPP_ROOT = None
for root in possible_roots:
    if not os.path.isdir(root):
        continue
    # Check if this folder has the expected structure:
    # video files, manipulation-method folders, or a 'csv' entry directly inside it.
    subdirs = os.listdir(root)
    has_videos = any(name.endswith('.mp4') for name in subdirs)
    has_method_folders = any(name in subdirs for name in METHOD_FOLDERS)
    if has_videos or has_method_folders or 'csv' in subdirs:
        FFPP_ROOT = root
        break

if FFPP_ROOT is None:
    print("WARNING: Could not auto-detect FF++ root. Using default.")
    FFPP_ROOT = "/content/data/raw/ffpp"

print(f"Using FFPP_ROOT: {FFPP_ROOT}")

# Fail with an actionable message rather than a bare listdir traceback.
if not os.path.isdir(FFPP_ROOT):
    raise FileNotFoundError(f"{FFPP_ROOT} does not exist - run the extraction cell first.")
print(f"Contents: {os.listdir(FFPP_ROOT)[:10]}...")

# Save for later cells (the training cells read it back from the environment)
os.environ['FFPP_ROOT'] = FFPP_ROOT

In [None]:
# 7. Create output directory on Drive
OUTPUT_DIR = "/content/drive/MyDrive/ffpp_training/staged"
!mkdir -p {OUTPUT_DIR}

# Also create local cache dir
CACHE_DIR = "/content/cache/faces"
!mkdir -p {CACHE_DIR}

print(f"Output will be saved to: {OUTPUT_DIR}")
print(f"Face cache: {CACHE_DIR}")

# Check available disk space
!df -h /content | tail -1

In [None]:
# 8. Run Stage A: Head-only stabilization (2 epochs)
import os
FFPP_ROOT = os.environ.get('FFPP_ROOT', '/content/data/raw/ffpp')

print("="*60)
print("STAGE A: Head-only training (2 epochs)")
print("="*60)

!python train_staged.py --config config.yaml \
    --override dataset.ffpp_root={FFPP_ROOT} \
    --override caching.cache_dir=/content/cache/faces \
    --stages A \
    --output_dir /content/drive/MyDrive/ffpp_training/staged

In [None]:
# 9. Run Stage B: Partial unfreeze - layer4 (8 epochs)
import os
FFPP_ROOT = os.environ.get('FFPP_ROOT', '/content/data/raw/ffpp')

# Check if Stage A checkpoint exists
checkpoint = "/content/drive/MyDrive/ffpp_training/staged/best_model.pt"
if not os.path.exists(checkpoint):
    print(f"ERROR: Checkpoint not found at {checkpoint}")
    print("Make sure Stage A completed successfully!")
    raise FileNotFoundError(checkpoint)

print("="*60)
print("STAGE B: Partial unfreeze - layer4 (8 epochs)")
print("="*60)

!python train_staged.py --config config.yaml \
    --override dataset.ffpp_root={FFPP_ROOT} \
    --override caching.cache_dir=/content/cache/faces \
    --stages B \
    --resume {checkpoint} \
    --output_dir /content/drive/MyDrive/ffpp_training/staged

In [None]:
# 10. (Optional) Stage C: Deeper unfreeze - run only if Stage B plateaus
# Uncomment the lines below to run Stage C

# import os
# FFPP_ROOT = os.environ.get('FFPP_ROOT', '/content/data/raw/ffpp')
# 
# print("="*60)
# print("STAGE C: Deeper unfreeze (5 epochs)")
# print("="*60)
# 
# !python train_staged.py --config config.yaml \
#     --override dataset.ffpp_root={FFPP_ROOT} \
#     --override caching.cache_dir=/content/cache/faces \
#     --stages C \
#     --resume /content/drive/MyDrive/ffpp_training/staged/best_model.pt \
#     --output_dir /content/drive/MyDrive/ffpp_training/staged

In [None]:
# 11. View Training History
import json
import os

history_path = '/content/drive/MyDrive/ffpp_training/staged/training_history.json'


def best_entry(epochs, values, pick):
    """Return ``(best_value, epoch)`` for the best entry of *values*.

    *epochs* and *values* are parallel, non-empty lists. *pick* is ``min`` or
    ``max``; on ties the earliest epoch is reported.
    """
    best = pick(values)
    return best, epochs[values.index(best)]


if not os.path.exists(history_path):
    print(f"Training history not found at {history_path}")
    print("Training may not have completed yet.")
else:
    with open(history_path) as f:
        history = json.load(f)

    if not history:
        # An empty list would make min()/max() below raise ValueError.
        print("Training history is empty - no epochs recorded yet.")
    else:
        # Imported lazily: plotting is only needed once there is data to show.
        import matplotlib.pyplot as plt

        epochs = [h['epoch'] for h in history]
        train_loss = [h['train_loss'] for h in history]
        val_loss = [h['val_loss'] for h in history]
        # Entries without an AUC are plotted at chance level (0.5).
        val_auc = [h.get('val_auc', 0.5) for h in history]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        ax1.plot(epochs, train_loss, 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, val_loss, 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.set_title('Loss Curves')
        ax1.grid(True, alpha=0.3)

        ax2.plot(epochs, val_auc, 'g-', label='Val AUC', linewidth=2)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('AUC')
        # Keep the usual 0.5-1.0 window, but widen it when an epoch scored
        # below chance so that point is not clipped out of the plot.
        ax2.set_ylim(min(0.5, min(val_auc)), 1.0)
        ax2.legend()
        ax2.set_title('Validation AUC')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('/content/drive/MyDrive/ffpp_training/training_curves.png', dpi=150)
        plt.show()

        best_loss, best_loss_epoch = best_entry(epochs, val_loss, min)
        best_auc, best_auc_epoch = best_entry(epochs, val_auc, max)
        print(f"\nBest val_loss: {best_loss:.4f} (epoch {best_loss_epoch})")
        print(f"Best val_auc: {best_auc:.4f} (epoch {best_auc_epoch})")

In [None]:
# 12. Final Summary & Save
import os
import shutil

output_dir = '/content/drive/MyDrive/ffpp_training/staged'
final_model = '/content/drive/MyDrive/ffpp_training/final_model.pt'


def describe_entry(directory, name):
    """Return the listing line for *name* inside *directory*.

    Regular files are shown with their size in MB (empty files included, as
    ``0.0 MB``); anything else is shown with a trailing slash.
    """
    path = os.path.join(directory, name)
    if os.path.isfile(path):
        return f"  {name}: {os.path.getsize(path) / 1e6:.1f} MB"
    return f"  {name}/"


# Copy best model to final location
best_model = os.path.join(output_dir, 'best_model.pt')
model_saved = os.path.exists(best_model)
if model_saved:
    shutil.copy(best_model, final_model)
    size_mb = os.path.getsize(final_model) / 1e6
    print(f"Final model saved: {final_model} ({size_mb:.1f} MB)")
else:
    print(f"WARNING: Best model not found at {best_model}")

# List all output files
print("\nOutput files:")
if os.path.isdir(output_dir):
    for name in sorted(os.listdir(output_dir)):
        print(describe_entry(output_dir, name))
else:
    # e.g. Drive not mounted or training never started
    print(f"  (output directory not found: {output_dir})")

# Only claim success when a model was actually produced.
print("\n" + "="*60)
print("TRAINING COMPLETE!" if model_saved else "TRAINING INCOMPLETE - best model not found")
print("="*60)