In [71]:
import sys
sys.path.append('core')

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation ## FOR playing images as video on jupyter notebook
from IPython.display import HTML ## FOR playing images as video on jupyter notebook

import cv2
import torch
from torchvision import datasets, transforms

import argparse

In [2]:
## Use the GPU when one is available. `torch.cuda.is_available` must be *called*:
## the bare function object is always truthy and would force 'cuda' on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Generate optical flow images from the training video

## 1a. Load training video

The videos were downloaded from the comma.ai speed challenge GitHub repository: https://github.com/commaai/speedchallenge

In [3]:
## Training video from the comma.ai speed challenge (path relative to the notebook)
video_file = 'data/train.mp4'

Load the video with OpenCV and copy its frames. Only the first 100 frames are used in this demo.

In [4]:
## Read at most `max_frames` frames from the video. Stop early when the video ends
## or a read fails: `cap_in.read()` then keeps returning ret=False, and a loop that
## only exits on a frame count would spin forever.
max_frames = 100  ## only the first 100 frames are used for this demo
frames = []

cap_in = cv2.VideoCapture(video_file)
f_counter = 0  ## frame counter
try:
    if not cap_in.isOpened():
        raise IOError(f"Could not open video file: {video_file}")
    while f_counter < max_frames:
        ret, frame = cap_in.read()
        if not ret:
            break  ## end of video (or read error)
        f_counter += 1
        frames.append(frame)
finally:
    # When everything done (or on error), release the capture
    cap_in.release()

In [5]:
## OpenCV decodes frames in BGR channel order; convert every frame to RGB
## for display and flow estimation.
frames = list(map(lambda bgr: cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), frames))

In [9]:
class Playable():
    """
    Play a sequence of frames as an HTML5 video inside a Jupyter notebook.

    Args:
        frames (list): List of numpy arrays (frames to play). Must be non-empty,
            because the first frame initialises the displayed image.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    def __init__(self, frames):
        if len(frames) == 0:
            raise ValueError("Playable needs at least one frame to play")
        self.frames = frames
        self.fig, ax = plt.subplots()
        self.im = ax.imshow(frames[0])
        # Close this figure so Jupyter does not also render it as a static image;
        # the animation still draws into it.
        plt.close(self.fig)

    def init_anim(self):
        """Reset the displayed image to the first frame (FuncAnimation ``init_func``)."""
        self.im.set_data(self.frames[0])
        return (self.im,)

    def animate(self, i):
        """Show frame ``i`` (FuncAnimation per-frame callback)."""
        self.im.set_data(self.frames[i])
        return self.im

    def play(self, interval=100):
        """
        Play the frames.

        Args:
            interval (int): delay between frames in milliseconds (default 100).

        Returns:
            IPython.display.HTML: an HTML5 video element rendered inline by Jupyter.
        """
        anim = animation.FuncAnimation(self.fig, self.animate, init_func=self.init_anim,
                                       frames=len(self.frames), interval=interval)
        return HTML(anim.to_html5_video())

In [13]:
## Preview the raw RGB frames as a video
Playable(frames=frames).play()

## 1b. Preprocess images (brightness adjustment and cropping)

In [41]:
def change_brightness(image, bright_factor):
    """Augment brightness by scaling the HSV value (V) channel of an RGB image.

    Args:
        image: RGB image as accepted by ``cv2.cvtColor`` (uint8 HxWx3 in this notebook).
        bright_factor: multiplier applied to the V channel.

    Returns:
        RGB image with the same shape and dtype as ``image``.
    """
    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    scaled = hsv_image[:, :, 2] * bright_factor
    # Saturate instead of overflowing: writing float values above the dtype's
    # maximum (e.g. 200 * 1.7 = 340 for uint8) straight into an integer array
    # wraps or yields garbage rather than clamping at 255.
    if np.issubdtype(hsv_image.dtype, np.integer):
        scaled = np.clip(scaled, 0, np.iinfo(hsv_image.dtype).max)
    hsv_image[:, :, 2] = scaled
    image_rgb = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2RGB)
    return image_rgb


def transform(image, bright_factor, top=100, bottom=440):
    """Augment brightness, then crop the image to rows ``top:bottom`` (no resizing).

    Args:
        image: RGB image (HxWxC numpy array).
        bright_factor: brightness multiplier passed to ``change_brightness``.
        top: first row kept (default 100).
        bottom: row after the last row kept (default 440). The default window is
            presumably chosen to drop rows without useful motion (sky / car hood)
            -- NOTE(review): confirm against the video.

    Returns:
        The brightness-adjusted image cropped to rows ``top:bottom``, all columns and channels.
    """
    image = change_brightness(image, bright_factor)
    return image[top:bottom, :, :]

In [57]:
def preprocess(prev_frame, curr_frame, rng=None):
    """Apply one shared random brightness factor and the crop to a pair of frames.

    The factor is drawn uniformly from [0.7, 1.7) once and used for both frames,
    presumably so the augmentation itself does not appear as motion in the
    optical flow computed between them.

    Args:
        prev_frame: earlier RGB frame (numpy array).
        curr_frame: later RGB frame (numpy array).
        rng: optional ``np.random.Generator`` for reproducible augmentation.
            When None (default), the global ``np.random`` state is used.

    Returns:
        tuple: ``(prev_frame, curr_frame)`` after ``transform``.
    """
    draw = np.random.uniform if rng is None else rng.uniform
    bright_factor = 0.7 + draw()
    return transform(prev_frame, bright_factor), transform(curr_frame, bright_factor)

In [58]:
## Pair each frame with its predecessor. `max` clamps the index so frame 0 is paired
## with itself (zero motion); `min(0, i - 1)` would pair every frame with frame 0 and
## frame 0 with the *last* frame. One pair per frame keeps flow images aligned with
## the per-frame speed labels.
preprocessed_frames = [preprocess(frames[max(0, i - 1)], frames[i]) for i in range(len(frames))]

In [59]:
## keep only the current (second) frame of every preprocessed pair, for preview
pf = [curr for _prev, curr in preprocessed_frames]

In [60]:
## Preview the brightness-augmented, cropped frames
Playable(frames=pf).play()

## 1c. Load pretrained RAFT model

The RAFT (Recurrent All-Pairs Field Transforms for Optical Flow) code and the model weights pretrained on the FlyingThings3D dataset (`raft-things.pth`) were downloaded from https://github.com/princeton-vl/RAFT

In [50]:
## RAFT reads its options from an argparse-style namespace. Build it directly
## (no faked sys.argv); these are the flags RAFT's demo script exposes, set to
## their defaults -- NOTE(review): confirm against core/raft.py.
args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)

## create model object; wrapped in DataParallel, presumably so the checkpoint's
## key names (saved from a DataParallel model) match
model = torch.nn.DataParallel(RAFT(args), device_ids=[0])
model.to(device)
## inference only: put normalisation/dropout layers in evaluation mode
model.eval()
## load weights onto the active device (needed if the checkpoint was saved on GPU
## and this kernel runs on CPU)
model.load_state_dict(torch.load('models/raft-things.pth', map_location=device))

<All keys matched successfully>

## 1d. Generate optical flow images for all frames

In [15]:
flo_frames = []
## Inference only: no_grad avoids building an autograd graph through the 20 RAFT iterations.
with torch.no_grad():
    for prev_frame, next_frame in zip(frames[:-1], frames[1:]):
        ## take consecutive frames: HxWxC numpy -> 1xCxHxW float tensor
        im1 = torch.from_numpy(prev_frame).permute(2, 0, 1).float().unsqueeze(dim=0).to(device)
        im2 = torch.from_numpy(next_frame).permute(2, 0, 1).float().unsqueeze(dim=0).to(device)
        ## Pad images such that dimensions are divisible by 8 (see RAFT paper)
        padder = InputPadder(im1.shape)
        im1, im2 = padder.pad(im1, im2)

        ## generate optical flow image which is a 2channel UV format image
        _, flo = model(im1, im2, iters=20, test_mode=True)
        ## remove the padding again so the flow matches the input frame size
        flo = padder.unpad(flo[0]).permute(1, 2, 0).cpu().numpy()
        # format flow to rgb image
        flo_frames.append(flow_viz.flow_to_image(flo))

In [16]:
## Preview the optical flow computed from the raw frames
Playable(frames=flo_frames).play()

In [61]:
flo_frames = []
## Inference only: no_grad avoids building an autograd graph through the 20 RAFT iterations.
with torch.no_grad():
    for prev_frame, curr_frame in preprocessed_frames:
        ## take the preprocessed pair: HxWxC numpy -> 1xCxHxW float tensor
        im1 = torch.from_numpy(prev_frame).permute(2, 0, 1).float().unsqueeze(dim=0).to(device)
        im2 = torch.from_numpy(curr_frame).permute(2, 0, 1).float().unsqueeze(dim=0).to(device)
        ## Pad images such that dimensions are divisible by 8 (see RAFT paper).
        ## The 340-row crop is not a multiple of 8, so padding actually happens here.
        padder = InputPadder(im1.shape)
        im1, im2 = padder.pad(im1, im2)

        ## generate optical flow image which is a 2channel UV format image
        _, flo = model(im1, im2, iters=20, test_mode=True)
        ## remove the padding so each flow image has the same size as the cropped frames
        flo = padder.unpad(flo[0]).permute(1, 2, 0).cpu().numpy()
        # format flow to rgb image
        flo_frames.append(flow_viz.flow_to_image(flo))

In [62]:
## Preview the optical flow computed from the preprocessed frame pairs
Playable(frames=flo_frames).play()

# 2. Create a PyTorch dataset of optical flow images

Create a PyTorch dataset with the generated optical flow images as features and the speed at each frame as the target values.

In [64]:
## Per-frame ground-truth speeds for the training video
speeds_file = 'data/train.txt'

In [65]:
## presumably one speed value per line / per frame -- verify against the dataset
y = np.loadtxt(fname=speeds_file)

In [70]:
## Keep one speed label per optical-flow image, so features and targets stay aligned
## even if fewer frames were read than requested (instead of a hard-coded 100).
y = y[:len(flo_frames)]

In [73]:
## HxWxC uint8 image -> CxHxW float tensor scaled to [0, 1]
TT = transforms.ToTensor()

In [75]:
def create_tensor_dataset(x, y, transform, target_dtype=None):
    """Build a ``TensorDataset`` of transformed images and their targets.

    Args:
        x: sequence of images; each is passed through ``transform``, which must
            return tensors of identical shape so they can be stacked.
        y: array-like of targets (numpy array or list), one per image.
        transform: callable mapping one image to a tensor (e.g. ``transforms.ToTensor()``).
        target_dtype: optional torch dtype for the targets. When None (default),
            the targets keep the dtype of ``y``. ``np.loadtxt`` output is float64,
            which typically mismatches float32 model outputs in a loss; pass
            ``torch.float32`` in that case.

    Returns:
        torch.utils.data.TensorDataset yielding ``(image_tensor, target)`` pairs.
    """
    X_torch = torch.stack([transform(img) for img in x])
    # as_tensor accepts lists as well as numpy arrays (from_numpy only takes ndarrays)
    # and, like from_numpy, shares memory with a numpy input when no cast is needed.
    y_torch = torch.as_tensor(y, dtype=target_dtype)
    return torch.utils.data.TensorDataset(X_torch, y_torch)

In [81]:
## index of the first validation sample: the first 80% of frames (rounded down) are for training
train_val_split = int(len(y) * 0.8)

In [83]:
## Chronological split (no shuffling): validation frames come after the training frames
x_train, x_val = flo_frames[:train_val_split], flo_frames[train_val_split:]
y_train, y_val = y[:train_val_split], y[train_val_split:]

In [84]:
## wrap the flow images and speed labels of each split in a TensorDataset
train_set = create_tensor_dataset(x=x_train, y=y_train, transform=TT)
val_set = create_tensor_dataset(x=x_val, y=y_val, transform=TT)

In [87]:
## shuffled mini-batches for training
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    num_workers=8,
    shuffle=True,
)

In [88]:
## validation batches in their original (temporal) order
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=32,
    num_workers=8,
    shuffle=False,
)