# MotionCNN

### Packages

In [1]:
from tqdm import tqdm
from abc import ABC
from glob import glob
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

import torch.nn.functional as F
import tensorflow as tf
import numpy as np

import multiprocessing
import os
import cv2
import json
import timm




  from .autonotebook import tqdm as notebook_tqdm


### Global variables

In [2]:
N_JOBS = 8

PATH_DATASET = '../dataset/'
PATH_TRAINING_DATASET = PATH_DATASET + 'training/'
PATH_TESTING_DATASET = PATH_DATASET + 'testing/'
PATH_VALIDATION_DATASET = PATH_DATASET + 'validation/'

PATH_PREPROCESSED_DATASET = './preprocessed/'
PATH_PREPROCESSED_TRAINING_DATASET = PATH_PREPROCESSED_DATASET + 'training/'
PATH_PREPROCESSED_TESTING_DATASET = PATH_PREPROCESSED_DATASET + 'testing/'
PATH_PREPROCESSED_VALIDATION_DATASET = PATH_PREPROCESSED_DATASET + 'validation/'

PATH_CHECKPOINTS = './checkpoints/'

CONFIG_PREPROCESSING = {
	'raster_size': 512,
	'scale': 3,
	'roadgraph_distillation_rate': 5,
	'center_x': 256,
	'center_y': 256,
}

CONFIG_MODEL = {
	'backbone': 'resnet34',
	'n_modes': 3,
	'n_timestamps': 80,
	'predict_covariances': True,
}

CONFIG_TRAINING = {
	'num_epochs': 5,
	'eval_every': 10000,
	'optimizer': {
		'lr': 0.0003
	},
	'train_dataloader': {
		'batch_size': 16,
		'num_workers': 8,
		'shuffle': True
	},
	'val_dataloader': {
		'batch_size': 16,
		'num_workers': 8,
		'shuffle': False
	},
}


### Preprocess dataset

#### Helper functions

##### Math

In [3]:
def rot_matrix(angle):
  return np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])

In [4]:
def shift_rotate(x, shift, angle):
	assert isinstance(angle, (np.float32, np.float16, float))
	assert shift.shape[-1] == 2
	assert x.shape[-1] == 2
	return (x + shift) @ rot_matrix(angle).T

In [5]:
def rotate_shift(x, shift, angle):
	assert isinstance(angle, (np.float32, np.float16, float))
	assert shift.shape[-1] == 2
	assert x.shape[-1] == 2
	return (x) @ rot_matrix(angle).T + shift

##### Saves

In [6]:
def generate_filename(data_dict):
	return str(data_dict['agent_id']) + '.npz'

In [7]:
def create_folder_if_not_exists(path):
	if not os.path.exists(path):
		os.makedirs(path)

In [8]:
def create_saving_paths(path, data_dict):
	scenario_folder = os.path.join(path, data_dict['scenario_id'])
	agent_data_folder = os.path.join(scenario_folder, 'agent_data')
	roadgraph_data_folder = os.path.join(scenario_folder, 'roadgraph_data')
	create_folder_if_not_exists(scenario_folder)
	create_folder_if_not_exists(agent_data_folder)
	create_folder_if_not_exists(roadgraph_data_folder)

In [9]:
def save_roadgraph_data(path, data_dict, roadgraph_data):
	scenario_folder = os.path.join(path, data_dict['scenario_id'])
	roadgraph_data_folder = os.path.join(scenario_folder, 'roadgraph_data')
	create_saving_paths(path, data_dict)
	np.savez_compressed(os.path.join(roadgraph_data_folder, 'segments_global.npz'), **roadgraph_data)

In [10]:
def save_agent_data(path, data_dict):
	scenario_folder = os.path.join(path, data_dict['scenario_id'])
	agent_data_folder = os.path.join(scenario_folder, 'agent_data')
	create_saving_paths(path, data_dict)
	np.savez_compressed(os.path.join(agent_data_folder, generate_filename(data_dict)), **data_dict)

##### Processors

In [11]:
class RoadgraphProcessor:
	def __init__(self, data, config):
		self._config = config
		self._segments = None
		validity_flag = data['roadgraph_samples/valid'].numpy().flatten()
		self._roadgraph_xy = data['roadgraph_samples/xyz'].numpy()[validity_flag == 1][:, :2]
		self._roadgraph_type = data['roadgraph_samples/type'].numpy().flatten()[validity_flag == 1]
		self._ids = data['roadgraph_samples/id'].numpy().flatten()[validity_flag == 1]

	def _get_splits(self):
		splits = []
		prev_value = self._ids[0]
		for i, idx in enumerate(self._ids):
			if idx != prev_value:
				splits.append(i)
				prev_value = idx
		splits.append(len(self._ids))
		splits = [[splits[i - 1], splits[i]] for i in range(1, len(splits))]
		return splits

	def _get_color(self, segment_type):
		type_to_color = {
			# 0:  (0, 0, 0, 0), # TODO: Include key 0
			1:  (0, 0, 0, 0),
			2:  (0, 0, 0, 255),
			3:  (0, 0, 255, 0),
			6:  (0, 0, 255, 255),
			7:  (0, 255, 0, 0),
			8:  (0, 255, 0, 255),
			9:  (0, 255, 255, 0),
			10: (0, 255, 255, 255),
			11: (255, 0, 0, 0),
			12: (255, 0, 0, 255),
			13: (255, 0, 255, 0),
			15: (255, 0, 255, 255),
			16: (255, 255, 0, 0),
			17: (255, 255, 0, 255),
			18: (255, 255, 255, 0),
			19: (255, 255, 255, 255)}
			# 20: (255, 255, 255, 255)} # TODO: Include key 20
		return type_to_color[segment_type]

	def _prepare_segments(self):
		graph_segments = []
		graph_types = []
		splits = self._get_splits()
		for (start, end) in splits:
			num_points = max(int((end - start) / self._config['roadgraph_distillation_rate']), 2)
			roadline_ids = self._ids[start:end]
			roadline_types = self._roadgraph_type[start:end]
			assert all(roadline_ids == roadline_ids[0])
			if roadline_types[0] == 18:
				distilled_roadline_data = self._roadgraph_xy[start:end]
			else:
				idx = np.linspace(start, end - 1, num_points).astype(int)
				distilled_roadline_data = self._roadgraph_xy[idx]
			segments = np.concatenate([
				np.pad(distilled_roadline_data, ((0, 1), (0, 0)))[:, None, :],
				np.pad(distilled_roadline_data, ((1, 0), (0, 0)))[:, None, :]],
				axis=1)[1:-1]
			graph_segments.append(segments)
			graph_types.extend([roadline_types[0]] * segments.shape[0])
		self._segments = np.concatenate(graph_segments, axis=0)
		self._segment_types = graph_types

	def center_to(self, target_agent_xy, target_agent_yaw):
		return shift_rotate(self._segments, -target_agent_xy, -target_agent_yaw)

	def json(self):
		return json.dumps(np.round(self._segments.tolist(), 2).tolist())

	def __str__(self) -> str:
		return self.json()

	def render(self, target_agent_xy, target_agent_yaw):
		if self._segments is None:
			self._prepare_segments()
		segments = self.center_to(target_agent_xy, target_agent_yaw)
		masked_raster = np.zeros((self._config['raster_size'], self._config['raster_size'], 1), np.uint8)
		typed_raster = np.zeros((self._config['raster_size'], self._config['raster_size'], 4), np.uint8)
		for segment_type, segment in zip(self._segment_types, segments):
			# INFO: Added code to skip 0 and 20 segment types newly added to the dataset
			if segment_type == 0 or segment_type == 20:
				continue
			int_segment = (segment * self._config['scale'] + \
				np.array(
					[self._config['center_x'], self._config['center_y']])) \
					.astype(int)
			masked_raster = cv2.line(
				masked_raster,
				int_segment[0], int_segment[1],
				255, 1)
			typed_raster = cv2.line(
				typed_raster,
				int_segment[0], int_segment[1],
				self._get_color(segment_type), 1)
		raster = np.concatenate([masked_raster, typed_raster], axis=-1)
		return raster

	def get_roadgraph_segments_data(self):
		return {'roadgraph_segments': self._segments}

In [12]:
class AgentProcessor:
	def __init__(self, data, config):
		self._config = config

		history_valid = np.concatenate([
      data['state/past/valid'].numpy(),
      data['state/current/valid'].numpy()], axis=-1
		)
		
		present_in_history = np.max(history_valid, axis=-1)
		self._is_target = data['state/tracks_to_predict'].numpy().flatten()
		selector = np.logical_or(present_in_history == 1, self._is_target == 1)

		self._history_xy = np.concatenate([
			np.concatenate([
				data['state/past/x'].numpy(),
				data['state/current/x'].numpy()], axis=-1)[:, :, None],
			np.concatenate([
				data['state/past/y'].numpy(),
				data['state/current/y'].numpy()], axis=-1)[:, :, None]],
		axis = -1)[selector]

		self._history_yaw = np.concatenate([
			data['state/past/bbox_yaw'].numpy(),
			data['state/current/bbox_yaw'].numpy()],
		axis = -1)[selector]

		self._history_valid = np.concatenate([
			data['state/past/valid'].numpy(),
			data['state/current/valid'].numpy()],
		axis = -1)[selector]

		self._future_xy = np.concatenate([
			data['state/future/x'].numpy()[:, :, None],
			data['state/future/y'].numpy()[:, :, None]],
		axis = -1)[selector]

		self._future_valid = data['state/future/valid'].numpy()[selector]

		self._current_xy = np.concatenate([
			data['state/current/x'].numpy(),
			data['state/current/y'].numpy()], axis = -1)[selector]

		self._history_speed = data['state/past/speed'].numpy()[selector]
		self._current_speed = data['state/current/speed'].numpy().flatten()[selector]
		self._future_speed = data['state/future/speed'].numpy()[selector]
		self._current_yaw = data['state/current/bbox_yaw'].numpy().flatten()[selector]
		self._agents_id = data['state/id'].numpy().flatten().astype(int)[selector]
		self._is_sdc = data['state/is_sdc'].numpy().flatten().astype(int)[selector]
		self._scenario_id = data['scenario/id'].numpy().item().decode()
		self._agents_type = data['state/type'].numpy().flatten().astype(int)[selector]
		self._agents_width = data['state/current/width'].numpy().flatten()[selector]
		self._agents_length = data['state/current/length'].numpy().flatten()[selector]

	def target_agents_idx(self):
		return np.arange(len(self._is_target))[self._is_target == 1]

	def get_target_agent_position(self, idx):
		return self._current_xy[idx], self._current_yaw[idx]
	
	def _gen_box(
    self, current_agent_xy, current_agent_yaw,
    target_agent_xy, target_agent_yaw,
    current_agent_length, current_agent_width
	):
		box = np.array([
			[-current_agent_length / 2, -current_agent_width / 2],
			[ current_agent_length / 2, -current_agent_width / 2],
			[ current_agent_length / 2,  current_agent_width / 2],
			[-current_agent_length / 2,  current_agent_width / 2]])[None, ]
		box *= self._config['scale']
		box = box @ rot_matrix(current_agent_yaw).T
		box = shift_rotate(box, (current_agent_xy - target_agent_xy) * self._config['scale'], -target_agent_yaw)
		return box

	def _draw_box(
		self, raster,
		current_agent_xy, current_agent_yaw,
		target_agent_xy, target_agent_yaw,
		current_agent_length, current_agent_width
	):
		raster = cv2.fillPoly(
			raster, 
			(self._gen_box(
				current_agent_xy, current_agent_yaw,
				target_agent_xy, target_agent_yaw,
				current_agent_length, current_agent_width) + \
					np.array([
						self._config['center_x'], self._config['center_y']])) \
						.astype(int),
			128, lineType=cv2.LINE_AA)
		raster = cv2.polylines(
			raster, 
			(self._gen_box(
				current_agent_xy, current_agent_yaw,
				target_agent_xy, target_agent_yaw,
				current_agent_length, current_agent_width) + \
					np.array([
						self._config['center_x'], self._config['center_y']])) \
						.astype(int),
			True, 255, lineType=cv2.LINE_AA, thickness=1)
		return raster

	def render(self, target_agent_order_idx):
		agents_raster = [np.zeros((
			self._config['raster_size'], self._config['raster_size'], 1),
			np.uint8) for _ in range(22)
		]

		target_agent_xy, target_agent_yaw = self.get_target_agent_position(target_agent_order_idx)
		
		for rendering_agent_agent_order_idx, (
			rendering_agent_history_xy,
			rendering_agent_history_yaw,
			rendering_agent_length,
			rendering_agent_width,
			rendering_agent_history_valid) in enumerate(zip(
				self._history_xy, self._history_yaw,
				self._agents_length, self._agents_width, self._history_valid)):
			
			for history_timestamp, (
				rendering_agent_history_xy_state,
        rendering_agent_history_yaw_state,
        rendering_agent_history_valid_state) in enumerate(zip(
					rendering_agent_history_xy,
					rendering_agent_history_yaw,
					rendering_agent_history_valid)):

				if rendering_agent_history_valid_state == 0:
						continue

				channel = history_timestamp

				if target_agent_order_idx == rendering_agent_agent_order_idx:
					channel += 11

				agents_raster[channel] = self._draw_box(
					agents_raster[channel],
					rendering_agent_history_xy_state,
					rendering_agent_history_yaw_state,
					target_agent_xy, target_agent_yaw,
					rendering_agent_length,
					rendering_agent_width
				)

		agents_raster = np.concatenate(agents_raster, axis=-1)
		return agents_raster

	def get_numerical_data(self, agent_order_idx):
		agent_gt_global = self._future_xy[agent_order_idx]
		target_agent_shift, target_agent_yaw = \
			self.get_target_agent_position(agent_order_idx)
		numerical_data = {
			'agent_id': self._agents_id[agent_order_idx],
			'scenario_id': self._scenario_id,
			'is_sdc': self._is_sdc[agent_order_idx],
			'agent_type': self._agents_type[agent_order_idx],
			'future_global': agent_gt_global,
			'future_local': shift_rotate(agent_gt_global, -target_agent_shift, -target_agent_yaw),
			'future_valid': self._future_valid[agent_order_idx],
			'history_global': self._history_xy[agent_order_idx],
			'history_valid': self._history_valid[agent_order_idx],
			'history_yaw_global': self._history_yaw[agent_order_idx],
			'current_xy_global': self._current_xy[agent_order_idx],
			'history_speed': self._history_speed[agent_order_idx],
			'current_speed': self._current_speed[agent_order_idx],
			'future_speed': self._future_speed[agent_order_idx],
			'width': self._agents_width[agent_order_idx],
			'length': self._agents_length[agent_order_idx],
			'shift': target_agent_shift,
			'yaw': target_agent_yaw}
		return numerical_data

##### Features

In [13]:
def generate_agent_features_by_timezone(timezone):
	_values_number_for_timezone = {
    "current": 1,
    "future": 80,
    "past": 10
	}

	n_values = _values_number_for_timezone[timezone]

	return {
		f"state/{timezone}/x": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),
		f"state/{timezone}/y": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),
		f"state/{timezone}/z": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),

		f"state/{timezone}/velocity_x": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),
		f"state/{timezone}/velocity_y": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),
		f"state/{timezone}/speed": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),

		f"state/{timezone}/length": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),
		f"state/{timezone}/width": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),
		f"state/{timezone}/height": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),

		f"state/{timezone}/bbox_yaw": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None),
		f"state/{timezone}/timestamp_micros": tf.io.FixedLenFeature(
				[128, n_values], tf.int64, default_value=None),
		f"state/{timezone}/valid": tf.io.FixedLenFeature(
				[128, n_values], tf.int64, default_value=None),
		f"state/{timezone}/vel_yaw": tf.io.FixedLenFeature(
				[128, n_values], tf.float32, default_value=None)
	}

In [14]:
def get_features_description():
	_roadgraph_features = {
    "roadgraph_samples/dir": tf.io.FixedLenFeature(
        [30000, 3], tf.float32, default_value=None
    ),
    "roadgraph_samples/id": tf.io.FixedLenFeature(
        [30000, 1], tf.int64, default_value=None
    ),
    "roadgraph_samples/type": tf.io.FixedLenFeature(
        [30000, 1], tf.int64, default_value=None
    ),
    "roadgraph_samples/valid": tf.io.FixedLenFeature(
        [30000, 1], tf.int64, default_value=None
    ),
    "roadgraph_samples/xyz": tf.io.FixedLenFeature(
        [30000, 3], tf.float32, default_value=None
    ),
	}

	_general_state_features = {
    "state/id": tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    "state/type": tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    "state/is_sdc": tf.io.FixedLenFeature([128], tf.int64, default_value=None),
    "state/tracks_to_predict": tf.io.FixedLenFeature(
        [128], tf.int64, default_value=None),
    "scenario/id": tf.io.FixedLenFeature([1], tf.string, default_value=None)
	}

	_traffic_light_features = {
    "traffic_light_state/current/state":
        tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
    "traffic_light_state/current/valid":
        tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
    "traffic_light_state/current/x":
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    "traffic_light_state/current/y":
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    "traffic_light_state/current/z":
        tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
    "traffic_light_state/past/state":
        tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None),
    "traffic_light_state/past/valid":
        tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None),
    "traffic_light_state/past/x":
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
    "traffic_light_state/past/y":
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
    "traffic_light_state/past/z":
        tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
	}

	features_description = {}
	features_description.update(_roadgraph_features)
	features_description.update(_general_state_features)
	features_description.update(_traffic_light_features)

	for timezone in ['past', 'current', 'future']:
		features_description.update(generate_agent_features_by_timezone(timezone))

	return features_description

#### Prerender

In [15]:
def prerender(data, config, output_path):
	data = tf.io.parse_single_example(data, get_features_description())

	agent_processor = AgentProcessor(data, config)
	roadgraph_processor = RoadgraphProcessor(data, config)

	for i in agent_processor.target_agents_idx():
		agents_raster = agent_processor.render(i)
		roadgraph_raster = roadgraph_processor.render(*agent_processor.get_target_agent_position(i))
		full_raster = np.concatenate([roadgraph_raster, agents_raster], axis=-1)

		prepared_data = {'raster': full_raster}
		prepared_data.update(agent_processor.get_numerical_data(i))

		save_agent_data(output_path, prepared_data)
		
	save_roadgraph_data(output_path, prepared_data, roadgraph_processor.get_roadgraph_segments_data())

In [16]:
path_dataset = [
	PATH_TRAINING_DATASET,
	PATH_TESTING_DATASET,
	PATH_VALIDATION_DATASET
]

path_preprocessed_dataset = [
	PATH_PREPROCESSED_TRAINING_DATASET,
	PATH_PREPROCESSED_TESTING_DATASET,
	PATH_PREPROCESSED_VALIDATION_DATASET
]

In [17]:
datasets = []
for path in path_dataset:
	files = os.listdir(path)
	datasets.append(tf.data.TFRecordDataset([os.path.join(path, f) for f in files]))

In [18]:
'''
Uncomment to generate preprocessed data
p = multiprocessing.Pool(N_JOBS)
processes = []

for i, dataset in enumerate(datasets):
	path_output = path_preprocessed_dataset[i]
	
	for data in tqdm(dataset.as_numpy_iterator()):
		processes.append(
			p.apply_async(
				prerender,
				kwds = dict(
					data = data,
					config = CONFIG_PREPROCESSING,
					output_path = path_output
				)
			)
		)
	
	for r in tqdm(processes):
		r.get()
'''

'\nUncomment to generate preprocessed data\np = multiprocessing.Pool(N_JOBS)\nprocesses = []\n\nfor i, dataset in enumerate(datasets):\n\tpath_output = path_preprocessed_dataset[i]\n\t\n\tfor data in tqdm(dataset.as_numpy_iterator()):\n\t\tprocesses.append(\n\t\t\tp.apply_async(\n\t\t\t\tprerender,\n\t\t\t\tkwds = dict(\n\t\t\t\t\tdata = data,\n\t\t\t\t\tconfig = CONFIG_PREPROCESSING,\n\t\t\t\t\toutput_path = path_output\n\t\t\t\t)\n\t\t\t)\n\t\t)\n\t\n\tfor r in tqdm(processes):\n\t\tr.get()\n'

### Training

#### Losses

In [19]:
class Loss(ABC, nn.Module):
	def _precision_matrix(shape, sigma_xx, sigma_yy):
		assert sigma_xx.shape[-1] == 1
		assert sigma_xx.shape == sigma_yy.shape

		batch_size, n_modes, n_future_timstamps = sigma_xx.shape[0], sigma_xx.shape[1], sigma_xx.shape[2]
		
		sigma_xx_inv = 1 / sigma_xx
		sigma_yy_inv = 1 / sigma_yy
		
		return torch.cat([
			sigma_xx_inv,
			torch.zeros_like(sigma_xx_inv),
			torch.zeros_like(sigma_yy_inv),
			sigma_yy_inv], dim=-1).reshape(batch_size, n_modes, n_future_timstamps, 2, 2)

	def _log_N_conf(self, data_dict, prediction_dict):
		gt = data_dict['future_local'].unsqueeze(1)

		diff = (prediction_dict['xy'] - gt) * data_dict['future_valid'][:, None, :, None]
		assert torch.isfinite(diff).all()

		precision_matrices = self._precision_matrix(prediction_dict['sigma_xx'], prediction_dict['sigma_yy'])
		assert torch.isfinite(precision_matrices).all()

		log_confidences = torch.log_softmax(prediction_dict['confidences'], dim=-1)
		assert torch.isfinite(log_confidences).all()

		bilinear = diff.unsqueeze(-2) @ precision_matrices @ diff.unsqueeze(-1)
		bilinear = bilinear[:, :, :, 0, 0]
		assert torch.isfinite(bilinear).all()

		log_N = -0.5 * np.log(2 * np.pi) - 0.5 * torch.log(
			prediction_dict['sigma_xx'] * prediction_dict['sigma_yy']
			).squeeze(-1) - 0.5 * bilinear
			
		return log_N, log_confidences


In [20]:
class NLLGaussian2d(Loss):
	def __init__(self):
		super().__init__()

	def forward(self, data_dict, prediction_dict):
		log_N, log_confidences = self._log_N_conf(data_dict, prediction_dict)
		assert torch.isfinite(log_N).all()

		log_L = torch.logsumexp(log_N.sum(dim=2) + log_confidences, dim=1)
		assert torch.isfinite(log_L).all()
		
		return -log_L.mean()

#### Postprocess predictions

In [21]:
def limited_softplus(x):
	return torch.clamp(F.softplus(x), min=0.1, max=10)

In [22]:
def postprocess_predictions(predicted_tensor, model_config):
	confidences = predicted_tensor[:, :model_config['n_modes']]
	components = predicted_tensor[:, model_config['n_modes']:]
	components = components.reshape(
		-1, model_config['n_modes'], model_config['n_timestamps'], 5)
	sigma_xx = components[:, :, :, 2:3]
	sigma_yy = components[:, :, :, 3:4]
	visibility = components[:, :, :, 4:]
	return {
		'confidences': confidences,
		'xy': components[:, :, :, :2],
		'sigma_xx': limited_softplus(sigma_xx) if \
				model_config['predict_covariances'] else torch.ones_like(sigma_xx),
		'sigma_yy': limited_softplus(sigma_yy) if \
				model_config['predict_covariances'] else torch.ones_like(sigma_yy),
		'visibility': visibility
	}

#### Dataset

In [23]:
class MotionCNNDataset(Dataset):
	def __init__(self, data_path, load_roadgraph=False) -> None:
		super().__init__()
		self._load_roadgraph = load_roadgraph
		self._files = glob(os.path.join(data_path, '*', 'agent_data', '*.npz'))
		self._roadgraph_data = glob(os.path.join(data_path, '*', 'roadgraph_data', 'segments_global.npz'))
		self._scid_to_roadgraph = {f.split('/')[-3]: f for f in self._roadgraph_data}

	def __len__(self):
		return len(self._files)

	def __getitem__(self, idx):
		data = dict(np.load(self._files[idx], allow_pickle=True))
		if self._load_roadgraph:
			roadgraph_data_file = self._scid_to_roadgraph[data['scenario_id'].item()]
			roadgraph_data = np.load(roadgraph_data_file)['roadgraph_segments']
			roadgraph_valid = np.ones(roadgraph_data.shape[0])

			n_to_pad = 6000 - roadgraph_data.shape[0]

			roadgraph_data = np.pad(roadgraph_data, ((0, n_to_pad), (0, 0), (0, 0)))
			roadgraph_valid = np.pad(roadgraph_valid, (0, n_to_pad))

			data['roadgraph_data'] = roadgraph_data
			data['roadgraph_valid'] = roadgraph_valid
			
		data['raster'] = data['raster'].transpose(2, 0, 1) / 255.
		data['scenario_id'] = data['scenario_id'].item()
		return data

In [24]:
def get_last_checkpoint_file(path):
	list_of_files = glob(f'{path}/*.pth')

	if len(list_of_files) == 0:
		return None
		
	latest_file = max(list_of_files, key=os.path.getctime)
	return latest_file

In [25]:
def dict_to_cuda(data_dict):
	gpu_required_keys = ['raster', 'future_valid', 'future_local']
	for key in gpu_required_keys:
			data_dict[key] = data_dict[key].cuda()
	return data_dict

In [26]:
def get_model(model_config):
	n_components = 5
	n_modes = model_config['n_modes']
	n_timestamps = model_config['n_timestamps']
	output_dim = n_modes + n_modes * n_timestamps * n_components
	model = timm.create_model(model_config['backbone'], pretrained=True, in_chans=27, num_classes=output_dim)
	return model

In [27]:
model = get_model(CONFIG_MODEL)

optimizer = Adam(model.parameters(), **CONFIG_TRAINING['optimizer'])

processed_batches = 0
epochs_processed = 0

loss_module = NLLGaussian2d()
train_losses = []

experiment_checkpoints_dir = os.path.join(PATH_CHECKPOINTS, 'basic')
if not os.path.exists(experiment_checkpoints_dir):
  os.makedirs(experiment_checkpoints_dir)

latest_checkpoint = get_last_checkpoint_file(experiment_checkpoints_dir)

if latest_checkpoint is not None:
	print(f'Loading checkpoint from {latest_checkpoint}')
	checkpoint_data = torch.load(latest_checkpoint)

	model.load_state_dict(checkpoint_data['model_state_dict'])
	optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
	epochs_processed = checkpoint_data['epochs_processed']
	processed_batches = checkpoint_data['processed_batches']

model = nn.DataParallel(model)

In [28]:
training_dataloader = DataLoader(MotionCNNDataset(PATH_PREPROCESSED_TRAINING_DATASET), **CONFIG_TRAINING['train_dataloader'])
validation_dataloader = DataLoader(MotionCNNDataset(PATH_PREPROCESSED_VALIDATION_DATASET, load_roadgraph=True), **CONFIG_TRAINING['val_dataloader'])

In [29]:
for epochs_processed in tqdm(
	range(epochs_processed, CONFIG_TRAINING['num_epochs']),
	total = CONFIG_TRAINING['num_epochs'],
	initial = epochs_processed
):
	train_progress_bar = tqdm(training_dataloader, total=len(training_dataloader))

	for train_data in train_progress_bar:
		optimizer.zero_grad()

		train_data = dict_to_cuda(train_data)

		prediction_tensor = model(train_data['raster'].float())
		prediction_dict = postprocess_predictions(prediction_tensor, model_config)

		loss = loss_module(train_data, prediction_dict)
		loss.backward()

		optimizer.step()

		train_losses.append(loss.item())

		processed_batches += 1
		train_progress_bar.set_description("Train loss: %.3f" % np.mean(train_losses[-100:]))

		if processed_batches % training_config['eval_every'] == 0:
			del train_data

			with torch.no_grad():
				for eval_data in tqdm(validation_dataloader):
					eval_data = dict_to_cuda(eval_data)

					prediction_tensor = model(eval_data['raster'].float())
					prediction_dict = postprocess_predictions(prediction_tensor, config_model)

					loss = loss_module(eval_data, prediction_dict)

			model_state_dict = model.module.state_dict()
			
			torch_checkpoint_data = {
				"model_state_dict": model_state_dict,
				"optimizer_state_dict": optimizer.state_dict(),
				"epochs_processed": epochs_processed,
				"processed_batches": processed_batches
			}
			torch_checkpoint_path = os.path.join(experiment_checkpoints_dir, f'e{epochs_processed}_b{processed_batches}.pth')
			torch.save(torch_checkpoint_data, torch_checkpoint_path)

  0%|          | 0/5 [00:00<?, ?it/s]