In [8]:
import os
import cv2
import torch
import torch.nn as nn
import torchvision
import numpy as np
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import pandas as pd
from PIL import Image

## Extracting Optical Flow Images

In [3]:
def extract_flow_images(video_path='test.mp4', out_dir='testflow2', start_index=2):
    """Write a dense optical-flow visualisation for every consecutive frame pair.

    Farneback flow is computed between each frame and the one before it and
    encoded as an HSV image (hue = flow direction, saturation = 255,
    value = min-max normalised magnitude), converted to BGR and saved as
    ``<out_dir>/<index>.jpg``. The first image written is numbered
    ``start_index`` (2 by default: the flow between frames 1 and 2).

    Returns:
        The number of flow images written.
    Raises:
        IOError: if the first frame cannot be read or an image cannot be written.
    """
    os.makedirs(out_dir, exist_ok=True)  # cv2.imwrite does not create directories
    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame1 = cap.read()
        if not ret:
            raise IOError('Could not read the first frame of {!r}'.format(video_path))
        prv = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)
        hsv = np.zeros_like(frame1)
        hsv[..., 1] = 255

        indx = start_index
        while True:
            ret, frame2 = cap.read()
            if not ret:
                break
            nxt = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)

            flow = cv2.calcOpticalFlowFarneback(prv, nxt, None, 0.5, 3, 15, 3, 5, 1.1, 1)
            mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            # Radians -> OpenCV's 8-bit hue scale, which spans [0, 180).
            hsv[..., 0] = ang * 180 / np.pi / 2
            hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
            bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

            out_path = os.path.join(out_dir, '{}.jpg'.format(indx))
            if not cv2.imwrite(out_path, bgr):
                raise IOError('Failed to write {!r}'.format(out_path))
            prv = nxt
            indx += 1
    finally:
        cap.release()
    return indx - start_index

extract_flow_images()

## Defining global parameters

In [9]:
nc = 3              # channels of the (BGR-encoded) flow images fed to `network`
ngf = 64            # base number of conv filters in `network`
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_path = 'flow/' # directory holding the training flow images named <index>.jpg
batch_size = 8
num_epochs = 10
num_images = 16319  # images per epoch; presumably the size of the 80% train split -- TODO confirm
seed = 42

# Seed every RNG the notebook relies on (numpy for batch sampling, torch for
# weight initialisation) so that runs are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

print(device)

cuda:0


## Data loader that yields random batches of (images, labels)

In [22]:
class DataLoader:
    """Random-batch sampler over the optical-flow images and their labels.

    Images in ``data_path`` are ordered by their numeric file name, paired with
    the rows of ``train.txt`` (first row skipped) and split 80/20 into
    train/test with a fixed ``random_state``.
    """

    def __init__(self):
        self.device = device
        self.batch_size = batch_size
        self.data_path = data_path
        # Sort numerically so "10.jpg" follows "9.jpg" and images line up with label rows.
        imgs = sorted(os.listdir(self.data_path), key=lambda x: int(os.path.splitext(x)[0]))
        lbls = pd.read_csv('train.txt', sep=" ", header=None)
        # Skip the first label row; presumably because the first flow image is index 2
        # (no flow exists for the very first frame) -- TODO confirm.
        lbls = lbls[lbls.columns[0]][1:].tolist()
        self.train_images, self.test_images, self.train_lbls, self.test_lbls = train_test_split(
            imgs, lbls, test_size=0.2, random_state=42)
        self.data_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # Map each channel from [0, 1] to [-1, 1].
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def load_img(self, image_name):
        """Load one image as a normalised ``(1, C, H, W)`` float tensor on ``self.device``.

        The leading batch dimension of 1 lets a single image be fed straight to the network.
        """
        # The context manager closes the file handle once ToTensor has read the pixels.
        with Image.open(image_name) as image:
            tensor = self.data_transforms(image).float()
        return tensor.unsqueeze(0).to(self.device)

    def get_data(self, train=True):
        """Yield random ``(images, labels)`` batches forever.

        Args:
            train: sample from the training split if truthy, else from the test split.
        Yields:
            images: ``(batch_size, C, H, W)`` tensor on ``self.device``.
            labels: ``(batch_size,)`` tensor (left on the CPU).
        Indices are drawn with replacement, so a batch may contain an image twice.
        """
        if train:
            images, labels = self.train_images, self.train_lbls
        else:
            images, labels = self.test_images, self.test_lbls

        while True:
            ix = np.random.choice(len(images), self.batch_size)
            x = [self.load_img(os.path.join(self.data_path, images[i])) for i in ix]
            y = [torch.tensor(labels[i]) for i in ix]
            # load_img already adds a batch dim of size 1, so concatenate along it;
            # torch.stack would produce an invalid 5-D (B, 1, C, H, W) tensor.
            yield torch.cat(x, dim=0), torch.stack(y)

In [23]:
# Build the loader once: it lists the flow images, reads train.txt and fixes the
# 80/20 train/test split (random_state=42) used by the rest of the notebook.
loader = DataLoader()


## Weight Initialization

In [12]:
def weights_init_normal(m):
    """Initialise one module in place; intended for use with ``nn.Module.apply``.

    * class name contains ``Conv``        -> weight ~ N(0, 0.01)
    * class name contains ``BatchNorm2d`` -> weight ~ N(1, 0.01), bias = 0
    Every other module, and every conv bias, is left untouched.
    Note: ``InstanceNorm2d`` does not match either pattern.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.01)
        return
    if 'BatchNorm2d' in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.01)
        torch.nn.init.constant_(m.bias.data, val=0.0)

## Network used for Prediction from Optical Flow images

In [13]:
class network(nn.Module):
    """Regress one scalar per sample from a 480 x 640 optical-flow image.

    Five 4x4 stride-2 convolutions shrink 480 x 640 down to 15 x 20. Five 1x1
    convolutions then reduce the channels to 3, and two linear layers map the
    flattened 15*20*3 features to a single output. ``lin11`` is sized for a
    15 x 20 final map, so inputs must be 480 x 640.

    Args:
        in_channels: input channels; defaults to the notebook-level ``nc``.
        base_filters: width of the first conv; defaults to the notebook-level ``ngf``.
    """

    def __init__(self, in_channels=None, base_filters=None):
        super(network, self).__init__()
        c_in = nc if in_channels is None else in_channels
        nf = ngf if base_filters is None else base_filters
        out_maps = 3  # channels after the 1x1 bottleneck; lin11 is sized for this

        # input: 480 x 640
        self.conv1 = nn.Conv2d(c_in, nf, 4, 2, 1, bias=True)
        self.r1 = nn.ReLU(inplace=True)

        # input: 240 x 320
        self.conv2 = nn.Conv2d(nf, nf*2, 4, 2, 1, bias=True)
        self.in2 = nn.InstanceNorm2d(nf*2)
        self.r2 = nn.ReLU(inplace=True)

        # input: 120 x 160
        self.conv3 = nn.Conv2d(nf*2, nf*4, 4, 2, 1, bias=True)
        self.in3 = nn.InstanceNorm2d(nf*4)
        self.r3 = nn.ReLU(inplace=True)

        # input: 60 x 80
        self.conv4 = nn.Conv2d(nf*4, nf*8, 4, 2, 1, bias=True)
        self.in4 = nn.InstanceNorm2d(nf*8)
        self.r4 = nn.ReLU(inplace=True)

        # input: 30 x 40
        self.conv5 = nn.Conv2d(nf*8, nf*16, 4, 2, 1, bias=True)
        self.in5 = nn.InstanceNorm2d(nf*16)
        self.r5 = nn.ReLU(inplace=True)

        # 15 x 20 from here on: 1x1 convs only reduce the channel count.
        self.conv6 = nn.Conv2d(nf*16, nf*8, 1, 1, 0, bias=True)
        self.in6 = nn.InstanceNorm2d(nf*8)
        self.r6 = nn.ReLU(inplace=True)

        self.conv7 = nn.Conv2d(nf*8, nf*4, 1, 1, 0, bias=True)
        self.in7 = nn.InstanceNorm2d(nf*4)
        self.r7 = nn.ReLU(inplace=True)

        self.conv8 = nn.Conv2d(nf*4, nf*2, 1, 1, 0, bias=True)
        self.in8 = nn.InstanceNorm2d(nf*2)
        self.r8 = nn.ReLU(inplace=True)

        self.conv9 = nn.Conv2d(nf*2, nf, 1, 1, 0, bias=True)
        self.in9 = nn.InstanceNorm2d(nf)
        self.r9 = nn.ReLU(inplace=True)

        self.conv10 = nn.Conv2d(nf, out_maps, 1, 1, 0, bias=True)
        # num_features must equal conv10's output channels. The previous ngf/64 (= 1.0)
        # made recent PyTorch versions warn on every forward pass. With affine=False
        # this layer has no parameters, so existing checkpoints are unaffected.
        self.in10 = nn.InstanceNorm2d(out_maps)
        self.r10 = nn.ReLU(inplace=True)

        self.lin11 = nn.Linear(15*20*out_maps, 100)

        # NOTE(review): there is no activation between lin11 and lin12, so the two
        # compose to a single affine map; kept as-is for checkpoint compatibility.
        self.lin12 = nn.Linear(100, 1)

    def forward(self, x):
        """Map a ``(B, C, 480, 640)`` batch to ``(B, 1)`` predictions."""
        c1 = self.conv1(x)
        c2 = self.in2(self.conv2(self.r1(c1)))
        c3 = self.in3(self.conv3(self.r2(c2)))
        c4 = self.in4(self.conv4(self.r3(c3)))
        c5 = self.in5(self.conv5(self.r4(c4)))
        c6 = self.in6(self.conv6(self.r5(c5)))
        c7 = self.in7(self.conv7(self.r6(c6)))
        c8 = self.in8(self.conv8(self.r7(c7)))
        c9 = self.in9(self.conv9(self.r8(c8)))
        c10 = self.in10(self.conv10(self.r9(c9)))
        l11 = self.lin11(self.r10(c10.view(c10.size()[0], -1)))
        l12 = self.lin12(l11)

        return l12

In [14]:
# Build the model on the selected device, then apply the custom weight init.
# Both Module.to and Module.apply return the module itself, so they chain.
net = network().to(device).apply(weights_init_normal)
criterion = nn.MSELoss()                                # regression loss on the scalar output
optim_net = torch.optim.Adam(params=net.parameters())   # Adam with its default hyper-parameters

## Training Loop

In [None]:
print_interval = 10
checkpoint_interval = 500
file_name = 'train/net.pth'
checkpoint_dir = os.path.dirname(file_name) or '.'
batches_per_epoch = num_images // batch_size

# Resume from the rolling checkpoint if a previous run left one behind.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
if os.path.isfile(file_name):
    net.load_state_dict(torch.load(file_name, map_location=device))
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create directories

# Create each batch generator once and keep drawing from it.
train_batches = loader.get_data(True)
val_batches = loader.get_data(False)

for epoch in range(num_epochs):
    for i in range(batches_per_epoch):
        x, y = next(train_batches)
        real_x = x.to(device)
        real_y = y.to(device)

        optim_net.zero_grad()
        pred = net(real_x).view(-1)
        loss = criterion(pred, real_y)
        loss.backward()
        optim_net.step()

        if i % print_interval == 0:
            print("\r[Epoch %d/%d] [Batch %d/%d] [Train_loss: %f]" %
                  (epoch, num_epochs, i, batches_per_epoch, loss.item()))
        if i % checkpoint_interval == 0:
            # Validation is computed inline rather than through validate(), which is
            # only defined in a later cell, so the notebook survives Restart & Run All.
            # no_grad avoids building an autograd graph for the validation batch.
            with torch.no_grad():
                val_x, val_y = next(val_batches)
                v_loss = criterion(net(val_x.to(device)).view(-1), val_y.to(device)).item()
            print("\r[Epoch %d/%d] [Batch %d/%d] [Train_loss: %f] [Validation_loss: %f]" %
                  (epoch, num_epochs, i, batches_per_epoch, loss.item(), v_loss))
            torch.save(net.state_dict(), file_name)

    # Keep one snapshot per epoch next to the rolling checkpoint.
    torch.save(net.state_dict(), os.path.join(checkpoint_dir, "net_{}.pth".format(epoch + 1)))

## Validation Loss (intended for early stopping, which is not implemented yet)

In [None]:
def validate():
    """Return the loss of `net` (via the global `criterion`) on one random held-out batch.

    The forward pass runs under ``torch.no_grad()``, so no autograd graph is
    kept for validation. Returns a loss tensor; call ``.item()`` for a float.
    """
    x, y = next(loader.get_data(False))
    with torch.no_grad():
        pred = net(x.to(device)).view(-1)
        loss = criterion(pred, y.to(device))
    return loss

## Predict from the optical flow and write an annotated video

In [None]:
def test_video(checkpoint="train/custom_net/net_10.pth", out_path='speed.avi',
               first_frame=2, end_frame=20400, fps=20, frame_size=(640, 480), show=True):
    """Overlay the model's per-frame prediction on the frames and save them as a video.

    For every index i in ``range(first_frame, end_frame)``, this function:
    1. reads ``frames/<i>.jpg`` and its flow image ``flow/<i>.jpg``;
    2. predicts a scalar with `net` and draws it in the top-left corner;
    3. appends the frame to an MJPG video at ``out_path``.
    ``frame_size`` must match the frame images, otherwise OpenCV drops the frames.
    NOTE(review): the default range stops at 20399 -- confirm whether 20400 exists.

    Raises:
        FileNotFoundError: if a frame image cannot be read.
    """
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    net.load_state_dict(torch.load(checkpoint, map_location=device))
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, frame_size)
    try:
        with torch.no_grad():  # inference only: no autograd graph per frame
            for i in range(first_frame, end_frame):
                frame_file = "frames/{}.jpg".format(i)
                im = cv2.imread(frame_file)
                if im is None:  # cv2.imread signals a missing/unreadable file with None
                    raise FileNotFoundError(frame_file)
                flow = loader.load_img("flow/{}.jpg".format(i))
                pred = net(flow.to(device)).view(-1)
                frame = cv2.putText(im, "{}".format(pred.item()), org=(20, 20), fontFace=3,
                                    fontScale=.5, color=(255, 255, 255), thickness=1)
                out.write(frame)
                if show:
                    cv2.imshow('frame', frame)
                    # imshow only repaints inside waitKey; Esc stops early.
                    if cv2.waitKey(1) == 27:
                        break
    finally:
        # Always finalise the video file and close the preview window.
        out.release()
        cv2.destroyAllWindows()

test_video()