In [None]:
# Use the %pip magic (not !pip) so flow_vis is installed into the environment
# of the kernel that runs this notebook rather than whatever `pip` is on PATH.
%pip install flow_vis

In [None]:
import torch
from torch import nn
import pytorch_lightning as pl
import numpy as np

import os
import flow_vis

import matplotlib.pyplot as plt
import torch.utils.data as data

from torch.utils.data import DataLoader
from torchvision import transforms
import pytorch_lightning as pl


from torch.utils.data import Dataset
from imageio import imread
from typing import Callable, Optional
import numpy as np


import time

In [None]:
def make_dataset_split(dataset: [], split: float = 0.8, random_seed: int = 42):
    """Randomly split ``dataset`` into two parts with a reproducible shuffle.

    Args:
        dataset: Any object supporting ``len()`` that ``random_split`` accepts
            (this notebook passes both a plain list and a ``Subset``).
        split: Fraction of the samples assigned to the first part; the count
            is truncated towards zero with ``int()``.
        random_seed: Seed for the dedicated generator driving the shuffle.

    Returns:
        The two subsets produced by ``torch.utils.data.random_split``:
        first part, then the remainder.
    """
    total = len(dataset)
    first_count = int(total * split)
    lengths = [first_count, total - first_count]

    # A private generator keeps the split independent of the global RNG state.
    rng = torch.Generator()
    rng.manual_seed(random_seed)
    return data.random_split(dataset, lengths, generator=rng)


def get_flying_chairs_data_paths(root: str):
    """List the samples found in a directory, one per ``*_flow.flo`` file.

    Args:
        root: Directory that is scanned (non-recursively).

    Returns:
        A list, ordered by sorted file name, of ``[[img1_path, img2_path],
        flow_path]`` entries. The image paths are derived from the flow file
        name by replacing the ``_flow.flo`` suffix with ``_img1.ppm`` /
        ``_img2.ppm``; their existence is not checked.
    """
    flow_suffix = '_flow.flo'

    def build_entry(flow_name):
        # Everything before the suffix identifies the sample.
        stem = flow_name[: -len(flow_suffix)]
        first = os.path.join(root, stem + "_img1.ppm")
        second = os.path.join(root, stem + "_img2.ppm")
        return [[first, second], os.path.join(root, flow_name)]

    return [
        build_entry(entry)
        for entry in sorted(os.listdir(root))
        if entry.endswith(flow_suffix)
    ]

In [None]:
def crop(input, target):
    """Trim ``input`` to the trailing spatial size of ``target``.

    If the shapes already agree from axis 2 onwards, ``input`` is returned
    unchanged (same object). Otherwise axes 2 and 3 are sliced from the origin
    down to ``target.size(2)`` and ``target.size(3)``; leading axes are kept.
    """
    if input.shape[2:] != target.shape[2:]:
        rows, cols = target.size(2), target.size(3)
        return input[:, :, :rows, :cols]
    return input


def mean_EPE(input_flow, target_flow):
    """Return the mean end-point error between two flow tensors.

    The L2 norm of ``target_flow - input_flow`` is taken along axis 1 and the
    resulting values are averaged into a single scalar tensor.
    """
    flow_error = target_flow - input_flow
    per_position_error = torch.norm(flow_error, p=2, dim=1)
    return per_position_error.mean()

In [None]:
class FlowNetSimple(pl.LightningModule):
    """Encoder/decoder CNN that regresses a 2-channel flow field from a frame pair.

    The network consumes a 6-channel, channels-first tensor (two frames stacked
    along the channel axis, see ``conv1``) and returns a 2-channel prediction.
    Batches handed to the ``*_step`` methods are expected as
    ``([frame1, frame2], flow)`` with the channel axis last; the steps convert
    them to channels-first before calling the network.
    NOTE(review): frames are presumably (B, H, W, 3) and flow (B, H, W, 2) as
    produced by the notebook's loader - confirm against the dataset.

    The training, validation and test losses are the mean end-point error
    (``mean_EPE``) between the final, bilinearly upsampled prediction and the
    ground-truth flow.
    """

    def __init__(self):
        super().__init__()

        # Contracting path. Layer names and creation order are kept stable so
        # that state_dict keys and parameter initialisation order do not change.
        self.conv1 = self._conv_block(6, 64, kernel_size=7, stride=2, padding=3)
        self.conv2 = self._conv_block(64, 128, kernel_size=5, stride=2, padding=2)
        self.conv3 = self._conv_block(128, 256, kernel_size=5, stride=2, padding=2)
        self.conv3_1 = self._conv_block(256, 256, kernel_size=3, stride=1, padding='same')
        self.conv4 = self._conv_block(256, 512, kernel_size=3, stride=2, padding=1)
        self.conv4_1 = self._conv_block(512, 512, kernel_size=3, stride=1, padding='same')
        self.conv5 = self._conv_block(512, 512, kernel_size=3, stride=2, padding=1)
        self.conv5_1 = self._conv_block(512, 512, kernel_size=3, stride=1, padding='same')
        self.conv6 = self._conv_block(512, 1024, kernel_size=3, stride=2, padding=1)

        # Expanding path. Input widths account for the concatenated skip
        # features and, from deconv3 on, the 2 upsampled flow channels
        # (e.g. 770 = 256 + 512 + 2).
        self.deconv5 = self._deconv_block(1024, 512)
        self.deconv4 = self._deconv_block(1024, 256)
        self.deconv3 = self._deconv_block(770, 128)
        self.deconv2 = self._deconv_block(386, 64)

        # One flow-prediction head per decoder level.
        self.predict_flow5 = self._flow_head(1024)
        self.predict_flow4 = self._flow_head(770)
        self.predict_flow3 = self._flow_head(386)
        self.predict_flow2 = self._flow_head(194)

        # A single learned 2x flow upsampler, shared by all decoder levels.
        self.upsampling = nn.ConvTranspose2d(2, 2, 5, 2, 1, bias=False)

        # Final fixed 4x upsampling of the finest flow prediction.
        self.upsample_bilinear = nn.Upsample(scale_factor=4, mode='bilinear')

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride, padding):
        """Build a ``Conv2d`` followed by ``ReLU``."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding),
            nn.ReLU()
        )

    @staticmethod
    def _deconv_block(in_channels, out_channels):
        """Build a bias-free 5x5, stride-2 ``ConvTranspose2d`` followed by ``ReLU``."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=(5, 5), stride=(2, 2), padding=(1, 1), bias=False),
            nn.ReLU()
        )

    @staticmethod
    def _flow_head(in_channels):
        """Build a bias-free 3x3 convolution mapping features to 2 flow channels."""
        return nn.Conv2d(in_channels, 2, kernel_size=3, stride=1, padding=1, bias=False)

    @staticmethod
    def _stack_frames(frames):
        """Concatenate two frames on axis 3 and move that axis to position 1.

        Equivalent to the two successive axis swaps (1<->3, then 2<->3), but
        expressed as a single native ``permute``.
        """
        stacked = torch.cat((frames[0], frames[1]), 3)
        return stacked.permute(0, 3, 1, 2)

    @staticmethod
    def _to_channels_first(batch):
        """Move the last axis of a 4-D batch to position 1."""
        return batch.permute(0, 3, 1, 2)

    @staticmethod
    def _show_flow_pair(target_flow, predicted_flow):
        """Plot the colour-coded target flow (left) next to the prediction (right).

        Args:
            target_flow: One ground-truth flow sample, passed to ``flow_vis``
                as is (channel axis last).
            predicted_flow: One network output sample with the channel axis
                first; it is permuted to channel-last before colour coding.
        """
        target_rgb = flow_vis.flow_to_color(
            target_flow.cpu().detach().numpy(), convert_to_bgr=False)
        predicted_rgb = flow_vis.flow_to_color(
            predicted_flow.permute(1, 2, 0).cpu().detach().numpy(), convert_to_bgr=False)

        fig = plt.figure(figsize=(8, 8))
        fig.add_subplot(1, 2, 1).imshow(target_rgb)
        fig.add_subplot(1, 2, 2).imshow(predicted_rgb)
        plt.show()
        # Release the figure; otherwise one figure per batch stays alive.
        plt.close(fig)

    def forward(self, x):
        """Run the network on a channels-first batch with 6 channels.

        Returns the 2-channel flow predicted at the finest decoder level after
        a 4x bilinear upsampling.
        """
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3_1 = self.conv3_1(self.conv3(out_conv2))
        out_conv4_1 = self.conv4_1(self.conv4(out_conv3_1))
        out_conv5_1 = self.conv5_1(self.conv5(out_conv4_1))
        out_conv6 = self.conv6(out_conv5_1)

        # The transposed convolutions (kernel 5, stride 2, padding 1) produce
        # 2 * size + 1 outputs, so every upsampled tensor is cropped to the
        # spatial size of the skip connection it is concatenated with.
        out_deconv5 = self.deconv5(out_conv6)
        input_to_deconv4 = torch.cat((crop(out_deconv5, out_conv5_1), out_conv5_1), 1)
        flow5 = self.predict_flow5(input_to_deconv4)

        upsampled_flow5_to_4 = crop(self.upsampling(flow5), out_conv4_1)
        out_deconv4 = self.deconv4(input_to_deconv4)
        input_to_deconv3 = torch.cat(
            (crop(out_deconv4, out_conv4_1), out_conv4_1, upsampled_flow5_to_4), 1)
        flow4 = self.predict_flow4(input_to_deconv3)

        upsampled_flow4_to_3 = crop(self.upsampling(flow4), out_conv3_1)
        out_deconv3 = self.deconv3(input_to_deconv3)
        input_to_deconv2 = torch.cat(
            (crop(out_deconv3, out_conv3_1), out_conv3_1, upsampled_flow4_to_3), 1)
        flow3 = self.predict_flow3(input_to_deconv2)

        upsampled_flow3_to_2 = crop(self.upsampling(flow3), out_conv2)
        out_deconv2 = self.deconv2(input_to_deconv2)
        input_to_upsampling = torch.cat(
            (crop(out_deconv2, out_conv2), out_conv2, upsampled_flow3_to_2), 1)

        flow2 = self.predict_flow2(input_to_upsampling)
        return self.upsample_bilinear(flow2)

    def training_step(self, train_batch, batch_idx):
        """Compute, log (``train_loss``) and return the mean EPE of one batch."""
        frames, target = train_batch
        prediction = self(self._stack_frames(frames))
        loss = mean_EPE(prediction, self._to_channels_first(target))
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log (``val_loss``) the mean EPE of one validation batch."""
        frames, target = val_batch
        prediction = self(self._stack_frames(frames))
        loss = mean_EPE(prediction, self._to_channels_first(target))
        self.log('val_loss', loss)

    def predict_step(self, pred_batch, batch_idx):
        """Predict the flow for a batch, plot its first sample and return the prediction."""
        frames, target = pred_batch
        prediction = self(self._stack_frames(frames))
        self._show_flow_pair(target[0], prediction[0])
        return prediction

    def test_step(self, test_batch, batch_idx):
        """Plot the first sample of the batch, then log (``test_loss``) and print the mean EPE."""
        frames, target = test_batch
        prediction = self(self._stack_frames(frames))
        self._show_flow_pair(target[0], prediction[0])

        loss = mean_EPE(prediction, self._to_channels_first(target))
        self.log('test_loss', loss)
        print(f"Batch id:  {batch_idx} -----> Test Loss: {loss}")

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-4."""
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [None]:
def flying_chairs_loader(sample):
    """Load one sample: two images and the flow field stored in a ``.flo`` file.

    Args:
        sample: ``[[img1_path, img2_path], flow_path]``.

    Returns:
        ``([img1, img2], flow)`` where the images are float32 arrays as
        decoded by ``imread`` and ``flow`` is a float32 array of shape
        ``(height, width, 2)``.

    Raises:
        AssertionError: If the file does not start with the ``.flo`` magic
            number 202021.25.
        ValueError: If the header holds non-positive dimensions or the file
            contains fewer flow values than the header announces.
    """
    inputs, target = sample[0], sample[1]
    img1 = np.asarray(imread(inputs[0]), dtype=np.float32)
    img2 = np.asarray(imread(inputs[1]), dtype=np.float32)

    with open(target, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        # Checking the size first keeps the assertion well-defined for an empty file.
        assert magic.size == 1 and magic[0] == 202021.25, 'Magic number incorrect. Invalid .flo file'
        # The .flo header stores the width first, then the height.
        # Converting to Python ints avoids int32 overflow in the product below.
        width = int(np.fromfile(f, np.int32, count=1)[0])
        height = int(np.fromfile(f, np.int32, count=1)[0])
        if width <= 0 or height <= 0:
            raise ValueError(f'Invalid .flo dimensions {width}x{height} in {target}')
        expected_values = 2 * width * height
        flow = np.fromfile(f, np.float32, count=expected_values)

    # np.fromfile silently returns fewer items for a truncated file, and
    # np.resize would hide that by repeating the data; fail loudly instead.
    if flow.size != expected_values:
        raise ValueError(
            f'Truncated .flo file {target}: expected {expected_values} values, got {flow.size}'
        )

    return [img1, img2], flow.reshape(height, width, 2)


class CustomDataset(Dataset):
    """Map-style dataset that loads samples lazily through a loader callable.

    Each entry of ``file_names`` is handed to ``loader``, which must return an
    ``(inputs, target)`` pair where ``inputs`` is an indexable, mutable
    sequence with at least two items when a ``transform`` is set.
    """

    def __init__(self, file_names: [str],
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 loader=flying_chairs_loader):
        """Store the sample descriptors, the optional transforms and the loader."""
        self.loader = loader
        self.file_names = file_names
        self.target_transform = target_transform
        self.transform = transform

    def __len__(self):
        """Return the number of sample descriptors."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Load sample ``idx`` and apply the configured transforms.

        ``transform`` is applied to the first two inputs in place;
        ``target_transform`` replaces the target. A falsy transform is skipped.
        """
        descriptor = self.file_names[idx]
        inputs, target = self.loader(descriptor)

        if self.transform:
            for position in (0, 1):
                inputs[position] = self.transform(inputs[position])

        if self.target_transform:
            target = self.target_transform(target)

        return inputs, target

In [None]:
# Collect every (image pair, flow file) sample from the FlyingChairs training folder.
flyingChairsData = get_flying_chairs_data_paths('flyingChairs/FlyingChairs_release/dataset')

# Hold out 10 % for testing, then split the remainder 80 / 20 into train and
# validation (0.8 is the default of make_dataset_split).
train_set, test_set = make_dataset_split(flyingChairsData, split=.9)
train_set, val_set = make_dataset_split(train_set)

train_data = CustomDataset(train_set)
val_data = CustomDataset(val_set)
# Reshuffle the training samples every epoch; without shuffle=True the model
# would see the batches in exactly the same order in each of the epochs.
train_loader = DataLoader(train_data, batch_size=8, shuffle=True, num_workers=0)
val_loader = DataLoader(val_data, batch_size=8, num_workers=0)

# The test loader keeps the DataLoader default batch size of 1.
test_data = CustomDataset(test_set)
test_loader = DataLoader(test_data, num_workers=0)

In [None]:
# Instantiate the network with freshly initialised weights.
model = FlowNetSimple()

In [None]:
# Select the 'high' float32 matmul precision setting
# (see the PyTorch documentation for the speed/accuracy trade-off).
torch.set_float32_matmul_precision('high')

# Train for at most 20 epochs; every other Trainer option keeps its default.
trainer = pl.Trainer(max_epochs=20)

In [None]:
# Train on train_loader and run validation on val_loader.
trainer.fit(model, train_loader, val_loader)

In [None]:
# Evaluate on the held-out split; test_step plots and prints the loss for every batch.
trainer.test(model, test_loader)

In [None]:
# Evaluate on the separate test folder and report the average time per sample.
flyingChairsDataTest = get_flying_chairs_data_paths('flyingChairs/FlyingChairs_release/test')

test = CustomDataset(flyingChairsDataTest)
loader = DataLoader(test, num_workers=0)

# perf_counter is monotonic, so the measurement is unaffected by clock adjustments.
# NOTE: the measured time includes the plotting done inside test_step.
start_time = time.perf_counter()
trainer.test(model, loader)
elapsed = time.perf_counter() - start_time

# Divide by the real number of samples instead of a hard-coded count;
# max(..., 1) avoids a ZeroDivisionError when the folder is empty.
print("--- %s seconds ---" % (elapsed / max(len(test), 1)))