# Trajectron Examples

In [1]:
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
plt.ioff()
from IPython.display import HTML

In [2]:
import mantrap
import torch


# Build a Trajectron environment (no ego) with four ados.
trajectron = mantrap.environment.Trajectron()
# Float literals keep every position/velocity tensor floating point: `torch.tensor([1, 0])`
# would be int64, while `torch.zeros(2)` and `torch.tensor([-0.1, -0.4])` are float32.
trajectron.add_ado(position=torch.zeros(2), velocity=torch.tensor([1.0, 0.0]))
trajectron.add_ado(position=torch.tensor([5.0, 0.0]), velocity=torch.tensor([-0.1, -0.4]))
trajectron.add_ado(position=torch.tensor([-4.0, 2.0]), velocity=torch.tensor([1.0, -1.0]))
# Trailing `;` suppresses the repr of the agent object returned by `add_ado`.
trajectron.add_ado(position=torch.tensor([-2.0, -2.0]), velocity=torch.tensor([1.0, 1.0]));

# TODO(review): the prediction visualisation below is disabled; confirm whether it still
# works and either restore it or delete it.
# HTML(trajectron.visualize_prediction_wo_ego(t_horizon=5, enforce=True))

<mantrap.agents.integrator_single.IntegratorDTAgent at 0x130d93390>

In [22]:
import mantrap
import mantrap.visualization
import numpy as np
import torch

# Heat-map configuration: prediction horizon (time-steps) and grid points per axis.
t_horizon = 5
num_grid_points = 200

trajectron = mantrap.environment.Trajectron()
trajectron.add_ado(position=torch.zeros(2), velocity=torch.tensor([3.0, 0.0]))
trajectron.add_ado(position=torch.tensor([-4.0, 2.0]), velocity=torch.tensor([-2.0, -2.0]))

# Compute full probability distribution of trajectory prediction (a mapping with one
# distribution per ado).
_, distribution = trajectron.build_connected_graph_wo_ego(t_horizon=t_horizon, return_distribution=True)

# Create probability heat-map by evaluating the distribution on a regular grid. This is
# pure inference, so the whole evaluation runs without autograd tracking.
with torch.no_grad():
    # torch.meshgrid uses "ij" indexing by default: flattened point k = i * N + j has
    # x = x_values[i] and y = y_values[j].
    x_grid, y_grid = torch.meshgrid(torch.linspace(*trajectron.x_axis, steps=num_grid_points),
                                    torch.linspace(*trajectron.y_axis, steps=num_grid_points))
    grid_points = torch.stack((x_grid.flatten(), y_grid.flatten()), dim=1)  # (N^2, 2)

    # Query every grid point at every step of the planning horizon T. Stacking along dim=1
    # keeps the point index first, so grid_points_stacked[k, 0, t] == grid_points[k] for
    # every t. (A `.view` of the (T, N^2, 2) stack would pair each k with wrong points.)
    grid_points_stacked = torch.stack(t_horizon * [grid_points], dim=1).unsqueeze(dim=1)  # (N^2, 1, T, 2)

    # Determine the probabilities over the full distribution, per ado and time-step.
    grid_probabilities = torch.zeros((num_grid_points**2, len(distribution), t_horizon))
    for i_ado, ado_dist in enumerate(distribution.values()):
        log_prob = ado_dist.log_prob(grid_points_stacked).squeeze(dim=-2)
        grid_probabilities[:, i_ado, :] = torch.exp(log_prob)

    # Sum over all ados to get one map per time-step, shape (N^2, T).
    grid_probabilities_sum = torch.sum(grid_probabilities, dim=1)

# Put the time axis first *before* reshaping so images[t] is really the map at step t;
# a plain `.view(T, N, N)` of the (N^2, T) tensor interleaves the time-steps.
images = grid_probabilities_sum.t().reshape(t_horizon, num_grid_points, num_grid_points)
images = images.detach().numpy()
assert images.shape == (t_horizon, num_grid_points, num_grid_points)

# Plot resulting heat-map for every time-step, using the same axes the grid was built on
# (bounds format kept as in the original call: ((x_min, y_min), (x_max, y_max))).
x_min, x_max = map(float, trajectron.x_axis)
y_min, y_max = map(float, trajectron.y_axis)
# NOTE(review): a single scalar resolution is passed, taken from the x-axis spacing;
# this matches the y spacing only when both axes span the same range.
resolution = (x_max - x_min) / (num_grid_points - 1)
HTML(mantrap.visualization.visualize_heat_map(images, bounds=((x_min, y_min), (x_max, y_max)), resolution=resolution))