<a href="https://colab.research.google.com/github/softmurata/colab_notebooks/blob/main/computervision/CoTracker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Installation

In [None]:
!git clone https://github.com/facebookresearch/co-tracker
%cd co-tracker
!pip install -e .
!pip install opencv-python einops timm matplotlib moviepy flow_vis
!mkdir checkpoints
%cd checkpoints
!wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth

Imports and utilities

In [None]:
# The installation cell leaves the working directory in ./checkpoints; step back to
# the co-tracker repo root so the relative './assets' and './checkpoints' paths below resolve.
%cd ..
import os
import torch

from base64 import b64encode
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from IPython.display import HTML

In [None]:
# Load the demo clip and convert it into the tensor layout passed to the model below.
# NOTE(review): the permute assumes read_video_from_path returns a (T, H, W, C) array,
# and the guard assumes it returns None when the file cannot be read — TODO confirm
# against cotracker.utils.visualizer.
frames = read_video_from_path('./assets/apple.mp4')
if frames is None:
    # Fail with a clear message instead of an opaque TypeError from torch.from_numpy(None).
    raise FileNotFoundError(
        "Could not read './assets/apple.mp4' — is the working directory the co-tracker repo root?"
    )
# (T, H, W, C) -> (T, C, H, W), then add a leading batch dimension -> (1, T, C, H, W).
video = torch.from_numpy(frames).permute(0, 3, 1, 2)[None].float()
# .float() made a copy, so the raw array is no longer needed.
del frames

def show_video(video_path, width=640, height=480):
    """Return an inline HTML player that embeds an MP4 file as a base64 data URL.

    Args:
        video_path: Path to the MP4 file to embed.
        width: Player width in pixels (default 640).
        height: Player height in pixels (default 480).

    Returns:
        An ``HTML`` object containing an autoplaying, looping ``<video>`` element.
    """
    # Read-only binary mode: the file is never written. The context manager closes
    # the handle deterministically instead of leaking it until garbage collection.
    with open(video_path, "rb") as video_fh:
        video_bytes = video_fh.read()
    video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
    return HTML(f"""<video width="{width}" height="{height}" autoplay loop controls><source src="{video_url}"></video>""")

# Preview the raw input clip before tracking.
show_video(video_path="./assets/apple.mp4")

Inference

In [4]:
from cotracker.predictor import CoTrackerPredictor

# Checkpoint downloaded by the installation cell, relative to the repo root.
# Join directory and file name as separate components (os.path.join on a single,
# already-joined string does nothing).
model = CoTrackerPredictor(
    checkpoint=os.path.join('./checkpoints', 'cotracker_stride_4_wind_8.pth')
)

In [5]:
# Dense tracking over a regular grid of query points (grid_size=30).
# Returns the predicted tracks and, presumably, per-point visibility flags.
pred_tracks, pred_visibility = model(
    video,
    grid_size=30,
)

In [None]:
# Render the grid tracks over the clip. show_video then embeds the rendered file
# (save_dir + filename + '_pred_track.mp4', following the paths used in this notebook).
vis = Visualizer(
    save_dir='./videos',
    pad_value=100,
)
vis.visualize(video=video, tracks=pred_tracks, filename='teaser');
show_video("./videos/teaser_pred_track.mp4")

Tracking manually selected points

In [None]:
import matplotlib.pyplot as plt

# Hand-picked query points, one row per point: [frame_index, x, y] in pixels.
# They are placed on the same device as the video rather than forced onto CUDA.
# The model and video are never moved to the GPU in this notebook, so `.cuda()`
# caused a device mismatch (and failed outright on CPU-only runtimes).
queries = torch.tensor([
    [0., 400., 350.],  # point tracked from the first frame
    [10., 600., 500.], # frame number 10
    [20., 750., 600.], # ...
    [30., 900., 200.]
]).to(video.device)

# Frame index of each query, used for the subplot titles.
frame_numbers = queries[:, 0].int().tolist()

# One subplot per query, two per row. squeeze=False keeps `axs` 2-D for any count,
# so adding or removing queries above does not break the layout.
n_queries = len(frame_numbers)
n_cols = min(n_queries, 2) or 1
n_rows = -(-n_queries // n_cols) or 1  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
axs = axs.flatten()

for ax, query, frame_number in zip(axs, queries, frame_numbers):
    ax.plot(query[1].item(), query[2].item(), 'ro')

    ax.set_title("Frame {}".format(frame_number))
    ax.set_xlabel("x (px)")
    ax.set_ylabel("y (px)")
    # Use the frame's pixel extent (video is (B, T, C, H, W)); invert y so the
    # origin is top-left, like image coordinates.
    ax.set_xlim(0, video.shape[4])
    ax.set_ylim(0, video.shape[3])
    ax.invert_yaxis()

# Hide spare axes when the number of queries does not fill the grid.
for ax in axs[n_queries:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

In [8]:
# Track only the hand-picked queries ([None] adds the batch dimension).
# Keep visibility paired with the tracks under its own name. Assigning to `__`
# would shadow IPython's output-history variable and stop it from auto-updating.
pred_tracks, pred_visibility = model(video, queries=queries[None])

In [None]:
# Thick, colour-mapped tracks. tracks_leave_trace=-1 presumably keeps each point's
# whole trail (confirm in Visualizer). Then embed the rendered clip.
vis = Visualizer(save_dir='./videos', linewidth=6, mode='cool', tracks_leave_trace=-1)
vis.visualize(video=video, tracks=pred_tracks, filename='queries');

show_video("./videos/queries_pred_track.mp4")

Tracking forward and backward from a chosen query frame

In [None]:
grid_size = 30
grid_query_frame = 20

# Seed the grid on frame `grid_query_frame` and track both forward and backward in time.
# Visibility is kept under its own name rather than `__`, which would shadow IPython's output history.
pred_tracks, pred_visibility = model(
    video,
    grid_size=grid_size,
    grid_query_frame=grid_query_frame,
    backward_tracking=True,
)

# Create the visualizer here, with the same settings as the query-tracking cell,
# so this cell does not silently depend on whichever `vis` was created last.
vis = Visualizer(save_dir='./videos', linewidth=6, mode='cool', tracks_leave_trace=-1)
# Derive the output name from the query frame so it stays correct if the frame changes.
output_name = f'grid_query_{grid_query_frame}_backward'
vis.visualize(
    video=video,
    tracks=pred_tracks,
    filename=output_name,
    query_frame=grid_query_frame);

show_video(f"./videos/{output_name}_pred_track.mp4")

Regular Grid + Segmentation Mask

In [None]:
import numpy as np
from PIL import Image
grid_size = 120

# Restrict the dense grid with a segmentation mask of the object (presumably only
# grid points inside the mask are tracked — confirm in CoTrackerPredictor).
input_mask = './assets/apple_mask.png'
# Context manager closes the image file once the pixels are copied into the array.
with Image.open(input_mask) as mask_image:
    segm_mask = np.array(mask_image)

# [None, None] adds two leading dims -> (1, 1, H, W). This assumes a single-channel
# mask image — TODO confirm. Visibility is named explicitly rather than assigned to
# `__`, which would shadow IPython's output history.
pred_tracks, pred_visibility = model(
    video,
    grid_size=grid_size,
    segm_mask=torch.from_numpy(segm_mask)[None, None],
)
vis = Visualizer(
    save_dir='./videos',
    pad_value=100,
    linewidth=2,
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    filename='segm_grid');

show_video("./videos/segm_grid_pred_track.mp4")

Tips: converting an output video to a GIF

In [None]:
!git clone https://github.com/nstevens1040/mp42gif.git
%cd mp42gif
!pip install -r requirements.txt

In [None]:
!python mp42gif.py --input /content/co-tracker/videos/teaser_pred_track.mp4