<a href="https://colab.research.google.com/github/softmurata/colab_notebooks/blob/main/analysis/ptlflow_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PTLFlow inference demo

This notebook shows a basic example of how to use PTLFlow ([https://github.com/hmorimitsu/ptlflow](https://github.com/hmorimitsu/ptlflow)) to estimate the optical flow between a pair of images.

In the first example, we will use the `infer.py` script provided by PTLFlow to do the estimation. The second example will show how to write simple code that estimates the optical flow without the script.

More details can be found in the official documentation at [https://ptlflow.readthedocs.io](https://ptlflow.readthedocs.io).

In [None]:
# First install the PTLFlow package with pip
!pip install ptlflow

In [None]:
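# A compatibility conflict may occur with torchtext and torchaudio.
# Uninstall them, or install versions that match the installed pytorch and torchvision.
# The -y flag skips the confirmation prompt so the cell does not wait for input.
!pip uninstall -y torchtext torchaudio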
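
Before uninstalling or reinstalling anything, it can help to see which versions are actually present. The cell below is a minimal sketch that uses only the standard library; `installed_versions` is a hypothetical helper written for this notebook, not part of PTLFlow, and the package list is just the set of names mentioned above.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Packages named in this notebook; edit the list as needed.
report = installed_versions(["torch", "torchvision", "torchtext", "torchaudio", "ptlflow"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

If torchtext or torchaudio show a version built for a different pytorch release than the one listed, that mismatch is the likely source of the conflict.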

In [None]:
# Import ptlflow and some dependencies for the example
import ptlflow

import cv2 as cv
from google.colab.patches import cv2_imshow

In [None]:
# Download two images to serve as inputs to the optical flow model
# The images below are from the MPI-Sintel dataset: http://sintel.is.tue.mpg.de/
!wget https://github.com/hmorimitsu/sift-flow-gpu/raw/master/mpi_sintel_images/frame_0001.png
!wget https://github.com/hmorimitsu/sift-flow-gpu/raw/master/mpi_sintel_images/frame_0002.png
cv2_imshow(cv.imread('frame_0001.png'))
cv2_imshow(cv.imread('frame_0002.png'))

## Example 1 - with infer.py

We first need to download the `infer.py` script. This can be done with the code below.

In [None]:
ptlflow.download_scripts()

# If you want to download the script directly from a terminal, you can run:
# python -c "import ptlflow; ptlflow.download_scripts()"

# Go to the folder where the scripts were downloaded to
%cd ptlflow_scripts

/content/ptlflow_scripts


Now that we have the script, we can use it to estimate the optical flow between our two images.

The code below does this using the small version of the RAFT model (see [https://github.com/princeton-vl/RAFT](https://github.com/princeton-vl/RAFT)). We are also going to initialize the RAFT network with the weights obtained after training on the FlyingThings3D dataset.

The `--write_outputs` argument saves the outputs of the network to disk.

In [None]:
!python infer.py raft_small --pretrained_ckpt things --input_path ../frame_0001.png ../frame_0002.png --write_outputs

In [None]:
# Let's visualize the predicted flow
flow_pred = cv.imread('outputs/inference/flows_viz/frame_0001.png')
cv2_imshow(flow_pred)

## Example 2 - without infer.py

This example shows how to write a short piece of code that does the same thing as the previous example.

In [None]:
# Additional dependencies for this example
from ptlflow.utils.io_adapter import IOAdapter
from ptlflow.utils import flow_utils

In [None]:
# Load the two images
img1 = cv.imread('../frame_0001.png')
img2 = cv.imread('../frame_0002.png')

# Get an initialized model from PTLFlow
model = ptlflow.get_model('raft_small', 'things')
model.eval()

# IOAdapter is a helper to transform the two images into the input format accepted by PTLFlow models
io_adapter = IOAdapter(model, img1.shape[:2])
inputs = io_adapter.prepare_inputs([img1, img2])

In [None]:
# Forward the inputs to obtain the model predictions
predictions = model(inputs)

# Some padding may have been added during prepare_inputs. The line below ensures that the padding is removed
# to make the predictions have the same size as the original images.
predictions = io_adapter.unpad_and_unscale(predictions)
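
To make the padding step more concrete, here is a minimal sketch of the idea in plain Python. It assumes the model needs input sizes that are multiples of 8; the actual stride and padding scheme used by `IOAdapter` may differ, and `pad_amounts` and `unpad` are hypothetical helpers written only for illustration.

```python
def pad_amounts(height, width, stride=8):
    """Return (top, bottom, left, right) padding that makes both sizes multiples of stride."""
    pad_h = (-height) % stride
    pad_w = (-width) % stride
    top, left = pad_h // 2, pad_w // 2
    return top, pad_h - top, left, pad_w - left


def unpad(rows, pads):
    """Crop a 2D list of rows back to the size it had before padding."""
    top, bottom, left, right = pads
    height, width = len(rows), len(rows[0])
    return [row[left:width - right] for row in rows[top:height - bottom]]


# A 436 x 1024 frame (the MPI-Sintel resolution) with an assumed stride of 8
pads = pad_amounts(436, 1024, stride=8)
print(pads)

# Build a dummy padded "prediction" and crop it back to the original size
padded = [[0] * (1024 + pads[2] + pads[3]) for _ in range(436 + pads[0] + pads[1])]
restored = unpad(padded, pads)
print(len(restored), len(restored[0]))
```

The height 436 is not divisible by 8, so four rows of padding are needed, while the width 1024 needs none. Cropping with the same four values returns the prediction to the original image size.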

In [None]:
# Visualize the predicted flow
flow = predictions['flows'][0, 0]  # Remove batch and sequence dimensions
flow = flow.permute(1, 2, 0)  # change from CHW to HWC shape
flow = flow.detach().numpy()
flow_viz = flow_utils.flow_to_rgb(flow)  # Represent the flow as RGB colors
flow_viz = cv.cvtColor(flow_viz, cv.COLOR_RGB2BGR)  # flow_to_rgb gives RGB, but OpenCV expects BGR
cv2_imshow(flow_viz)
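
For intuition about what a flow visualization encodes, the sketch below maps a single flow vector to a color: the direction sets the hue and the magnitude sets the saturation. This is a simplified HSV mapping written for illustration, not the implementation behind `flow_utils.flow_to_rgb`, which may use a different color wheel and normalization. `flow_vector_to_rgb` and `max_magnitude` are hypothetical names.

```python
import colorsys
import math


def flow_vector_to_rgb(u, v, max_magnitude):
    """Map one flow vector (u, v) to an (r, g, b) tuple of ints in 0..255."""
    magnitude = math.hypot(u, v)
    angle = math.atan2(v, u)  # direction in [-pi, pi]
    hue = (angle + math.pi) / (2 * math.pi) % 1.0  # direction -> hue in [0, 1)
    # Larger motion gives a more saturated color, clipped at max_magnitude
    saturation = min(magnitude / max_magnitude, 1.0) if max_magnitude > 0 else 0.0
    r, g, b = colorsys.hsv_to_rgb(hue, saturation, 1.0)
    return tuple(round(255 * c) for c in (r, g, b))


print(flow_vector_to_rgb(0.0, 0.0, 10.0))   # zero motion maps to white
print(flow_vector_to_rgb(10.0, 0.0, 10.0))  # full-strength motion along +u
```

In this sketch, pixels that do not move come out white, and pixels moving in different directions get different hues, which is why object boundaries tend to stand out in flow images.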

In [None]:
#@title Training
# !python -c "import ptlflow; print(ptlflow.get_trainable_model_names())"