In [None]:
#@title # Step 1: Setup

import os
# Set the locale environment variables to UTF-8 for this process.
os.environ['LANG'] = 'en_US.UTF-8'
os.environ['LC_ALL'] = 'en_US.UTF-8'
import locale
# Replace the locale lookup so every caller asking for the preferred encoding gets "UTF-8".
locale.getpreferredencoding = lambda: "UTF-8"

# Work from /content and clone the model repository only when it is not already present.
%cd /content
if not os.path.isdir('Thin-Plate-Spline-Motion-Model'):
    !git clone https://github.com/gee666/Thin-Plate-Spline-Motion-Model.git
# Later cells use paths relative to the repository root (config/..., checkpoints/...).
%cd Thin-Plate-Spline-Motion-Model

In [None]:
#@title # Step 2: Load models
#@markdown ## uncomment the models you want to use

!mkdir checkpoints
!pip3 install wldhx.yadisk-direct

!curl -L $(yadisk-direct https://disk.yandex.com/d/i08z-kCuDGLuYA) -o checkpoints/vox.pth.tar
#!curl -L $(yadisk-direct https://disk.yandex.com/d/vk5dirE6KNvEXQ) -o checkpoints/taichi.pth.tar
#!curl -L $(yadisk-direct https://disk.yandex.com/d/IVtro0k2MVHSvQ) -o checkpoints/mgif.pth.tar
#!curl -L $(yadisk-direct https://disk.yandex.com/d/B3ipFzpmkB1HIA) -o checkpoints/ted.pth.tar

# different source
#!curl -L $(yadisk-direct https://disk.yandex.ru/d/YbOdosYEwYY_SA) -o checkpoints/vox.pth.tar
#!curl -L $(yadisk-direct https://disk.yandex.ru/d/6eKgFjCUA-7k2w) -o checkpoints/taichi.pth.tar
#!curl -L $(yadisk-direct https://disk.yandex.ru/d/PRSRPrSgIExosw) -o checkpoints/mgif.pth.tar
#!curl -L $(yadisk-direct https://disk.yandex.ru/d/YbOdosYEwYY_SA) -o checkpoints/ted.pth.tar

In [None]:
#@title # Step 3: Settings
#@markdown ##### Import your driving video and/or image before filling in the form
#@markdown ##### For best results, the video and the image should be square (256x256 px)


import torch
import importlib.util
import os

testmode = False # for testing with little amount of frames
max_frames_in_testmode = 16

current_directory = os.getcwd()

device = torch.device('cuda:0')
dataset_name = 'vox' #@param {type:"string"} ['vox', 'taichi', 'ted', 'mgif']

source_image_path = '/content/Thin-Plate-Spline-Motion-Model/assets/source.png' #@param {type:"string"}
source_image_path = os.path.join(current_directory, source_image_path)
source_image_path = os.path.relpath(source_image_path, current_directory)

driving_video_path = '/content/Thin-Plate-Spline-Motion-Model/assets/driving.mp4' #@param {type:"string"}
driving_video_path = os.path.join(current_directory, driving_video_path)
driving_video_path = os.path.relpath(driving_video_path, current_directory)

output_frames_directory = './output_frames/'
output_video_chunks = './video_chunks'
output_video_path = './generated.mp4'
predict_mode = 'relative' #@param {type:"string"}  ['standard', 'relative', 'avd']

#@markdown ##### Max batch size will depend on your memory and video size
max_batch_size = 2000 #@param { type:"number" }

fps = 24 # later this value will be corrected
if testmode:
  max_batch_size = max_frames_in_testmode

find_best_frame = False
# for relative using find best frame, which can give beeter results
if predict_mode == 'relative':
  find_best_frame = True
  if importlib.util.find_spec('face_alignment') is None:
      !pip install face-alignment

# for vox, taichi and mgif, the resolution is 256*256
pixel = 256
# for ted, the resolution is 384*384
if(dataset_name == 'ted'):
    pixel = 384

config_path = f'config/{dataset_name}-{pixel}.yaml'
checkpoint_path = f'checkpoints/{dataset_name}.pth.tar'

In [None]:
#@title # Step 4: Define functions
try:
  import imageio
  import imageio_ffmpeg
except:
  !pip install imageio_ffmpeg
  import imageio
  import imageio_ffmpeg
import numpy as np
from skimage.transform import resize
import warnings
import os
from tqdm.auto import tqdm
from skimage import img_as_ubyte
import shutil
import gc
import sys

from demo import make_animation
from demo import load_checkpoints


# Define ensure and clean directory function
def prepare_directory(directory_path, clean=True):
    """
    Make sure a directory exists, optionally emptying it.

    A missing directory is created (including parents). An existing one is
    emptied only when ``clean`` is true; otherwise it is left untouched.
    Entries that cannot be removed are reported with a printed message and
    skipped.

    Parameters:
    - directory_path: The path to the directory to prepare.
    - clean: When True (default), delete everything inside an existing directory.
    """
    if not os.path.exists(directory_path):
        # Nothing there yet: create it and we are done
        os.makedirs(directory_path)
        return

    if not clean:
        return

    for entry_name in os.listdir(directory_path):
        entry_path = os.path.join(directory_path, entry_name)
        try:
            # Symlinks are always unlinked, never followed into their target
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            elif os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
        except Exception as e:
            print('Failed to delete %s. Reason: %s' % (entry_path, e))


def memory_usage():
  """Print the MemAvailable value from /proc/meminfo, divided by 1024 and labelled MB.

  Implemented as a shell pipeline via IPython's `!` magic, so it only works
  inside IPython and on a system that provides /proc/meminfo.
  """
  !echo -n "Available Memory: " && cat /proc/meminfo | grep 'MemAvailable' | awk '{print $2/1024 " MB"}'


def animate(source_image, driving_video):
  """Animate ``source_image`` with the frames of ``driving_video``.

  Loads the networks from ``config_path`` / ``checkpoint_path`` on every call.
  In 'relative' mode with ``find_best_frame`` enabled, the best driving frame
  is located first and the animation is generated outwards from it in both
  directions; otherwise the driving frames are processed in order.

  Parameters:
  - source_image: the prepared source image passed through to make_animation.
  - driving_video: sliceable sequence of driving frames.

  Returns the predictions produced by make_animation (for the best-frame
  path: the reversed backward run followed by the forward run without its
  first element).
  """
  inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(
      config_path=config_path, checkpoint_path=checkpoint_path, device=device)

  def _run(frames):
      # Single place for the long make_animation call shared by all paths
      return make_animation(source_image, frames, inpainting, kp_detector,
                            dense_motion_network, avd_network,
                            device=device, mode=predict_mode)

  if not (predict_mode == 'relative' and find_best_frame):
      return _run(driving_video)

  # Imported under an alias because `find_best_frame` is also the global flag
  from demo import find_best_frame as _find
  best_index = _find(source_image, driving_video, device.type == 'cpu')
  print("Best frame: " + str(best_index))

  forward = _run(driving_video[best_index:])
  backward = _run(driving_video[:(best_index + 1)][::-1])
  # Both runs start at the best frame, so drop the duplicate from the forward run
  return backward[::-1] + forward[1:]

def process_batch(source_image, reader, total_frames_processed):
    """Animate every frame of one video reader and save the results as PNG files.

    Parameters:
    - source_image: the prepared source image handed to animate().
    - reader: an opened video reader; iterated once, must support count_frames().
    - total_frames_processed: number of frames already saved by earlier batches;
      used to continue the frame numbering.

    Returns the updated running total of saved frames.
    """
    # Resize each frame to the model resolution and keep the first three channels
    driving_video = [
        resize(raw_frame, (pixel, pixel))[..., :3]
        for raw_frame in tqdm(reader, total=reader.count_frames(), desc="resizing frames: ")
    ]

    memory_usage()
    print("animating frames...")
    predictions = animate(source_image, driving_video)
    memory_usage()

    # The input frames are no longer needed once the predictions exist
    del driving_video
    gc.collect()

    print(f"saving the batch frames...")
    # Frame numbers continue from the previous batch, starting at 1 overall
    for frame_number, predicted in enumerate(predictions, start=total_frames_processed + 1):
        target_path = f'{output_frames_directory}/frame_{frame_number:04d}.png'
        imageio.imsave(target_path, img_as_ubyte(predicted))
        total_frames_processed = frame_number
    print(f"...batch saved")

    del predictions
    gc.collect()

    return total_frames_processed


def _video_part_sort_key(filename):
    """Sort key that orders 'part_<N>.mp4' names numerically by N; other names sort last."""
    stem = os.path.splitext(filename)[0]
    index_text = stem.rpartition('_')[2]
    if index_text.isdigit():
        return (0, int(index_text), filename)
    return (1, 0, filename)


def process_video(source_image):
    """Animate every chunk in ``output_video_chunks`` and save the frames.

    Chunks are processed in numeric part order, one chunk per batch, so the
    saved frames keep the order of the driving video. Prints the total number
    of processed frames at the end.

    Parameters:
    - source_image: the prepared source image handed to process_batch().
    """
    # Sort numerically: a plain string sort puts part_100.mp4 before part_11.mp4
    video_parts = sorted(
        (name for name in os.listdir(output_video_chunks) if name.endswith('.mp4')),
        key=_video_part_sort_key,
    )
    total_frames_processed = 0

    for video_part in video_parts:
        video_part_path = os.path.join(output_video_chunks, video_part)
        print(f"Processing video part: {video_part_path}")

        # read the current video part
        reader = imageio.get_reader(video_part_path)
        try:
            # Process the entire part as a single batch
            total_frames_processed = process_batch(source_image, reader, total_frames_processed)
        finally:
            # Release the reader even when the batch fails
            reader.close()

        del reader
        gc.collect()

    print("_______________________")
    print(f"Total frames processed: {total_frames_processed}")
    print("_______________________")
    memory_usage()



def split_video():
  """Split the driving video into chunks of at most ``max_batch_size`` frames.

  Chunks are written to ``output_video_chunks`` as part_01.mp4, part_02.mp4, ...
  at the fps of the driving video. The global ``fps`` is updated to that value.
  In test mode only the first chunk is written.
  """
  global fps

  reader = imageio.get_reader(driving_video_path)
  writer = None
  try:
      fps = reader.get_meta_data()['fps']

      print("_______________________")
      print(f"Total frames count {reader.count_frames()}")
      print(f"Original fps {fps}")
      print("_______________________")

      part_number = 1
      frame_count = 0

      for frame in reader:
          # Start a new part when the previous one is full (or on the very first frame)
          if frame_count == 0:
              if writer:
                  writer.close()
              writer = imageio.get_writer(f'{output_video_chunks}/part_{part_number:02d}.mp4', fps=fps)
              part_number += 1

          # Write the current frame
          writer.append_data(frame)
          frame_count += 1

          # The current part is full once it holds max_batch_size frames
          if frame_count == max_batch_size:
              frame_count = 0
              # Test mode stops after the first complete part. Checking here,
              # rather than on part_number, avoids stopping after one frame.
              if testmode:
                  break
  finally:
      # Close the last writer and the reader even when an error occurred
      if writer:
          writer.close()
      reader.close()

  gc.collect()

In [None]:
#@title # Step 5: Here the magic happens

warnings.filterwarnings("ignore")

# Start from empty output directories so stale frames/chunks from a previous run are not mixed in
prepare_directory(output_frames_directory)
prepare_directory(output_video_chunks)

memory_usage()

# read the image
print("Preparing source image...")
source_image = imageio.imread(source_image_path)
# A grayscale image has no channel axis, so the `[..., :3]` slice below would cut
# its width to 3 columns; replicate it into three channels first
if source_image.ndim == 2:
    source_image = np.stack([source_image] * 3, axis=-1)
# Resize to the model resolution and drop any alpha channel
source_image = resize(source_image, (pixel, pixel))[..., :3]
memory_usage()

print("Splitting the video on chuncs...")
split_video()

memory_usage()
process_video(source_image)
gc.collect()

# When all frames are done, combine all the output_frames into the resulting video
print(f"Saving video to {output_video_path}...")
# Sort by the numeric frame index embedded in 'frame_<N>.png'
output_frames = sorted(os.listdir(output_frames_directory), key=lambda x: int(x.split('_')[1].split('.')[0]))
# `fps` was set by split_video() to the fps of the driving video
with imageio.get_writer(output_video_path, fps=fps) as writer:
    for frame_filename in output_frames:
        frame_path = os.path.join(output_frames_directory, frame_filename)
        frame = imageio.imread(frame_path)
        writer.append_data(frame)
print("done!")

del output_frames, source_image
gc.collect()
memory_usage()

In [None]:
#@title # Step 6: Terminate the runtime (kills the kernel process)
import os
# Send signal 9 (SIGKILL on Linux) to the current process, i.e. the process running this cell.
os.kill(os.getpid(), 9)