In [None]:

from bc_algos.dataset.robomimic import RobomimicDataset
from bc_algos.rollout.robomimic import RobomimicRolloutEnv
from bc_algos.dataset.isaac_gym import IsaacGymDataset
from bc_algos.rollout.isaac_gym_simple import IsaacGymSimpleRolloutEnv
import bc_algos.utils.obs_utils as ObsUtils
import bc_algos.utils.tensor_utils as TensorUtils
from bc_algos.models.obs_nets import ObservationGroupEncoder, ActionDecoder
from bc_algos.models.backbone import Transformer, MLP
from bc_algos.models.policy_nets import BC_Transformer, BC, BC_MLP
import bc_algos.utils.constants as Constants
from torch.utils.data import DataLoader
from addict import Dict
import json

# Paths are relative to this notebook's directory.
config_path = "../config/bc_transformer.json"  # experiment config to exercise
video_dir = "../outputs/test_rollout"          # where rollout videos are written

# Parse the JSON config and wrap it in an addict.Dict so nested keys can be read
# with attribute access (e.g. `config.dataset.type`).
with open(config_path) as config_file:
    config = Dict(json.load(config_file))

### ObsUtils Init

In [None]:
# Initialize ObsUtils' module-level state from the config, then print the
# lookup tables exposed on the module so they can be inspected by eye.
ObsUtils.init_obs_utils(config=config)
for obs_table in (
    ObsUtils.MODALITY_TO_ENC_CORE_CLASS,
    ObsUtils.OBS_KEY_TO_SHAPE,
    ObsUtils.OBS_KEY_TO_MODALITY,
    ObsUtils.OBS_GROUP_TO_KEY,
):
    print(obs_table)

### Test Factory Functions

In [None]:
# Build every component under test from the config. Each dispatch raises on an
# unrecognized type, so a bad config fails here instead of as a NameError on an
# unbound variable in a later cell.

# Datasets: the train and valid splits come from the same dataset class.
if config.dataset.type == Constants.DatasetType.ROBOMIMIC:
    dataset_cls = RobomimicDataset
elif config.dataset.type == Constants.DatasetType.ISAAC_GYM:
    dataset_cls = IsaacGymDataset
else:
    raise ValueError(f"Unsupported dataset type: {config.dataset.type!r}")
trainset = dataset_cls.factory(config=config, train=True)
validset = dataset_cls.factory(config=config, train=False)

obs_group_enc = ObservationGroupEncoder.factory(config=config)

# Backbone: the MLP needs the encoder's output width; the Transformer reads everything from config.
if config.policy.type == Constants.PolicyType.MLP:
    backbone = MLP.factory(config=config, embed_dim=obs_group_enc.output_dim)
elif config.policy.type == Constants.PolicyType.TRANSFORMER:
    backbone = Transformer.factory(config=config)
else:
    raise ValueError(f"Unsupported policy type: {config.policy.type!r}")

action_dec = ActionDecoder.factory(config=config, input_dim=backbone.output_dim)

# Rollout env: evaluated on the validation split, using the training set's normalization stats.
if config.rollout.type == Constants.RolloutType.ROBOMIMIC:
    rollout_env_cls = RobomimicRolloutEnv
elif config.rollout.type == Constants.RolloutType.ISAAC_GYM:
    rollout_env_cls = IsaacGymSimpleRolloutEnv
else:
    raise ValueError(f"Unsupported rollout type: {config.rollout.type!r}")
rollout_env = rollout_env_cls.factory(
    config=config,
    validset=validset,
    normalization_stats=trainset.normalization_stats,
)

### Test Policy

In [None]:
# Draw one shuffled training batch to feed through the policy below.
train_loader = DataLoader(trainset, batch_size=config.train.batch_size, shuffle=True)
train_loader_iter = iter(train_loader)
input = BC.prepare_input(input=next(train_loader_iter))

# Keep only the first `frame_stack + 1` entries of the observations along dim 1
# (presumably the time axis — TODO confirm against TensorUtils.slice usage).
num_obs_frames = config.dataset.frame_stack + 1
input["obs"] = TensorUtils.slice(x=input["obs"], dim=1, start=0, end=num_obs_frames)

In [None]:
# Assemble the policy from the encoder / backbone / decoder built above. An
# unknown policy type raises here rather than leaving `policy` unbound.
if config.policy.type == Constants.PolicyType.MLP:
    policy = BC_MLP(obs_group_enc=obs_group_enc, backbone=backbone, action_dec=action_dec)
elif config.policy.type == Constants.PolicyType.TRANSFORMER:
    policy = BC_Transformer.factory(
        config=config,
        obs_group_enc=obs_group_enc,
        backbone=backbone,
        action_dec=action_dec,
    )
else:
    raise ValueError(f"Unsupported policy type: {config.policy.type!r}")

# Forward pass on the prepared batch. The output should be
# (batch_size, seq_length, *action_shape).
actions = policy(input)
expected_shape = [config.train.batch_size, config.dataset.seq_length, *config.policy.action_shape]
assert list(actions.shape) == expected_shape, (
    f"unexpected action shape {list(actions.shape)}, expected {expected_shape}"
)

### Test Rollout

In [None]:
# Short smoke-test rollout of the policy on the first validation demo.
# A single constant keeps the requested horizon and the assertion in sync.
ROLLOUT_HORIZON = 25
assert len(validset.demos) > 0, "validation set has no demos to roll out"

results = rollout_env.rollout_with_stats(
    policy=policy,
    demo_id=validset.demos[0],
    video_dir=video_dir,
    horizon=ROLLOUT_HORIZON,
    video_skip=1,  # presumably records every frame — confirm in rollout_with_stats
)
assert results["horizon"] == ROLLOUT_HORIZON, (
    f"rollout reported horizon {results['horizon']}, expected {ROLLOUT_HORIZON}"
)

### ObsUtils Deinit

In [None]:
# Tear down the ObsUtils state set up by `init_obs_utils` at the top of the notebook.
ObsUtils.deinit_obs_utils()