Add example for optical flow visualization and RAFT #5316
@@ -0,0 +1,198 @@
""" | ||||||||||||||||||
===================================================== | ||||||||||||||||||
Optical Flow: Predicting movement with the RAFT model | ||||||||||||||||||
===================================================== | ||||||||||||||||||
|
||||||||||||||||||
Optical flow is the task of predicting movement between two images, usually two | ||||||||||||||||||
consecutive frames of a video. Optical flow models take two images as input, and | ||||||||||||||||||
predict a flow: the flow indicates the displacement of every single pixel in the | ||||||||||||||||||
first image, and maps it to its corresponding pixel in the second image. Flows | ||||||||||||||||||
are (2, H, W)-dimensional tensors, where the first axis corresponds to the | ||||||||||||||||||
predicted horizontal and vertical displacements. | ||||||||||||||||||
|
||||||||||||||||||
The following example illustrates how torchvision can be used to predict flows | ||||||||||||||||||
using our implementation of the RAFT model. We will also see how to convert the | ||||||||||||||||||
predicted flows to RGB images for visualization. | ||||||||||||||||||
""" | ||||||||||||||||||

import numpy as np
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = "tight"
# sphinx_gallery_thumbnail_number = 2


def plot(imgs, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            img = F.to_pil_image(img.to("cpu"))
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    plt.tight_layout()


###################################
# Reading Videos Using Torchvision
# --------------------------------
# We will first read a video using :func:`~torchvision.io.read_video`.
# Alternatively one can use the new :class:`~torchvision.io.VideoReader` API
# (if torchvision is built from source).
# The video we will use here is free to use from `pexels.com
# <https://www.pexels.com/video/a-man-playing-a-game-of-basketball-5192157/>`_,
# credits go to `Pavel Danilyuk <https://www.pexels.com/@pavel-danilyuk>`_.


import tempfile
from pathlib import Path
from urllib.request import urlretrieve


video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
_ = urlretrieve(video_url, video_path)
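
#########################
# As an aside: if torchvision was built with video support, the downloaded file
# could also be decoded with the fine-grained :class:`~torchvision.io.VideoReader`
# API mentioned above. The commented lines below are only a sketch of that
# alternative (our addition, assuming the video backend is available), not part
# of the original example:

# from torchvision.io import VideoReader
# reader = VideoReader(str(video_path), "video")
# # Each item yielded by the reader is a dict with a (C, H, W) "data" tensor
# frames_alt = torch.stack([frame["data"] for frame in reader])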

#########################
# :func:`~torchvision.io.read_video` returns the video frames, audio frames and
# the metadata associated with the video. In our case, we only need the video
# frames.
#
# Here we will just make 2 predictions between 2 pre-selected pairs of frames,
# namely frames (100, 101) and (150, 151). Each of these pairs corresponds to a
# single model input.

from torchvision.io import read_video
frames, _, _ = read_video(str(video_path))
frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

img1_batch = torch.stack([frames[100], frames[150]])
img2_batch = torch.stack([frames[101], frames[151]])

plot(img1_batch)

#########################
# The RAFT model that we will use accepts RGB float images with pixel values in
# [-1, 1]. The frames we got from :func:`~torchvision.io.read_video` are int
# images with values in [0, 255], so we will have to pre-process them. We also
# reduce the image sizes for the example to run faster. Image dimensions must be
# divisible by 8.
Review comment: Why do the image dimensions need to be divisible by 8?

Author reply: It's a hardcoded constraint within the model; the feature extractor downsamples the images by 8.

Review comment: My doubt is whether it is a hard necessity that image sizes be divisible by 8, or whether it's adjusted by the model.

Author reply: It's really a hardcoded constraint, and the model cannot accept images whose dimensions aren't divisible by 8 (even if it wanted to). It has to be exactly divisible by 8 because we first downsample the inputs by 8, predict a downsampled flow, and then upsample the predicted flow by a factor of 8 (vision/torchvision/models/optical_flow/_utils.py, lines 26 to 32 in 74a1efc). If the image wasn't a multiple of 8 to begin with, we wouldn't be able to upsample the flow to the right dimensions.


def preprocess(batch):
    transforms = T.Compose(
        [
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            T.Resize(size=(520, 960)),
        ]
    )
    batch = transforms(batch)
    return batch
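
#########################
# If you would rather keep the original resolution than resize, padding each
# dimension up to the next multiple of 8 is a common alternative. The helper
# below is a sketch we add for illustration (``pad_to_multiple_of_8`` is our
# own name, not a torchvision function); remember to crop the predicted flows
# back to the original size afterwards:


def pad_to_multiple_of_8(batch):
    h, w = batch.shape[-2:]
    pad_h = (8 - h % 8) % 8
    pad_w = (8 - w % 8) % 8
    # pad the right and bottom edges so both dimensions become multiples of 8
    return torch.nn.functional.pad(batch, (0, pad_w, 0, pad_h))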


# If you can, run this example on a GPU; it will be a lot faster.
device = "cuda" if torch.cuda.is_available() else "cpu"

img1_batch = preprocess(img1_batch).to(device)
img2_batch = preprocess(img2_batch).to(device)

print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}")


####################################
# Estimating Optical flow using RAFT
# ----------------------------------
# We will use our RAFT implementation from
# :func:`~torchvision.models.optical_flow.raft_large`, which follows the same
# architecture as the one described in the `original paper <https://arxiv.org/abs/2003.12039>`_.
# We also provide the :func:`~torchvision.models.optical_flow.raft_small` model
# builder, which is smaller and faster to run, sacrificing a bit of accuracy.

from torchvision.models.optical_flow import raft_large

model = raft_large(pretrained=True, progress=False).to(device)
model = model.eval()
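
#########################
# If inference speed matters more than accuracy, the smaller variant mentioned
# above can be swapped in through the same builder interface. This is a
# commented sketch (our addition), not part of the original example:

# from torchvision.models.optical_flow import raft_small
# model = raft_small(pretrained=True, progress=False).to(device)
# model = model.eval()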

list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
print(f"type = {type(list_of_flows)}")
print(f"length = {len(list_of_flows)} = number of iterations of the model")

####################################
# The RAFT model outputs lists of predicted flows where each entry is a
# (N, 2, H, W) batch of predicted flows that corresponds to a given "iteration"
# in the model. For more details on the iterative nature of the model, please
# refer to the `original paper <https://arxiv.org/abs/2003.12039>`_. Here, we
# are only interested in the final predicted flows (they are the most accurate
# ones), so we will just retrieve the last item in the list.
#
# As described above, a flow is a tensor with dimensions (2, H, W) (or (N, 2, H,
# W) for batches of flows) where each entry corresponds to the horizontal and
# vertical displacement of each pixel from the first image to the second image.
# Note that the predicted flows are in "pixel" units; they are not normalized
# w.r.t. the dimensions of the images.
predicted_flows = list_of_flows[-1]
print(f"dtype = {predicted_flows.dtype}")
print(f"shape = {predicted_flows.shape} = (N, 2, H, W)")
print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}")


####################################
# Visualizing predicted flows
# ---------------------------
# Torchvision provides the :func:`~torchvision.utils.flow_to_image` utility to
# convert a flow into an RGB image. It also supports batches of flows.
#
# Each "direction" in the flow will be mapped to a given RGB color. In the
# images below, pixels with similar colors are assumed by the model to be moving
# in similar directions. The model is properly able to predict the movement of
# the ball and the player. Note in particular the different predicted direction
# of the ball in the first image (going to the left) and in the second image
# (going up).

from torchvision.utils import flow_to_image

flow_imgs = flow_to_image(predicted_flows)

# The images have been mapped into [-1, 1] but for plotting we want them in [0, 1]
img1_batch = [(img1 + 1) / 2 for img1 in img1_batch]

grid = [[img1, flow_img] for (img1, flow_img) in zip(img1_batch, flow_imgs)]
plot(grid)
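
#########################
# To see which color corresponds to which direction, we can feed a synthetic
# radial flow (pointing away from the image center) into
# :func:`~torchvision.utils.flow_to_image`. This is a sketch we add for
# illustration, not part of the original example:

ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, 200), torch.linspace(-1, 1, 200), indexing="ij"
)
radial_flow = torch.stack([xs, ys])[None]  # a (1, 2, H, W) batch with one flow
plot(flow_to_image(radial_flow))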


####################################
# Bonus: Creating GIFs of predicted flows
# ---------------------------------------
# In the example above we have only shown the predicted flows of 2 pairs of
# frames. A fun way to apply optical flow models is to run the model on an
# entire video, and create a new video from all the predicted flows. Below is a
# snippet that can get you started with this. We comment out the code, because
# this example is being rendered on a machine without a GPU, and it would take
# too long to run it.

# from torchvision.io import write_jpeg
# for i, (img1, img2) in enumerate(zip(frames, frames[1:])):
#     # Note: it would be faster to predict batches of flows instead of individual flows
#     img1 = preprocess(img1[None]).to(device)
#     img2 = preprocess(img2[None]).to(device)

#     list_of_flows = model(img1, img2)
#     predicted_flow = list_of_flows[-1][0]
#     flow_img = flow_to_image(predicted_flow).to("cpu")
#     output_folder = "/tmp/"  # Update this to the folder of your choice
#     write_jpeg(flow_img, output_folder + f"predicted_flow_{i}.jpg")

####################################
# Once the .jpg flow images are saved, you can convert them into a video or a
# GIF using ffmpeg with e.g.:
#
# ffmpeg -f image2 -framerate 30 -i predicted_flow_%d.jpg -loop -1 flow.gif
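
#########################
# If ffmpeg is not available, the saved frames could also be assembled into a
# GIF with PIL. This is a commented sketch (our addition); ``num_frames`` is a
# placeholder for however many images you wrote:

# from PIL import Image
# imgs = [Image.open(f"/tmp/predicted_flow_{i}.jpg") for i in range(num_frames)]
# imgs[0].save("flow.gif", save_all=True, append_images=imgs[1:], duration=33, loop=0)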
Review comment: I guess we can put the rendered GIF here? (If sphinx can render it.)

Author reply: I'm not sure sphinx can. It's a good idea though; I'll do it once I find a way to reduce the GIF to a decent size (the ones I have right now are 300MB...).

Review comment: Looks like we can. We need to compress the GIF though :( Maybe possible by running on a very small video size, say 120x60, for 3 seconds?
Review comment: I couldn't find the thumbnail. Are you sure it's there?

Author reply: Yes, the ``sphinx_gallery_thumbnail_number = 2`` comment tells sphinx-gallery to use the second image as the thumbnail instead of the first one: https://1177465-73328905-gh.circle-artifacts.com/0/docs/generated/torchvision.models.optical_flow.raft_large.html#torchvision.models.optical_flow.raft_large