In [1]:
import numpy as np
import cv2
import pims
from tqdm.notebook import trange
import xml.etree.ElementTree as ET

import matplotlib.pyplot as plt
from matplotlib.pyplot import plot

import torch
import torch.nn as nn
import torch.nn.functional as F
from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator

# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
print(device)

cuda:0


In [2]:
# CONSTANTS
# network input resolution (width, height) that annotations are rescaled to
W = 320
H = 160

# resolution of the frames the XML annotations were drawn on; passed to
# convert_annotations as the source resolution when rescaling to (W, H)
# NOTE(review): 2562 is unusual next to 1440 (standard 16:9 is 2560x1440) —
# confirm against the actual video / annotation tool export
annot_W = 2562
annot_H = 1440

In [3]:
# POLYLINES FUNCTIONS

def extract_polylines(filename):
  """Parse every <polyline> element of an annotation XML file.

  Each element is expected to carry a "frame" attribute (parsed as int) and a
  "points" attribute of the form "x0,y0;x1,y1;..." whose values are parsed as
  floats.

  Args:
    filename: path (or file object) accepted by xml.etree.ElementTree.parse.

  Returns:
    A list of (frame, points) tuples, where points is a list of [x, y, ...]
    float lists, sorted with Python tuple ordering (by frame, then points).
  """
  document = ET.parse(filename)  # read in the XML
  found = []
  for element in document.iter(tag='polyline'):
    raw_frame = element.get("frame")
    coords = [
      [float(value) for value in pair.split(",")]
      for pair in element.get("points").split(";")
    ]
    found.append((int(raw_frame), coords))
  found.sort()
  return found

# returns frames with annotations
def extract_frame_lines(polylines):
  """Group (frame, points) polylines into one list of polylines per frame.

  Works in a single pass over `polylines` (the previous version rescanned the
  whole list for every frame) and does not require the input to be sorted by
  frame: the number of frames comes from the largest frame index present.

  Args:
    polylines: iterable of (frame, points) pairs, e.g. from extract_polylines.
      Entries with a negative frame index are ignored.

  Returns:
    A list indexed by frame number 0..max_frame; entry i is the sorted list of
    the point lists annotated on frame i (empty when frame i has none). An
    empty input gives an empty list.
  """
  polylines = list(polylines)
  if not polylines:
    return []

  n_frames = max(frame_idx for frame_idx, _ in polylines)
  frames = [[] for _ in range(n_frames + 1)]
  for frame_idx, points in polylines:
    if frame_idx >= 0:
      frames[frame_idx].append(points)

  return [sorted(frame) for frame in frames]

# TODO: need to project polylines from coefficients as well (maybe use matplotlib on top of cv2 image) OR create new points from coefficients (since network will output only coefficients)
def extract_coefficients(annotations):
  """Fit a 2nd-degree polynomial y = a*x**2 + b*x + c to every polyline.

  Args:
    annotations: per-frame sequence of polylines, each a sequence of (x, y)
      points (e.g. the output of extract_frame_lines or convert_annotations).

  Returns:
    np.array(..., dtype=object) whose entry [frame][line] holds the np.polyfit
    coefficients [a, b, c] (highest degree first) for that polyline.

  NOTE(review): y is fitted as a function of x. The sample `path` output has
  nearly constant x (~160-161) while y spans ~106-160, i.e. near-vertical
  lines, for which y(x) is ill-conditioned; fitting x as a function of y (as
  PolyLaneNet does) may be more appropriate — confirm before relying on these.
  """
  def fit_line(points):
    xy = np.array(points)
    return np.polyfit(xy.T[0], xy.T[1], 2)

  per_frame = [[fit_line(points) for points in frame] for frame in annotations]
  return np.array(per_frame, dtype=object)

def get_coefficients(coefficients, frame, line):
  """Return the coefficients of polyline number `line` on frame `frame`.

  With the output of extract_coefficients these are np.polyfit's [a, b, c]
  for y = a*x**2 + b*x + c.
  """
  frame_coefficients = coefficients[frame]
  return frame_coefficients[line]

# helps plot lines using coefficients
def poly_coefficients(x, coeffs):
  """Evaluate the polynomial with coefficients `coeffs` at `x`.

  `coeffs` is ordered highest degree first, matching np.polyfit and therefore
  extract_coefficients: [a, b, c] evaluates a*x**2 + b*x + c. (Treating
  coeffs[i] as the coefficient of x**i would evaluate np.polyfit output with
  the terms reversed.)

  Args:
    x: scalar or numpy array of sample points.
    coeffs: sequence of polynomial coefficients, highest degree first.

  Returns:
    The polynomial value(s) at `x`; 0 when `coeffs` is empty.
  """
  y = 0
  for c in coeffs:  # Horner's scheme, equivalent to np.polyval(coeffs, x)
    y = y * x + c
  return y

# TODO: to make this work with coefs instead of points we need to have a start and a finish for every line (like in the PolyLaneNet paper)
def plot_coefs(frame, x=None):
  """Plot every polynomial of one frame over a shared set of x values.

  Args:
    frame: iterable of coefficient sequences, each evaluated with
      poly_coefficients.
    x: sample points to evaluate at; defaults to np.linspace(0, 9, 10), the
      previously hard-coded range (pass e.g. np.linspace(0, W, 50) to plot
      over the image width instead).
  """
  if x is None:
    x = np.linspace(0, 9, 10)
  for coeffs in frame:
    plt.plot(x, poly_coefficients(x, coeffs))
  plt.show()

def display(video_file, annotations, display_size=(1920//2, 1080//2)):
  """Play `video_file` in an OpenCV window with its polylines drawn on top.

  For each decoded frame, the polylines in `annotations[frame_idx]` are drawn
  in (0, 0, 255) and, as a debugging aid, each polyline's points and its
  2nd-degree np.polyfit coefficients (y as a function of x) are printed.
  Press 'q' in the window to stop early. Frames past the end of `annotations`
  are shown without polylines instead of raising IndexError.

  NOTE: inside the notebook this name shadows IPython's built-in `display`.

  Args:
    video_file: path to a video readable by cv2.VideoCapture.
    annotations: per-frame sequence of polylines (sequences of (x, y) points)
      in the video's pixel coordinates, e.g. from extract_frame_lines.
    display_size: (width, height) each frame is resized to before showing.
  """
  cap = cv2.VideoCapture(video_file)
  try:
    idx = 0
    while True:
      ret, frame = cap.read()
      if not ret:
        break
      print("[+] Frame:", idx)
      # extract_frame_lines only extends to the last annotated frame, so the
      # video can outlive the annotation list
      polylines = annotations[idx] if idx < len(annotations) else []
      for polyline in polylines:
        polyline = np.array(polyline) # get points for every polyline
        print("Polyline:")
        print(polyline)
        x, y = polyline.T[0], polyline.T[1]
        coefficients = np.polyfit(x, y, 2)  # extract coefficients from 2nd degree polyline points
        print("Coefficients:")
        print(coefficients)
        frame = cv2.polylines(frame, np.int32([polyline]), False, (0, 0, 255))
      # TODO: need to find a way to resize the polylines to a low resolution to multitask-train with the 320x160-input CrossroadDetection model
      frame = cv2.resize(frame, display_size)
      cv2.imshow('frame', frame)
      if cv2.waitKey(1) & 0xff == ord('q'):
        break
      idx += 1
  finally:
    # always free the capture and close windows, even on error / interrupt
    cap.release()
    cv2.destroyAllWindows()

def draw_polylines(frame, polylines, color=(0, 0, 255), thickness=2):
  """Draw each polyline onto `frame` as an open curve with cv2.polylines.

  Args:
    frame: image array to draw on (cv2.polylines draws into it and returns it).
    polylines: iterable of point sequences, each of shape (n_points, 2);
      coordinates are truncated to int32 pixel positions.
    color: colour tuple passed to cv2 (default (0, 0, 255)).
    thickness: line thickness in pixels (default 2, the former fixed value).

  Returns:
    The frame returned by the last cv2.polylines call, or `frame` itself when
    `polylines` is empty.
  """
  for polyline in polylines:
    points = np.int32([np.asarray(polyline)])
    frame = cv2.polylines(frame, points, False, color, thickness)
  return frame

def convert_annotations(old_res, new_res, annotations):
  """Rescale every point of every frame from `old_res` to `new_res`.

  Each (x, y) point becomes (x*new_width/old_width, y*new_height/old_height).

  Args:
    old_res: (width, height) the annotations are currently expressed in.
    new_res: (width, height) to rescale to.
    annotations: per-frame sequence of polylines, each a sequence of (x, y)
      points.

  Returns:
    np.array(..., dtype=object) of the rescaled frames -> polylines -> (x, y)
    tuples (a regular 4-D object array when all frames share a shape).
  """
  src_w, src_h = old_res
  dst_w, dst_h = new_res
  rescaled = [
    [
      [((px*dst_w) / src_w, (py*dst_h) / src_h) for px, py in line]
      for line in frame_lines
    ]
    for frame_lines in annotations
  ]
  return np.array(rescaled, dtype=object)

# converts predicted polylines to new resolution
def convert_polylines(old_res, new_res, polylines):
  """Rescale one frame's polylines from `old_res` to `new_res`.

  Each (x, y) point becomes (x*new_W/old_W, y*new_H/old_H).

  Args:
    old_res: (width, height) the points are currently expressed in.
    new_res: (width, height) to rescale to.
    polylines: iterable of polylines, each an iterable of (x, y) points.

  Returns:
    A float array of shape (n_lines, n_points, 2) when every polyline has the
    same number of points; otherwise a 1-D object array holding one list of
    (x, y) tuples per polyline. (np.array raises on ragged input in
    NumPy >= 1.24, and deserialize_polylines can yield lines of different
    lengths once out-of-bounds points are removed.)
  """
  W, H = old_res
  new_W, new_H = new_res
  new_polylines = [
    [((x*new_W) / W, (y*new_H) / H) for x, y in polyline]
    for polyline in polylines
  ]
  if len({len(polyline) for polyline in new_polylines}) > 1:
    # ragged: store each polyline as an opaque object
    ragged = np.empty(len(new_polylines), dtype=object)
    for i, polyline in enumerate(new_polylines):
      ragged[i] = polyline
    return ragged
  return np.array(new_polylines)

# NOTE: an earlier draft of serialize_polylines used to sit here as a dead
# string literal. It matched the live version below except that it removed
# items from `polylines` while iterating over that same list (which skips the
# element following each removal); the live version iterates over a copy.

# convert polylines per frame to a fixed-length net target vector
def serialize_polylines(polylines, n_coords, n_points, max_n_lines, pad_value=-100.):
  """Flatten one frame's polylines into a fixed-length network target vector.

  Polylines that do not have exactly `n_points` points are dropped. The rest
  are padded up to `max_n_lines` with placeholder lines whose coordinates are
  all `pad_value`, and the result is flattened line -> point -> coordinate.
  Work is linear in the output length (max_n_lines * n_points * n_coords).

  The caller's `polylines` is not modified; modifying it in place would
  permanently add padding lines to the annotations. Object arrays, such as
  the per-frame entries of convert_annotations' output, are accepted as well
  as lists.

  Args:
    polylines: sequence of polylines, each a sequence of points with at least
      `n_coords` coordinates.
    n_coords: coordinates kept per point.
    n_points: required number of points per polyline.
    max_n_lines: number of lines in the output vector.
    pad_value: coordinate used for padding lines (default -100.).

  Returns:
    1-D numpy array of length max_n_lines * n_points * n_coords.

  Raises:
    AssertionError: if more than `max_n_lines` polylines have `n_points` points.
  """
  # TODO: instead of removing the whole line, just get polyline[:n_points]
  kept = [polyline for polyline in polylines if len(polyline) == n_points]
  assert len(kept) <= max_n_lines, "More than max number of lines found"

  # fill the gaps with placeholder lines (pad_value == NULL => out of bounds)
  placeholder = [[pad_value] * n_coords for _ in range(n_points)]
  lines = kept + [placeholder] * (max_n_lines - len(kept))

  # flatten line-major: line -> point -> coordinate
  ret = [lines[i][j][k]
         for i in range(max_n_lines)
         for j in range(n_points)
         for k in range(n_coords)]
  return np.array(ret)

# TODO: this needs more work depending on the net output, since it is tested only on annotations
# convert network output vector to polylines per frame
def deserialize_polylines(net_output, n_coords, n_points, max_n_lines):
  """Inverse of serialize_polylines: rebuild polylines from a flat vector.

  Values are grouped into points of `n_coords` coordinates and lines of
  `n_points` points. The previous version ignored both arguments and always
  used 2 and 4. Points with any negative coordinate are treated as
  out-of-bounds/padding and removed; this covers both (-1, -1) markers and
  serialize_polylines' default -100 padding. Lines left empty are dropped.
  Trailing values that do not fill a whole line are ignored.

  Args:
    net_output: flat sequence of numbers (list or 1-D numpy array; presumably
      a torch tensor needs .detach().cpu() first — TODO confirm).
    n_coords: coordinates per point.
    n_points: points per line.
    max_n_lines: currently unused; kept for symmetry with serialize_polylines.

  Returns:
    A float array of shape (n_lines, n_kept_points, n_coords) when every
    remaining line kept the same number of points, a 1-D object array of
    per-line float arrays when they differ, or an empty array when no points
    remain.

  Raises:
    ValueError: if n_coords or n_points is not positive.
  """
  if n_coords <= 0 or n_points <= 0:
    raise ValueError("n_coords and n_points must be positive")
  values = np.asarray(net_output, dtype=float).ravel()
  per_line = n_coords * n_points
  n_lines = len(values) // per_line
  lines = values[:n_lines * per_line].reshape(n_lines, n_points, n_coords)

  # remove out-of-bounds points (any negative coordinate), then empty lines
  polylines = [line[(line >= 0).all(axis=1)] for line in lines]
  polylines = [line for line in polylines if len(line)]

  if not polylines:
    return np.array([])
  if len({len(line) for line in polylines}) == 1:
    return np.stack(polylines)
  # ragged: np.array would raise on NumPy >= 1.24, so store lines as objects
  ragged = np.empty(len(polylines), dtype=object)
  for i, line in enumerate(polylines):
    ragged[i] = line
  return ragged

In [6]:
# directory holding the test video and its polyline annotation XML
data_path = "../data/videos/with_crossroads/"
video_file = f"{data_path}city_1.mp4"
data_file = f"{data_path}city_1_path.xml"

In [17]:
# XML -> (frame, points) tuples -> per-frame polylines -> network resolution
path_polylines = extract_polylines(data_file)
path_frames = extract_frame_lines(path_polylines)
path = convert_annotations((annot_W, annot_H), (W, H), path_frames)
print(path.shape)
path[0]

(2113, 1, 8, 2)


array([[[159.93255269320844, 160.0],
        [160.1086651053864, 146.21666666666667],
        [160.28602654176424, 132.64666666666665],
        [160.4633879781421, 123.96777777777778],
        [160.64074941451992, 117.81333333333332],
        [160.81811085089774, 113.39555555555555],
        [160.99547228727556, 109.29222222222221],
        [161.1728337236534, 106.76777777777778]]], dtype=object)