# Fine-tuning Diff2Lip on Custom Data

This notebook shows how to fine-tune the mouth-region diffusion model (Diff2Lip)
on your own paired video+audio dataset for improved lip-sync fidelity.

---

## 1. Setup

In [None]:
# Clone Diff2Lip repo if not already present
!git clone https://github.com/YuanGary/DiffusionLi.git diff2lip
%cd diff2lip

# Install dependencies
!pip install -r requirements.txt
!pip install torch torchvision diffusers accelerate transformers

# Create a workspace directory
!mkdir -p ../workspace/data
!mkdir -p ../workspace/checkpoints

## 2. Prepare Your Dataset

Place your training data under `../workspace/data` (relative to the `diff2lip` checkout, i.e. a `workspace/data` folder next to it) using this structure:

```
workspace/data/
└── train/
    ├── video1.mp4
    ├── video1.wav
    ├── video2.mp4
    ├── video2.wav
    └── ...
```
Each `videoN.mp4` should be a talking-head clip, and `videoN.wav` its exactly matching audio track. Videos without a matching `.wav` file are skipped by the data loader.

## 3. Data Loader

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import torchaudio

class Diff2LipDataset(Dataset):
    """Paired talking-head video / audio clips for Diff2Lip fine-tuning.

    Scans ``<data_dir>/train`` for ``*.mp4`` files and keeps those that have a
    sibling ``.wav`` file with the same stem; videos without audio are skipped.
    Samples are sorted by file name so indexing is reproducible across runs and
    platforms (``os.listdir`` order is arbitrary).

    Args:
        data_dir: Root directory containing a ``train`` sub-directory.
        fps: Stored as ``self.fps``; not used by this class (frames are read at
            the file's native rate). NOTE(review): wire into resampling if needed.
        crop_size: Side length in pixels of the square frames returned.

    Each item is a dict with:
        ``"frames"``: float tensor ``[T, 3, crop_size, crop_size]`` scaled to
            ``[0, 1]`` (RGB, centre-cropped to a square, then resized);
        ``"audio"``: path to the matching ``.wav`` file (loaded by the caller).
    """

    def __init__(self, data_dir, fps=25, crop_size=256):
        train_dir = os.path.join(data_dir, "train")
        self.samples = []
        for file in sorted(os.listdir(train_dir)):
            if file.endswith(".mp4"):
                vid = os.path.join(train_dir, file)
                # splitext only touches the final extension; str.replace would
                # also rewrite any ".mp4" appearing in a directory name.
                wav = os.path.splitext(vid)[0] + ".wav"
                if os.path.exists(wav):
                    self.samples.append((vid, wav))
        self.fps = fps
        self.crop = crop_size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        vid_path, wav_path = self.samples[idx]
        cap = cv2.VideoCapture(vid_path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(self._center_crop_resize(frame))
        finally:
            # Release the decoder even if a frame fails to convert.
            cap.release()
        if not frames:
            raise ValueError(f"No frames could be decoded from {vid_path!r}")
        # Stack [H, W, 3] uint8 frames -> [T, H, W, 3], then -> [T, 3, H, W] in [0, 1].
        video = torch.stack([torch.from_numpy(f) for f in frames])
        # Audio is returned as a path and loaded by the caller.
        return {"frames": video.permute(0, 3, 1, 2) / 255.0, "audio": wav_path}

    def _center_crop_resize(self, frame):
        """Convert an OpenCV BGR frame to RGB, centre-crop it square, and resize to crop x crop."""
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h, w, _ = frame.shape
        m = min(h, w)
        y0, x0 = (h - m) // 2, (w - m) // 2
        square = frame[y0:y0 + m, x0:x0 + m]
        return cv2.resize(square, (self.crop, self.crop))

def collate_fn(batch):
    """Batch dataset items: zero-pad frame stacks along time, keep audio paths as a list.

    Returns ``(videos, audio_paths)``. ``videos`` has shape ``[B, T_max, ...]``,
    with shorter clips padded with zeros at the end. ``audio_paths`` lists each
    item's ``"audio"`` entry in batch order.
    """
    frame_stacks, audio_paths = [], []
    for item in batch:
        frame_stacks.append(item["frames"])
        audio_paths.append(item["audio"])
    videos = torch.nn.utils.rnn.pad_sequence(frame_stacks, batch_first=True)
    return videos, audio_paths

# Two clips per step; shuffle so each epoch sees the clips in a different order.
dataset = Diff2LipDataset("../workspace/data")
if len(dataset) == 0:
    # Fail early with an actionable message instead of the DataLoader's opaque
    # "num_samples should be a positive integer" error (raised for shuffle=True).
    raise RuntimeError(
        "No training pairs found: expected matching <name>.mp4 / <name>.wav files "
        "under ../workspace/data/train"
    )
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

## 4. Model & Optimizer

In [None]:
from diff2lip.diff2lip import Diff2Lip

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): the Diff2Lip constructor comes from the cloned package; its
# `device` argument is passed straight through as given here.
model = Diff2Lip(device=device)
model = model.to(device)

# A small learning rate, since this is fine-tuning pretrained weights rather than training from scratch.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

## 5. Training Loop

In [None]:
num_epochs = 3
# Switch to training mode explicitly. The inference cell below calls
# model.eval(), and that mode would otherwise persist if this cell is re-run.
model.train()
for epoch in range(num_epochs):
    for i, (video_batch, audio_paths) in enumerate(loader):
        video_batch = video_batch.to(device)  # [B, T_max, 3, H, W], zero-padded in time by collate_fn
        # Use each clip's first frame as the "driving" frame.
        driving_frame = video_batch[:, 0]

        # Load the matching audio, average channels to mono, and zero-pad to a common length.
        # NOTE(review): `sr` is not used. Clips with different sample rates are padded
        # together unchanged, so resample to the rate the model expects if that matters.
        aud_tensors = []
        for ap in audio_paths:
            wav, sr = torchaudio.load(ap)
            aud_tensors.append(wav.mean(0))  # mono
        aud_batch = torch.nn.utils.rnn.pad_sequence(aud_tensors, batch_first=True).to(device)

        # NOTE(review): the signature and semantics of `training_step` come from the
        # Diff2Lip package and are not visible in this notebook.
        loss = model.training_step(driving_frame, aud_batch, video_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Batch {i+1}/{len(loader)}, Loss: {loss.item():.4f}")

## 6. Save Fine-tuned Weights

In [None]:
# Persist only the model's parameters (its state_dict); the inference cell reloads them.
checkpoint_dir = "../workspace/checkpoints"
save_path = checkpoint_dir + "/diff2lip_finetuned.pth"
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(model.state_dict(), save_path)
print("Saved to " + save_path)

## 7. Inference with Fine-tuned Model

In [None]:
from base64 import b64encode

from IPython.display import HTML, display

# Load the fine-tuned weights written by the save cell. map_location lets a
# checkpoint written on GPU be loaded when only the CPU is available.
model.load_state_dict(
    torch.load("../workspace/checkpoints/diff2lip_finetuned.pth", map_location=device)
)
model.eval()

# Use the same inference API as before.
# Note: the original 'render_diff2lip' function is not exposed in the provided code,
# so this assumes the model object has a 'render' method.
# Check for the method up front, rather than wrapping the call in
# `except AttributeError`, so that real errors raised inside render() are not
# misreported as "method missing".
render_fn = getattr(model, "render", None)
if not callable(render_fn):
    print("Inference function 'render' not found in the Diff2Lip model object.")
    print("Please adapt this cell to the correct inference API for your fine-tuned model.")
else:
    output_path = "../output_diff2lip_finetuned.mp4"
    render_fn(
        video_in="../output_fomm.mp4",  # assumed output of the previous notebook
        audio_in="../assets/hello.wav",  # assumed asset from the previous notebook
        video_out=output_path,
        upscale_factor=1,
    )

    # Embed the result inline as a base64 data URL; `with` closes the file handle.
    with open(output_path, "rb") as fh:
        data_url = "data:video/mp4;base64," + b64encode(fh.read()).decode()
    display(HTML(f'<video controls width=480 src="{data_url}"></video>'))


You’ve now fine-tuned Diff2Lip on your own dataset for personalized, high-fidelity lip syncing. Integrate these checkpoints into your Avatar Renderer Pod pipeline for production-grade avatar videos!