In [9]:
import json
import os

import robomimic.utils.train_utils as TrainUtils
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.obs_utils as ObsUtils

from robomimic.config import config_factory

from torch.utils.data import DataLoader

def _make_loader(dataset, batch_size, num_workers):
    """Wrap a robomimic dataset in a DataLoader, using the dataset's own sampler if it has one.

    ``dataset.get_dataset_sampler()`` may return None; in that case the loader
    falls back to plain shuffling. PyTorch's DataLoader rejects a sampler combined
    with ``shuffle=True``, hence ``shuffle=(sampler is None)``.

    ``drop_last=True`` discards a trailing partial batch, so a dataset with fewer
    samples than ``batch_size`` yields no batches at all.
    """
    sampler = dataset.get_dataset_sampler()
    return DataLoader(
        dataset=dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=(sampler is None),
        num_workers=num_workers,
        drop_last=True,
    )


def get_dataloaders_from_mimic(config_path, data_path, action_keys=None):
    """Build robomimic train/validation DataLoaders from a JSON config and a dataset file.

    Args:
        config_path: path to an external robomimic JSON config. Its ``algo_name``
            selects the base config, which is then overridden by the JSON values.
        data_path: dataset path; overrides ``config.train.data``.
        action_keys: action keys passed to
            ``FileUtils.get_shape_metadata_from_dataset``. Defaults to ``["actions"]``.

    Returns:
        ``(train_loader, valid_loader)``. ``valid_loader`` is None when
        ``config.experiment.validate`` is False.

    Side effects:
        Calls ``ObsUtils.initialize_obs_utils_with_config(config)``, which sets up
        robomimic's observation-modality handling (presumably module-global state;
        confirm before building loaders for different configs in one kernel).
    """
    # Use a context manager so the config file handle is closed deterministically.
    with open(config_path, "r") as config_file:
        ext_cfg = json.load(config_file)
    config = config_factory(ext_cfg["algo_name"])
    # update config with external json - this will throw errors if
    # the external config has keys not present in the base algo config
    with config.values_unlocked():
        config.update(ext_cfg)
        config.train.data = data_path

    ObsUtils.initialize_obs_utils_with_config(config)

    # None sentinel avoids a shared mutable default list.
    if action_keys is None:
        action_keys = ["actions"]
    shape_meta = FileUtils.get_shape_metadata_from_dataset(
        dataset_path=config.train.data,
        all_obs_keys=config.all_obs_keys,
        action_keys=list(action_keys),
        language_conditioned=config.observation.language_conditioned,
        verbose=True,
    )

    trainset, validset = TrainUtils.load_data_for_training(
        config, obs_keys=shape_meta["all_obs_keys"])

    train_loader = _make_loader(
        trainset,
        batch_size=config.train.batch_size,
        num_workers=config.train.num_data_workers,
    )

    valid_loader = None
    if config.experiment.validate:
        # cap num workers for validation dataset at 1
        valid_loader = _make_loader(
            validset,
            batch_size=config.train.batch_size,
            num_workers=min(config.train.num_data_workers, 1),
        )

    return train_loader, valid_loader

In [10]:
# Machine-specific locations: override with the ROBOMIMIC_CONFIG_PATH /
# ROBOMIMIC_DATA_PATH environment variables instead of editing the notebook.
config_path = os.environ.get(
    "ROBOMIMIC_CONFIG_PATH",
    "/home/memmelma/Projects/robotic/robomimic_pcd/robomimic/config/default_templates/proprio_pcd_no_noise.json",
)
data_path = os.environ.get(
    "ROBOMIMIC_DATA_PATH",
    "/home/memmelma/Projects/robotic/gifs_curobo/red_cube_500_pcd_vanilla.hdf5",
)

train_loader, valid_loader = get_dataloaders_from_mimic(config_path, data_path)



using obs modality: low_dim with keys: ['qpos_normalized']
using obs modality: rgb with keys: []
using obs modality: depth with keys: []
using obs modality: scan with keys: []
using obs modality: pc with keys: ['camera_intrinsic', 'depth', 'camera_extrinsic']
obs key camera_extrinsic with shape (4, 4)
obs key camera_intrinsic with shape (3, 3)
obs key depth with shape (128, 128)
obs key qpos_normalized with shape (7,)
SequenceDataset: loading dataset into memory...
  0%|          | 0/495 [00:00<?, ?it/s]

100%|██████████| 495/495 [00:00<00:00, 5935.81it/s]
SequenceDataset: loading dataset into memory...
100%|██████████| 5/5 [00:00<00:00, 4509.03it/s]


In [11]:
# Pull a single batch from the training loader to sanity-check what it yields.
# Each run builds a fresh iterator (and, when num_data_workers > 0, fresh worker
# processes, since the loader is not created with persistent workers).
batch = next(iter(train_loader))

In [12]:
# Pull a single batch from the validation loader. This overwrites `batch` from the
# training cell. valid_loader is None when config.experiment.validate is False, so
# fail with an actionable message instead of the opaque TypeError from iter(None).
if valid_loader is None:
    raise RuntimeError(
        "valid_loader is None: enable experiment.validate in the config to build a validation loader"
    )
batch = next(iter(valid_loader))

In [7]:
# Show which observation keys are present in the most recently fetched `batch`.
# The recorded output lists the same four keys printed by the shape-metadata step.
# NOTE(review): that output carries execution count 7, i.e. it was produced before
# cells 9-12 ran in this session; Restart & Run All to confirm it still holds.
batch["obs"].keys()

dict_keys(['camera_extrinsic', 'camera_intrinsic', 'depth', 'qpos_normalized'])