In [320]:
"""

This module contains the `AgentCenter` class which applies agent-centered
transformation to the given batch data.

"""

import numpy as np
import torch


def homogenize_matrix(matrix):
    """
    Homogenize a 2D matrix by adding a column of ones.

    Args:
        matrix (np.ndarray): 2D matrix.

    Returns:
        np.ndarray: Homogenized matrix with an additional column of ones.
    """

    # get the original shape
    original_shape = matrix.shape

    # get the non-numerical dimensions
    non_numerical_dims = original_shape[:-1]

    # add the '1' layer/row
    shape = non_numerical_dims + (1,)
    ones = np.ones(shape)

    homogenized_matrix = np.concatenate(
        [matrix, ones],
        axis=-1,
    )
    return homogenized_matrix


def get_translation_matrix(positions):
    """
    Gets the translation matrix for the given positions.
    """
    num_timesteps = positions.shape[0]

    translation_transforms = np.eye(3)[np.newaxis].repeat(num_timesteps, axis=0)

    # set the translation component of the transformation matrices
    translation_transforms[:, :2, 2] -= positions

    return translation_transforms


def get_rotation_matrix(positions):
    """
    Gets the rotation matrix for the given positions.
    """
    rotation_transforms = np.eye(3)

    # get the angle from the target agent's first input position to the
    # final input position
    first_position = positions[0]
    last_position = positions[-1]

    # get the angle
    theta = (
        -np.arctan2(
            last_position[1] - first_position[1],
            last_position[0] - first_position[0],
        )
        + np.pi / 2
    )

    rotation_transforms[0, 0] = np.cos(theta)
    rotation_transforms[0, 1] = -np.sin(theta)
    rotation_transforms[1, 0] = np.sin(theta)
    rotation_transforms[1, 1] = np.cos(theta)

    return rotation_transforms


def apply(datum):
    """
    Apply agent-centered transformation to the given datum.

    Args:
        datum (dict): Dictionary representing a single data point.

    Returns:
        dict: Transformed datum with updated positions.
    """
    # get all of the ids for the agents being tracked
    # renaming due to bad naming in the dataset
    agent_ids = datum["track_id"]

    # extract the agent_id from the datum
    target_id = datum["agent_id"]

    # get the index of the target agent
    agent_index = np.where(agent_ids == target_id)[0][0]

    # get the input and output data
    positions_in = np.array(datum["p_in"])
    velocities_in = np.array(datum["v_in"])
    positions_out = np.array(datum["p_out"])
    velocities_out = np.array(datum["v_out"])

    # FIXME:
    # save the input length before we extend it
    input_length = positions_in.shape[1]

    # extend by the output data
    positions = np.concatenate([positions_in, positions_out], axis=1)
    velocities = np.concatenate([velocities_in, velocities_out], axis=1)

    # homogenize the 2D data
    # shape: (num_timesteps, num_agents, 3)
    # positions_h = homogenize_matrix(positions)
    # velocities_h = homogenize_matrix(velocities)

    offsets_h = np.diff(positions[agent_index], axis=0)

    # create a list of transformation matrices that center all points
    # around the target agent
    target_positions = positions[agent_index]

    # Shape: (num_timesteps, 3, 3)
    positions = positions - target_positions

    first_offset = np.array([0, 0, 0])
    offsets_h = np.vstack([first_offset,offsets_h])
    offsets = offsets_h[:, :2]
    positions[agent_index] = offsets

    p_out = positions[:, input_length:]

    # update the positions in the datum
    datum["p_in"] = positions[:, :input_length]
    datum["v_in"] = velocities[:, :input_length]
    datum["p_out"] = p_out
    datum["v_out"] = velocities[:, input_length:]

    # update the prediction correction
    datum["prediction_correction"] = inverse


    metadata = {
        "target_offset": target_positions,
    }

    datum["batch_correction_metadata"] = metadata

    return datum


def inverse(predictions, metadata):
    """TODO: correct_predictions"""

    # IMPORTANT: inputs are batched

    # NOTE: Since we have only moved the positions, we can just leave them
    # as is for training, but we'll need to revert them back to the original
    # positions when we test against the dataset.

    # Really, this should convert all the way back to the original, global
    # positions, but that's a bit more effor. Leaving it as a TODO for now.
    # thought: embed metadata in the data to know how to invert it:
    # input, output, prediction_correction, batch_correction_metadata.
    predictions = np.cumsum(predictions, axis=0)

    predictions = homogenize_matrix(predictions)

    # get the translation and rotation matrices
    target_positions = metadata["target_offset"][18]

    # (30, 3) @ (3, 3) -> (30, 3)
    predictions = predictions + target_positions

    # dehomogenize the data
    predictions = predictions[:, :2]

    return predictions

    # # apply corrections needed by other transformations.
    # return AgentCenter.prior_prediction_correction(batch_predictions, batch_metadata)


In [321]:

import pickle
from glob import glob
import torch
from torch.utils.data import Dataset, DataLoader
import os
import os.path

class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""

    def __init__(self, data_path: str, transform=None):
        """TODO: init"""
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform

        self.pkl_list = glob(os.path.join(self.data_path, "*"))
        self.pkl_list.sort()

    def __len__(self):
        """TODO: len"""
        return len(self.pkl_list)

    def __getitem__(self, idx):
        """getitem"""
        pkl_path = self.pkl_list[idx]
        with open(pkl_path, "rb") as f:
            data = pickle.load(f)

        if self.transform:
            data = self.transform(data)

        return data

dataset = ArgoverseDataset(data_path="data/train", transform=None)

In [322]:
# import tqdm 
# # test loop through time

# for datum in tqdm.tqdm(dataset):
#     _ = apply(datum)

# # using matmul for translation: 1:13, 1:14
# # using addition for translation: 0:41, 0:33, clearly much better.
    

In [323]:
datum = dataset[100]
agent_index = np.where(datum["track_id"] == datum["agent_id"])[0][0]
# print(agent_index)
print(datum["p_out"][agent_index][:5])

datum_transformed = apply(datum)
# print(datum_transformed["p_out"][agent_index][:5])

# print(datum_transformed["batch_correction_metadata"]["translation_transforms"][0])

[[ 151.68348694 2466.12939453]
 [ 151.58137512 2465.3894043 ]
 [ 151.59472656 2464.39868164]
 [ 151.56767273 2463.66186523]
 [ 151.54214478 2462.75317383]]


ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 3 and the array at index 1 has size 2

In [318]:
predictions = datum_transformed["p_out"][agent_index]
metadata = datum_transformed["batch_correction_metadata"]

inversed = inverse(predictions, metadata)

print(inversed.shape)
print(inversed[:5])

(30, 2)
[[ 151.68348694 2466.12939453]
 [ 151.58137512 2465.3894043 ]
 [ 151.59472656 2464.39868164]
 [ 151.56767273 2463.66186523]
 [ 151.54214478 2462.75317383]]
