In [None]:
# Downgrade python to 3.9
!sudo apt-get update -y
!sudo apt-get install python3.9
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
!sudo update-alternatives --config python3
!python --version
!sudo apt-get install python3-pip

In [None]:
# Clone the Thin-Plate-Spline-Motion-Model repository into the current working directory.
!git clone https://github.com/shunqilei/Thin-Plate-Spline-Motion-Model.git

In [None]:
cd Thin-Plate-Spline-Motion-Model

/content/Thin-Plate-Spline-Motion-Model


In [None]:
!pip install -r requirements.txt
!pip install matplotlib
!pip install scikit-image==0.18.3
!pip install scikit-learn==1.0
!pip install PyYAML==5.4.1
!pip install torch
!pip install torchvision
!pip install tqdm

In [None]:
!mkdir checkpoints
!wget -c https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1 -O checkpoints/vox.pth.tar
!wget -c https://cloud.tsinghua.edu.cn/f/483ef53650b14ac7ae70/?dl=1 -O checkpoints/ted.pth.tar
!wget -c https://cloud.tsinghua.edu.cn/f/9ec01fa4aaef423c8c02/?dl=1 -O checkpoints/taichi.pth.tar
!wget -c https://cloud.tsinghua.edu.cn/f/cd411b334a2e49cdb1e2/?dl=1 -O checkpoints/mgif.pth.tar

mkdir: cannot create directory ‘checkpoints’: File exists
--2023-05-11 05:24:24--  https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1
Resolving cloud.tsinghua.edu.cn (cloud.tsinghua.edu.cn)... 166.111.6.101, 2402:f000:1:406:166:111:6:101
Connecting to cloud.tsinghua.edu.cn (cloud.tsinghua.edu.cn)|166.111.6.101|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cloud.tsinghua.edu.cn/seafhttp/files/391d9495-8935-47c9-9857-2ca88c605eed/vox.pth.tar [following]
--2023-05-11 05:24:25--  https://cloud.tsinghua.edu.cn/seafhttp/files/391d9495-8935-47c9-9857-2ca88c605eed/vox.pth.tar
Reusing existing connection to cloud.tsinghua.edu.cn:443.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.

--2023-05-11 05:24:26--  https://cloud.tsinghua.edu.cn/f/483ef53650b14ac7ae70/?dl=1
Resolving cloud.tsinghua.edu.cn (cloud.tsinghua.edu.cn)... 166.111.6.101, 2402:f000:1:406:166:111:6:

In [None]:
import torch

# edit the config
device = torch.device('cuda:0')
dataset_name = 'mgif' # ['vox', 'taichi', 'ted', 'mgif']
source_image_path = './assets/mgif4.png'
driving_video_path = './assets/00004.mp4'
output_video_path = './generated-mgif4.mp4'
config_path = 'config/mgif-256.yaml'
checkpoint_path = 'checkpoints/mgif.pth.tar'
predict_mode = 'standard' # ['standard', 'relative', 'avd']
find_best_frame = False # when use the relative mode to animate a face, use 'find_best_frame=True' can get better quality result

pixel = 256 # for vox, taichi and mgif, the resolution is 256*256
if(dataset_name == 'ted'): # for ted, the resolution is 384*384
    pixel = 384

if find_best_frame:
  !pip install face_alignment

# Evaluate pretrained model

In [None]:
import os

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

from frames_dataset import FramesDataset
from demo import load_checkpoints
from modules.inpainting_network import InpaintingNetwork
from modules.bg_motion_predictor import BGMotionPredictor
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork
from modules.avd_network import AVDNetwork

def load_model_from_checkpoint(config, checkpoint_path):
    """Build the networks described by ``config`` and load their weights from a checkpoint.

    Args:
        config: Parsed YAML configuration. Reads ``config['model_params']`` and its
            ``common_params``, ``generator_params`` and ``dense_motion_params`` entries.
        checkpoint_path: Path of a checkpoint file containing the keys
            ``'inpainting_network'``, ``'kp_detector'``, ``'dense_motion_network'`` and,
            optionally, ``'bg_predictor'``.

    Returns:
        Tuple ``(inpainting_network, kp_detector, bg_predictor, dense_motion_network)``.
        ``bg_predictor`` is ``None`` when ``common_params['bg']`` is falsy.

    Notes:
        When CUDA is available the networks are moved to the notebook-level ``device``.
    """
    model_params = config['model_params']
    common_params = model_params['common_params']

    inpainting_network = InpaintingNetwork(**model_params['generator_params'],
                                           **common_params)
    kp_detector = KPDetector(**common_params)
    dense_motion_network = DenseMotionNetwork(**common_params,
                                              **model_params['dense_motion_params'])
    bg_predictor = BGMotionPredictor() if common_params['bg'] else None

    # Deserialize onto the CPU: without map_location a checkpoint saved from a GPU cannot
    # be loaded on a CPU-only host. load_state_dict copies the values into the parameters
    # regardless of where those parameters live.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    inpainting_network.load_state_dict(checkpoint['inpainting_network'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
    # The background predictor keeps its freshly initialised weights when the
    # checkpoint has no entry for it.
    if bg_predictor is not None and 'bg_predictor' in checkpoint:
        bg_predictor.load_state_dict(checkpoint['bg_predictor'])

    if torch.cuda.is_available():
        for network in (inpainting_network, kp_detector, dense_motion_network, bg_predictor):
            if network is not None:
                network.to(device)

    return inpainting_network, kp_detector, bg_predictor, dense_motion_network

def load_dataset(mode, config_path):
    """Create the FramesDataset described by the ``dataset_params`` section of a YAML file.

    Args:
        mode: Dataset mode; the dataset is built with ``is_train=True`` only when this
            equals ``'train'``.
        config_path: Path of the YAML configuration file to read.

    Returns:
        The constructed ``FramesDataset``.
    """
    with open(config_path) as config_file:
        dataset_params = yaml.safe_load(config_file)['dataset_params']
    wants_training_split = mode == 'train'
    return FramesDataset(is_train=wants_training_split, **dataset_params)

def reconstruction(config_path, inpainting_network, kp_detector, bg_predictor, dense_motion_network, dataset):
    """Reconstruct every video of ``dataset`` from its first frame and measure the L1 error.

    For each video the first frame is used as the source image and every frame
    (including the first) as the driving frame. The predicted frames of a video are
    concatenated side by side and written to ``checkpoints/reconstruction/png/<name>.png``.

    Args:
        config_path: Unused; kept so existing calls keep working.
        inpainting_network: Network called as ``inpainting_network(source, dense_motion)``;
            its output must contain a ``'prediction'`` entry.
        kp_detector: Keypoint detector applied to the source and driving frames.
        bg_predictor: Optional background motion predictor; skipped when falsy.
        dense_motion_network: Dense motion network combining source, keypoints and
            background parameters.
        dataset: Dataset whose items provide ``'video'`` and ``'name'`` entries.

    Returns:
        List with one entry per video: the mean absolute error between predicted and
        driving frames, averaged over the frames of that video.
    """
    log_dir = os.path.join('checkpoints', 'reconstruction')
    png_dir = os.path.join(log_dir, 'png')
    # exist_ok avoids the check-then-create race and makes re-runs a no-op.
    os.makedirs(png_dir, exist_ok=True)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    inpainting_network.eval()
    kp_detector.eval()
    dense_motion_network.eval()
    if bg_predictor:
        bg_predictor.eval()

    # Feed the inputs to whichever device the models actually live on instead of
    # assuming the default CUDA device.
    model_device = next(kp_detector.parameters()).device

    loss_list = []
    test_loss_per_video = []

    with torch.no_grad():
        # Iterating the loader directly lets tqdm show the total number of videos.
        for x in tqdm(dataloader):
            video = x['video'].to(model_device)
            num_frames = video.shape[2]
            source = video[:, :, 0]
            kp_source = kp_detector(source)

            predictions = []
            test_loss = 0
            for frame_idx in range(num_frames):
                driving = video[:, :, frame_idx]
                kp_driving = kp_detector(driving)
                bg_params = bg_predictor(source, driving) if bg_predictor else None

                dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
                                                    kp_source=kp_source, bg_param=bg_params,
                                                    dropout_flag=False)
                out = inpainting_network(source, dense_motion)

                # Channels-last copy of the first (only) batch element for saving.
                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
                loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()

                loss_list.append(loss)
                test_loss += loss

            test_loss_per_video.append(test_loss / num_frames)
            # Lay the predicted frames out side by side in a single image.
            predictions = np.concatenate(predictions, axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
    print("\nReconstruction L1 loss: %s" % np.mean(loss_list))
    return test_loss_per_video

## Pretrained model on the mgif dataset

In [None]:
# Evaluate pre-trained model performance
with open(config_path) as f:
    config = yaml.safe_load(f)
# A dedicated name is used for this checkpoint so the notebook-wide `checkpoint_path`
# setting from the configuration cell is left untouched for later cells.
mgif_checkpoint_path = 'checkpoints/mgif.pth.tar'
inpainting_network, kp_detector, bg_predictor, dense_motion_network = load_model_from_checkpoint(config, mgif_checkpoint_path)
dataset = load_dataset('reconstruction', config_path)
# Reconstructed videos will be saved to ./checkpoints/reconstruction/png
mgif_test_loss = reconstruction(config_path, inpainting_network, kp_detector, bg_predictor, dense_motion_network, dataset)

None
Use predefined train-test split.


100it [03:56,  2.37s/it]


Reconstruction L1 loss: 0.02012752





In [None]:
# Plot the mean L1 reconstruction loss of every test video for the mgif checkpoint.
# The values come from `mgif_test_loss`; the name `test_loss` only exists as a local
# variable inside reconstruction() and is undefined at notebook level.
fig = plt.figure(figsize=(40, 20))
plt.plot(mgif_test_loss)
# One tick per evaluated video instead of a hard-coded count.
plt.xticks(np.arange(len(mgif_test_loss)))
plt.yticks(np.arange(0,0.1,step=0.005))
plt.savefig("test_loss_graph")

In [64]:
# Read one GIF from the test split into `driving` for visual inspection.
driving = imageio.imread("moving-gif/test/00001.gif")

In [73]:
# Scratch check: zero-pad an integer to five digits, the pattern used to build test GIF file names.
str(12).zfill(5)

'00012'

In [70]:
# Render the image held in `driving` on a large canvas and save it as "random".
fig, ax = plt.subplots(figsize=(50, 50))
ax.imshow(driving)
fig.savefig("random")

In [87]:
# Show the test GIFs with the highest mean reconstruction loss for the mgif checkpoint.
# NOTE(review): the file name is derived from the position of the loss in
# `mgif_test_loss`; this assumes the dataset yields videos in file-name order - confirm.
TOP_K = 5
fig = plt.figure(figsize=(50, 10))
# (index, loss) pairs, worst first; there may be fewer than TOP_K results.
sorted_loss = sorted(enumerate(mgif_test_loss), key=lambda x: x[1], reverse=True )[:TOP_K]
for i, (video_idx, video_loss) in enumerate(sorted_loss):
    a = fig.add_subplot(1, len(sorted_loss), i+1)
    gif = imageio.imread(f"moving-gif/test/{str(video_idx).zfill(5)}.gif")
    if len(gif.shape) < 3:
      # Two-dimensional data has no colour channels: draw it as 8-bit grayscale.
      a.imshow(gif, cmap='gray', vmin=0, vmax=255)
    else:
      a.imshow(gif)
    a.axis("off")
fig.savefig("top_5_worst")

In [81]:
# Display the selected (index in mgif_test_loss, mean L1 loss) pairs, highest loss first.
sorted_loss

[(70, 0.09117175338569083),
 (67, 0.07672411038400355),
 (80, 0.06674102617907161),
 (87, 0.06580174004525219),
 (89, 0.06377158158190355)]

In [None]:
# Evaluate pre-trained model performance
# The vox checkpoint is evaluated with the `config` / `config_path` that are currently loaded.
# A dedicated name is used for this checkpoint so the notebook-wide `checkpoint_path`
# setting is not overwritten for the cells that call load_checkpoints later on.
vox_checkpoint_path = 'checkpoints/vox.pth.tar'
inpainting_network, kp_detector, bg_predictor, dense_motion_network = load_model_from_checkpoint(config, vox_checkpoint_path)
dataset = load_dataset('reconstruction', config_path)
# Reconstructed videos will be saved to ./checkpoints/reconstruction/png
vox_test_loss = reconstruction(config_path, inpainting_network, kp_detector, bg_predictor, dense_motion_network, dataset)

None
Use predefined train-test split.


100it [05:55,  3.55s/it]


Reconstruction L1 loss: 0.07166834





In [None]:
# Evaluate pre-trained model performance
# The ted checkpoint is evaluated with the `config` / `config_path` that are currently loaded.
# A dedicated name is used for this checkpoint so the notebook-wide `checkpoint_path`
# setting is not overwritten for the cells that call load_checkpoints later on.
ted_checkpoint_path = 'checkpoints/ted.pth.tar'
inpainting_network, kp_detector, bg_predictor, dense_motion_network = load_model_from_checkpoint(config, ted_checkpoint_path)
dataset = load_dataset('reconstruction', config_path)
# Reconstructed videos will be saved to ./checkpoints/reconstruction/png
ted_test_loss = reconstruction(config_path, inpainting_network, kp_detector, bg_predictor, dense_motion_network, dataset)

None
Use predefined train-test split.


100it [04:13,  2.53s/it]


Reconstruction L1 loss: 0.058726206





In [None]:
# Evaluate pre-trained model performance
# The taichi checkpoint is evaluated with the `config` / `config_path` that are currently loaded.
# A dedicated name is used for this checkpoint so the notebook-wide `checkpoint_path`
# setting is not overwritten for the cells that call load_checkpoints later on.
taichi_checkpoint_path = 'checkpoints/taichi.pth.tar'
inpainting_network, kp_detector, bg_predictor, dense_motion_network = load_model_from_checkpoint(config, taichi_checkpoint_path)
dataset = load_dataset('reconstruction', config_path)
# Reconstructed videos will be saved to ./checkpoints/reconstruction/png
taichi_test_loss = reconstruction(config_path, inpainting_network, kp_detector, bg_predictor, dense_motion_network, dataset)

None
Use predefined train-test split.


100it [05:12,  3.13s/it]


Reconstruction L1 loss: 0.05116591





In [90]:
# Compare the four evaluated checkpoints by the average of their per-video losses.
checkpoint_labels = ['mgif','vox','ted','taichi']
per_checkpoint_losses = (mgif_test_loss, vox_test_loss, ted_test_loss, taichi_test_loss)
losses = [np.mean(video_losses) for video_losses in per_checkpoint_losses]

fig = plt.figure(figsize=(40, 20))
plt.plot(checkpoint_labels, losses)
# Large fonts keep the tick labels legible on the 40x20 inch canvas.
plt.xticks(fontsize=50)
plt.yticks(fontsize=35)
plt.savefig("generalization_loss_graph")

# Visual examples of reconstructed videos

In [None]:
try:
  import imageio
  import imageio_ffmpeg
except:
  !pip install imageio_ffmpeg
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML
import warnings
import os

# Suppress all Python warnings from this point on.
warnings.filterwarnings("ignore")
def display(source, driving, generated=None):
    """Build an animation showing source, driving and (optionally) generated frames side by side.

    Args:
        source: Source image; repeated as the left panel of every animation frame.
        driving: Sequence of driving frames; one animation frame is created per entry.
        generated: Optional sequence of generated frames, indexed like ``driving``.
            When given, the figure is widened and the frame is added as a third panel.

    Returns:
        A ``matplotlib.animation.ArtistAnimation``; the figure is closed before returning.
    """
    show_generated = generated is not None
    fig = plt.figure(figsize=(8 + 4 * show_generated, 6))

    animation_frames = []
    for frame_idx, driving_frame in enumerate(driving):
        panels = [source, driving_frame]
        if show_generated:
            panels.append(generated[frame_idx])
        # Panels are joined along the width axis into one image per animation frame.
        artist = plt.imshow(np.concatenate(panels, axis=1), animated=True)
        plt.axis('off')
        animation_frames.append([artist])

    ani = animation.ArtistAnimation(fig, animation_frames, interval=50, repeat_delay=1000)
    plt.close()
    return ani
def preprocess( source_image_path, driving_video_path, predict_mode ):
    """Load the source image and driving video and resize both to ``pixel`` x ``pixel``.

    Args:
        source_image_path: Path of the source image.
        driving_video_path: Path of the driving video.
        predict_mode: Unused; kept so existing calls keep working.

    Returns:
        Tuple ``(source_image, driving_video, fps)``: the resized source image, the list
        of resized driving frames, and the ``'fps'`` value from the video metadata.
        Only the first three entries of the last axis are kept for image and frames.

    Notes:
        The target size is read from the notebook-level ``pixel`` variable.
    """
    source_image = imageio.imread(source_image_path)
    source_image = resize(source_image, (pixel, pixel))[..., :3]

    reader = imageio.get_reader(driving_video_path)
    driving_video = []
    try:
        fps = reader.get_meta_data()['fps']
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            # Best effort: keep the frames decoded so far when the reader gives up.
            pass
    finally:
        # Always release the reader, even when metadata lookup or decoding raises.
        reader.close()

    driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]
    return source_image, driving_video, fps

In [None]:
from demo import make_animation
from skimage import img_as_ubyte

def predict( inpainting, kp_detector, dense_motion_network, avd_network, source_image, driving_video, fps ):
  """Animate ``source_image`` with ``driving_video`` and save the result as a video.

  The notebook-level ``predict_mode``, ``find_best_frame``, ``device`` and
  ``output_video_path`` are read at call time. When the mode is ``'relative'`` and
  best-frame search is enabled, the driving video is split at the best frame, animated
  in both directions from there, and the two halves are stitched back together.

  Args:
      inpainting, kp_detector, dense_motion_network, avd_network: Networks forwarded
          to ``make_animation``.
      source_image: Source image to animate.
      driving_video: Sequence of driving frames.
      fps: Frame rate used when writing the output video.

  Returns:
      The list of predicted frames (also written to ``output_video_path``).
  """
  networks = (inpainting, kp_detector, dense_motion_network, avd_network)
  search_best_frame = predict_mode=='relative' and find_best_frame

  if not search_best_frame:
      predictions = make_animation(source_image, driving_video, *networks, device = device, mode = predict_mode)
  else:
      from demo import find_best_frame as _find
      best_idx = _find(source_image, driving_video, device.type=='cpu')
      print ("Best frame: " + str(best_idx))
      # Animate outwards from the best frame in both directions.
      frames_after = driving_video[best_idx:]
      frames_before = driving_video[:(best_idx+1)][::-1]
      animated_after = make_animation(source_image, frames_after, *networks, device = device, mode = predict_mode)
      animated_before = make_animation(source_image, frames_before, *networks, device = device, mode = predict_mode)
      # The best frame is part of both halves; drop it from the forward half.
      predictions = animated_before[::-1] + animated_after[1:]

  #save resulting video
  encoded_frames = [img_as_ubyte(frame) for frame in predictions]
  imageio.mimsave(output_video_path, encoded_frames, fps=fps)
  return predictions

In [None]:
# Use pretrained model to predict on single data point
# NOTE(review): `config_path`, `checkpoint_path` and `device` are read from notebook state
# at run time - verify they name the intended configuration and checkpoint before running.
from demo import load_checkpoints
inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)
# Load and resize the source image and driving video.
source_image, driving_video, fps = preprocess(source_image_path, driving_video_path, predict_mode)
# Preview source image and driving video side by side as an embedded HTML5 video.
HTML(display(source_image, driving_video).to_html5_video())

In [None]:
# Generate the animation (predict() also writes it to `output_video_path`) and
# show source, driving and generated frames side by side.
predictions = predict(inpainting, kp_detector, dense_motion_network, avd_network, source_image, driving_video, fps)
HTML(display(source_image, driving_video, predictions).to_html5_video())

In [None]:
# Compare the three animation modes on one source image / driving video pair, using the
# networks, `source_image` and `driving_video` loaded by the preceding cells.
# Row 1 shows the source image followed by the driving frames; rows 2-4 show the frames
# generated with each mode.
nCols = 1+len(driving_video)
fig = plt.figure(figsize=(nCols*10, 20))
a = fig.add_subplot(4, nCols, 1)
a.imshow(source_image)
a.axis("off")
for i, frame in enumerate( driving_video ):
    a = fig.add_subplot(4, nCols, i+2)
    a.imshow(frame)
    a.axis("off")

# predict() reads the notebook-level `predict_mode`, so it has to be switched for each
# row; otherwise every row would be generated with the same mode despite its label.
# The configured value is restored afterwards, even if a prediction fails.
# The inputs do not depend on the mode, so they are not re-read inside the loop.
# Note that every predict() call rewrites the file at `output_video_path`.
configured_predict_mode = predict_mode
try:
    for i, mode in enumerate(['standard', 'relative', 'avd']):
        predict_mode = mode
        predictions = predict(inpainting, kp_detector, dense_motion_network, avd_network, source_image, driving_video, fps)
        for j, frame in enumerate( predictions ):
            # Skip the first column of the row so frames line up with the driving frames.
            a = fig.add_subplot(4, nCols, j+2+(i+1)*nCols)
            a.imshow(frame)
            a.axis("off")
        fig.text(0.5, 0.50-0.20*i, f'reconstructed video with mode {mode}', ha='center', va='center', fontsize=80)
finally:
    predict_mode = configured_predict_mode
fig.text(0.15, 0.70, 'source image', ha='center', va='center', fontsize=80)
fig.text(0.5, 0.70, 'driving video', ha='center', va='center', fontsize=80)
fig.savefig("mgif4-comparison")