# Preprocessing of an input video
### At the moment this script only supports the creation of three subfolders containing: (1) the frames, (2) a mask for each frame, (3) the masked object for each frame
### We need to add another folder containing (4) the inpainted background for each frame

In [27]:
# --- User configuration: edit the values below to match your own machine ---
# Directory that holds the raw input videos.
path_to_videos_input = "/home/vinker/dev/input_images/videos_input"
# Root directory for the extracted frames (a "frames" folder inside the videos directory).
path_to_output_frames = f"{path_to_videos_input}/frames"
# Path to the U2Net weights file.
unet_dir = "/home/vinker/dev/backgroundCLIPasso/CLIPasso/U2Net_/saved_models/u2net.pth"
# File name (not a full path) of the video you want to process.
video_filename = "horse_vid.mp4"

In [28]:
# imports
import pylab
import imageio
import matplotlib.pyplot as plt
import os
import sys 
from torchvision import transforms
from PIL import Image
import torch
import PIL
from skimage.transform import resize
import numpy as np
from scipy import ndimage
from torch.utils.data import DataLoader
from torchvision import models, transforms

p = os.path.abspath('..')
sys.path.insert(1, p)
import sketch_utils as sketch_utils
from U2Net_.model import U2NET
import u2net_utils

# Pick the first CUDA device when one is usable, otherwise fall back to the CPU.
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [29]:
# Full path of the input video and the per-video directory that receives its outputs.
video_path = f"{path_to_videos_input}/{video_filename}"
video_name = os.path.splitext(video_filename)[0]
frames_res_dir = f"{path_to_output_frames}/{video_name}"
# makedirs (unlike mkdir) also creates missing parent directories, so this works
# on a fresh checkout where the output root does not exist yet.
os.makedirs(frames_res_dir, exist_ok=True)

## Utils

In [30]:
def extract_frames(video_path, frames_output_dir, rescale_frames=1, start_frame=137, end_frame=140, display_frames=1):
    """Save the frames ``start_frame``..``end_frame`` of a video as PNG files.

    Each saved frame is written to ``frames_output_dir`` as
    ``<frame_index:03d>.png``.

    Notes:
        (1) Despite its name, ``rescale_frames`` center-crops each frame to a
            square (it does not resize). It is better to provide a square
            video and turn this flag off.
        (2) Some videos are very long, so only the frames from ``start_frame``
            to ``end_frame`` are saved. Change these according to your video.

    Args:
        video_path: Path of the video, read with imageio's ffmpeg reader.
        frames_output_dir: Directory receiving the PNG frames; created if missing.
        rescale_frames: If truthy, center-crop every frame to a square whose
            side is the shorter of the frame's height and width.
        start_frame: Index of the first frame to save (inclusive).
        end_frame: Index of the last frame to save (inclusive).
        display_frames: If truthy, show every saved frame with matplotlib.
    """
    print("=" * 50)
    print(f"Reading video from [{video_path}] ...")
    print(f"Saving frames number [{start_frame}-{end_frame}] to [{frames_output_dir}]")
    if rescale_frames:
        print("Applying rescale to the video")

    # Make sure the target directory exists so imsave does not fail on a fresh run.
    os.makedirs(frames_output_dir, exist_ok=True)

    # The context manager closes the reader even if saving a frame raises.
    with imageio.get_reader(video_path, 'ffmpeg') as vid:
        for frame_index, image in enumerate(vid):
            if frame_index > end_frame:
                # Frames arrive in order, so nothing after end_frame is needed.
                break
            if frame_index < start_frame:
                # Both bounds are inclusive, matching the range printed above.
                continue

            if rescale_frames:
                height, width = image.shape[0], image.shape[1]
                side = min(height, width)
                if width > height:
                    left = (width - side) // 2
                    image = image[:, left:left + side]
                else:
                    top = (height - side) // 2
                    image = image[top:top + side]

            if display_frames:
                plt.imshow(image)
                plt.show()
                plt.close()
            imageio.imsave(f"{frames_output_dir}/{frame_index:03d}.png", image)

def create_output_dirs(frames_output_dir_top):
    """Create the per-video output directory and its four sub-directories.

    For each video we want to keep its frames ("scene"), the mask of each frame
    ("masks"), the masked object ("object") and the inpainted background
    ("background"). Directories that already exist are left untouched, and
    missing parent directories are created as well.

    Args:
        frames_output_dir_top: Top-level output directory for one video.
    """
    os.makedirs(frames_output_dir_top, exist_ok=True)
    for dirname in ("scene", "masks", "object", "background"):
        os.makedirs(f"{frames_output_dir_top}/{dirname}", exist_ok=True)
    print(f"Your frames will be saved to [{frames_output_dir_top}]")

    
def extract_masks(scene_frames_dir, output_masks_path, output_object_path, use_gpu=1, display_frames=1):
    """Run U2Net on every frame and save its binary mask and masked object.

    For each file in ``scene_frames_dir`` two images are written under the same
    file name: a binary (0/255) mask into ``output_masks_path``, and the frame
    with every pixel outside the mask painted white into ``output_object_path``.

    Args:
        scene_frames_dir: Directory containing the frame images.
        output_masks_path: Existing directory that receives the masks.
        output_object_path: Existing directory that receives the masked objects.
        use_gpu: If truthy and CUDA is available, run on the module-level
            ``device``; otherwise run on the CPU.
        display_frames: If truthy, show the frame, mask and masked object
            with matplotlib.
    """
    # The network and its inputs must live on the same device. Previously the
    # inputs always went to the global `device`, which broke use_gpu=0 on a
    # CUDA machine (CPU weights, CUDA inputs).
    run_device = device if (use_gpu and torch.cuda.is_available()) else torch.device("cpu")

    net = U2NET(3, 1)
    net.load_state_dict(torch.load(unet_dir, map_location=run_device))
    net.to(run_device)
    net.eval()

    # Sorted for a deterministic processing order (os.listdir order is arbitrary).
    for frame_i in sorted(os.listdir(scene_frames_dir)):
        frame_path = f"{scene_frames_dir}/{frame_i}"
        if not os.path.isfile(frame_path):
            # Skip stray sub-directories (e.g. notebook checkpoint folders).
            continue

        # convert() returns a new image, so the file handle can be closed right away.
        with Image.open(frame_path) as opened:
            frame = opened.convert("RGB")
        if display_frames:
            plt.imshow(frame)
            plt.show()
            plt.close()

        w, h = frame.size[0], frame.size[1]
        # NOTE(review): SalObjDataset / RescaleT / ToTensorLab are project helpers;
        # presumably they resize the frame to 320 and turn it into a tensor - confirm.
        test_salobj_dataset = u2net_utils.SalObjDataset(imgs_list=[frame],
                                                        lbl_name_list=[],
                                                        transform=transforms.Compose([u2net_utils.RescaleT(320),
                                                                                      u2net_utils.ToTensorLab(flag=0)]))
        test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=1)
        input_im_trans = next(iter(test_salobj_dataloader))

        with torch.no_grad():
            input_im_trans = input_im_trans.type(torch.FloatTensor)
            # Only the first of the network's outputs is used.
            d1, *_ = net(input_im_trans.to(run_device))

        # Min-max normalise the prediction, then binarise it at 0.5.
        pred = d1[:, 0, :, :]
        span = pred.max() - pred.min()
        if span > 0:  # a constant prediction would otherwise divide by zero
            pred = (pred - pred.min()) / span
        predict = (pred >= 0.5).to(pred.dtype)

        # Replicate to three channels, go to (H, W, 3) and resize back to the
        # frame's resolution; re-binarise because resizing interpolates.
        mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
        mask = mask.cpu().numpy()
        mask = resize(mask, (h, w), anti_aliasing=False)
        mask = (mask >= 0.5).astype(mask.dtype)

        # Save as explicit uint8 instead of relying on an implicit float conversion.
        imageio.imsave(f"{output_masks_path}/{frame_i}", (mask * 255).astype(np.uint8))
        if display_frames:
            plt.imshow(mask)
            plt.show()
            plt.close()

        frame_np = np.array(frame, dtype=np.float64)
        frame_peak = frame_np.max()
        if frame_peak > 0:  # an all-black frame would otherwise divide by zero
            frame_np = frame_np / frame_peak

        # Keep the object's pixels and paint everything outside the mask white.
        masked_obg = mask * frame_np
        masked_obg[mask == 0] = 1
        object_peak = masked_obg.max()
        if object_peak > 0:
            masked_obg = masked_obg / object_peak
        masked_obg = (masked_obg * 255).astype(np.uint8)

        imageio.imsave(f"{output_object_path}/{frame_i}", masked_obg)
        if display_frames:
            plt.imshow(masked_obg)
            plt.show()
            plt.close()

## Run the entire preprocess

In [31]:
display_frames = False
# Create the output folders first: extract_frames writes into "<frames_res_dir>/scene",
# which does not exist yet on the first run for a video.
create_output_dirs(frames_res_dir)
extract_frames(video_path, f"{frames_res_dir}/scene", rescale_frames=1, start_frame=137, end_frame=140, display_frames=display_frames)
# Masks and masked objects are computed from the frames saved in the previous step.
extract_masks(f"{frames_res_dir}/scene", f"{frames_res_dir}/masks", f"{frames_res_dir}/object", display_frames=display_frames)

Reading video from [/home/vinker/dev/input_images/videos_input/horse_vid.mp4] ...
Saving frames number [137-140] to [/home/vinker/dev/input_images/videos_input/frames/horse_vid/scene]
Applying rescale to the video
Your frames will be saved to [/home/vinker/dev/input_images/videos_input/frames/horse_vid]


