<a href="https://colab.research.google.com/github/zhangxs131/paper_demo_for_fun/blob/main/make_picture_move.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##Thin-Plate Spline Motion Model for Image Animation 

A motion model for image animation. The paper is a CVPR 2022 paper from Tsinghua University and is available at https://arxiv.org/abs/2203.14367

The source code is on GitHub: https://github.com/yoyo-nb/thin-plate-spline-motion-model


In [1]:
# Clone the source code
!git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model.git

Cloning into 'Thin-Plate-Spline-Motion-Model'...
remote: Enumerating objects: 78, done.
remote: Counting objects: 100% (47/47), done.
remote: Compressing objects: 100% (33/33), done.
remote: Total 78 (delta 15), reused 32 (delta 14), pack-reused 31
Unpacking objects: 100% (78/78), done.


In [2]:
%cd Thin-Plate-Spline-Motion-Model

/content/Thin-Plate-Spline-Motion-Model


In [3]:
# Download the pretrained model weights

!mkdir -p checkpoints
!wget -c "https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1" -O checkpoints/vox.pth.tar

--2022-05-04 12:25:59--  https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1
Resolving cloud.tsinghua.edu.cn (cloud.tsinghua.edu.cn)... 101.6.8.7
Connecting to cloud.tsinghua.edu.cn (cloud.tsinghua.edu.cn)|101.6.8.7|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cloud.tsinghua.edu.cn/seafhttp/files/a03fce5c-e169-4da6-a4e3-d86e9435e0a4/vox.pth.tar [following]
--2022-05-04 12:26:01--  https://cloud.tsinghua.edu.cn/seafhttp/files/a03fce5c-e169-4da6-a4e3-d86e9435e0a4/vox.pth.tar
Connecting to cloud.tsinghua.edu.cn (cloud.tsinghua.edu.cn)|101.6.8.7|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 350993469 (335M) [application/octet-stream]
Saving to: ‘checkpoints/vox.pth.tar’


2022-05-04 12:26:27 (13.2 MB/s) - ‘checkpoints/vox.pth.tar’ saved [350993469/350993469]
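
Because every later cell depends on this file, it can be worth confirming that the download finished. The following is a minimal sketch, assuming the byte count printed by wget in the log above (350993469) is still the size of the hosted file. `checkpoint_is_complete` is a hypothetical helper, not part of the repository.

```python
import os

# Size reported by wget in the log above ("Length: 350993469").
# Assumption: the hosted checkpoint has not been replaced since that run.
EXPECTED_BYTES = 350993469


def checkpoint_is_complete(path, expected_bytes=EXPECTED_BYTES):
    """Return True if `path` is a file of exactly `expected_bytes` bytes."""
    return os.path.isfile(path) and os.path.getsize(path) == expected_bytes


# Usage in the notebook:
# if not checkpoint_is_complete('checkpoints/vox.pth.tar'):
#     print('Checkpoint is missing or incomplete; rerun the wget cell.')
```

A size match only shows that the transfer was not truncated. It does not verify the file contents.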



In [12]:
import torch

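# Configuration

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dataset_name = 'vox'
# Change the source image and driving video here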
#source_image_path = './assets/source.png'
source_image_path = '/content/myself.jpg'
driving_video_path = './assets/driving.mp4'
output_video_path = './generated.mp4'
config_path = 'config/vox-256.yaml'
checkpoint_path = 'checkpoints/vox.pth.tar'

predict_mode='relative'
find_best_frame=False

pixel=256

if dataset_name == 'ted':  # for ted, the resolution is 384*384
    pixel = 384

if find_best_frame:
    !pip install face_alignment
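
The configuration above points at four files, and a wrong path only surfaces later as an error inside imageio or the model loader. As a minimal sketch, the paths could be checked up front. `missing_paths` is a hypothetical helper introduced here for illustration.

```python
import os


def missing_paths(paths):
    """Return the entries of `paths` that do not exist on disk."""
    return [p for p in paths if not os.path.exists(p)]


# Usage with the variable names from the configuration cell:
# missing = missing_paths([source_image_path, driving_video_path,
#                          config_path, checkpoint_path])
# if missing:
#     raise FileNotFoundError(f"Missing input files: {missing}")
```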


##Display the source image and the driving video motion

For fun, I used a photo I uploaded myself as the source image, and the result looks pretty good.
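
When you use your own photo, note that the cell below resizes it straight to `pixel` x `pixel`, which stretches any image that is not square. One possible workaround, sketched here as an assumption rather than something the notebook does, is to crop the largest centered square first. `center_square_box` is a hypothetical helper.

```python
def center_square_box(height, width):
    """Return (top, left, side) of the largest centered square in an image."""
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    return top, left, side


# Example: a photo 480 pixels high and 640 wide keeps its middle 480x480 region.
top, left, side = center_square_box(480, 640)
print(top, left, side)  # 0 80 480

# With a NumPy image array `img` of shape (height, width, channels):
# img = img[top:top + side, left:left + side]
```

A centered crop assumes the face sits near the middle of the photo. If it does not, pick the crop box by hand.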

In [13]:
import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from IPython.display import HTML

import warnings
import os

warnings.filterwarnings('ignore')

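# Load the source image and open the driving video.
source_image = imageio.imread(source_image_path)
reader = imageio.get_reader(driving_video_path)

# Resize to the model resolution and keep only the RGB channels.
source_image = resize(source_image, (pixel, pixel))[..., :3]

fps = reader.get_meta_data()['fps']
driving_video = []

# Read every frame. A RuntimeError from the reader is ignored so that
# the frames read so far are still used.
try:
    for im in reader:
        driving_video.append(im)
except RuntimeError:
    pass
reader.close()

driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]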

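def display(source, driving, generated=None):
    """Build a side-by-side animation: source | driving | generated (optional)."""
    fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))

    ims = []
    for i in range(len(driving)):
        cols = [source, driving[i]]
        if generated is not None:
            cols.append(generated[i])
        # Concatenate the panels horizontally into a single frame.
        im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
        plt.axis('off')
        ims.append([im])

    # interval is the delay between frames in milliseconds.
    ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
    plt.close()
    return ani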

HTML(display(source_image,driving_video).to_html5_video())
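
The preview uses a fixed `interval=50`, which is 50 ms per frame, or 20 frames per second, regardless of the `fps` read from the driving video. As a minimal sketch, the interval could be derived from `fps` so the preview plays at the source speed. `frame_interval_ms` is a hypothetical helper, not part of the notebook or the repository.

```python
def frame_interval_ms(fps):
    """Milliseconds between frames for a video played at `fps` frames per second."""
    if fps <= 0:
        raise ValueError("fps must be positive")
    return 1000.0 / fps


print(frame_interval_ms(20))  # 50.0
print(frame_interval_ms(25))  # 40.0

# Possible use inside display(), assuming `fps` is passed in as an argument:
# animation.ArtistAnimation(fig, ims, interval=frame_interval_ms(fps),
#                           repeat_delay=1000)
```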





##Run the model to animate the picture

In [14]:
from demo import load_checkpoints

inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)

In [15]:
from demo import make_animation
from skimage import img_as_ubyte

if predict_mode=='relative' and find_best_frame:
    from demo import find_best_frame as _find
    i = _find(source_image, driving_video, device.type=='cpu')
    print("Best frame: " + str(i))
    driving_forward = driving_video[i:]
    driving_backward = driving_video[:(i+1)][::-1]
    predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)
    predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)
    predictions = predictions_backward[::-1] + predictions_forward[1:]
else:
    predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)

# Save the video
imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)

HTML(display(source_image, driving_video, predictions).to_html5_video())

100%|██████████| 169/169 [00:55<00:00,  3.05it/s]
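
When `find_best_frame` is enabled, the cell above animates forward and backward from the chosen frame and then stitches the two halves back into the original order. The toy sketch below mirrors that slicing with plain integers in place of frames. `stitch_from_best_frame` and its `animate` argument are hypothetical stand-ins for the notebook code and `make_animation`.

```python
def stitch_from_best_frame(frames, best, animate):
    """Animate outward from index `best` and return results in original order.

    `animate` stands in for make_animation: it takes a list of frames and
    returns one output per input frame, in the same order.
    """
    forward = animate(frames[best:])             # best, best+1, ..., end
    backward = animate(frames[:best + 1][::-1])  # best, best-1, ..., 0
    # Reverse the backward pass and drop the duplicated best frame.
    return backward[::-1] + forward[1:]


# Toy check with integers as "frames" and an identity "animate" function.
frames = list(range(6))
result = stitch_from_best_frame(frames, best=2, animate=lambda fs: list(fs))
print(result)  # [0, 1, 2, 3, 4, 5]
```

The best frame appears first in both passes, which is why `forward[1:]` skips it when the two halves are joined.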
