# In[6]:

from l5kit.dataset.utils import convert_str_to_fixed_length_tensor
import bisect
from functools import partial
import numpy as np
from typing import Callable,Optional
from torch.utils.data  import Dataset
from l5kit.data import ChunkedDataset,get_frames_slice_from_scenes

class MyBaseEgo(Dataset):
    """Base PyTorch dataset yielding one sample per frame of a zarr ``ChunkedDataset``.

    Subclasses must implement ``_get_sample_function`` to return the callable that
    builds a sample dict for a single frame. This base class maps a flat frame index
    to a (scene index, frame-within-scene index) pair and adds bookkeeping fields
    (scene index, host id, timestamp, track id) to each sample.
    """

    def __init__(
            self,
            cfg: dict,
            zarr_dataset: ChunkedDataset,
    ) -> None:
        """Store the config and dataset and build the scene lookup table.

        Args:
            cfg (dict): configuration; ``get_frame`` reads
                ``cfg["raster_params"]["disable_traffic_light_faces"]``.
            zarr_dataset (ChunkedDataset): the opened dataset to sample from.
        """
        super().__init__()
        self.cfg = cfg
        self.dataset = zarr_dataset
        # Second column of each scene's frame_index_interval. ``__getitem__`` bisects
        # into it as a running total of scene lengths, which assumes scenes are stored
        # contiguously from frame 0 and that the interval end is exclusive.
        self.cumulative_sizes = self.dataset.scenes["frame_index_interval"][:, 1]
        self.sample_function = self._get_sample_function()

    def _get_sample_function(self) -> Callable[..., dict]:
        """Return the callable that builds a sample dict; must be overridden.

        ``get_frame`` invokes it as
        ``fn(state_index, frames, agents, tl_faces, track_id)`` and expects a dict back.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def __len__(self) -> int:
        """Return the total number of frames in the dataset (one sample per frame)."""
        # NOTE: the original author noted roughly 2,000,000 frames for their dataset.
        return len(self.dataset.frames)

    def get_frame(self, scene_index: int, state_index: int, track_id: Optional[int]) -> dict:
        """Build the sample for one frame of one scene.

        Args:
            scene_index (int): the index of the scene in the zarr
            state_index (int): a frame index relative to the start of the scene
            track_id (Optional[int]): the agent to build the sample for, or None for the AV

        Returns:
            dict: the dict produced by ``self.sample_function`` with these keys added:
            ``scene_index``; ``host_id`` (the scene's ``host`` string passed through
            ``convert_str_to_fixed_length_tensor`` and cast to ``np.uint8``);
            ``timestamp`` (of the selected frame); and ``track_id`` (-1 when
            ``track_id`` is None, i.e. the AV).
        """
        scene = self.dataset.scenes[scene_index]
        frames = self.dataset.frames[get_frames_slice_from_scenes(scene)]

        tl_faces = self.dataset.tl_faces
        # When traffic lights are disabled, pass an empty array of the same dtype
        # so the sample function still receives a well-formed input.
        if self.cfg["raster_params"]["disable_traffic_light_faces"]:
            tl_faces = np.empty(0, dtype=self.dataset.tl_faces.dtype)

        data = self.sample_function(state_index, frames, self.dataset.agents, tl_faces, track_id)
        data["scene_index"] = scene_index
        data["host_id"] = np.uint8(convert_str_to_fixed_length_tensor(scene["host"]).cpu())
        data["timestamp"] = frames[state_index]["timestamp"]
        data["track_id"] = np.int64(-1 if track_id is None else track_id)
        return data

    def __getitem__(self, index: int) -> dict:
        """
        Function called by Torch to get an element

        Args:
            index (int): flat frame index over the whole dataset; negative values
                count back from the end

        Returns: please look get_frame signature and docstring

        Raises:
            ValueError: if index is negative and its absolute value exceeds len(self)
            IndexError: if index >= len(self)
        """
        length = len(self)
        if index < 0:
            if -index > length:
                raise ValueError("absolute value of index should not exceed dataset length")
            index = length + index
        elif index >= length:
            # Explicit IndexError instead of failing deep inside the scene lookup;
            # it also lets plain iteration over the dataset terminate cleanly.
            raise IndexError(f"index {index} is out of range for dataset of length {length}")

        # First scene whose (exclusive) end index is strictly greater than ``index``.
        scene_index = bisect.bisect_right(self.cumulative_sizes, index)

        if scene_index == 0:
            state_index = index
        else:
            state_index = index - self.cumulative_sizes[scene_index - 1]
        return self.get_frame(scene_index, state_index, None)







