In [1]:
import torch
import dgl

from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader
from argoverse.map_representation.map_api import ArgoverseMap

from dgl.data import DGLDataset
# from torch.utils.data import DataLoader,Dataset

import numpy as np
from functools import lru_cache
import os
from typing import Any, List, Sequence, Union, Tuple, Callable
from pathlib import Path

Using backend: pytorch


In [11]:
# Scratch checks of a few numpy / torch basics.
print(type(np.zeros(1)))

# A length-2 array unpacks directly into two scalars.
xxx = np.array([1, 2])
i, o = xxx

# Integers 2..7 (the step defaults to 1).
print(list(range(2, 8)))

# A zero-sized trailing dimension yields an empty (3, 0) tensor.
print(torch.zeros(3, 1))
print(torch.zeros(3, 0))

<class 'numpy.ndarray'>
[2, 3, 4, 5, 6, 7]
tensor([[0.],
        [0.],
        [0.]])
tensor([], size=(3, 0))


In [None]:
from dgl.data import DGLDataset


def get_center_agent_traj(argo_loader_obj:"ArgoverseForecastingLoader") -> np.ndarray:
    """Fetch the trajectory of the center agent.

    Args:
        argo_loader_obj: object exposing an ``agent_traj`` attribute
            (annotated as an ArgoverseForecastingLoader).

    Returns:
        The value of ``argo_loader_obj.agent_traj`` (annotated as a
        numpy.ndarray of track coordinates).
    """
    agent_track = argo_loader_obj.agent_traj
    return agent_track

# Map handle created at module level; get_lane_centerlines reads the
# centerline dictionary derived from it.
argo_map = ArgoverseMap()
# Lane segments, indexed first by city name and then by lane id.
argo_centerlines = argo_map.city_lane_centerlines_dict
def get_lane_centerlines(argo_loader_obj:"ArgoverseForecastingLoader",
                         center_coordinate:Tuple[float, float],
                         radius:float,
                         city_name:str) -> List[np.ndarray]:
    """Collect the lane centerlines whose midpoint lies near a given point.

    A lane is kept when the middle vertex of its centerline falls strictly
    inside the axis-aligned square of half-width ``radius`` centered on
    ``center_coordinate``.

    Args:
        argo_loader_obj: currently unused; kept so existing callers that
            pass it keep working.
        center_coordinate: a single ``(x, y)`` point (any 2-element
            sequence or array).
        radius: half-width of the square search window, in the same units
            as the coordinates.
        city_name: key into the module-level ``argo_centerlines`` dict.

    Returns:
        List of the ``centerline`` arrays of the selected lane segments.

    Raises:
        KeyError: if ``city_name`` is not a key of ``argo_centerlines``.
    """
    center_x, center_y = center_coordinate
    x_min, x_max = center_x - radius, center_x + radius
    y_min, y_max = center_y - radius, center_y + radius

    centerlines_list = []
    for lane_seg in argo_centerlines[city_name].values():
        # lane_seg has attributes :
        #  has_traffic_control,turn_direction,is_intersection,l_neighbor_id,r_neighbor_id,predecessors,successors,centerline,
        curr_centerline = lane_seg.centerline
        num_vertices = curr_centerline.shape[0]
        if num_vertices == 0:
            # An empty centerline has no midpoint to test.
            continue
        # Only the first two columns are read, so any extra columns are ignored.
        x, y = curr_centerline[num_vertices // 2, :2]
        if x_min < x < x_max and y_min < y < y_max:
            centerlines_list.append(curr_centerline)

    return centerlines_list

def track_transform_wrapper(agent_traj, end_idx = 19):
    """Build a function that re-centers trajectories on one agent point.

    Args:
        agent_traj: indexable sequence of agent positions; only
            ``agent_traj[end_idx]`` is read.
        end_idx: index of the reference point (default 19).

    Returns:
        A one-argument callable returning its input minus the reference point.
    """
    origin = agent_traj[end_idx]

    def transform(traj):
        # Express the trajectory relative to the reference point.
        shifted = traj - origin
        return shifted

    return transform

def build_one_track_graph(traj:np.ndarray, tranform_func: Callable[[np.ndarray], np.ndarray]):
    """Build a chain graph with one node per trajectory point.

    Consecutive points are linked by a directed edge ``i -> i + 1``. The
    transformed coordinates are stored as the node feature ``"state"``.

    Args:
        traj: array whose first dimension indexes the trajectory points.
        tranform_func: callable applied to ``traj``; its result must have
            one row per point.

    Returns:
        A DGL graph with ``traj.shape[0]`` nodes and ``traj.shape[0] - 1``
        edges (none when there are fewer than two points).
    """
    num_points = traj.shape[0]
    # dgl.graph takes the edges as ONE (src, dst) pair; a second positional
    # argument would be interpreted as the node type.
    src_ids = torch.arange(0, num_points - 1)
    dst_ids = torch.arange(1, num_points)
    # num_nodes pins the node count so short inputs with no edges still get
    # one node per point.
    g = dgl.graph((src_ids, dst_ids), num_nodes=num_points)
    # Node features must be backend tensors, so the NumPy result is converted.
    g.ndata["state"] = torch.as_tensor(tranform_func(traj))
    return g

def lane_transform_wrapper(agent_traj, end_idx = 19):
    """Build a function that re-centers lane coordinates on one agent point.

    Args:
        agent_traj: indexable sequence of agent positions; only
            ``agent_traj[end_idx]`` is read.
        end_idx: index of the reference point (default 19).

    Returns:
        A one-argument callable returning its input minus the reference point.
    """
    reference_point = agent_traj[end_idx]

    def transform(lane):
        # Express the lane relative to the reference point.
        recentered = lane - reference_point
        return recentered

    return transform

def build_one_lane_graph(lane:np.ndarray, tranform_func: Callable[[np.ndarray], np.ndarray]):
    """Build a chain graph with one node per lane centerline vertex.

    Consecutive vertices are linked by a directed edge ``i -> i + 1``. The
    transformed coordinates are stored as the node feature ``"coordinate"``.

    Args:
        lane: array whose first dimension indexes the centerline vertices.
        tranform_func: callable applied to ``lane``; its result must have
            one row per vertex.

    Returns:
        A DGL graph with ``lane.shape[0]`` nodes and ``lane.shape[0] - 1``
        edges (none when there are fewer than two vertices).
    """
    num_points = lane.shape[0]
    # dgl.graph takes the edges as ONE (src, dst) pair; a second positional
    # argument would be interpreted as the node type.
    src_ids = torch.arange(0, num_points - 1)
    dst_ids = torch.arange(1, num_points)
    # num_nodes pins the node count so short inputs with no edges still get
    # one node per vertex.
    g = dgl.graph((src_ids, dst_ids), num_nodes=num_points)
    # Node features must be backend tensors, so the NumPy result is converted.
    g.ndata["coordinate"] = torch.as_tensor(tranform_func(lane))
    return g


class MyDataset(DGLDataset):
    """DGL dataset that turns Argoverse forecasting samples into graphs.

    Each item is a dict with two entries: ``"tracks"``, a list holding the
    center agent's trajectory graph, and ``"lanes"``, a list with one graph
    per lane centerline selected around the agent.

    Parameters
    ----------
    url : str
        URL to download the raw dataset
    raw_dir : str
        Specifying the directory that will store the
        downloaded data or the directory that
        already stores the input data.
        Default: ~/.dgl/
    save_dir : str
        Directory to save the processed dataset.
        Default: the value of `raw_dir`
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information
    frac : float
        Proportion of the dataset, in (0, 1]. Only the value reported by
        ``len()`` is scaled; indexing itself is not restricted.
        Default: 1.0
    """
    # Index of the agent trajectory point used as the local origin; matches
    # the default ``end_idx`` of the transform wrappers.
    OBS_END_IDX = 19
    # Half-width of the square window used to select lane centerlines.
    LANE_RADIUS = 50

    def __init__(self,
                 url=None,
                 raw_dir=None,
                 save_dir=None,
                 force_reload=False,
                 verbose=False,
                 frac=1.0):
        if not 0.0 < frac <= 1.0:
            raise ValueError("frac must be in (0, 1], got {!r}".format(frac))
        super(MyDataset, self).__init__(name='dataset_name',
                                        url=url,
                                        raw_dir=raw_dir,
                                        save_dir=save_dir,
                                        force_reload=force_reload,
                                        verbose=verbose)
        self.frac = frac
        # Use the directory resolved by the base class so the documented
        # default (raw_dir=None) does not hand None to the loader.
        self.argo_loader = ArgoverseForecastingLoader(self.raw_dir)

    def download(self):
        """Download raw data to local disk (not implemented: no-op)."""
        pass

    def process(self):
        """Process raw data ahead of time (not implemented: graphs are built
        on demand in ``__getitem__``)."""
        pass

    def __getitem__(self, idx):
        """Build the graphs for the sample at position ``idx``.

        Returns
        -------
        dict
            ``{"tracks": [agent_graph], "lanes": [lane_graph, ...]}``
        """
        curr_argo_sample = self.argo_loader[idx]
        curr_agent_traj = get_center_agent_traj(curr_argo_sample)
        # The lane search needs ONE (x, y) point, not a slice of points.
        curr_lane_centerlines = get_lane_centerlines(argo_loader_obj = curr_argo_sample,
                                                     center_coordinate = curr_agent_traj[self.OBS_END_IDX],
                                                     radius = self.LANE_RADIUS,
                                                     city_name = curr_argo_sample.city)

        track_transform = track_transform_wrapper(curr_agent_traj, end_idx=self.OBS_END_IDX)
        curr_agent_graph = build_one_track_graph(traj=curr_agent_traj, tranform_func=track_transform)
        # TODO(review): only the center agent is added to "tracks"; a
        # commented-out stub suggested an AV graph was planned -- confirm.

        # The lane transform depends only on the agent trajectory, so it is
        # built once and shared by every lane.
        lane_transform = lane_transform_wrapper(curr_agent_traj, end_idx=self.OBS_END_IDX)
        curr_lane_graphs = [build_one_lane_graph(lane=tmp_lane, tranform_func=lane_transform)
                            for tmp_lane in curr_lane_centerlines]

        graph_dict = {"tracks": [curr_agent_graph], "lanes": curr_lane_graphs}
        return graph_dict

    def __len__(self):
        """Number of data examples, scaled by ``frac``."""
        return int(len(self.argo_loader) * self.frac)

    def save(self):
        """Save processed data to `self.save_path` (not implemented: no-op)."""
        pass

    def load(self):
        """Load processed data from `self.save_path` (not implemented: no-op)."""
        pass

    def has_cache(self):
        """Report whether processed data exists in `self.save_path`.

        Nothing is ever cached, so this is always False.
        """
        return False