In [1]:
import os
os.environ["HF_HOME"] = "/mnt/disks/celebv-hq/.cache/huggingface"
from pathlib import Path
from datasets import load_dataset
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


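## Data preparation

Build `metadata.csv` files with `file_name` and `text` columns: first pair each local VoxCeleb2 video clip with its transcript, then convert the CelebV-HQ metadata to the same schema.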

In [6]:
# Root of the local VoxCeleb2 copy and its audio / video / transcript sub-trees.
data_dir = Path("data/vox2/")
audio, video, txt = (
    data_dir / "aac" / "dev" / "aac",
    data_dir / "mp4" / "dev" / "mp4",
    data_dir / "txt",
)

# Each sub-directory of the audio root is treated as one speaker id (sorted by name).
speaker_ids = sorted(entry.name for entry in audio.iterdir() if entry.is_dir())

In [None]:
def _collect_audio_video_pairs(audio_root, video_root, speaker_ids):
    """Pair each ``<speaker>/<scene>/<clip>.m4a`` under ``audio_root`` with the
    ``<speaker>/<scene>/<clip>.mp4`` file of the same stem under ``video_root``.

    Speakers or scenes that have no counterpart on the video side, and clips
    without a matching mp4, are skipped. Scenes and clips are visited in sorted
    order so the result does not depend on filesystem iteration order.

    Args:
        audio_root: directory containing one sub-directory per speaker id.
        video_root: directory with the same layout holding the mp4 files.
        speaker_ids: speaker directory names to visit, in order.

    Returns:
        list[tuple[Path, Path]]: ``(audio_file, video_file)`` pairs.
    """
    pairs = []
    for id_name in speaker_ids:
        audio_id_dir = audio_root / id_name
        video_id_dir = video_root / id_name
        # Both sides must have this speaker.
        if not audio_id_dir.exists() or not video_id_dir.exists():
            continue

        for audio_scene in sorted(d for d in audio_id_dir.iterdir() if d.is_dir()):
            video_scene = video_id_dir / audio_scene.name
            if not video_scene.exists():
                continue

            for audio_file in sorted(audio_scene.glob("*.m4a")):
                # The video twin shares the audio file's stem.
                video_file = video_scene / f"{audio_file.stem}.mp4"
                if video_file.exists():
                    pairs.append((audio_file, video_file))
    return pairs


audio_video_pairs = _collect_audio_video_pairs(audio, video, speaker_ids)
print(f"Found {len(audio_video_pairs)} audio-video pairs")

In [None]:
# Peek at the first (audio, video) pair to sanity-check the pairing.
first_pair = audio_video_pairs[0]
first_pair

In [None]:
# Flatten every paired clip into one folder so later cells can work from file names alone.
import shutil

output_dir = Path("data/combined_files")
output_dir.mkdir(parents=True, exist_ok=True)

for audio_file, video_file in audio_video_pairs:
    # ".../<speaker_id>/<scene>/<clip>.<ext>" -> "<speaker_id>_<scene>_<clip>.<ext>".
    # NOTE(review): the matching cell further down parses this name back apart,
    # so the "<speaker>_<scene>_<clip>" shape is a contract between the two cells.
    scene = audio_file.parent.name
    speaker_id = audio_file.parent.parent.name
    prefix = f"{speaker_id}_{scene}_"

    # copy2 also preserves file metadata such as modification times.
    shutil.copy2(audio_file, output_dir / f"{prefix}{audio_file.name}")
    shutil.copy2(video_file, output_dir / f"{prefix}{video_file.name}")

print(f"Copied {len(audio_video_pairs)} audio files and {len(audio_video_pairs)} video files to {output_dir}")

In [2]:
# Pull the VoxCeleb2 transcripts dataset from the Hugging Face Hub into the
# workspace disk cache (same directory as HF_HOME above).
HF_CACHE_DIR = "/mnt/disks/celebv-hq/.cache/huggingface"
ds = load_dataset("acul3/voxceleb2", cache_dir=HF_CACHE_DIR)

Downloading data:   0%|          | 1/215 [00:05<17:54,  5.02s/files]


KeyboardInterrupt: 

In [None]:
# Inspect one raw training record to see which fields it carries.
sample_record = ds['train'][0]
sample_record

In [None]:
from concurrent.futures import ProcessPoolExecutor
import numpy as np
from tqdm.auto import tqdm

# Set lookup: `speaker_ids` is a list, so `in` on it costs O(#speakers) per row.
speaker_id_set = set(speaker_ids)


def process_range(range_tuple):
    """Collect the ``ds['train']`` rows in ``[start_idx, end_idx)`` whose
    ``speaker_id`` is one of the locally available speakers.

    NOTE(review): runs in a worker process and reads ``ds`` / ``speaker_id_set``
    as module globals; this presumably relies on workers inheriting the notebook
    state via the ``fork`` start method (Linux). Under ``spawn`` a function
    defined in a notebook cannot be resolved by the workers.
    NOTE(review): ``ds['train'][i]`` materialises the whole row; if
    ``audio_path`` is a decoded Audio feature this also decodes the waveform.

    Args:
        range_tuple: ``(start_idx, end_idx)`` half-open index range.

    Returns:
        list[dict]: one dict per matching row with keys ``speaker_id``,
        ``audio_path`` (the ``['path']`` entry), ``transcription`` and ``gender``.
    """
    start_idx, end_idx = range_tuple
    train = ds['train']
    total = len(train)
    results = []
    for i in range(start_idx, end_idx):
        entry = train[i]
        speaker_id = entry['speaker_id']
        if i % 1000 == 0:
            print(f"Processing entry {i} of {total}")
        if speaker_id in speaker_id_set:
            results.append({"speaker_id": speaker_id, "audio_path": entry['audio_path']['path'], "transcription": entry['transcription'], "gender": entry['gender']})
    return results

# Split the index space into at most `num_workers` contiguous ranges.
total_items = len(ds['train'])
num_workers = 4
# Ceil-divide and never go below 1. With plain floor division, a dataset smaller
# than `num_workers` gives chunk_size == 0, and range(..., 0) raises ValueError.
chunk_size = max(1, -(-total_items // num_workers))
ranges = [(i, min(i + chunk_size, total_items)) for i in range(0, total_items, chunk_size)]

print(f"Processing {total_items} items in {len(ranges)} chunks")

# Process ranges in parallel; executor.map preserves the order of `ranges`.
with ProcessPoolExecutor(max_workers=num_workers) as executor:
    results = list(tqdm(
        executor.map(process_range, ranges),
        total=len(ranges),
        desc="Processing data chunks"
    ))

# Flatten the per-chunk lists into one list of records.
speaker_videos = [record for chunk_result in results for record in chunk_result]

print(f"Found {len(speaker_videos)} matching speaker videos")


In [15]:
# Group records by speaker; sorted() is stable, so per-speaker order is kept.
speaker_videos = sorted(
    speaker_videos,
    key=lambda record: record['speaker_id'],
)

In [None]:
# Show one collected record (index 1 of the speaker-sorted list).
second_record = speaker_videos[1]
second_record

In [17]:
# Sorted names of every .mp4 in the combined folder.
all_video_paths = sorted(
    name for name in os.listdir('data/combined_files') if name.endswith('.mp4')
)

In [18]:
# Build {"file_name", "text"} records keyed as "<speaker_id>_<audio_path with .wav -> .mp4>".
# A comprehension is used deliberately: a module-level loop variable named
# `video` would overwrite the mp4 root Path defined in the data-dir cell, and
# re-running the pairing cell would then fail.
# NOTE(review): if `audio_path` is only the clip name (e.g. "00001.wav") without
# the scene, clips from different scenes of one speaker get the same key and
# only one transcript survives in the lookup dict below. Confirm the dataset's
# path format.
video_path_and_transcripts = [
    {
        "file_name": record['speaker_id'] + "_" + record['audio_path'].replace('.wav', '.mp4'),
        "text": record['transcription'].strip(),
    }
    for record in speaker_videos
]

In [19]:
# Lookup: record file_name -> transcript (later duplicates overwrite earlier ones).
video_paths = dict(
    (record['file_name'], record['text']) for record in video_path_and_transcripts
)

In [21]:
# Attach a transcript to every copied video whose "<speaker_id>_<clip>" key is known.
cleaned_paths = []
for vid_path in all_video_paths:
    # Copied names are "<speaker_id>_<scene>_<clip>.mp4". Scene directory names
    # (YouTube ids in VoxCeleb2) may themselves contain "_", so take the speaker
    # from the first field and the clip from the last one. Unpacking
    # split('_')[:3] would pick up part of the scene instead of the clip.
    speaker_id = vid_path.split('_', 1)[0]
    vid_id = vid_path.rsplit('_', 1)[-1]
    target_id = f"{speaker_id}_{vid_id}"

    if target_id in video_paths:
        cleaned_paths.append({"file_name": vid_path, "text": video_paths[target_id].strip()})

In [33]:
# Deterministic row order for the metadata file.
cleaned_paths = sorted(
    cleaned_paths,
    key=lambda record: record['file_name'],
)

In [34]:
# Fix the column set and order explicitly. An empty `cleaned_paths` then still
# yields a "file_name,text" header instead of a column-less frame and CSV.
videos_df = pd.DataFrame(cleaned_paths, columns=["file_name", "text"])

In [35]:
# Write the file_name/text table to data/metadata.csv without the pandas index.
metadata_path = Path("data") / "metadata.csv"
videos_df.to_csv(metadata_path, index=False)

### CelebV-HQ

In [None]:
# Rewrite the CelebV-HQ metadata to the file_name/text schema: point rows at the
# .mp4 instead of the .wav and strip the quotes around each transcript.
# NOTE(review): this rebinds `ds`, shadowing the Hugging Face dataset loaded above.
ds_path = Path("dataset/metadata.csv")
ds = pd.read_csv(ds_path)

# The next cell writes the result back to this same file. On a second run the
# CSV is therefore already in the target schema, and indexing ds['file'] would
# raise KeyError. Only convert when the original columns are still present.
if "file" in ds.columns:
    ds = ds.assign(
        file=ds["file"].str.replace(".wav", ".mp4", regex=False),
        transcript_with_timestamps=(
            ds["transcript_with_timestamps"].str.strip('"').str.strip()
        ),
    ).rename(columns={"file": "file_name", "transcript_with_timestamps": "text"})

In [5]:
# Overwrite the CelebV-HQ metadata in place with the converted table (no index column).
celebv_metadata_path = Path("dataset/metadata.csv")
ds.to_csv(celebv_metadata_path, index=False)

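## Finetuning results

Render a sample clip with the trained LoRA, then decode stored training latents back to video as a sanity check.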

In [None]:
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData

# Checkpoint locations for the sample render.
WAN_MODEL_DIR = "models/Wan-AI/Wan2.1-T2V-1.3B"
LORA_CHECKPOINT = "models/lightning_logs/version_2/checkpoints/epoch=4-step=625.ckpt"

# Load the base Wan2.1 T2V weights (DiT, T5 text encoder, VAE) on CPU in bf16,
# then apply the LoRA checkpoint before building the CUDA pipeline.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    f"{WAN_MODEL_DIR}/diffusion_pytorch_model.safetensors",
    f"{WAN_MODEL_DIR}/models_t5_umt5-xxl-enc-bf16.pth",
    f"{WAN_MODEL_DIR}/Wan2.1_VAE.pth",
])
model_manager.load_lora(LORA_CHECKPOINT, lora_alpha=1.0)
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Named `generated_video` rather than `video` so this cell does not overwrite
# the VoxCeleb2 mp4 root Path bound to `video` in the data-preparation section.
generated_video = pipe(
    prompt="i was telling my friend about how much i wanted to see him",
    negative_prompt="low quality, unclear facial expressions, blurry",
    num_inference_steps=50,
    seed=0, tiled=True
)
save_video(generated_video, "video.mp4", fps=30, quality=5)

In [None]:
from diffusers import AutoencoderKLWan
import torch
import os

In [4]:
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)

latents_dir = "data/train"
# Sorted because os.listdir order is arbitrary. This keeps the order of
# latents_list, and so the numbering of the decoded videos, stable across runs.
# NOTE(review): torch.load unpickles arbitrary objects; only use it on trusted
# files (consider weights_only=True if they hold only tensors/containers).
latents_list = [
    torch.load(os.path.join(latents_dir, name))
    for name in sorted(os.listdir(latents_dir))
    if name.endswith(".pth")
]

In [28]:
# Put the VAE on the GPU, then batch every stored "latents" tensor into one
# tensor on the VAE's device and in its dtype.
vae.to('cuda')
latents_tensor = torch.stack([item["latents"] for item in latents_list]).to(
    device=vae.device, dtype=vae.dtype
)

In [None]:
# Decode the whole batch without autograd tracking and keep the `.sample` output.
# NOTE(review): diffusers' Wan pipelines de-normalise latents with the VAE
# config's latents_mean / latents_std before calling decode. Confirm the stored
# training latents are already in raw VAE space, otherwise the colours will be off.
with torch.no_grad():
    decoder_output = vae.decode(latents_tensor)
decoded_images = decoder_output.sample

In [None]:
print(decoded_images.shape)

# Create videos from each decoded image sequence
import os
from diffsynth import save_video

os.makedirs("output_videos", exist_ok=True)

for i, image_sequence in enumerate(decoded_images):
    # Each batch item is (channels, frames, height, width); permute(1, 2, 3, 0)
    # moves channels last to give (frames, height, width, channels).
    frames = image_sequence.permute(1, 2, 3, 0)

    # Decoded pixels follow the [-1, 1] convention (diffusers' video
    # post-processing maps them with x/2 + 0.5). Apply the same transform to
    # every clip. Rescaling only when min < 0 or max > 1 would leave a clip
    # whose values happen to lie in [0, 1] unscaled and too dark.
    frames = ((frames + 1.0) / 2.0).clamp(0, 1)

    # Give the writer explicit uint8 frames so it does not have to guess the
    # value range of float input.
    frames_uint8 = (frames * 255).round().to(torch.uint8).cpu().numpy()

    output_path = os.path.join("output_videos", f"video_{i}.mp4")
    save_video(frames_uint8, output_path, fps=30, quality=5)

    print(f"Saved video {i+1}/{len(decoded_images)} to {output_path}")
