# Training fine flow prediction
Assuming the source image $I_s$ and the target image $I_t$ are already coarsely aligned, this notebook trains a network to predict a fine flow $F_{s\rightarrow t}$ between them.
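A dense flow assigns a 2-D displacement to every pixel, and applying it means resampling one image at the displaced coordinates. The sketch below shows this for a pixel-unit flow of shape `(B, 2, H, W)` used for backward warping, where each output pixel looks up the location to sample in the source. That convention is an assumption and is not necessarily the one `ransacflow` uses. The function name `warp` is hypothetical.

```python
import torch
import torch.nn.functional as F


def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp `image` (B, C, H, W) with a pixel-unit `flow` (B, 2, H, W).

    flow[:, 0] is the horizontal (x) offset and flow[:, 1] the vertical (y)
    offset: output pixel p is sampled from `image` at p + flow(p).
    """
    _, _, h, w = image.shape
    # Base pixel grid: xs varies along the width, ys along the height.
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=image.dtype, device=image.device),
        torch.arange(w, dtype=image.dtype, device=image.device),
        indexing="ij",
    )
    x = xs.unsqueeze(0) + flow[:, 0]  # (B, H, W) sampling x-coordinates
    y = ys.unsqueeze(0) + flow[:, 1]  # (B, H, W) sampling y-coordinates
    # grid_sample expects coordinates in [-1, 1]; with align_corners=True,
    # -1 and 1 are the centers of the first and last pixel.
    x = 2.0 * x / max(w - 1, 1) - 1.0
    y = 2.0 * y / max(h - 1, 1) - 1.0
    grid = torch.stack((x, y), dim=-1)  # (B, H, W, 2), last dim ordered (x, y)
    # Locations outside the image are filled with zeros.
    return F.grid_sample(image, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
```

A zero flow returns the input unchanged, and a constant flow of `+1` in x shifts the content one pixel to the left.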

TODO describe objective functions used in this project

We assume you already have a folder called `workspace` that contains the zipped dataset.

In [None]:
%cd ../notebooks/workspace

Import packages that we will use throughout this notebook.

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

## Prepare dataset
We have already packaged some of the datasets used in the original paper as `LightningDataModule`s, and we import the MegaDepth one here.

In [None]:
from ransacflow.data import MegaDepthDataModule

In [None]:
#TODO add some sanity check for the dataset here, previews

In [None]:
# TODO setup environments for the following training sessions, how?

## Stage 1
Only train the **reconstruction loss**.
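The reconstruction loss compares the warped source with the target. A common choice for this photometric term is a structural-similarity (SSIM) loss computed over small windows. The sketch below implements a 3x3-window variant with the standard constants for images in [0, 1]. `ssim_loss` is a hypothetical, illustrative function and may differ from what `RANSACFlowModelStage1` actually computes.

```python
import torch
import torch.nn.functional as F


def ssim_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Mean of (1 - SSIM) / 2 between images x and y of shape (B, C, H, W), values in [0, 1]."""
    c1, c2 = 0.01 ** 2, 0.03 ** 2  # standard SSIM stabilizing constants for a [0, 1] range

    def local_mean(t: torch.Tensor) -> torch.Tensor:
        # 3x3 box filter; reflection padding keeps the output at H x W
        return F.avg_pool2d(F.pad(t, (1, 1, 1, 1), mode="reflect"), kernel_size=3, stride=1)

    mu_x, mu_y = local_mean(x), local_mean(y)
    var_x = local_mean(x * x) - mu_x ** 2
    var_y = local_mean(y * y) - mu_y ** 2
    cov_xy = local_mean(x * y) - mu_x * mu_y
    ssim = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )
    # Map SSIM from [-1, 1] to a loss in [0, 1] and average over pixels and channels
    return ((1 - ssim) / 2).clamp(0, 1).mean()
```

In training, `x` would be the source warped by the predicted flow and `y` the target. Identical images give a loss of 0.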

In [None]:
from ransacflow.train import RANSACFlowModelStage1

ransac_flow = RANSACFlowModelStage1(alpha=0, beta=0, gamma=0, kernel_size=7, lr=2e-4)

# FIXME unify TB logging location and experiment name
trainer = Trainer(
    max_epochs=200,
    logger=TensorBoardLogger("tb_logs", name="RANSAC-Flow"),
    callbacks=[EarlyStopping(monitor="val_loss", min_delta=0.01, patience=3)],
)
# NOTE: assumes MegaDepthDataModule can be built with its default arguments
trainer.fit(ransac_flow, datamodule=MegaDepthDataModule())


The command-line arguments below are temporarily copied from the original implementation for reference. Each loss weight is annotated with the corresponding constructor argument.

In [None]:
# Original RANSAC-Flow training arguments for stage 1 (reference only):
#   --nEpochs 200
#   --lr 2e-4
#   --kernelSize 7
#   --imgSize 224
#   --batchSize 16
#   --lambda-match 0.0   # alpha
#   --mu-cycle 0.0       # beta
#   --grad 0.0           # gamma
#   --trainMode flow
#   --margin 88

## Stage 2
Jointly train with the **reconstruction loss** and the **cycle consistency of the flow**.
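Cycle consistency requires that following the flow from one image to the other, and then the reverse flow back, returns every pixel to where it started: $F_{s\rightarrow t}(p) + F_{t\rightarrow s}(p + F_{s\rightarrow t}(p)) \approx 0$. The minimal sketch below assumes pixel-unit flows of shape `(B, 2, H, W)`. `sample_at` and `cycle_consistency_loss` are hypothetical names, and the exact formulation used by `RANSACFlowModelStage2` may differ.

```python
import torch
import torch.nn.functional as F


def sample_at(t: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Bilinearly sample t (B, C, H, W) at p + flow(p) for every pixel p."""
    _, _, h, w = t.shape
    ys, xs = torch.meshgrid(torch.arange(h, dtype=t.dtype, device=t.device),
                            torch.arange(w, dtype=t.dtype, device=t.device), indexing="ij")
    # Normalize sampling positions to [-1, 1] (align_corners=True convention).
    gx = 2.0 * (xs + flow[:, 0]) / max(w - 1, 1) - 1.0
    gy = 2.0 * (ys + flow[:, 1]) / max(h - 1, 1) - 1.0
    return F.grid_sample(t, torch.stack((gx, gy), dim=-1), align_corners=True)


def cycle_consistency_loss(flow_st: torch.Tensor, flow_ts: torch.Tensor) -> torch.Tensor:
    """Mean endpoint error of the round trip s -> t -> s (0 for perfectly consistent flows)."""
    # Reverse flow evaluated at the point each source pixel lands on in the target.
    flow_ts_at_target = sample_at(flow_ts, flow_st)
    residual = flow_st + flow_ts_at_target  # total displacement after the round trip
    return torch.linalg.vector_norm(residual, dim=1).mean()
```

Pixels whose round trip leaves the image sample a zero reverse flow (zero padding) and are therefore penalized. A real implementation would typically mask them out.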

In [None]:
from ransacflow.train import RANSACFlowModelStage2

ransac_flow = RANSACFlowModelStage2(alpha=0, beta=1, gamma=0, kernel_size=7, lr=2e-4)

# FIXME unify TB logging location and experiment name
trainer = Trainer(
    max_epochs=50,
    logger=TensorBoardLogger("tb_logs", name="RANSAC-Flow_stage2"),
    callbacks=[EarlyStopping(monitor="val_loss", min_delta=0.01, patience=3)],
)
# NOTE: assumes MegaDepthDataModule can be built with its default arguments
trainer.fit(ransac_flow, datamodule=MegaDepthDataModule())

In [None]:
# Original RANSAC-Flow training arguments for stage 2 (reference only):
#   --nEpochs 50
#   --lr 2e-4
#   --kernelSize 7
#   --imgSize 224
#   --batchSize 16
#   --lambda-match 0.0   # alpha
#   --mu-cycle 1.0       # beta
#   --grad 0.0           # gamma
#   --trainMode flow
#   --margin 88

## Stage 3
Train with all three losses together: the **reconstruction loss**, the **cycle consistency of the flow**, and the **matchability loss**.
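With matchability, the network also predicts a per-pixel mask $M \in [0, 1]$ that marks pixels it expects to match. One plausible way to combine the terms is $\mathcal{L} = \mathcal{L}_{rec} + \alpha\,\mathcal{L}_{match} + \beta\,\mathcal{L}_{cycle}$, with the reconstruction and cycle maps weighted by $M$. This assumes the weights follow the flag annotations in each stage's argument list (`alpha` for `--lambda-match`, `beta` for `--mu-cycle`). The hypothetical `total_loss` below illustrates that structure. `gamma` is 0 in all three stages, so its term is left out.

```python
import torch


def total_loss(rec_map: torch.Tensor, cycle_map: torch.Tensor, mask: torch.Tensor,
               alpha: float, beta: float) -> torch.Tensor:
    """Combine per-pixel loss maps (B, 1, H, W) using a matchability mask in [0, 1]."""
    rec = (mask * rec_map).mean()      # reconstruction, counted where M marks pixels matchable
    cycle = (mask * cycle_map).mean()  # cycle consistency, weighted the same way
    match = (1.0 - mask).abs().mean()  # pushes M towards 1
    return rec + alpha * match + beta * cycle
```

If the losses were only weighted by $M$, the network could set $M = 0$ everywhere. The matchability term is what rules out that trivial solution.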

In [None]:
from ransacflow.train import RANSACFlowModelStage3

ransac_flow = RANSACFlowModelStage3(alpha=0.01, beta=1, gamma=0, kernel_size=7, lr=2e-4)

# FIXME unify TB logging location and experiment name
trainer = Trainer(
    max_epochs=50,
    logger=TensorBoardLogger("tb_logs", name="RANSAC-Flow_stage3"),
    callbacks=[EarlyStopping(monitor="val_loss", min_delta=0.01, patience=3)],
)
# NOTE: assumes MegaDepthDataModule can be built with its default arguments
trainer.fit(ransac_flow, datamodule=MegaDepthDataModule())


In [None]:
# Original RANSAC-Flow training arguments for stage 3 (reference only):
#   --nEpochs 50
#   --lr 2e-4
#   --kernelSize 7
#   --imgSize 224
#   --batchSize 16
#   --lambda-match 0.01  # alpha
#   --mu-cycle 1.0       # beta
#   --grad 0.0           # gamma
#   --trainMode flow+match
#   --margin 88


## Stage 4.1
This additional stage fine-tunes on SOMETHING MAGICAL so that the output image introduces fewer distortions.

TODO need to update description from the original paper

## Stage 4.2
This additional stage uses a perceptual loss.

TODO add description about why and how to use perceptual loss
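A perceptual loss compares two images in the feature space of a fixed, pretrained network rather than pixel by pixel. This tends to make it less sensitive to small pixel shifts than a pixel-wise loss. The sketch below is a generic, hypothetical `PerceptualLoss` wrapper around any frozen feature extractor. Which network and layers this project uses is not specified here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerceptualLoss(nn.Module):
    """L1 distance between feature maps of a frozen feature extractor."""

    def __init__(self, feature_extractor: nn.Module):
        super().__init__()
        self.features = feature_extractor.eval()
        for p in self.features.parameters():
            p.requires_grad_(False)  # the feature network is never updated

    def train(self, mode: bool = True) -> "PerceptualLoss":
        super().train(mode)
        self.features.eval()  # keep e.g. BatchNorm/Dropout layers in inference mode
        return self

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Gradients still flow back to x (e.g. the warped image), just not into the extractor.
        return F.l1_loss(self.features(x), self.features(y))
```

In practice the extractor is often a truncated ImageNet-pretrained VGG, for example `torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1).features[:16]` (up to `relu3_3`), with inputs normalized the way that network expects.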