In [None]:
import sys
import torch
import platform

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

!nvidia-smi

In [None]:
import os
import torch
import sys
import platform

# 1. Clone Repository (Safe to re-run)
if not os.path.exists("InfiniteTalk"):
    !git clone https://github.com/MeiGen-AI/InfiniteTalk
else:
    print("Repository already cloned.")

# 2. Enter Directory (Required after every restart)
%cd InfiniteTalk

# 3. Safe Dependency Installation
# Derive the tags that appear in pre-built wheel filenames from the running
# interpreter and the installed PyTorch build.

# CPython tag, e.g. "cp311".
py_ver = "cp{}{}".format(sys.version_info.major, sys.version_info.minor)

# Torch version without the local build suffix ("2.5.1+cu121" -> "2.5.1"),
# plus its "major.minor" form.
torch_ver = torch.__version__.partition('+')[0]
major_minor_parts = torch_ver.split(".")[:2]
torch_ver_short = ".".join(major_minor_parts)

# CUDA version with the dots removed, capped at three characters; the literal
# "cpu" when this Torch build reports no CUDA version.
if torch.version.cuda:
    cuda_ver = torch.version.cuda.replace('.', '')[:3]
else:
    cuda_ver = "cpu"

# Upper-case C++11 ABI flag, as spelled in the wheel filenames.
abi = ("FALSE", "TRUE")[bool(torch._C._GLIBCXX_USE_CXX11_ABI)]

print(f"✅ Detected: Python={py_ver}, Torch={torch_ver}, CUDA={cuda_ver}, ABI={abi}")
print("Installing dependencies... (Fast mode enabled)")

# Install dependencies needed for extensions
!pip install --no-cache-dir packaging ninja psutil wheel

# Install Flash Attention 2 (Dynamic Wheel Selection)
try:
    import flash_attn
    print("✅ Flash Attention already installed.")
except ImportError:
    print("Installing Flash Attention...")
    
    # URL Selection Logic
    url = None
    
    # CASE 1: PyTorch 2.9 (Nightly) - Use community wheel
    if torch_ver.startswith("2.9"):
        print("⚠️ Detected PyTorch 2.9 (Nightly). Using community pre-built wheel...")
        # Source: https://github.com/mjun0812/flash-attention-prebuild-wheels
        url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.16/flash_attn-2.6.3+cu126torch2.9-cp311-cp311-linux_x86_64.whl"
    
    # CASE 2: Standard Versions - Try to guess official URL
    if not url:
        fa_v = "2.8.3"
        cu_tag = f"cu{cuda_ver[:2]}" # e.g., cu12
        wheel = f"flash_attn-{fa_v}+{cu_tag}torch{torch_ver_short}cxx11abi{abi}-{py_ver}-{py_ver}-linux_x86_64.whl"
        url = f"https://github.com/Dao-AILab/flash-attention/releases/download/v{fa_v}/{wheel}"

    try:
        print(f"Attempting to install: {url}")
        # Using os.system to avoid notebook capture of massive output if needed, or just standard !
        res = os.system(f"pip install --no-cache-dir {url}")
        if res != 0:
            raise Exception("Pip install returned non-zero exit code")
        print(f"✅ Successfully installed Flash Attention via wheel")
    except Exception as e:
        print(f"⚠️ Wheel install failed: {e}")
        print("Falling back to slow source build (This may take 45+ mins)...")
        !pip install flash-attn --no-build-isolation

# Install other dependencies from requirements.txt
# The sed call edits requirements.txt in place, dropping the ",<2" upper bound
# from the numpy requirement. Once rewritten the pattern no longer matches, so
# running it again changes nothing.
!sed -i 's/numpy>=1.23.5,<2/numpy>=1.23.5/' requirements.txt
!pip install --no-cache-dir -r requirements.txt
# Extra packages installed on top of requirements.txt.
# NOTE(review): the weight-download cell calls the `hf` command, presumably
# provided by huggingface_hub - confirm the installed version ships it.
!pip install --no-cache-dir librosa huggingface_hub

# Install system dependencies
# The output redirection applies only to the `apt-get install` half of the
# command; `apt-get update` still prints to the cell.
!apt-get update && apt-get install -y ffmpeg > /dev/null 2>&1

# Printed unconditionally: none of the commands above are checked for failure.
print("\n✅ Setup Complete! If you see a 'Restart Session' button, you can IGNORE it and try running the next cell.")

In [None]:
import os
# Create the weights directory relative to the current working directory.
# NOTE(review): the setup cell ends up inside the InfiniteTalk checkout via
# `%cd`, so this presumably lands in InfiniteTalk/weights - after a kernel
# restart, re-run the setup cell first.
os.makedirs("weights", exist_ok=True)

# Download base model and InfiniteTalk weights
# Each repository is stored in its own sub-directory of ./weights.
!hf download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./weights/Wan2.1-I2V-14B-480P
!hf download MeiGen-AI/InfiniteTalk --local-dir ./weights/InfiniteTalk

# Download English audio encoder (DEFAULT)
# The launch cell passes this directory as --wav2vec_dir.
!hf download facebook/wav2vec2-base-960h --local-dir ./weights/wav2vec2-base-960h

# Alternative: For Chinese audio support, uncomment the lines below:
# (the second command fetches model.safetensors from revision refs/pr/1 of the same repository)
# !hf download TencentGameMate/chinese-wav2vec2-base --local-dir ./weights/chinese-wav2vec2-base
# !hf download TencentGameMate/chinese-wav2vec2-base model.safetensors --revision refs/pr/1 --local-dir ./weights/chinese-wav2vec2-base

In [None]:
# Enable public link for Gradio
# Edits app.py in place, adding share=True to its demo.launch(...) call. The
# pattern must match that call's current text; if app.py differs, sed changes
# nothing and the app launches without share=True. After a successful edit the
# pattern no longer matches, so re-running this cell is harmless.
!sed -i 's/demo.launch(server_name="0.0.0.0", debug=True, server_port=8418)/demo.launch(server_name="0.0.0.0", debug=True, server_port=8418, share=True)/' app.py

# Launch Gradio interface with English audio encoder (DEFAULT)
# All paths are relative to the current working directory and point at the
# folders created by the weight-download cell.
!python app.py --ckpt_dir weights/Wan2.1-I2V-14B-480P --wav2vec_dir 'weights/wav2vec2-base-960h' --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors --num_persistent_param_in_dit 0 --motion_frame 9

# Alternative: For Chinese audio, uncomment below and comment the English line above:
# (same command, with --wav2vec_dir pointing at the Chinese encoder directory)
# !python app.py --ckpt_dir weights/Wan2.1-I2V-14B-480P --wav2vec_dir 'weights/chinese-wav2vec2-base' --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors --num_persistent_param_in_dit 0 --motion_frame 9