In [52]:
from torch.utils.data import Dataset
import numpy as np
import logging
import os
import h5py

In [53]:
# Image resolution constant; not referenced in this section (presumably (height, width) -- TODO confirm).
IMG_SIZE = (320, 512)

# Natural-language goal instruction per task, keyed by dataset name.
# RobomimicDataset uses each .hdf5 file's name without extension as the lookup key.
GOAL_PROMPT = dict(
    can_ph="Pick up the red can and place it in the bottom right container.",
    lift_ph="Pick up the red block.",
    square_ph="Pick up the square tool and place it on the square peg.",
)

In [54]:
# Named logger used by the dataset loader below.
logger = logging.getLogger("robomimic_datasets")
# Give the root logger a handler at INFO level so logger.info(...) output is shown.
# NOTE: basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)

In [None]:
class RobomimicDataset(Dataset):
    """PyTorch dataset of per-timestep samples drawn from Robomimic HDF5 demos.

    Each ``.hdf5``/``.h5`` file in ``data_dir`` is treated as one dataset. Its
    filename stem (e.g. ``can_ph``) is stored as the goal key of every sample it
    contributes and is mapped to a language instruction through ``GOAL_PROMPT``
    in ``__getitem__``.
    """

    # File extensions treated as Robomimic datasets; anything else in data_dir is skipped.
    _HDF5_EXTENSIONS = (".hdf5", ".h5")

    def __init__(self, data_dir, cam_views=("agentview",), max_demos=50, max_samples=50):
        """
        Initializes a general PyTorch dataset for Robomimic data.

        Args:
            data_dir (str): Directory that contains Robomimic datasets (.hdf5 files).
                Files are processed in sorted name order; entries that are not
                HDF5 files (by extension) are skipped.
            cam_views (Sequence[str] | str): Only images from these camera views
                will be extracted. A single string is treated as one view.
            max_demos (int): Maximum number of demos per dataset to be extracted.
                Demos are taken in numeric order (demo_0, demo_1, ..., demo_10),
                not lexical order.
            max_samples (int): Maximum number of samples per demo to be extracted.
                The selected timesteps are spread evenly from the first to the last.
        """
        # Avoid iterating a bare string character by character.
        if isinstance(cam_views, str):
            cam_views = [cam_views]
        cam_views = list(cam_views)

        # N = demos*samples
        self.data_dir = data_dir
        # (N, 7) - ee_pos, ee_quat
        self.state = []
        # (len(cam_views), N, H, W, 3)
        self.rgb = [[] for _ in cam_views]
        # (len(cam_views), N, H, W, 1)
        self.depth = [[] for _ in cam_views]
        # (N, 7) - delta ee_pos, delta ee_aa, grip_cmd
        self.action = []
        # (N,) - key into GOAL_PROMPT (dataset filename stem)
        self.goal_key = []

        # Sorted so the sample order is reproducible across filesystems.
        for dataset_filename in sorted(os.listdir(self.data_dir)):
            dataset_path = os.path.join(self.data_dir, dataset_filename)
            dataset, ext = os.path.splitext(dataset_filename)
            if ext.lower() not in self._HDF5_EXTENSIONS or not os.path.isfile(dataset_path):
                logger.debug(f"skipping {dataset_filename}: not an HDF5 dataset file")
                continue
            if dataset not in GOAL_PROMPT:
                # __getitem__ would raise KeyError for these samples; surface it at load time.
                logger.warning(f"no GOAL_PROMPT entry for {dataset!r}; indexing its samples will fail")
            logger.info(f"loading data from {dataset}...")
            # Close the file handle once its demos have been copied into memory.
            with h5py.File(dataset_path, "r") as f:
                demos = sorted(f["data"].keys(), key=self._demo_sort_key)
                for i, demo in enumerate(demos):
                    if i == max_demos:
                        break
                    self._load_demo(f[f"data/{demo}"], dataset, cam_views, max_samples)

        self.state = np.array(self.state)
        self.rgb = np.array(self.rgb)
        self.depth = np.array(self.depth)
        self.action = np.array(self.action)
        self.goal_key = np.array(self.goal_key)

    @staticmethod
    def _demo_sort_key(name):
        """Sort key placing ``<prefix>_<int>`` names in numeric order, other names after them."""
        prefix, _, suffix = name.rpartition("_")
        if suffix.isdecimal():
            return (0, prefix, int(suffix))
        return (1, name, 0)

    def _load_demo(self, demo_grp, dataset, cam_views, max_samples):
        """Append up to ``max_samples`` evenly spaced timesteps of one demo group to the buffers."""
        num_samples = demo_grp.attrs["num_samples"]
        num_samples_trunc = min(num_samples, max_samples)
        # Evenly spaced indices; with num_samples_trunc <= num_samples they are strictly
        # increasing, which h5py fancy indexing expects.
        sample_idx = np.linspace(0, num_samples - 1, num_samples_trunc, dtype=int)
        # goal prompt
        self.goal_key.extend([dataset] * num_samples_trunc)
        # action (extending a list with an array appends its rows)
        self.action.extend(demo_grp["actions"][sample_idx])
        # state
        eef_pos = demo_grp["obs/robot0_eef_pos"][sample_idx]
        eef_quat = demo_grp["obs/robot0_eef_quat"][sample_idx]
        self.state.extend(np.concatenate((eef_pos, eef_quat), axis=-1))
        # rgb, depth
        for j, cam_view in enumerate(cam_views):
            self.rgb[j].extend(demo_grp[f"obs/{cam_view}_image"][sample_idx])
            self.depth[j].extend(demo_grp[f"obs/{cam_view}_depth"][sample_idx])

    def __getitem__(self, index):
        """Return ``(state, rgb, depth, action, goal_prompt)`` for sample ``index``.

        ``rgb`` and ``depth`` hold this sample's frame from every extracted camera
        view (the leading axis indexes the view). ``goal_prompt`` is the
        ``GOAL_PROMPT`` entry for the dataset the sample came from; a missing
        entry raises ``KeyError``.
        """
        # [insert pre-processing steps here for particular use case]
        return (self.state[index],
                self.rgb[:, index, :],
                self.depth[:, index, :],
                self.action[index],
                GOAL_PROMPT[self.goal_key[index]])

    def __len__(self):
        """Return the total number of samples across all loaded demos."""
        return len(self.state)
    
# Smoke test: load the sample data and log the labeled shape of each stacked array.
dataset = RobomimicDataset(data_dir="../data/svd_sample")
for attr_name in ("state", "action", "goal_key", "rgb", "depth"):
    logger.info(f"{attr_name} shape: {getattr(dataset, attr_name).shape}")