# A sample for fine-tuning Wav2Lip on SageMaker

In [None]:
## Update sagemaker python sdk version
!pip install -U sagemaker

In [None]:
import sagemaker
import boto3
from sagemaker import get_execution_role

# SageMaker session, execution role and default S3 bucket; the later cells
# reuse these for data upload and for the training job.
sess = sagemaker.Session()
role = get_execution_role()
sagemaker_default_bucket = sess.default_bucket()

# Account id and region of the current session; both are used further down
# to build the ECR image URI.
account = sess.boto_session.client("sts").get_caller_identity()["Account"]
region = sess.boto_session.region_name

In [None]:
## download training script from github
!rm -rf ./wav2lip_288x288
!git clone https://github.com/whn09/wav2lip_288x288.git
!cp ./s5cmd ./wav2lip_288x288/

## Download pretrained model(expert Discriminator & face detect) and upload to s3

In [None]:
!cd ./wav2lip_288x288/face_detection/detection/sfd/ && wget https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth
!mv ./wav2lip_288x288/face_detection/detection/sfd/s3fd-619a316812.pth ./wav2lip_288x288/face_detection/detection/sfd/s3fd.pth
# !cd ./Wav2Lip/models && wget https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EQRvmiZg-HRAjvI6zqN9eTEBP74KefynCwPWVmF57l-AYA?e=ZRPHKP

## Prepare docker image

In [None]:
%%writefile Dockerfile
## Change the region in the base-image URI below to the region you use (this sample uses us-west-2).
FROM 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04
#From pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime

ENV LANG=C.UTF-8
ENV PYTHONUNBUFFERED=TRUE
ENV PYTHONDONTWRITEBYTECODE=TRUE

# Refresh the package index and install ffmpeg in ONE layer, so a cached
# `apt-get update` layer can never be combined with a later install step.
# `-y` plus DEBIAN_FRONTEND=noninteractive answers every prompt instead of
# piping "Y" into apt-get.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y ffmpeg \
    && rm -rf /var/lib/apt/lists/*
RUN pip3 install --no-cache-dir deepspeed


In [None]:
## You should change below region code to the region you used, here sample is use us-west-2
!aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com

### Build image and push to ECR

In [None]:
## define repo name, should contain *sagemaker* in the name
repo_name = "sagemaker-wav2lip-288x288-demo"

In [None]:
%%script env repo_name=$repo_name bash

#!/usr/bin/env bash

# Build the Docker image from the Dockerfile in the current directory and push
# it to ECR so that SageMaker can use it.

# The image name comes from the `repo_name` environment variable, which the
# `%%script env repo_name=...` line above exports into this shell. It is used
# as the local image name and combined with the account and region to form
# the repository name for ECR.
algorithm_name=${repo_name}

account=$(aws sts get-caller-identity --query Account --output text)

# Get the region defined in the current AWS CLI configuration (falls back to
# us-east-1 if none is defined).
# NOTE(review): the notebook later builds `image_uri` from the SageMaker
# session's region; confirm both resolve to the same region.
region=$(aws configure get region)
region=${region:-us-east-1}

fullname="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:latest"

# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${algorithm_name}" > /dev/null 2>&1

# $? is the exit status of the describe-repositories call above.
if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${algorithm_name}" > /dev/null
fi

# Log Docker in to ECR with the password printed by the AWS CLI
aws ecr get-login-password --region ${region}|docker login --username AWS --password-stdin ${fullname}

# Build the docker image locally with the image name and then push it to ECR
# with the full name.

docker build -t ${algorithm_name} .
docker tag ${algorithm_name} ${fullname}

docker push ${fullname}

### Prepare train dataset

Use an S3 dataset path whose layout matches the `data_root` expected by Wav2Lip.  
It should look like this:  
data_root (mvlrs_v1)  
├── main, pretrain (we use only main folder in this work)  
|	├── list of folders  
|	│   ├── five-digit numbered video IDs ending with (.mp4)  

In [None]:
#!wget https://bj.bcebos.com/ai-studio-online/88f38e14dc9f4893bdb3ad6857b810a9e0ed6d63f4f84d7da0b69e165ca4f7a5?authorization=bce-auth-v1%2F5cfe9a5e1454405eb2a975c43eace6ec%2F2022-09-04T15%3A27%3A21Z%2F-1%2F%2F2048bfe72ab62c84e2a6622f565ac7cc30830ffdace4607bf847909d1038b5b0&responseContentDisposition=attachment%3B%20filename%3Dmain.zip
!./s5cmd sync ./main2/ s3://{sagemaker_default_bucket}/wav2lip_288x288/train_data/main2/

In [None]:
### Prepare filelists/train.txt
import os

import time
from glob import glob
import shutil,os
 
from sklearn.model_selection import train_test_split
 

# Strip special characters from names and rename the video files to uniform sequence numbers
 
# def original_video_name_format():
#     base_path = "../LSR2/main"
#     result = list(glob("{}/*".format(base_path),recursive=False))
#     file_num = 0
#     result_list = []
 
#     for each in result:
#         file_num +=1
#         new_position ="{0}{1}".format( int(time.time()),file_num)
#         result_list.append(new_position)
#         shutil.move(each, os.path.join(base_path,new_position+".mp4"))
#         pass

def trained_data_name_format(base_path="../LSR2/lrs2_preprocessed", filelists_dir="filelists"):
    """Write ``train.txt``, ``test.txt`` and ``val.txt`` file lists for a preprocessed dataset.

    Every sub-directory of ``base_path`` is treated as one group of videos.
    The groups are split into train/test/val, and each list file receives one
    ``<group>/<video>`` line per entry found inside the group's directory.

    Args:
        base_path: Root folder that contains one sub-directory per group.
        filelists_dir: Folder the three list files are written to; created
            if it does not exist.

    With fewer than 14 groups there is too little data to split, so all three
    lists contain every group.
    """
    # Only directories can hold videos. Sorting makes the seeded split
    # independent of the order in which the filesystem lists entries.
    result_list = sorted(
        name for name in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, name))
    )
    print(result_list)

    if len(result_list) < 14:
        train_result = test_result = val_result = result_list
    else:
        train_result, holdout = train_test_split(result_list, test_size=0.15, random_state=42)
        test_result, val_result = train_test_split(holdout, test_size=0.5, random_state=42)

    os.makedirs(filelists_dir, exist_ok=True)
    splits = (("train.txt", train_result), ("test.txt", test_result), ("val.txt", val_result))
    for file_name, dataset in splits:
        with open(os.path.join(filelists_dir, file_name), 'w', encoding='utf-8') as fi:
            for dataset_i in dataset:
                # List the videos of THIS group. The previous code listed the
                # leftover loop variable, i.e. always the last group found.
                for video in sorted(os.listdir(os.path.join(base_path, dataset_i))):
                    fi.write(dataset_i + '/' + video + "\n")

trained_data_name_format()

### for local data process test

In [None]:
!pip install ffmpeg-python

In [None]:
###data process
#!pip install -r ./requirements.txt
!cd ./Wav2Lip/ && python preprocess.py --data_root ../main2 --preprocessed_root ../lrs2_preprocessed2/

In [None]:
!./s5cmd sync ./lrs2_preprocessed2/ s3://{sagemaker_default_bucket}/288x288/train_data/lrs2_preprocessed/

### 开始训练(单机单卡）

In [71]:
%%writefile ./wav2lip_288x288/train.sh
#!/bin/bash

# Training entry point: download the raw videos, preprocess them, build the
# file lists, train the expert discriminator, train Wav2Lip, upload the result.
# ${sagemaker_default_bucket} is read from the environment; the notebook
# passes it through the Estimator's `environment` argument.

# Pull the raw training videos from S3 to local disk and install dependencies.
chmod +x ./s5cmd
./s5cmd sync s3://${sagemaker_default_bucket}/wav2lip_288x288/train_data/main2/* /tmp/main2/
pip install -r ./requirements.txt

echo "--------preprocess wav-----------"
python ./preprocess.py --data_root /tmp/main2 --preprocessed_root /tmp/lrs2_preprocessed2/
echo "finished preprocess"

echo "--------prepare tain list--------"
python ./generate_filelists.py
echo "finished train list prepare"

echo "--------train the expert discriminator------"
# Time the discriminator training. Despite its name, interval_minutes holds
# the elapsed time in SECONDS (difference of two `date +%s` values).
startdt=$(date +%s)
python ./color_syncnet_train.py --data_root /tmp/lrs2_preprocessed2/ --checkpoint_dir /tmp/trained_syncnet/
enddt=$(date +%s)
interval_minutes=$(( (enddt - startdt) ))
echo "时间间隔为 $interval_minutes 秒"
echo "finished train syncnet"


echo "--------train wav2lip-----------------------" 
# NOTE(review): the discriminator checkpoint file name is hard-coded to step 1;
# confirm that color_syncnet_train.py actually writes this file.
python ./hq_wav2lip_train.py --data_root /tmp/lrs2_preprocessed2/ --checkpoint_dir /tmp/trained_wav2lip_288x288/ --syncnet_checkpoint_path /tmp/trained_syncnet/checkpoint_step000000001.pth
echo "finished wav2lip"

# Upload the Wav2Lip checkpoint directory to a timestamped S3 prefix.
./s5cmd sync /tmp/trained_wav2lip_288x288/ s3://${sagemaker_default_bucket}/models/wav2lip_288x288/output/$(date +%Y-%m-%d-%H-%M-%S)/

###inference
#echo "begin inference"
#./s5cmd sync  s3://${sagemaker_default_bucket}/wav2lip/inference/face_video/* /tmp/face_video/
#./s5cmd sync  s3://${sagemaker_default_bucket}/wav2lip/inference/audio/* /tmp/audio/
#python ./inference.py --checkpoint_path /tmp/trained_wav2lip_288x288/checkpoint_step000000001.pth  --face /tmp/face_video/VID20230623143819.mp4 --audio /tmp/audio/测试wav2lip.mp3
#./s5cmd sync ./results/result_voice.mp4  s3://${sagemaker_default_bucket}/models/wav2lip_288x288/results/

Overwriting ./wav2lip_288x288/train.sh


In [72]:
## The image uri which is build and pushed above
image_uri = "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(account, region, repo_name)
image_uri

'687912291502.dkr.ecr.us-west-2.amazonaws.com/sagemaker-wav2lip-288x288-demo:latest'

In [73]:
## Point train_data_path at the S3 prefix that holds your training dataset
train_data_path = 's3://{}/wav2lip_288x288/train_data/'.format(sagemaker_default_bucket)
inputs = dict(data_root=train_data_path)


In [None]:
import time
from sagemaker.estimator import Estimator

# Environment variables for the training job. train.sh reads
# `sagemaker_default_bucket` to locate the dataset and to upload the model.
environment = {
              'sagemaker_default_bucket': sagemaker_default_bucket # The bucket to store pretrained model and fine-tune model
}

base_job_name = 'wav2lip-288x288-demo'

instance_type = 'ml.g5.48xlarge'

# Run ./wav2lip_288x288/train.sh inside the custom image built above.
# max_run evaluates to 172800, i.e. two days if the unit is seconds
# (NOTE(review): confirm the unit against the SageMaker SDK documentation).
estimator = Estimator(role=role,
                      entry_point='train.sh',
                      source_dir='./wav2lip_288x288/',
                      base_job_name=base_job_name,
                      instance_count=1,
                      instance_type=instance_type,
                      image_uri=image_uri,
                      environment=environment,
                      keep_alive_period_in_seconds=3600,
                      disable_profiler=True,
                      debugger_hook_config=False,
                      max_run=24*60*60*2)

# `inputs` maps the channel name 'data_root' to the S3 training data prefix.
estimator.fit(inputs)

In [None]:
!aws s3 cp s3://${sagemaker_default_bucket}/models/wav2lip_288x288/results/result_voice.mp4 ./

### deepspeed （单机多卡）

In [None]:
## define repo name, should contain *sagemaker* in the name
repo_name = "sagemaker-wav2lip-288x288-demo"

In [None]:
## The image uri which is build and pushed above
image_uri = "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(account, region, repo_name)
image_uri

In [None]:
## Point train_data_path at the S3 prefix that holds your training dataset
train_data_path = 's3://{}/wav2lip_288x288/train_data/'.format(sagemaker_default_bucket)
inputs = dict(data_root=train_data_path)


In [63]:
%%writefile ./wav2lip_288x288/color_syncnet_dist_train.py
from os.path import dirname, join, basename, isfile
from tqdm import tqdm
import time

from models import SyncNet_color as SyncNet
import audio

import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np

from glob import glob
import deepspeed

import os, random, cv2, argparse
from hparams import hparams, get_image_list

# Command-line interface of the expert-discriminator training script.
parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')

parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)

parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)
## Include DeepSpeed configuration arguments
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser = deepspeed.add_config_arguments(parser)
# Parsed at import time: the Dataset class below reads `args.data_root`.
args = parser.parse_args()

# Training progress counters; updated by train() and load_checkpoint().
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

# Number of consecutive frames in one sample window.
syncnet_T = 5
# Number of mel-spectrogram rows in one audio window.
syncnet_mel_step_size = 16

class Dataset(object):
    """Randomly samples (frame window, mel window, label) triples for the sync discriminator.

    The label is 1 when the frame window starts at the same frame as the audio
    window and 0 when it starts at a different, randomly chosen frame of the
    same video.
    """

    def __init__(self, split):
        """Collect the video entries for ``split`` via ``get_image_list``."""
        # presumably a list of per-video frame directories — TODO confirm against hparams.get_image_list
        self.all_videos = get_image_list(args.data_root, split)

    def get_frame_id(self, frame):
        """Return the integer before the first '.' of the file name, e.g. 12 for ``12.jpg``."""
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        """Return the paths of ``syncnet_T`` consecutive frames starting at ``start_frame``.

        Returns None as soon as one of the expected ``<id>.jpg`` files is missing.
        """
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))  # switch to '{}.png' for PNG frames
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def crop_audio_window(self, spec, start_frame):
        """Slice ``syncnet_mel_step_size`` rows of ``spec`` starting at the frame's position.

        The result can be shorter when the slice runs past the end of ``spec``.
        """
        # num_frames = (T x hop_size * fps) / sample_rate
        start_frame_num = self.get_frame_id(start_frame)
        # 80. presumably is the number of mel rows per second — TODO confirm against the audio hparams
        start_idx = int(80. * (start_frame_num / float(hparams.fps)))

        end_idx = start_idx + syncnet_mel_step_size

        return spec[start_idx : end_idx, :]


    def __len__(self):
        """Number of video entries in the split."""
        return len(self.all_videos)

    def __getitem__(self, idx):
        """Return one random sample ``(x, mel, y)``; the incoming ``idx`` is ignored.

        Keeps drawing random videos and frames until a complete sample can be
        built, so it never returns None.
        """
        while 1:
            # The requested index is replaced by a random one on every attempt.
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]

            img_names = list(glob(join(vidname, '*.jpg')))  # switch to '*.png' for PNG frames
            # Skip videos with too few frames.
            if len(img_names) <= 3 * syncnet_T:
                continue
            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            # Redraw until the "wrong" frame differs from the reference frame.
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            # Coin flip: in-sync sample (label 1) or out-of-sync sample (label 0).
            if random.choice([True, False]):
                y = torch.ones(1).float()
                chosen = img_name
            else:
                y = torch.zeros(1).float()
                chosen = wrong_img_name

            window_fnames = self.get_window(chosen)
            if window_fnames is None:
                continue

            window = []
            all_read = True
            for fname in window_fnames:
                img = cv2.imread(fname)
                if img is None:
                    all_read = False
                    break
                try:
                    img = cv2.resize(img, (hparams.img_size, hparams.img_size))
                except Exception as e:
                    all_read = False
                    break

                window.append(img)

            # Any unreadable or unresizable frame invalidates the whole window.
            if not all_read: continue

            try:
                wavpath = join(vidname, "audio.wav")  # use "../audio.wav" if the audio sits one level up
                wav = audio.load_wav(wavpath, hparams.sample_rate)

                orig_mel = audio.melspectrogram(wav).T
            except Exception as e:
                # Best effort: a missing or broken audio file just triggers another draw.
                continue

            # The audio window is always taken at the reference frame (img_name),
            # also when the frame window was taken from wrong_img_name.
            mel = self.crop_audio_window(orig_mel.copy(), img_name)

            if (mel.shape[0] != syncnet_mel_step_size):
                continue

            # H x W x 3 * T
            x = np.concatenate(window, axis=2) / 255.
            # Move the stacked channel axis to the front.
            x = x.transpose(2, 0, 1)
            # Keep only the rows from the vertical midpoint downwards.
            x = x[:, x.shape[1]//2:]

            x = torch.FloatTensor(x)
            mel = torch.FloatTensor(mel.T).unsqueeze(0)

            return x, mel, y

logloss = nn.BCELoss()
def cosine_loss(a, v, y):
    """Binary cross-entropy between the cosine similarity of ``a`` and ``v`` and the label ``y``.

    Args:
        a: Audio embeddings, one row per sample.
        v: Video embeddings with the same shape as ``a``.
        y: Target labels shaped ``(batch, 1)`` with values in [0, 1].

    Returns:
        Scalar loss tensor.
    """
    d = nn.functional.cosine_similarity(a, v)
    # nn.BCELoss rejects inputs outside [0, 1]. Negative similarities are cut
    # to 0, and floating-point overshoot slightly above 1 is cut to 1.
    d = torch.clamp(d, min=0.0, max=1.0)
    loss = logloss(d.unsqueeze(1), y)

    return loss

def train(device, model_engine, train_data_loader, test_data_loader, optimizer_deepspeed,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Train the sync discriminator through a DeepSpeed engine until ``nepochs`` is reached.

    Advances the module-level ``global_step`` / ``global_epoch`` counters and
    saves a checkpoint every ``checkpoint_interval`` steps.
    ``test_data_loader`` is accepted but not used (evaluation is disabled).
    """
    global global_step, global_epoch

    while global_epoch < nepochs:
        epoch_loss_sum = 0.
        progress = tqdm(enumerate(train_data_loader))
        for batch_idx, batch in progress:
            frames, mel, labels = batch
            model_engine.train()

            # Move the batch to the training device.
            frames = frames.to(device)
            mel = mel.to(device)

            audio_emb, video_emb = model_engine(mel, frames)
            labels = labels.to(device)

            step_loss = cosine_loss(audio_emb, video_emb, labels)
            # The engine performs the backward pass and the optimizer update.
            model_engine.backward(step_loss)
            model_engine.step()

            global_step += 1
            epoch_loss_sum += step_loss.item()

            if global_step % checkpoint_interval == 0:
                save_checkpoint(
                    model_engine, optimizer_deepspeed, global_step, checkpoint_dir, global_epoch)

            # Show the mean loss of the current epoch so far.
            progress.set_description('Loss: {}'.format(epoch_loss_sum / (batch_idx + 1)))

        global_epoch += 1

def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
    """Evaluate ``model`` on ``test_data_loader`` and print the mean sync loss.

    Args:
        test_data_loader: Iterable yielding ``(x, mel, y)`` batches.
        global_step: Current training step (unused, kept for the call signature).
        device: Device the batches are moved to.
        model: Module called as ``model(mel, x)`` and returning ``(a, v)``.
        checkpoint_dir: Unused, kept for the call signature.

    Returns:
        The mean loss as a float, or None when the loader produced no batch.
    """
    eval_steps = 1400
    print('Evaluating for {} steps'.format(eval_steps))
    losses = []

    # Switch to eval mode once; no gradients are needed for evaluation.
    model.eval()
    with torch.no_grad():
        for step, (x, mel, y) in enumerate(test_data_loader):
            # Transform data to CUDA device
            x = x.to(device)
            mel = mel.to(device)

            a, v = model(mel, x)
            y = y.to(device)

            loss = cosine_loss(a, v, y)
            losses.append(loss.item())

            if step > eval_steps: break

    # An empty loader previously ended in a ZeroDivisionError.
    if not losses:
        print('No evaluation batches available, skipping evaluation')
        return None

    averaged_loss = sum(losses) / len(losses)
    print(averaged_loss)

    return averaged_loss

def save_checkpoint(model_engine, optimizer, step, checkpoint_dir, epoch):
    """Save the engine state into ``checkpoint_dir`` via ``model_engine.save_checkpoint``.

    ``optimizer``, ``step`` and ``epoch`` are accepted for signature
    compatibility with the callers but are not used here.
    """
    model_engine.save_checkpoint(checkpoint_dir)
    print("Saved checkpoint:", checkpoint_dir)

def _load(checkpoint_path):
    """Deserialize a checkpoint file with ``torch.load``.

    Without CUDA, tensors are kept on the storage they were deserialized into
    instead of being restored to their saved device.
    """
    if not use_cuda:
        return torch.load(checkpoint_path,
                          map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)

def load_checkpoint(path, model, optimizer, reset_optimizer=False):
    """Restore model weights, optionally optimizer state, and the step/epoch counters.

    Args:
        path: Checkpoint file to read.
        model: Module that receives ``checkpoint["state_dict"]``.
        optimizer: Optimizer that receives ``checkpoint["optimizer"]`` unless
            ``reset_optimizer`` is true or the stored state is None.
        reset_optimizer: Skip restoring the optimizer state when true.

    Returns:
        The same ``model`` instance.
    """
    global global_step, global_epoch

    print("Load checkpoint from: {}".format(path))
    state = _load(path)
    model.load_state_dict(state["state_dict"])

    restore_optimizer = not reset_optimizer
    if restore_optimizer and state["optimizer"] is not None:
        print("Load optimizer state from {}".format(path))
        optimizer.load_state_dict(state["optimizer"])

    # Resume the module-level progress counters.
    global_step = state["global_step"]
    global_epoch = state["global_epoch"]

    return model

if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir
    checkpoint_path = args.checkpoint_path

    # Several launcher processes may reach this line at the same time, and the
    # path may be nested. makedirs(exist_ok=True) handles both, whereas the
    # previous exists()/mkdir() pair could fail with FileExistsError.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Dataset and Dataloader setup
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')
    print('train_dataset:', len(train_dataset))
    print('test_dataset:', len(test_dataset))

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.syncnet_batch_size,
        num_workers=hparams.num_workers)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = SyncNet().to(device)
    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.syncnet_lr)

    # DeepSpeed integration: wrap the model and optimizer in a DeepSpeed engine.
    model_engine, optimizer_deepspeed, _, _ = deepspeed.initialize(args=args,model=model,optimizer=optimizer)

    # NOTE(review): load_checkpoint() reads a torch file with a "state_dict" key,
    # while save_checkpoint() writes through model_engine.save_checkpoint();
    # confirm that resuming from this script's own checkpoints works.
    if checkpoint_path is not None:
        load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)

    train(device, model_engine, train_data_loader, test_data_loader, optimizer_deepspeed,
          checkpoint_dir=checkpoint_dir,
          checkpoint_interval=hparams.syncnet_checkpoint_interval,
          nepochs=hparams.nepochs)

Overwriting ./wav2lip_288x288/color_syncnet_dist_train.py


In [64]:
%%writefile ./wav2lip_288x288/hq_wav2lip_train_deepspeed.py
from os.path import dirname, join, basename, isfile
from tqdm import tqdm

from models import SyncNet_color as SyncNet
from models import Wav2Lip, Wav2Lip_disc_qual
import audio

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np

from glob import glob

import os, random, cv2, argparse
from hparams import hparams, get_image_list

# Command-line interface of the Wav2Lip (generator + quality discriminator) training script.
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')

parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)

parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)

parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)

# Parsed at import time: the Dataset class below reads `args.data_root`.
args = parser.parse_args()


# Training progress counters; updated by train().
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

# Number of consecutive frames in one sample window.
syncnet_T = 5
# Number of mel-spectrogram rows in one audio window.
syncnet_mel_step_size = 16

class Dataset(object):
    """Randomly samples generator training examples ``(x, indiv_mels, mel, y)``.

    Each example pairs a window of target frames (lower half blanked out) with
    a window of other frames from the same video, plus the matching audio features.
    """

    def __init__(self, split):
        """Collect the video entries for ``split`` via ``get_image_list``."""
        self.all_videos = get_image_list(args.data_root, split)

    def get_frame_id(self, frame):
        """Return the integer before the first '.' of the file name."""
        stem = basename(frame).split('.')[0]
        return int(stem)

    def get_window(self, start_frame):
        """Return ``syncnet_T`` consecutive frame paths, or None if one is missing."""
        first_id = self.get_frame_id(start_frame)
        folder = dirname(start_frame)

        paths = []
        for offset in range(syncnet_T):
            candidate = join(folder, '{}.jpg'.format(first_id + offset))
            if not isfile(candidate):
                return None
            paths.append(candidate)
        return paths

    def read_window(self, window_fnames):
        """Read and resize every frame of a window; None if any frame fails."""
        if window_fnames is None:
            return None

        frames = []
        target_size = (hparams.img_size, hparams.img_size)
        for path in window_fnames:
            image = cv2.imread(path)
            if image is None:
                return None
            try:
                image = cv2.resize(image, target_size)
            except Exception:
                return None
            frames.append(image)
        return frames

    def crop_audio_window(self, spec, start_frame):
        """Slice ``syncnet_mel_step_size`` rows of ``spec`` at a frame number or frame path."""
        if type(start_frame) == int:
            frame_num = start_frame
        else:
            frame_num = self.get_frame_id(start_frame)
        begin = int(80. * (frame_num / float(hparams.fps)))
        return spec[begin : begin + syncnet_mel_step_size, :]

    def get_segmented_mels(self, spec, start_frame):
        """Return one transposed mel window per frame of the window, or None."""
        assert syncnet_T == 5
        first = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
        if first < 2:
            return None

        pieces = []
        for frame_no in range(first, first + syncnet_T):
            piece = self.crop_audio_window(spec, frame_no - 2)
            if piece.shape[0] != syncnet_mel_step_size:
                return None
            pieces.append(piece.T)
        return np.asarray(pieces)

    def prepare_window(self, window):
        """Scale frames by 1/255 and move the last axis to the front."""
        # 3 x T x H x W
        scaled = np.asarray(window) / 255.
        return np.transpose(scaled, (3, 0, 1, 2))

    def __len__(self):
        """Number of video entries in the split."""
        return len(self.all_videos)

    def __getitem__(self, idx):
        """Return one random example; the incoming ``idx`` is ignored.

        Keeps drawing until every part of the example can be built.
        """
        while True:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue

            # Reference frame plus a different frame of the same video.
            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            window_fnames = self.get_window(img_name)
            wrong_window_fnames = self.get_window(wrong_img_name)
            if window_fnames is None or wrong_window_fnames is None:
                continue

            window = self.read_window(window_fnames)
            if window is None:
                continue

            wrong_window = self.read_window(wrong_window_fnames)
            if wrong_window is None:
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, hparams.sample_rate)
                orig_mel = audio.melspectrogram(wav).T
            except Exception:
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)
            if mel.shape[0] != syncnet_mel_step_size:
                continue

            indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
            if indiv_mels is None:
                continue

            # Ground truth is the untouched window; the input copy has its
            # lower half zeroed out.
            window = self.prepare_window(window)
            y = window.copy()
            window[:, :, window.shape[2]//2:] = 0.

            wrong_window = self.prepare_window(wrong_window)
            x = np.concatenate([window, wrong_window], axis=0)

            return (
                torch.FloatTensor(x),
                torch.FloatTensor(indiv_mels).unsqueeze(1),
                torch.FloatTensor(mel.T).unsqueeze(0),
                torch.FloatTensor(y),
            )

def save_sample_images(x, g, gt, global_step, checkpoint_dir):
    """Write side-by-side sample collages for the current step as JPEG files.

    For every batch item and time step, one image is written to
    ``<checkpoint_dir>/samples_step<global_step>/<batch>_<t>.jpg``. It holds
    the reference frame, the model input, the generated frame and the ground
    truth next to each other.

    Args:
        x: Model input tensor; channels 0-2 are the inputs and channels 3+
            the references after the axis move below.
        g: Generated frames tensor.
        gt: Ground-truth frames tensor.
        global_step: Step number used in the folder name.
        checkpoint_dir: Parent directory of the sample folder.
    """
    # Move axis 1 to the end and rescale to 0-255 bytes.
    # presumably (B, C, T, H, W) -> (B, T, H, W, C) — TODO confirm against the model output
    x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
    g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
    gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)

    refs, inps = x[..., 3:], x[..., :3]
    folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
    # Also creates missing parent directories and tolerates an existing folder.
    os.makedirs(folder, exist_ok=True)
    collage = np.concatenate((refs, inps, g, gt), axis=-2)
    for batch_idx, c in enumerate(collage):
        for t in range(len(c)):
            cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])

logloss = nn.BCELoss()
def cosine_loss(a, v, y):
    """Binary cross-entropy between the cosine similarity of ``a`` and ``v`` and the label ``y``.

    Args:
        a: Audio embeddings, one row per sample.
        v: Video embeddings with the same shape as ``a``.
        y: Target labels shaped ``(batch, 1)`` with values in [0, 1].

    Returns:
        Scalar loss tensor.
    """
    d = nn.functional.cosine_similarity(a, v)
    # Cosine similarity lies in [-1, 1] but nn.BCELoss only accepts [0, 1].
    # Clamp so a negative similarity no longer raises a RuntimeError.
    d = torch.clamp(d, min=0.0, max=1.0)
    loss = logloss(d.unsqueeze(1), y)

    return loss

device = torch.device("cuda" if use_cuda else "cpu")
# The sync discriminator is only used to score generated frames here, so all
# of its parameters are excluded from gradient computation.
syncnet = SyncNet().to(device)
for p in syncnet.parameters():
    p.requires_grad = False

recon_loss = nn.L1Loss()
def get_sync_loss(mel, g):
    """Score generated frames ``g`` against ``mel`` with the frozen sync discriminator.

    Only the part of ``g`` from the midpoint of dimension 3 onwards is used.
    The ``syncnet_T`` time steps are stacked along the channel dimension, and
    the loss is computed against an all-ones ("in sync") target.
    """
    midpoint = g.size(3) // 2
    lower_half = g[:, :, :, midpoint:]
    # B, 3 * T, H//2, W
    stacked = torch.cat([lower_half[:, :, t] for t in range(syncnet_T)], dim=1)
    audio_emb, video_emb = syncnet(mel, stacked)
    target = torch.ones(stacked.size(0), 1).float().to(device)
    return cosine_loss(audio_emb, video_emb, target)

def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Alternate generator and quality-discriminator updates until ``nepochs`` is reached.

    Per batch, the generator is updated first on a weighted sum of sync,
    perceptual and L1 losses. The discriminator is then updated on real and
    detached generated frames. Advances the module-level ``global_step`` /
    ``global_epoch`` counters, writes sample images and saves checkpoints at
    ``checkpoint_interval``. ``test_data_loader`` is accepted but unused
    because evaluation is disabled.
    """
    global global_step, global_epoch
    resumed_step = global_step

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        # Running sums for the progress-bar averages of this epoch.
        running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
        running_disc_real_loss, running_disc_fake_loss = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            disc.train()
            model.train()

            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            # Generator step: clear the gradients of both optimizers first.
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            g = model(indiv_mels, x)

            # A loss term whose weight is zero is skipped and contributes a plain 0.
            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            if hparams.disc_wt > 0.:
                perceptual_loss = disc.perceptual_forward(g)
            else:
                perceptual_loss = 0.

            l1loss = recon_loss(g, gt)

            # The L1 term receives whatever weight the other two terms leave over.
            loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
                                    (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss

            loss.backward()
            optimizer.step()

            # Discriminator step: drop the gradients left over from the
            # generator's backward pass.
            disc_optimizer.zero_grad()

            pred = disc(gt)
            disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
            disc_real_loss.backward()

            # detach(): the fake-sample loss must not propagate into the generator.
            pred = disc(g.detach())
            disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
            disc_fake_loss.backward()

            disc_optimizer.step()

            running_disc_real_loss += disc_real_loss.item()
            running_disc_fake_loss += disc_fake_loss.item()

            # Checked before the step counter is incremented, so samples are
            # also written on the very first step.
            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            # Logs
            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if hparams.disc_wt > 0.:
                running_perceptual_loss += perceptual_loss.item()
            else:
                running_perceptual_loss += 0.

            # Save after the first step as well as at every checkpoint interval.
            if global_step==1 or global_step % checkpoint_interval == 0:
                save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir, global_epoch)
                save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')


            #if global_step % hparams.eval_interval == 0:
            #    with torch.no_grad():
            #        average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)

            #        if average_sync_loss < .75:
            #            hparams.set_hparam('syncnet_wt', 0.03)

            prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
                                                                                        running_sync_loss / (step + 1),
                                                                                        running_perceptual_loss / (step + 1),
                                                                                        running_disc_fake_loss / (step + 1),
                                                                                        running_disc_real_loss / (step + 1)))

        global_epoch += 1

def eval_model(test_data_loader, global_step, device, model, disc):
    """Evaluate the generator and the quality discriminator on held-out batches.

    Puts both networks in eval mode, runs at most ``eval_steps`` batches from
    ``test_data_loader``, prints the mean of every tracked loss term and
    returns the mean sync loss.

    Args:
        test_data_loader: iterable yielding ``(x, indiv_mels, mel, gt)`` batches.
        global_step: current training step; not used inside this function.
        device: torch device every batch tensor is moved to.
        model: generator, called as ``model(indiv_mels, x)``.
        disc: discriminator, called as ``disc(frames)``; ``perceptual_forward``
            is only called when ``hparams.disc_wt > 0``.

    Returns:
        float: mean sync loss over the evaluated batches, or ``nan`` when the
        loader produced no batch at all.
    """
    eval_steps = 300
    print('Evaluating for {} steps'.format(eval_steps))
    running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []

    # Nothing inside the loop switches the mode back, so one call is enough.
    model.eval()
    disc.eval()

    for step, (x, indiv_mels, mel, gt) in enumerate(test_data_loader):
        x = x.to(device)
        mel = mel.to(device)
        indiv_mels = indiv_mels.to(device)
        gt = gt.to(device)

        # Discriminator loss on real frames (target = 1).
        pred = disc(gt)
        disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))

        # Discriminator loss on generated frames (target = 0).
        g = model(indiv_mels, x)
        pred = disc(g)
        disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))

        running_disc_real_loss.append(disc_real_loss.item())
        running_disc_fake_loss.append(disc_fake_loss.item())

        sync_loss = get_sync_loss(mel, g)
        l1loss = recon_loss(g, gt)

        running_l1_loss.append(l1loss.item())
        running_sync_loss.append(sync_loss.item())

        # The perceptual term is only computed when it carries weight.
        if hparams.disc_wt > 0.:
            running_perceptual_loss.append(disc.perceptual_forward(g).item())
        else:
            running_perceptual_loss.append(0.)

        # Stop after exactly `eval_steps` batches, matching the banner above.
        if step + 1 >= eval_steps:
            break

    def _mean(values):
        # An empty loader must not raise ZeroDivisionError; nan compares False
        # against any threshold, so callers do not act on a missing measurement.
        return sum(values) / len(values) if values else float('nan')

    print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(_mean(running_l1_loss),
                                                        _mean(running_sync_loss),
                                                        _mean(running_perceptual_loss),
                                                        _mean(running_disc_fake_loss),
                                                        _mean(running_disc_real_loss)))
    return _mean(running_sync_loss)


def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
    """Write a training checkpoint to ``checkpoint_dir``.

    The file is named ``{prefix}checkpoint_step{step:09d}.pth`` and stores the
    model weights, the optimizer state (only when
    ``hparams.save_optimizer_state`` is truthy, otherwise ``None``), and the
    step / epoch counters.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is optionally saved.
        step: step counter stored in the checkpoint and used in the file name.
        checkpoint_dir: directory the file is written into.
        epoch: epoch counter stored in the checkpoint.
        prefix: optional file-name prefix (e.g. ``'disc_'``).
    """
    # Use the `step` argument for the file name (not the module-level
    # global_step) so the name always matches the "global_step" stored inside.
    checkpoint_path = join(
        checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, step))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

def _load(checkpoint_path):
    """Deserialize a checkpoint file with ``torch.load``.

    When the module-level ``use_cuda`` flag is set, torch's default device
    placement is used. Otherwise an identity ``map_location`` is passed so
    storages are returned as deserialized instead of being moved back to the
    device they were saved from.
    """
    # map_location=None is torch.load's default, i.e. the same as omitting it.
    remap = None if use_cuda else (lambda storage, loc: storage)
    return torch.load(checkpoint_path, map_location=remap)


def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
    """Restore model weights (and optionally optimizer state and counters).

    Args:
        path: checkpoint file, read through ``_load``.
        model: module that receives the stored ``"state_dict"``.
        optimizer: optimizer that receives the stored ``"optimizer"`` state,
            unless ``reset_optimizer`` is set or the stored state is ``None``.
        reset_optimizer: skip restoring the optimizer state when True.
        overwrite_global_states: when True, the module-level ``global_step``
            and ``global_epoch`` are replaced by the checkpoint's values.

    Returns:
        The ``model`` that was passed in, with the weights loaded.
    """
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    ckpt = _load(path)

    # Remove every "module." fragment from parameter names (presumably left
    # by a parallel wrapper at save time -- confirm against the saving code).
    cleaned_state = {name.replace('module.', ''): tensor
                     for name, tensor in ckpt["state_dict"].items()}
    model.load_state_dict(cleaned_state)

    if not reset_optimizer and ckpt["optimizer"] is not None:
        print("Load optimizer state from {}".format(path))
        optimizer.load_state_dict(ckpt["optimizer"])

    if overwrite_global_states:
        global_step = ckpt["global_step"]
        global_epoch = ckpt["global_epoch"]

    return model


def load_checkpoint_syncnet(path, model):
    """Load SyncNet weights stored under the checkpoint's ``"module"`` key.

    Args:
        path: checkpoint file (presumably a DeepSpeed model-states ``.pt``
            file, which keeps the weights under ``"module"`` -- confirm).
        model: SyncNet module that receives the weights.

    Returns:
        The ``model`` that was passed in, with the weights loaded.
    """
    print("Load checkpoint from: {}".format(path))
    # Go through _load (instead of a bare torch.load) so a checkpoint written
    # on a GPU can still be opened when CUDA is not in use.
    checkpoint = _load(path)
    model.load_state_dict(checkpoint["module"])
    return model


if __name__ == "__main__":
    # Script entry point: build the data loaders, networks and optimizers,
    # restore any checkpoints given on the command line, then start training.
    checkpoint_dir = args.checkpoint_dir

    # Dataset and Dataloader setup
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    disc = Wav2Lip_disc_qual().to(device)

    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))

    # Separate Adam optimizers for the generator and the discriminator.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
    disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
                           lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))

    # Optional resume: the generator checkpoint also restores the global
    # step/epoch counters, the discriminator checkpoint leaves them untouched.
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)

    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
                                reset_optimizer=False, overwrite_global_states=False)

    # SyncNet weights are always loaded. `syncnet` is not defined in this
    # block; presumably it is a module-level model created earlier in the file.
    load_checkpoint_syncnet(args.syncnet_checkpoint_path, syncnet)

    # makedirs(exist_ok=True) creates missing parent directories as well and
    # does not fail when the directory already exists.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Train!
    train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
              checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs)


Overwriting ./wav2lip_288x288/hq_wav2lip_train_deepspeed.py


In [65]:
%%writefile ./wav2lip_288x288/ds_config.json
{
  "fp16": {
    "enabled": false,
    "auto_cast": true,
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  },
  "scheduler": {
      "type": "WarmupLR",
      "params": {
          "warmup_min_lr": 0,
          "warmup_max_lr": 0.001,
          "warmup_num_steps": 1000
      }
  },
  "zero_optimization": {
      "stage": 0
  },
  "gradient_accumulation_steps": 2,
  "gradient_clipping": 0.0,
  "steps_per_print": 2000,
  "train_batch_size": 64,
  "train_micro_batch_size_per_gpu": 4,
  "wall_clock_breakdown": false
}

Overwriting ./wav2lip_288x288/ds_config.json


In [66]:
%%writefile ./wav2lip_288x288/train-distribute.sh
#!/bin/bash
# Training job entry point: download the training data from S3, preprocess it,
# train the expert discriminator (SyncNet) with DeepSpeed, train Wav2Lip with
# the resulting SyncNet checkpoint, then upload the trained model to S3.

# Stop at the first failing command, unset variable or failed pipeline stage.
# Without this the script's exit status is the one of the final upload, so a
# failed training step would not fail the job.
set -euo pipefail

# MODEL_S3_BUCKET is passed in through the estimator's `environment` mapping.
: "${MODEL_S3_BUCKET:?MODEL_S3_BUCKET must be set}"

chmod +x ./s5cmd
# The wildcard is quoted so that s5cmd, not the shell, expands it.
./s5cmd sync "s3://${MODEL_S3_BUCKET}/wav2lip_288x288/train_data/main2/*" /tmp/main2/
pip install -r ./requirements.txt

echo "--------preprocess wav-----------"
python ./preprocess.py --data_root /tmp/main2 --preprocessed_root /tmp/lrs2_preprocessed2/
echo "finished preprocess"

echo "--------prepare tain list--------"
python ./generate_filelists.py
echo "finished train list prepare"

echo "--------train the expert discriminator------"
startdt=$(date +%s)
deepspeed --num_gpus=8 ./color_syncnet_dist_train.py \
          --data_root /tmp/lrs2_preprocessed2/ \
          --checkpoint_dir /tmp/trained_syncnet/  \
          --deepspeed --deepspeed_config ds_config.json
enddt=$(date +%s)
# Elapsed wall-clock time of the SyncNet training, in seconds.
elapsed_seconds=$(( enddt - startdt ))
echo "时间间隔为 $elapsed_seconds 秒"
echo "finished train syncnet"

# Take the first *.pt file reported by find as the SyncNet checkpoint.
# NOTE(review): find's order is not sorted; if several checkpoints exist this
# is not necessarily the latest one -- confirm which file should be used.
trained_syncnet_model_file=$(find /tmp/trained_syncnet/ -name "*.pt" -print)
trained_syncnet_model_file=$(echo $trained_syncnet_model_file | cut -d' ' -f1)
if [ -z "${trained_syncnet_model_file}" ]; then
    # Fail here with a clear message instead of passing an empty path on.
    echo "no SyncNet checkpoint (*.pt) found under /tmp/trained_syncnet/" >&2
    exit 1
fi
echo "trained_syncnet_model_file: ${trained_syncnet_model_file}"

echo "--------train wav2lip-----------------------"
python ./hq_wav2lip_train_deepspeed.py --data_root /tmp/lrs2_preprocessed2/ --checkpoint_dir /tmp/trained_wav2lip_288x288/ --syncnet_checkpoint_path "${trained_syncnet_model_file}"
#python ./hq_wav2lip_train.py --data_root /tmp/lrs2_preprocessed2/ --checkpoint_dir /tmp/trained_wav2lip_288x288/ --syncnet_checkpoint_path /tmp/trained_syncnet/10/global_step2/mp_rank_00_model_states.pt

echo "finished wav2lip"

# Upload the trained checkpoints under a timestamped output prefix.
./s5cmd sync /tmp/trained_wav2lip_288x288/ "s3://${MODEL_S3_BUCKET}/wav2lip_288x288/output/$(date +%Y-%m-%d-%H-%M-%S)/"

###inference
#echo "begin inference"
#./s5cmd sync  s3://${sagemaker_default_bucket}/wav2lip/inference/face_video/* /tmp/face_video/
#./s5cmd sync  s3://${sagemaker_default_bucket}/wav2lip/inference/audio/* /tmp/audio/
#python ./inference.py --checkpoint_path /tmp/trained_wav2lip_288x288/checkpoint_step000000001.pth  --face /tmp/face_video/VID20230623143819.mp4 --audio /tmp/audio/测试wav2lip.mp3
#./s5cmd sync ./results/result_voice.mp4  s3://${sagemaker_default_bucket}/models/wav2lip_288x288/results/

Overwriting ./wav2lip_288x288/train-distribute.sh


In [67]:
import time
from sagemaker.estimator import Estimator

# Job naming and hardware.
base_job_name = 'wav2lip-288x288-demo'
instance_type = 'ml.g5.48xlarge'

# Environment variables for the training container: MODEL_S3_BUCKET is the
# bucket that stores the pretrained model and the fine-tuned model.
environment = {'MODEL_S3_BUCKET': sagemaker_default_bucket}

# `image_uri` and `inputs` are not defined in this cell; presumably they come
# from earlier cells of the notebook.
estimator = Estimator(
    image_uri=image_uri,
    role=role,
    source_dir='./wav2lip_288x288/',
    entry_point='train-distribute.sh',
    base_job_name=base_job_name,
    instance_type=instance_type,
    instance_count=1,
    environment=environment,
    max_run=24*60*60*2,  # two days, expressed in seconds
    keep_alive_period_in_seconds=3600,
    debugger_hook_config=False,
    disable_profiler=True,
)

estimator.fit(inputs)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
Using provided s3_resource


INFO:sagemaker:Creating training-job with name: wav2lip-288x288-demo-2023-10-21-08-08-21-699


2023-10-21 08:10:46 Starting - Starting the training job......
2023-10-21 08:11:31 Starting - Preparing the instances for training.........
2023-10-21 08:13:10 Downloading - Downloading input data...
2023-10-21 08:13:30 Training - Downloading the training image.....................
2023-10-21 08:17:06 Training - Training image download completed. Training in progress.......[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2023-10-21 08:18:00,953 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2023-10-21 08:18:01,013 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2023-10-21 08:18:01,024 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2023-10-21 08:18:01,026 sagemaker_pytorch_container.training INFO     Invoking user training script.[0m
[34m2023-10-21 

You can find the S3 path of the trained model in the logs above.