<a href="https://colab.research.google.com/github/verneh/transformers/blob/main/mttr_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Dependency Installation

In [1]:
%%capture
!pip install av moviepy gdown yt-dlp ruamel.yaml einops timm transformers
import torch
import torchvision
import torchvision.transforms.functional as F
from einops import rearrange
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFont
from yt_dlp import YoutubeDL
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from IPython.display import HTML
from base64 import b64encode
from tqdm.notebook import trange, tqdm
from transformers import logging
logging.set_verbosity_error()

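Initialization — choose the YouTube video, the subclip to process (start and end in seconds, at most 10 seconds long) and 1–2 text queries describing the object(s) to segment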

In [38]:
# Inputs: the YouTube video, the subclip to process (start/end, in seconds), and the
# free-form text queries describing the object(s) to segment in it.
video_url = 'https://www.youtube.com/watch?v=Nid2HId9EVY'
start_pt, end_pt = 0.0, 3.0
text_queries = ['a person suffering from emotional damage']

# sanity-check the inputs before downloading anything:
assert 0 < end_pt - start_pt <= 10, 'error - the subclip length must be 0-10 seconds long'
assert len(text_queries) in (1, 2), 'error - 1-2 input text queries are expected'

In [39]:
download_resolution = 360  # maximum video height (in pixels) requested from YouTube
full_video_path = 'full_video.mp4'
input_clip_path = 'input_clip.mp4'

# download parameters: best single-file format no taller than `download_resolution`,
# overwriting any previous download so re-running the cell with a new `video_url` works:
ydl_opts = {'format': f'best[height<={download_resolution}]', 'overwrites': True, 'outtmpl': full_video_path}
# download the whole video:
with YoutubeDL(ydl_opts) as ydl:
    ydl.download([video_url])

# extract the relevant subclip [start_pt, end_pt] (seconds) into its own file:
with VideoFileClip(full_video_path) as video:
    subclip = video.subclip(start_pt, end_pt)
    subclip.write_videofile(input_clip_path)

# visualize the input clip (read via a context manager so the file handle is closed):
with open(input_clip_path, 'rb') as clip_file:
    input_clip = clip_file.read()
data_url = "data:video/mp4;base64," + b64encode(input_clip).decode()
HTML("""<video width=720 controls><source src="%s" type="video/mp4"></video>""" % data_url)

[youtube] Nid2HId9EVY: Downloading webpage
[youtube] Nid2HId9EVY: Downloading android player API JSON
[info] Nid2HId9EVY: Downloading 1 format(s): 18
Deleting existing file full_video.mp4
[download] Destination: full_video.mp4
[download] 100% of 214.58KiB in 00:00               
[MoviePy] >>>> Building video input_clip.mp4
[MoviePy] Writing audio in input_clipTEMP_MPY_wvf_snd.mp3


100%|██████████| 67/67 [00:00<00:00, 876.59it/s]

[MoviePy] Done.
[MoviePy] Writing video input_clip.mp4



100%|██████████| 83/83 [00:00<00:00, 135.68it/s]


[MoviePy] Done.
[MoviePy] >>>> Video ready: input_clip.mp4 



Inference

In [40]:
# Fetch the MTTR model and the postprocessor used on its outputs from the MTTR GitHub
# repository through torch.hub. force_reload=True re-downloads the repository on every run
# instead of reusing the local hub cache.
model, postprocessor = torch.hub.load(
    'mttr2021/MTTR:main',
    'mttr_refer_youtube_vos',
    force_reload=True,
)
model = model.cuda()  # inference below feeds CUDA tensors, so the model lives on the GPU

Downloading: "https://github.com/mttr2021/MTTR/archive/main.zip" to /root/.cache/torch/hub/main.zip


  0%|          | 0.00/2.83k [00:00<?, ?B/s]

In [41]:
class NestedTensor:
    """Minimal container pairing a padded tensor with its padding mask.

    As built by `nested_tensor_from_videos_list`, `mask` is True at padded positions and
    False at real data. The values are stored as given: no copying or validation is done.
    """

    def __init__(self, tensors, mask):
        # the padded data
        self.tensors = tensors
        # boolean padding mask matching `tensors` (without the channel axis)
        self.mask = mask

def nested_tensor_from_videos_list(videos_list):
    """Zero-pad a list of videos to a common size and stack them into a NestedTensor.

    Each element of `videos_list` is a 4-D tensor indexed as (t, c, h, w). Videos may differ
    in t, h and w, and each is padded at the end of those axes up to the per-axis maximum.
    The channel axis is copied whole, so all videos are expected to share the channel count.

    Returns:
        NestedTensor whose `tensors` has shape (T, B, C, H, W) and whose `mask` has shape
        (T, B, H, W). In the mask, True marks padded positions and False marks real frames/pixels.
        Both tensors use the dtype/device of the first video.
    """
    # per-axis maximum over all video shapes
    max_size = list(videos_list[0].shape)
    for vid in videos_list[1:]:
        for axis, dim in enumerate(vid.shape):
            max_size[axis] = max(max_size[axis], dim)

    batch_shape = [len(videos_list)] + max_size
    b, t, _, h, w = batch_shape
    first = videos_list[0]
    padded_videos = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    # start fully "padded" and clear the mask wherever real data is copied in
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=first.device)
    for src, dst, pad_mask in zip(videos_list, padded_videos, videos_pad_masks):
        n_frames, height, width = src.shape[0], src.shape[2], src.shape[3]
        dst[:n_frames, :, :height, :width].copy_(src)
        pad_mask[:n_frames, :height, :width] = False
    # switch to time-major layout: (B, T, ...) -> (T, B, ...)
    return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1))

def apply_mask(image, mask, color, transparency=0.7):
    """Alpha-blend a solid `color` into `image` wherever `mask` is set.

    Args:
        image: array of shape (H, W, C). This notebook passes RGB frames scaled to [0, 1].
        mask: array of shape (H, W). 1 means fully masked, 0 means untouched, and intermediate
            values blend proportionally.
        color: per-channel color, in the same value range as `image`. It must be
            broadcastable against `image`'s last axis.
        transparency: blend weight given to `color` where `mask == 1`
            (0 leaves the image unchanged, 1 paints the color opaquely).

    Returns:
        A new float64 array shaped like `image`. The inputs are not modified.
    """
    # Per-pixel blend weight with a trailing channel axis so it broadcasts over channels.
    # Broadcasting replaces the original full-size color matrix, which was built with
    # `np.float`, an alias removed in NumPy 1.24.
    alpha = mask[..., np.newaxis] * transparency
    color = np.asarray(color, dtype=np.float64)
    return color * alpha + image * (1.0 - alpha)

In [42]:
window_length = 24  # number of frames fed to the model per forward pass
window_overlap = 6  # overlap (in frames) between consecutive windows
assert 0 <= window_overlap < window_length, 'error - window_overlap must be in [0, window_length)'

with torch.inference_mode():
    # read and preprocess the video clip:
    video, audio, meta = torchvision.io.read_video(filename=input_clip_path)
    video = rearrange(video, 't h w c -> t c h w')
    # Resize the shorter side to 360 (longer side capped at 640), scale to [0, 1], then
    # normalize with the per-channel mean/std below (the usual ImageNet statistics):
    input_video = F.resize(video, size=360, max_size=640).cuda()
    input_video = input_video.to(torch.float).div_(255)
    input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}

    # Partition the clip into overlapping windows of frames. A window starting at or after
    # `num_frames - window_overlap` would lie entirely inside the previous window, so such
    # starts are skipped. The first window is always kept.
    num_frames = len(input_video)
    window_step = window_length - window_overlap
    window_starts = [s for s in range(0, num_frames, window_step)
                     if s == 0 or s + window_overlap < num_frames]
    windows = [input_video[s:s + window_length] for s in window_starts]
    # clean up the text queries (lowercase, collapse whitespace):
    text_queries = [" ".join(q.lower().split()) for q in text_queries]

    pred_masks_per_query = []
    t, _, h, w = video.shape
    for text_query in tqdm(text_queries, desc='text queries'):
        # one mask per frame at the original resolution; filled window by window
        pred_masks = torch.zeros(size=(t, 1, h, w))
        for win_start_idx, window in tqdm(zip(window_starts, windows), total=len(windows), desc='windows'):
            window_nt = nested_tensor_from_videos_list([window])
            valid_indices = torch.arange(len(window_nt.tensors)).cuda()
            outputs = model(window_nt, valid_indices, [text_query])
            window_masks = postprocessor(outputs, [video_metadata], window_nt.tensors.shape[-2:])[0]['pred_masks']
            # later windows overwrite the overlapping frames written by earlier windows
            pred_masks[win_start_idx:win_start_idx + window_length] = window_masks
        pred_masks_per_query.append(pred_masks)

text queries:   0%|          | 0/1 [00:00<?, ?it/s]

windows:   0%|          | 0/5 [00:00<?, ?it/s]

In [43]:
# RGB colors for instance masks; the i-th text query is drawn with the i-th color:
light_blue = (41, 171, 226)
purple = (237, 30, 121)
dark_green = (35, 161, 90)
orange = (255, 148, 59)
colors = np.array([light_blue, purple, dark_green, orange])

# height (in pixels), per text query, of the black strip added above the video on which the
# text queries are displayed:
text_border_height_per_query = 36
# the font is the same for every frame, so load it once rather than once per frame:
font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=30)


def _text_extent(draw, text, font):
    """Return the (width, height) extent of `text` measured from the drawing origin.

    `ImageDraw.textsize` was removed in Pillow 10, and `textbbox` (Pillow >= 8.0) replaces it.
    The right/bottom edges of the bbox anchored at (0, 0) are used so the layout stays
    close to the old `textsize` behaviour.
    """
    if hasattr(draw, 'textbbox'):
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    return draw.textsize(text, font=font)  # older Pillow without textbbox


video_np = rearrange(video, 't c h w -> t h w c').numpy() / 255.0
pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy()
border_height = len(text_queries) * text_border_height_per_query
masked_video = []
for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'):
    # apply the masks, one per text query, each in its own color:
    for inst_mask, color in zip(frame_masks, colors):
        vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0)
    vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8))
    # visualize the text queries on a black strip added above the frame:
    vid_frame = ImageOps.expand(vid_frame, border=(0, border_height, 0, 0))
    frame_w, frame_h = vid_frame.size
    draw = ImageDraw.Draw(vid_frame)
    for i, (text_query, color) in enumerate(zip(text_queries, colors), start=1):
        text_w, text_h = _text_extent(draw, text_query, font)
        # centered horizontally; bottom of the text sits 3px above the bottom of the i-th strip
        draw.text(((frame_w - text_w) / 2, (text_border_height_per_query * i) - text_h - 3),
                  text_query, fill=tuple(int(c) for c in color) + (255,), font=font)
    masked_video.append(np.array(vid_frame))

# generate and save the output clip, reusing the input clip's audio track:
output_clip_path = 'output_clip.mp4'
clip = ImageSequenceClip(sequence=masked_video, fps=meta['video_fps'])
with AudioFileClip(input_clip_path) as audio_clip:
    clip = clip.set_audio(audio_clip)
    clip.write_videofile(output_clip_path, fps=meta['video_fps'], audio=True)
del masked_video

# visualize the output clip (read via a context manager so the file handle is closed):
with open(output_clip_path, 'rb') as clip_file:
    output_clip = clip_file.read()
data_url = "data:video/mp4;base64," + b64encode(output_clip).decode()
HTML("""<video width=720 controls><source src="%s" type="video/mp4"></video>""" % data_url)

applying masks...:   0%|          | 0/83 [00:00<?, ?it/s]

[MoviePy] >>>> Building video output_clip.mp4
[MoviePy] Writing audio in output_clipTEMP_MPY_wvf_snd.mp3


100%|██████████| 67/67 [00:00<00:00, 836.18it/s]

[MoviePy] Done.
[MoviePy] Writing video output_clip.mp4



100%|██████████| 84/84 [00:00<00:00, 148.16it/s]


[MoviePy] Done.
[MoviePy] >>>> Video ready: output_clip.mp4 

