In [None]:
!pip install pims av
!python --version

import numpy as np
import cv2
import pims
from tqdm.notebook import trange
import xml.etree.ElementTree as ET

import matplotlib.pyplot as plt
from matplotlib.pyplot import plot

import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Mount Google Drive into the Colab VM; the dataset directory read further down
# (base_dir) lives under this /content/drive mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# CONSTANTS
# Network input resolution in pixels, as (width, height).
W, H = 320, 160

# Resolution (width, height) at which the polyline annotations were drawn;
# annotation coordinates are rescaled from this to (W, H).
annot_W, annot_H = 480, 320

In [None]:
# DATA FUNCTIONS
from os import listdir

# get polylines from file
def extract_polylines(filename):
  """Parse every ``<polyline>`` element of an XML annotation file.

  Args:
    filename: path (or file object) accepted by ``ET.parse``.

  Returns:
    A list of ``(frame, points)`` tuples, sorted (by frame index, ties broken by
    the points). ``frame`` is the element's ``frame`` attribute as an int;
    ``points`` is its ``points`` attribute (``"x1,y1;x2,y2;..."``) parsed into a
    list of float lists, one per ``;``-separated entry.
  """
  doc_tree = ET.parse(filename)
  parsed = []
  for node in doc_tree.iter(tag='polyline'):
    frame_attr = node.get("frame")
    coords = [[float(value) for value in pair.split(",")]
              for pair in node.get("points").split(";")]
    parsed.append((int(frame_attr), coords))
  return sorted(parsed)

# get polylines from each frame
def extract_frame_lines(polylines):
  """Group ``(frame, points)`` pairs into a per-frame list of polylines.

  Args:
    polylines: list of ``(frame, points)`` pairs as returned by
      ``extract_polylines``. The list does not need to be sorted.

  Returns:
    A list indexed by frame number, covering 0 up to the largest frame index in
    ``polylines``. Entry ``i`` is the sorted list of polylines annotated on frame
    ``i`` and is empty for frames without annotations. Empty input gives ``[]``.
  """
  if not polylines:
    return []

  # Use the largest frame index rather than the last element, so unsorted
  # input does not silently lose trailing frames.
  last_frame = max(item[0] for item in polylines)
  frames = [[] for _ in range(last_frame + 1)]
  for item in polylines:
    frame_idx, points = item[0], item[1]
    # A negative index would wrap around to the end of the list; such frames
    # were never matched before either, so skip them.
    if frame_idx >= 0:
      frames[frame_idx].append(points)
  return [sorted(lines) for lines in frames]

# convert annotations to new resolution
def convert_annotations(old_res, new_res, annotations):
  """Rescale per-frame polyline coordinates from one resolution to another.

  Args:
    old_res: ``(width, height)`` the annotations were drawn at.
    new_res: ``(width, height)`` to rescale to.
    annotations: sequence of frames. Each frame is a list of polylines, each
      polyline is a list of ``(x, y)`` points.

  Returns:
    A 1-D numpy object array with one entry per frame. Each entry is the frame's
    list of polylines with points as ``[new_x, new_y]`` float lists.
    Frames usually hold different numbers of polylines. ``np.array`` on such
    ragged nested lists raises ``ValueError`` on NumPy >= 1.24, so the array is
    built explicitly. Entries stay plain lists, which ``serialize_polylines``
    accepts.
  """
  old_W, old_H = old_res
  new_W, new_H = new_res
  scaled = [
    [[[(x * new_W) / old_W, (y * new_H) / old_H] for x, y in polyline]
     for polyline in polylines]
    for polylines in annotations
  ]

  new_annotations = np.empty(len(scaled), dtype=object)
  for idx, frame_lines in enumerate(scaled):
    new_annotations[idx] = frame_lines
  return new_annotations

# get training data from path
def get_data(video_path, annotations_path):
  """Load one video and its road-edge polyline annotations.

  Args:
    video_path: path of the .mp4 video, opened with ``pims.Video``.
    annotations_path: path of the XML annotation file whose ``polyline``
      elements are read by ``extract_polylines``.

  Returns:
    ``(frames, annotations)``. ``frames`` is the ``pims.Video`` object, not yet
    converted to arrays (``conv_frames`` does that). ``annotations`` holds the
    per-frame polylines, rescaled from the annotation resolution
    (annot_W x annot_H) to the network input resolution (W x H).
  """
  frames = pims.Video(video_path, format="mp4")
  lines_by_frame = extract_frame_lines(extract_polylines(annotations_path))
  return frames, convert_annotations((annot_W, annot_H), (W, H), lines_by_frame)

# make pims video into actual numpy frames
def conv_frames(frames):
  """Turn decoded video frames into one numpy array at network resolution.

  Each frame is passed through ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`` and then
  ``cv2.resize(..., (W, H))``. The results are stacked with ``np.array``.
  NOTE(review): pims readers typically yield RGB frames already. If so, this
  "BGR2RGB" swap actually produces BGR. Confirm it matches the channel order
  used at inference time.

  Args:
    frames: iterable of 3-channel images accepted by ``cv2.cvtColor``
      (presumably the ``pims.Video`` returned by ``get_data``).

  Returns:
    ``np.ndarray`` of the converted frames, in iteration order.
  """
  print("Getting frames into proper arrays")
  converted = [cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (W, H)) for img in frames]
  print("Frames converted to numpy arrays")
  return np.array(converted)

base_dir = "/content/drive/MyDrive/OpenCRD_dataset/"
# List the directory once and split it into videos and annotation files.
# Both lists are sorted by name; presumably they are paired by position downstream.
dir_listing = listdir(base_dir)
video_files = sorted(name for name in dir_listing if name.endswith(".mp4"))
annot_files = sorted(name for name in dir_listing if name.endswith(".xml"))

# TODO: temporary hack -- keep only the first 3 videos until every video has annotations
video_files = video_files[:3]
print(video_files)
print(annot_files)

# every kept video needs exactly one annotation file
assert len(video_files) == len(annot_files), "Number of video files != number of annotation files"

# POLYLINES TRANSFORMATIONS

# Convert one frame's polylines into the flat network target vector.
# Cost is linear in the output size (max_n_lines * n_points * n_coords).
def serialize_polylines(polylines, n_coords, n_points, max_n_lines):
  """Flatten one frame's polylines into a fixed-length target vector.

  Polylines whose point count differs from ``n_points`` are dropped
  (TODO: truncate to ``polyline[:n_points]`` instead of dropping the line).
  Empty slots are filled with placeholder lines whose coordinates are all
  ``-100.``, which mark "no line" (out of bounds).

  The caller's ``polylines`` is NOT modified. The previous version removed and
  appended lines in place, which overwrote the stored per-frame annotations.

  Args:
    polylines: iterable of polylines. Each polyline is a sequence of points,
      and each point is indexable for its first ``n_coords`` values.
    n_coords: coordinates per point.
    n_points: points per polyline.
    max_n_lines: number of polyline slots in the output.

  Returns:
    1-D numpy array of length ``max_n_lines * n_points * n_coords``, ordered
    line by line, then point by point, then coordinate by coordinate.

  Raises:
    AssertionError: if more than ``max_n_lines`` polylines have exactly
      ``n_points`` points.
  """
  # keep only correctly-sized lines, in a new list so the input stays intact
  kept = [polyline for polyline in polylines if len(polyline) == n_points]
  assert len(kept) <= max_n_lines, "More than max number of lines found"

  # pad the missing slots with -100. placeholder lines (only read below)
  placeholder = [[-100.] * n_coords for _ in range(n_points)]
  kept.extend(placeholder for _ in range(max_n_lines - len(kept)))

  flat = [
    polyline[j][k]
    for polyline in kept[:max_n_lines]
    for j in range(n_points)
    for k in range(n_coords)
  ]
  return np.array(flat)

# TODO: this needs more work depending on the net output, since it is tested only on annotations
# convert network output vector to polylines per frame
def deserialize_polylines(net_output, n_coords, n_points, max_n_lines):
  """Rebuild polylines from a flat vector (the inverse of ``serialize_polylines``).

  Consecutive values are grouped into points of ``n_coords`` values, and points
  into lines of ``n_points`` points. Trailing values that do not complete a
  line are ignored. Placeholder points (every coordinate exactly ``-100.``) are
  removed, and lines left without points are dropped.

  Args:
    net_output: flat sequence of numbers that supports iteration (presumably a
      1-D list, numpy array or tensor -- confirm for real network outputs).
    n_coords: coordinates per point (previously hard-coded to 2).
    n_points: points per polyline (previously hard-coded to 4).
    max_n_lines: accepted for symmetry with ``serialize_polylines``; unused.

  Returns:
    ``np.array(lines)`` when every remaining line has the same number of
    points. Otherwise it returns a 1-D object array whose entries are the
    per-line lists of points. ``np.array`` raises ``ValueError`` on such ragged
    nested lists in NumPy >= 1.24.
  """
  placeholder = [-100.] * n_coords
  lines = []
  point, line = [], []
  for value in net_output:
    point.append(value)
    if len(point) == n_coords:
      line.append(point)
      point = []
    if len(line) == n_points:
      lines.append(line)
      line = []

  # Drop placeholder points. Only exact -100. matches are removed.
  # TODO: raw network outputs will rarely hit -100. exactly; treat any
  # out-of-range point as empty once the output format is settled.
  lines = [[p for p in pts if p != placeholder] for pts in lines]
  lines = [pts for pts in lines if pts]

  if len({len(pts) for pts in lines}) <= 1:
    return np.array(lines)
  ragged = np.empty(len(lines), dtype=object)
  for idx, pts in enumerate(lines):
    ragged[idx] = pts
  return ragged

In [None]:
# MODEL DEFINITION

# ResNet block
class ResBlock(nn.Module):
  """Residual block used by every ResNet depth that ComboModel supports.

  Depths above 34 (ResNet50/101/152) use the bottleneck wiring: a 1x1 conv,
  then a 3x3 conv (which carries ``stride``), then a 1x1 conv that widens the
  channels 4x. For depths up to 34, ``forward`` skips the leading 1x1 conv and
  feeds the input straight into the 3x3 conv. The 1x1 conv and its BatchNorm
  are still constructed, so their parameters exist but go unused. The final
  1x1 conv then keeps the channel count. Every BN output except the last one
  goes through ELU. The sum of the residual branch and the shortcut also
  goes through ELU.

  Args:
    num_layers: ResNet depth; values > 34 select the bottleneck wiring.
    in_channels: channels of the incoming feature map.
    out_channels: inner width; the block emits ``out_channels * expansion``.
    identity_downsample: optional module applied to the shortcut input. It is
      needed whenever the residual branch changes shape (stride or channels).
    stride: stride of the 3x3 convolution.
  """

  def __init__(self, num_layers, in_channels, out_channels, identity_downsample=None, stride=1):
    super().__init__()
    self.num_layers = num_layers
    is_bottleneck = num_layers > 34
    self.expansion = 4 if is_bottleneck else 1
    widened = out_channels * self.expansion

    # 1x1 reduction conv: always built, only applied by forward() for depths > 34
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    self.bn1 = nn.BatchNorm2d(out_channels)
    # the 3x3 conv reads conv1's output (bottleneck) or the raw block input
    conv2_in = out_channels if is_bottleneck else in_channels
    self.conv2 = nn.Conv2d(conv2_in, out_channels, kernel_size=3, stride=stride, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channels)
    self.conv3 = nn.Conv2d(out_channels, widened, kernel_size=1, stride=1, padding=0)
    self.bn3 = nn.BatchNorm2d(widened)
    self.elu = nn.ELU()
    self.identity_downsample = identity_downsample

  def forward(self, x):
    """Apply the residual branch, add the (optionally projected) shortcut, then ELU."""
    out = x
    if self.num_layers > 34:
      out = self.elu(self.bn1(self.conv1(out)))
    out = self.elu(self.bn2(self.conv2(out)))
    out = self.bn3(self.conv3(out))

    shortcut = x if self.identity_downsample is None else self.identity_downsample(x)
    out += shortcut
    return self.elu(out)

# Multitask Model
class ComboModel(nn.Module):
  """Multitask network: a ResNet-style backbone feeding two fully connected heads.

  ``forward`` returns ``(cr, re)``:
    cr: ``(batch, 1)`` -- ``cr_head`` output passed through a sigmoid.
    re: ``(batch, n_coords * n_points * max_n_lines)`` raw values from
      ``re_head``. This is the same flat layout that ``serialize_polylines``
      produces.

  Args:
    num_layers: ResNet depth, one of 18, 34, 50, 101, 152.
    block: residual block class, constructed as
      ``block(num_layers, in_channels, out_channels[, identity_downsample, stride])``.
    image_channels: channels of the input images.
    input_hw: ``(height, width)`` of the input images. It is only used to size
      the first linear layer of both heads. The default (160, 320) matches the
      notebook's H x W network input.
  """
  def __init__(self, num_layers=18, block=ResBlock, image_channels=3, input_hw=(160, 320)):
    assert num_layers in [18, 34, 50, 101, 152], "Unknown ResNet architecture, number of layers must be 18, 34, 50, 101 or 152"
    super(ComboModel, self).__init__()

    # polylines' shape
    self.n_coords = 2  # 2 coordinates: x,y
    self.n_points = 4  # number of points of each polyline
    self.max_n_lines = 6 # max number of polylines per frame

    # bottleneck blocks (ResNet50 and deeper) output 4x their inner width
    self.expansion = 1 if num_layers < 50 else 4
    # residual blocks per stage, as in the standard ResNet configurations
    layers = {
      18: [2, 2, 2, 2],
      34: [3, 4, 6, 3],
      50: [3, 4, 6, 3],
      101: [3, 4, 23, 3],
      152: [3, 8, 36, 3],
    }[num_layers]

    self.in_channels = 16
    self.conv1 = nn.Conv2d(image_channels, 16, kernel_size=7, stride=2, padding=3)  # TODO: maybe kernel 5x5
    self.bn1 = nn.BatchNorm2d(16)
    self.elu = nn.ELU()
    self.avgpool1 = nn.AvgPool2d(3, 2, padding=1)

    # ResNet Layers (make_layers advances self.in_channels after every stage)
    self.layer1 = self.make_layers(num_layers, block, layers[0], intermediate_channels=64, stride=1)
    self.layer2 = self.make_layers(num_layers, block, layers[1], intermediate_channels=128, stride=2)
    self.layer3 = self.make_layers(num_layers, block, layers[2], intermediate_channels=256, stride=2)
    self.layer4 = self.make_layers(num_layers, block, layers[3], intermediate_channels=512, stride=2)

    # kernel 1 / stride 1: leaves the feature map unchanged
    self.avgpool2 = nn.AvgPool2d(1, 1)

    # Both heads consume the same flattened backbone output.
    feat_h, feat_w = self._backbone_output_hw(input_hw)
    self.n_flat = 512 * self.expansion * feat_h * feat_w

    # Fully Connected Layers
    self.cr_head = self.get_cr_head()
    self.re_head = self.get_re_head()

  def forward(self, x):
    """Run the backbone once and feed the flattened features to both heads."""
    x = self.avgpool1(self.elu(self.bn1(self.conv1(x))))
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool2(x)
    x = x.view(-1, self.num_flat_features(x))
    cr = torch.sigmoid(self.cr_head(x))
    re = self.re_head(x)
    return cr, re

  def make_layers(self, num_layers, block, num_residual_blocks, intermediate_channels, stride):
    """Build one ResNet stage of ``num_residual_blocks`` blocks.

    The first block applies ``stride``. Its shortcut is a 1x1 conv + BN that
    projects ``self.in_channels`` to ``intermediate_channels * expansion``. The
    remaining blocks keep the shape. ``self.in_channels`` is updated for the
    next stage.
    """
    out_channels = intermediate_channels * self.expansion
    identity_downsample = nn.Sequential(
      nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
      nn.BatchNorm2d(out_channels),
    )
    blocks = [block(num_layers, self.in_channels, intermediate_channels, identity_downsample, stride)]
    self.in_channels = out_channels
    for _ in range(num_residual_blocks - 1):
      blocks.append(block(num_layers, self.in_channels, intermediate_channels))
    return nn.Sequential(*blocks)

  @staticmethod
  def _backbone_output_hw(input_hw):
    """Spatial size after the five stride-2 stages (conv1, avgpool1, layer2-4).

    With (kernel 7, pad 3), (kernel 3, pad 1) or (kernel 1, pad 0) and stride 2,
    each of those stages maps a length n to (n - 1) // 2 + 1.
    """
    h, w = input_hw
    for _ in range(5):
      h, w = (h - 1) // 2 + 1, (w - 1) // 2 + 1
    return h, w

  def num_flat_features(self, x):
    """Number of features per sample: the product of all non-batch dims of ``x``."""
    n_features = 1
    for dim in x.size()[1:]:
      n_features *= dim
    return n_features

  def get_cr_head(self):
    """Build the ``cr`` head: n_flat -> 1024 -> 128 -> 84 -> 1.

    Every hidden layer is Linear + BatchNorm1d + ReLU; the final layer is Linear only.
    """
    relu = nn.ReLU()
    fc1 = nn.Linear(self.n_flat, 1024)
    bn1 = nn.BatchNorm1d(1024)
    fc2 = nn.Linear(1024, 128)
    bn2 = nn.BatchNorm1d(128)
    fc3 = nn.Linear(128, 84)
    bn3 = nn.BatchNorm1d(84)
    fc4 = nn.Linear(84, 1)

    head = nn.Sequential(fc1, bn1, relu, fc2, bn2, relu, fc3, bn3, relu, fc4)
    return head

  def get_re_head(self):
    """Build the ``re`` head.

    Layer sizes: n_flat -> 8192 -> 4096 -> 2048 -> 1024 -> 512 -> 256 -> 128
    -> 64 -> n_coords * n_points * max_n_lines. Hidden layers are
    Linear + BatchNorm1d + LeakyReLU; the final layer is Linear only.
    """
    l_relu = nn.LeakyReLU()
    # NOTE(review): fc1 alone has n_flat * 8192 weights (~210M for ResNet18 at 160x320)
    fc1 = nn.Linear(self.n_flat, 8192)
    bn1 = nn.BatchNorm1d(8192)
    fc2 = nn.Linear(8192, 4096)
    bn2 = nn.BatchNorm1d(4096)
    fc3 = nn.Linear(4096, 2048)
    bn3 = nn.BatchNorm1d(2048)
    fc4 = nn.Linear(2048, 1024)
    bn4 = nn.BatchNorm1d(1024)
    fc5 = nn.Linear(1024, 512)
    bn5 = nn.BatchNorm1d(512)
    fc6 = nn.Linear(512, 256)
    bn6 = nn.BatchNorm1d(256)
    fc7 = nn.Linear(256, 128)
    bn7 = nn.BatchNorm1d(128)
    fc8 = nn.Linear(128, 64)
    bn8 = nn.BatchNorm1d(64)
    fc9 = nn.Linear(64, self.n_coords*self.n_points*self.max_n_lines)

    # fc7/bn7 were previously left out, so fc6's 256 outputs hit fc8's 128 inputs
    head = nn.Sequential(fc1, bn1, l_relu, fc2, bn2, l_relu, fc3, bn3, l_relu,
                         fc4, bn4, l_relu, fc5, bn5, l_relu, fc6, bn6, l_relu,
                         fc7, bn7, l_relu, fc8, bn8, l_relu, fc9)
    return head

# TODO: progress bar and loss/accuracy for each epoch instead of each sub-iteration + different graphs for each task

In [None]:
# LOSS WRAPPER

class ComboLoss(nn.Module):
  """Combined loss for the two ComboModel tasks.

  By default it returns the plain sum ``BCE(preds[0], cr) + MSE(preds[1], re)``,
  both mean-reduced. With ``uncertainty_weighting=True``, each task loss L_i is
  weighted by a learned precision and the loss becomes
  ``sum_i exp(-s_i) * L_i + s_i`` with ``s_i = log_vars[i]``. The previous code
  computed these precisions but never used them. At initialisation
  (``s_i = 0``) both modes give the same value.

  Args:
    task_num: number of tasks; sizes ``log_vars`` (forward reads entries 0 and 1).
    model: the multitask model, stored as a submodule so its parameters appear
      in ``self.parameters()``.
    uncertainty_weighting: enable the learned task weighting (default off,
      matching the previous behavior).
  """
  def __init__(self, task_num, model, uncertainty_weighting=False):
    super(ComboLoss, self).__init__()
    self.task_num = task_num  # TODO: maybe make this constant
    self.model = model
    self.uncertainty_weighting = uncertainty_weighting
    # learnable per-task log-variances; only read when uncertainty_weighting is on
    self.log_vars = nn.Parameter(torch.zeros((task_num)))

  def forward(self, preds, cr, re):
    """Return the combined scalar loss.

    Args:
      preds: ``(cr_pred, re_pred)``. ``cr_pred`` must hold probabilities in
        [0, 1], as required by binary cross-entropy.
      cr: targets for ``cr_pred``.
      re: targets for ``re_pred``.
    """
    loss0 = F.binary_cross_entropy(preds[0], cr)
    loss1 = F.mse_loss(preds[1], re)
    if not self.uncertainty_weighting:
      return loss0 + loss1

    precision0 = torch.exp(-self.log_vars[0])
    precision1 = torch.exp(-self.log_vars[1])
    return precision0 * loss0 + self.log_vars[0] + precision1 * loss1 + self.log_vars[1]