This repository is the PyTorch implementation of STaRFlow, a recurrent convolutional neural network for multi-frame optical flow estimation. The algorithm is presented in our paper "STaRFlow: A SpatioTemporal Recurrent Cell for Lightweight Multi-Frame Optical Flow Estimation" by Pierre Godet, Alexandre Boulch, Aurélien Plyer, and Guy Le Besnerais (preprint: arXiv:2007.05481).
Please cite our paper if you find our work useful.
@article{godet2020starflow,
title={STaRFlow: A SpatioTemporal Recurrent Cell for Lightweight Multi-Frame Optical Flow Estimation},
author={Godet, Pierre and Boulch, Alexandre and Plyer, Aur{\'e}lien and Le Besnerais, Guy},
journal={arXiv preprint arXiv:2007.05481},
year={2020}
}
Contact: pierre.godet@onera.fr
This code has been developed and tested under Anaconda (Python 3.7, scipy 1.1, numpy 1.16), PyTorch 1.1 and CUDA 10.1 on Ubuntu 18.04.
Please install the following:
- Anaconda (Python 3.7)
- PyTorch 1.1 (Linux, Conda, Python 3.7, CUDA 10) (conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch)
- Depending on your system, configure -gencode, -ccbin and cuda-path in models/correlation_package/setup.py accordingly
- scipy 1.1 (conda install scipy=1.1)
- colorama (conda install colorama)
- tqdm 4.32 (conda install -c conda-forge tqdm=4.32)
- pypng (pip install pypng)
Then, install the correlation package:
./install.sh
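Since the code was tested only against the versions listed above, it can help to check the active environment before training. The snippet below is a minimal sketch that is not part of the repository: the helper names are hypothetical, only the pinned components (e.g. major.minor) are compared, on the assumption that patch releases are compatible, and colorama and pypng are left out because no version is pinned for them.

```python
import importlib

# Versions pinned in the prerequisites above (only the listed components are compared).
REQUIRED = {"torch": "1.1", "torchvision": "0.3", "scipy": "1.1", "numpy": "1.16", "tqdm": "4.32"}


def version_matches(installed, required):
    """True if the dotted components of `installed` start with those of `required`.

    Local build tags such as "+cu100" are ignored; components are compared one by
    one, so "1.10.0" does not match "1.1".
    """
    installed_parts = installed.split("+")[0].split(".")
    required_parts = required.split(".")
    return installed_parts[: len(required_parts)] == required_parts


def check_environment(required=REQUIRED):
    """Map each package name to (installed version or None, whether it matches)."""
    report = {}
    for name, wanted in required.items():
        try:
            found = getattr(importlib.import_module(name), "__version__", None)
        except ImportError:
            found = None  # package not installed
        report[name] = (found, found is not None and version_matches(found, wanted))
    return report


if __name__ == "__main__":
    for name, (found, ok) in check_environment().items():
        print("%-12s required %-5s found %-10s %s" % (name, REQUIRED[name], found, "ok" if ok else "MISMATCH"))
```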
The saved_checkpoint folder contains the pre-trained models of STaRFlow trained on:
- FlyingChairsOcc -> FlyingThings3D, or
- FlyingChairsOcc -> FlyingThings3D -> MPI Sintel, or
- FlyingChairsOcc -> FlyingThings3D -> KITTI (2012 and 2015).
The script inference.py can be used to test the pre-trained models. Example:
python inference.py \
--model StarFlow \
--checkpoint saved_checkpoint/StarFlow_things/checkpoint_best.ckpt \
--data-root /data/mpisintelcomplete/training/final/ambush_6/ \
--file-list frame_0004.png frame_0005.png frame_0006.png frame_0007.png
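In the example above, --file-list receives four consecutive frames in temporal order. As a hypothetical convenience that is not part of inference.py, the sketch below builds that list and the full command line, assuming Sintel-style names of the form frame_XXXX.png; the helper names, the frame count and the naming pattern are illustrative parameters you may need to adapt.

```python
import shlex


def consecutive_frames(start, count, pattern="frame_{:04d}.png"):
    """Return `count` consecutive frame names beginning at index `start`."""
    if count < 2:
        raise ValueError("optical flow needs at least two frames")
    return [pattern.format(i) for i in range(start, start + count)]


def inference_command(checkpoint, data_root, start, count):
    """Assemble an inference.py call like the example, quoting arguments for the shell."""
    args = [
        "python", "inference.py",
        "--model", "StarFlow",
        "--checkpoint", checkpoint,
        "--data-root", data_root,
        "--file-list", *consecutive_frames(start, count),
    ]
    return " ".join(shlex.quote(a) for a in args)
```

For instance, consecutive_frames(4, 4) reproduces the four file names used in the example, and inference_command(...) returns the same command on a single line.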
By default, inference.py saves the results in ./output/.
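If the saved flow fields use the Middlebury .flo format (an assumption: code derived from IRR-PWC commonly writes it, but check the files actually produced in ./output/), they can be inspected without PyTorch. The minimal sketch below reads and writes that format using only the standard library; the function names are hypothetical.

```python
import os
import struct
import tempfile

# Float32 tag at the start of every .flo file; its little-endian bytes spell "PIEH".
FLO_MAGIC = 202021.25


def write_flo(path, width, height, values):
    """Write a flow field given as interleaved (u, v) floats in row-major order."""
    if len(values) != width * height * 2:
        raise ValueError("expected width * height * 2 values")
    with open(path, "wb") as f:
        f.write(struct.pack("<fii", FLO_MAGIC, width, height))  # header: magic, width, height
        f.write(struct.pack("<%df" % len(values), *values))


def read_flo(path):
    """Return (width, height, flow), where flow[y][x] is the (u, v) vector at pixel (x, y)."""
    with open(path, "rb") as f:
        magic, width, height = struct.unpack("<fii", f.read(12))
        if magic != FLO_MAGIC:
            raise ValueError("%s is not a .flo file (magic=%r)" % (path, magic))
        n = width * height * 2
        values = struct.unpack("<%df" % n, f.read(4 * n))
    flow = [
        [(values[2 * (y * width + x)], values[2 * (y * width + x) + 1]) for x in range(width)]
        for y in range(height)
    ]
    return width, height, flow


if __name__ == "__main__":
    # Round trip on a tiny 2x1 field written to a temporary directory.
    path = os.path.join(tempfile.mkdtemp(), "example.flo")
    write_flo(path, 2, 1, [1.5, -0.25, 0.0, 2.0])
    print(read_flo(path))
```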
Data loaders for multi-frame training can be found in the datasets folder, multi-frame losses are in losses.py, and every architecture used in the experiments presented in our paper is available in the models folder.
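To make the idea of a multi-frame loss concrete, the sketch below computes the average endpoint error (EPE) of each predicted flow field against its ground truth and combines the frames with weights. This is a hypothetical illustration in plain Python, not the content of losses.py: the function names and the uniform default weights are assumptions, and the actual losses presumably operate on PyTorch tensors and may also combine several pyramid levels.

```python
import math


def endpoint_error(pred, gt):
    """Mean Euclidean distance between predicted and ground-truth flow vectors.

    `pred` and `gt` are lists of rows, each row a list of (u, v) tuples.
    """
    total, count = 0.0, 0
    for row_p, row_g in zip(pred, gt):
        for (up, vp), (ug, vg) in zip(row_p, row_g):
            total += math.hypot(up - ug, vp - vg)
            count += 1
    return total / count


def multi_frame_loss(preds, gts, weights=None):
    """Weighted sum of per-frame EPE over a sequence of flow estimates.

    Hypothetical: frames are weighted uniformly unless `weights` is given.
    """
    if len(preds) != len(gts):
        raise ValueError("one ground-truth field per predicted field is required")
    if weights is None:
        weights = [1.0 / len(preds)] * len(preds)
    return sum(w * endpoint_error(p, g) for w, p, g in zip(weights, preds, gts))
```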
The following datasets are used in this project:
- FlyingChairsOcc dataset
- MPI Sintel Dataset
- KITTI Optical Flow 2015 and KITTI Optical Flow 2012
- FlyingThings3D subset
The scripts folder contains training scripts for STaRFlow.
To train the model, simply run a script file, e.g., ./train_starflow_chairsocc.sh.
In the script files, set your own experiment directory (EXPERIMENTS_HOME) and the dataset directories on your local system (e.g., SINTEL_HOME or KITTI_HOME).
This repository is a fork of the IRR-PWC implementation from Junhwa Hur and Stefan Roth.