# Towards An <strong>E</strong>nd-to-<strong>E</strong>nd Framework for <strong>F</strong>low-<strong>G</strong>uided <strong>V</strong>ideo <strong>I</strong>npainting (CVPR 2022)

In this demo, you can try to inpaint an example video through our framework.

# Setup Environment

In [1]:
# Install PyTorch 1.5.1 and torchvision 0.6.1 (CUDA 10.1 builds) from the
# official PyTorch wheel index.
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html 
# Install MMCV (mmcv-full) from the OpenMMLab wheel index for cu101 / torch 1.5.
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.5/index.html

# Prepare the code: clone the E2FGVI repository into ./E2FGVI and make it the
# working directory.
import os
CODE_DIR = 'E2FGVI'
# NOTE: os.makedirs raises FileExistsError if ./E2FGVI already exists.
os.makedirs(f'./{CODE_DIR}')
!git clone https://github.com/MCG-NKU/E2FGVI.git $CODE_DIR
# Later cells import `core.utils` / `model.*` and use repo-relative paths such
# as 'release_model/...', so they must run from inside the repository.
os.chdir(f'./{CODE_DIR}')


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.5.1+cu101
  Downloading https://download.pytorch.org/whl/cu101/torch-1.5.1%2Bcu101-cp37-cp37m-linux_x86_64.whl (704.4 MB)
[K     |████████████████████████████████| 704.4 MB 1.3 kB/s 
[?25hCollecting torchvision==0.6.1+cu101
  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.6.1%2Bcu101-cp37-cp37m-linux_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 66.9 MB/s 
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.12.0+cu113
    Uninstalling torch-1.12.0+cu113:
      Successfully uninstalled torch-1.12.0+cu113
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.13.0+cu113
    Uninstalling torchvision-0.13.0+cu113:
      Successfully uninstalled torchvision-0.13.0+cu1

## Download Model

In [2]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import os

# When False, no Google authentication is performed and only files that are
# already present locally can be "downloaded".
download_with_pydrive = True


class Downloader(object):
    """Download Google Drive files into the repository's ``release_model`` folder.

    The destination folder is ``<parent of cwd>/<CODE_DIR>/release_model`` and
    is created on construction if it does not exist yet.
    """

    def __init__(self, use_pydrive):
        """Prepare the destination folder and optionally authenticate.

        Args:
            use_pydrive: when truthy, authenticate with Google right away so
                that ``download_file`` can fetch files through PyDrive.
        """
        self.use_pydrive = use_pydrive
        # Only authenticate() populates this; it stays None when PyDrive is
        # disabled so download_file can report the problem clearly.
        self.drive = None
        current_directory = os.getcwd()
        self.save_dir = os.path.join(os.path.dirname(current_directory), CODE_DIR, "release_model")
        os.makedirs(self.save_dir, exist_ok=True)
        if self.use_pydrive:
            self.authenticate()

    def authenticate(self):
        """Authenticate the Colab user and build the PyDrive client."""
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)

    def download_file(self, file_id, file_name):
        """Fetch the Drive file ``file_id`` to ``<save_dir>/<file_name>``.

        Nothing is downloaded when the destination already exists.

        Args:
            file_id: Google Drive file id.
            file_name: destination path relative to ``self.save_dir``.

        Raises:
            RuntimeError: if a download is required but no Drive client is
                available (the instance was created with ``use_pydrive=False``).
        """
        file_dst = f'{self.save_dir}/{file_name}'
        if os.path.exists(file_dst):
            print(f'{file_name} already exists!')
            return
        drive = getattr(self, 'drive', None)
        if drive is None:
            raise RuntimeError(
                f'Cannot download {file_name}: no Google Drive client is '
                'available (create the Downloader with use_pydrive=True).')
        downloaded = drive.CreateFile({'id': file_id})
        downloaded.FetchMetadata(fetch_all=True)
        downloaded.GetContentFile(file_dst)

# Create the downloader; with download_with_pydrive=True this triggers the
# Google authentication flow.
downloader = Downloader(download_with_pydrive)
#path = {"id": "1tNJMTJ2gmWdIXJoHVi5-H504uImUiJW9", "name": "E2FGVI_CVPR22_models.zip"}
#downloader.download_file(file_id=path["id"], file_name=path["name"])

# Model checkpoint: saved into the downloader's release_model folder.
downloader.download_file('10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3', 'E2FGVI-HQ-CVPR22.pth')
# Example inputs: the file name is relative to release_model, so
# '../../input/input.zip' resolves to /content/input/input.zip when the working
# directory is /content/E2FGVI.
!mkdir /content/input
downloader.download_file('1Dx-kHrSAcrheuXtCqby9BKbm093dhvhh', '../../input/input.zip')
# Unpack the examples in place, then return to the repository root.
os.chdir('/content/input')
!unzip input.zip
os.chdir('/content/E2FGVI')

Archive:  input.zip
   creating: delogo_examples/
   creating: delogo_examples/mask/
  inflating: delogo_examples/mask/test_01_mask.png  
  inflating: delogo_examples/mask/test_02_mask.png  
  inflating: delogo_examples/mask/test_03_mask.png  
  inflating: delogo_examples/mask/test_04_mask.png  
  inflating: delogo_examples/mask/test_05_mask.png  
  inflating: delogo_examples/mask/west1_mask.png  
  inflating: delogo_examples/mask/west2_mask.png  
  inflating: delogo_examples/mask/west3_mask.png  
  inflating: delogo_examples/mask/west4_mask.png  
  inflating: delogo_examples/mask/west5_mask.png  
  inflating: delogo_examples/mask/west6_mask.png  
  inflating: delogo_examples/test_01.mp4  
  inflating: delogo_examples/test_02.mp4  
  inflating: delogo_examples/test_03.mp4  
  inflating: delogo_examples/test_04.mp4  
  inflating: delogo_examples/test_05.mp4  
  inflating: delogo_examples/west1.mp4  
  inflating: delogo_examples/west2.mp4  
  inflating: delogo_examples/west3.mp4  
  infl

# Inpainting 


### Change directory if needed

In [None]:
# Make the cloned repository the working directory (if it is not already);
# the following cells import `core.utils` / `model.*` and use repo-relative paths.
import os
os.chdir('/content/E2FGVI')

### Import modules

In [6]:
import cv2
from PIL import Image
import numpy as np
import importlib
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation
import torch
from core.utils import to_tensors

### Setup Global Variables

In [3]:
# ---- global configuration shared by the cells below ----

# Reference / neighbour sampling
ref_length = 10       # step between candidate reference frame indices
num_ref = 3           # reference budget; -1 samples across the whole video
neighbor_stride = 3   # step of the sliding window over the video frames
default_fps = 24      # frame rate used when writing the result video

# Input clip and its mask
video_path = '/content/input/delogo_examples/test_03.mp4'
mask_path = '/content/input/delogo_examples/mask/test_03_mask.png'
use_mp4 = video_path.endswith('.mp4')

# Checkpoint and model variant
ckpt = 'release_model/E2FGVI-HQ-CVPR22.pth'
model_name = 'e2fgvi_hq'

# Working resolution as (width, height); the HQ model runs at a larger size.
size = (960, 640) if model_name == 'e2fgvi_hq' else (432, 240)
# size = (1920, 1080)

### Utility Functions

In [4]:
# sample reference frames from the whole video
def get_ref_index(f, neighbor_ids, length, ref_step=None, max_refs=None):
    """Pick reference frame indices for the window centred on frame ``f``.

    Args:
        f: index of the current frame.
        neighbor_ids: indices already used as local neighbours; these are
            never returned as references.
        length: total number of frames in the video.
        ref_step: step between candidate reference indices. Defaults to the
            module-level ``ref_length``.
        max_refs: maximum number of references to return, or ``-1`` to sample
            the whole video every ``ref_step`` frames. Defaults to the
            module-level ``num_ref``.

    Returns:
        A list of frame indices in increasing order.
    """
    step = ref_length if ref_step is None else ref_step
    limit = num_ref if max_refs is None else max_refs
    neighbors = set(neighbor_ids)  # O(1) membership tests

    if limit == -1:
        return [i for i in range(0, length, step) if i not in neighbors]

    half_span = step * (limit // 2)
    start_idx = max(0, f - half_span)
    end_idx = min(length - 1, f + half_span)
    ref_index = []
    for i in range(start_idx, end_idx + 1, step):
        if i in neighbors:
            continue
        # Stop once the budget is reached; the previous `>` comparison let
        # one reference too many through.
        if len(ref_index) >= limit:
            break
        ref_index.append(i)
    return ref_index


# load every mask image of a directory, one per frame
def read_mask(mpath, size):
    """Read all masks in directory ``mpath`` (sorted by file name).

    Each mask is resized to ``size`` with nearest-neighbour sampling,
    converted to grayscale, binarised (any non-zero pixel counts as masked),
    dilated 4 times with a 3x3 cross kernel and returned as a PIL image whose
    masked pixels are 255.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    masks = []
    for fname in sorted(os.listdir(mpath)):
        gray = Image.open(os.path.join(mpath, fname)).resize(
            size, Image.NEAREST).convert('L')
        binary = (np.array(gray) > 0).astype(np.uint8)
        dilated = cv2.dilate(binary, kernel, iterations=4)
        masks.append(Image.fromarray(dilated * 255))
    return masks


# load only the masks at the requested frame indices
def read_mask_lst(mpath, size, lst):
    """Read the masks of directory ``mpath`` selected by the indices in ``lst``.

    Indices refer to the directory listing sorted by file name. Each selected
    mask is resized to ``size`` with nearest-neighbour sampling, converted to
    grayscale, binarised (non-zero means masked), dilated 4 times with a 3x3
    cross kernel and returned as a PIL image whose masked pixels are 255.
    """
    names = sorted(os.listdir(mpath))
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    masks = []
    for idx in lst:
        gray = Image.open(os.path.join(mpath, names[idx])).resize(
            size, Image.NEAREST).convert('L')
        binary = (np.array(gray) > 0).astype(np.uint8)
        dilated = cv2.dilate(binary, kernel, iterations=4)
        masks.append(Image.fromarray(dilated * 255))
    return masks


def read_mask_static(mpath, size, n):
    """Load the single mask image ``mpath`` and repeat it ``n`` times.

    The mask is resized to ``size`` with nearest-neighbour sampling, converted
    to grayscale, binarised (non-zero means masked) and dilated 4 times with a
    3x3 cross kernel. The returned list holds ``n`` references to the same PIL
    image, whose masked pixels are 255.
    """
    gray = Image.open(mpath).resize(size, Image.NEAREST).convert('L')
    binary = (np.array(gray) > 0).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    dilated = cv2.dilate(binary, kernel, iterations=4)
    mask_img = Image.fromarray(dilated * 255)
    return [mask_img for _ in range(0, n)]


def get_frame_count():
    """Return the number of frames of the current input.

    Reads the module-level ``use_mp4`` and ``video_path``. For an mp4 input
    the value is OpenCV's ``CAP_PROP_FRAME_COUNT``; otherwise it is the number
    of entries in the ``video_path`` directory.
    """
    if not use_mp4:
        return len(os.listdir(video_path))
    vidcap = cv2.VideoCapture(video_path)
    try:
        return int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Free the decoder handle explicitly instead of relying on GC.
        vidcap.release()


def read_frame_from_videos_by_index_list(index_lst):
    """Read the frames at the given indices as RGB PIL images.

    Reads the module-level ``use_mp4`` and ``video_path``. For an mp4 input
    each index is seeked with ``CAP_PROP_POS_FRAMES``; otherwise indices refer
    to the files of the ``video_path`` directory sorted by name.

    Args:
        index_lst: iterable of frame indices, returned in the same order.

    Returns:
        A list of ``PIL.Image`` objects in RGB channel order.

    Raises:
        RuntimeError: if a frame cannot be read. (Previously this called
            ``exit(1)``, which ends the session without saying which frame
            failed.)
    """
    frames = []
    if use_mp4:
        vidcap = cv2.VideoCapture(video_path)
        try:
            for i in index_lst:
                vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
                success, image = vidcap.read()
                if not success:
                    raise RuntimeError(
                        f'Failed to read frame {i} from {video_path}')
                # OpenCV decodes to BGR; the rest of the pipeline uses RGB.
                frames.append(
                    Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
        finally:
            vidcap.release()
    else:
        lst = os.listdir(video_path)
        lst.sort()
        fr_lst = [video_path + '/' + name for name in lst]
        for i in index_lst:
            image = cv2.imread(fr_lst[i])
            if image is None:  # cv2.imread signals failure by returning None
                raise RuntimeError(f'Failed to read image {fr_lst[i]}')
            frames.append(
                Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
    return frames


#  read frames from video
def read_frame_from_videos():
    """Read every frame of the current input as RGB PIL images.

    Reads the module-level ``use_mp4`` and ``video_path``. For an mp4 input
    frames are decoded sequentially until a read fails; otherwise every file
    of the ``video_path`` directory is loaded in sorted-name order.

    Returns:
        A list of ``PIL.Image`` objects in RGB channel order.

    Raises:
        RuntimeError: if an image file of a frame directory cannot be read.
    """
    frames = []
    if use_mp4:
        vidcap = cv2.VideoCapture(video_path)
        try:
            success, image = vidcap.read()
            while success:
                # OpenCV decodes to BGR; the rest of the pipeline uses RGB.
                frames.append(
                    Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
                success, image = vidcap.read()
        finally:
            vidcap.release()
    else:
        lst = os.listdir(video_path)
        lst.sort()
        for name in lst:
            fr = video_path + '/' + name
            image = cv2.imread(fr)
            if image is None:  # cv2.imread signals failure by returning None
                raise RuntimeError(f'Failed to read image {fr}')
            frames.append(
                Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
    return frames


# bring all frames to a common size
def resize_frames(frames, size=None):
    """Resize ``frames`` to ``size`` and return ``(frames, size)``.

    When ``size`` is None the frames are returned untouched together with the
    ``size`` attribute of the first frame.
    """
    if size is None:
        return frames, frames[0].size
    resized = [frame.resize(size) for frame in frames]
    return resized, size

## Main Worker

In [None]:
# set up models
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import the `model.<model_name>` module from the repository and build its
# generator on the selected device.
net = importlib.import_module('model.' + model_name)
model = net.InpaintGenerator().to(device)
# Load the checkpoint onto the selected device and restore the weights.
data = torch.load(ckpt, map_location=device)
model.load_state_dict(data)
# NOTE: this message is printed after the weights have already been loaded.
print(f'Loading model from: {ckpt}')
# Switch to evaluation mode (as the last expression of the cell, its return
# value is also what the notebook displays).
model.eval()

In [None]:

# Alternative batch (logo-removal examples), kept for reference.
# NOTE(review): the commented-out lists below do not line up -- there are 8
# videos but 7 masks (no west4 mask) and the mask paths live under
# /content/drive/ rather than /content/input/. Fix before re-enabling.
# video_path_lst = [
#                 '/content/input/delogo_examples/test_04.mp4',
#                 '/content/input/delogo_examples/test_05.mp4',
#                 '/content/input/delogo_examples/west1.mp4',
#                 '/content/input/delogo_examples/west2.mp4',
#                 '/content/input/delogo_examples/west3.mp4',
#                 '/content/input/delogo_examples/west4.mp4',
#                 '/content/input/delogo_examples/west5.mp4',
#                 '/content/input/delogo_examples/west6.mp4',
# ]
# mask_path_lst = [
#                  '/content/drive/input/delogo_examples/mask/test_04_mask.png',
#                  '/content/drive/input/delogo_examples/mask/test_05_mask.png',
#                  '/content/drive/input/delogo_examples/mask/west1_mask.png',
#                  '/content/drive/input/delogo_examples/mask/west2_mask.png',
#                  '/content/drive/input/delogo_examples/mask/west3_mask.png',
#                  '/content/drive/input/delogo_examples/mask/west5_mask.png',
#                  '/content/drive/input/delogo_examples/mask/west6_mask.png',
# ]

# Clips to process and, at the same position, the mask used for each clip.
video_path_lst = [
    '/content/input/detext_examples/chinese1.mp4',
    '/content/input/detext_examples/chinese2.mp4',
    '/content/input/detext_examples/chinese3.mp4',
    '/content/input/detext_examples/chinese4.mp4',
    '/content/input/detext_examples/chinese5.mp4',
    '/content/input/detext_examples/english1.mp4',
    '/content/input/detext_examples/english2.mp4',
    '/content/input/detext_examples/french1.mp4',
    '/content/input/detext_examples/french2.mp4',
    '/content/input/detext_examples/others.mp4',
    '/content/input/detext_examples/russian.mp4',
    '/content/input/detext_examples/spanish.mp4',
]
# NOTE(review): 'french1._mask.png' and 'others._mask.png' carry an extra '.'
# compared with their siblings -- confirm against the archive contents.
mask_path_lst = [
    '/content/input/detext_examples/mask/chinese1_mask.png',
    '/content/input/detext_examples/mask/chinese2_mask.png',
    '/content/input/detext_examples/mask/chinese3_mask.png',
    '/content/input/detext_examples/mask/chinese4_mask.png',
    '/content/input/detext_examples/mask/chinese5_mask.png',
    '/content/input/detext_examples/mask/english1_mask.png',
    '/content/input/detext_examples/mask/english2_mask.png',
    '/content/input/detext_examples/mask/french1._mask.png',
    '/content/input/detext_examples/mask/french2_mask.png',
    '/content/input/detext_examples/mask/others._mask.png',
    '/content/input/detext_examples/mask/russian_mask.png',
    '/content/input/detext_examples/mask/spanish_mask.png',
]

# Fail early instead of hitting an IndexError after several clips were processed.
if len(video_path_lst) != len(mask_path_lst):
    raise ValueError(
        f'video/mask list length mismatch: {len(video_path_lst)} videos vs '
        f'{len(mask_path_lst)} masks')

save_dir_name = '/content/drive/MyDrive/video_inpating/results/detext'
ext_name = '_results.mp4'

for video_path, mask_path in zip(video_path_lst, mask_path_lst):
    # The frame helpers read the module-level `video_path` and `use_mp4`, so
    # the flag has to be refreshed for every clip (it used to keep whatever
    # value the configuration cell had computed).
    use_mp4 = video_path.endswith('.mp4')

    # prepare dataset
    print(
        f'Loading videos and masks from: {video_path} | INPUT MP4 format: {use_mp4}'
    )
    video_length = get_frame_count()
    print('video_length={}'.format(video_length))

    # `size` is (width, height).
    h, w = size[1], size[0]
    # Pad so that height/width become multiples of these values.
    # NOTE(review): 60 / 108 are presumably dictated by the model
    # architecture -- confirm before changing.
    mod_size_h = 60
    mod_size_w = 108
    h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
    w_pad = (mod_size_w - w % mod_size_w) % mod_size_w

    # One composited frame per video frame, filled in window by window.
    comp_frames = [None] * video_length

    # completing holes by e2fgvi
    print('Start test...')
    for f in tqdm(range(0, video_length, neighbor_stride)):
        # Local window around f plus sampled reference frames.
        neighbor_ids = list(
            range(max(0, f - neighbor_stride),
                  min(video_length, f + neighbor_stride + 1)))
        ref_ids = get_ref_index(f, neighbor_ids, video_length)

        # read temp imgs and masks (only the frames needed for this window)
        index_lst = neighbor_ids + ref_ids
        selected_frames = read_frame_from_videos_by_index_list(index_lst)
        selected_frames, size = resize_frames(selected_frames, size)

        # presumably to_tensors() yields values in [0, 1], so `* 2 - 1`
        # rescales to [-1, 1] -- TODO confirm against core.utils
        selected_imgs = to_tensors()(selected_frames).unsqueeze(0) * 2 - 1

        selected_frames = [
            np.array(frame).astype(np.uint8) for frame in selected_frames
        ]
        selected_imgs = selected_imgs.to(device)

        # A '.png' mask path is one static mask shared by all frames;
        # anything else is treated as a directory of per-frame masks.
        if mask_path.endswith('.png'):
            selected_masks_data = read_mask_static(mask_path, size, len(index_lst))
        else:
            selected_masks_data = read_mask_lst(mask_path, size, index_lst)
        binary_masks = [
            np.expand_dims((np.array(m) != 0).astype(np.uint8), 2)
            for m in selected_masks_data
        ]
        selected_masks = to_tensors()(selected_masks_data).unsqueeze(0).to(device)

        with torch.no_grad():
            masked_imgs = selected_imgs * (1 - selected_masks)
            # Pad by appending a flipped copy along the height / width axis
            # and cropping to the padded size.
            masked_imgs = torch.cat(
                [masked_imgs, torch.flip(masked_imgs, [3])],
                3)[:, :, :, :h + h_pad, :]
            masked_imgs = torch.cat(
                [masked_imgs, torch.flip(masked_imgs, [4])],
                4)[:, :, :, :, :w + w_pad]
            pred_imgs, _ = model(masked_imgs, len(neighbor_ids))
            # Drop the padding again and rescale the prediction to 0..255.
            pred_imgs = pred_imgs[:, :, :h, :w]
            pred_imgs = (pred_imgs + 1) / 2
            pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
            for k, idx in enumerate(neighbor_ids):
                # Keep original pixels outside the mask and use the
                # prediction inside it.
                img = np.array(pred_imgs[k]).astype(
                    np.uint8) * binary_masks[k] + selected_frames[k] * (
                        1 - binary_masks[k])
                if comp_frames[idx] is None:
                    comp_frames[idx] = img
                else:
                    # Frame already produced by an overlapping window:
                    # blend the two results equally.
                    comp_frames[idx] = comp_frames[idx].astype(
                        np.float32) * 0.5 + img.astype(np.float32) * 0.5

    print('Saving videos...')
    save_base_name = video_path.split('/')[-1]
    save_name = save_base_name.replace(
        '.mp4', ext_name) if use_mp4 else save_base_name + ext_name
    os.makedirs(save_dir_name, exist_ok=True)
    save_path = os.path.join(save_dir_name, save_name)
    writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"),
                             default_fps, size)
    try:
        for comp in comp_frames:
            # Composited frames are RGB; OpenCV writers expect BGR.
            writer.write(
                cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise the file, even if a frame fails to convert.
        writer.release()
    print(f'Finish test! The result video is saved in: {save_path}.')

Loading videos and masks from: /content/input/detext_examples/chinese1.mp4 | INPUT MP4 format: True
video_length=591
Start test...


 20%|█▉        | 39/197 [02:17<09:28,  3.60s/it]