In [None]:
import bc_algos.utils.obs_utils as ObsUtils
from bc_algos.models.obs_core import EncoderCore, VisualCore
from bc_algos.dataset.robomimic import RobomimicDataset
from bc_algos.models.obs_nets import ObservationGroupEncoder, ActionDecoder
from bc_algos.models.backbone import Transformer
from bc_algos.models.policy_nets import BC_Transformer
import torch
from addict import Dict
import json
from collections import OrderedDict

# Path to the BC-Transformer JSON config, relative to this notebook.
config_path = "../config/bc_transformer.json"

# Parse the JSON and wrap it in an addict Dict so that nested entries can be
# read with attribute access (e.g. config.observation.shapes).
with open(config_path, 'r') as config_file:
    config = Dict(json.load(config_file))

In [None]:
# Register one encoder core class per observation modality
# (low-dimensional first, then RGB).
for core_cls, modality in (
    (EncoderCore, ObsUtils.Modality.LOW_DIM),
    (VisualCore, ObsUtils.Modality.RGB),
):
    ObsUtils.register_encoder_core(core_cls, modality)

### ObsUtils Initialization

In [None]:
# Initialize ObsUtils from the loaded config.
ObsUtils.init_obs_utils(config=config)

# Dump the module-level lookup tables, one per line, for visual inspection.
# NOTE(review): presumably these are populated by init_obs_utils / the
# registration calls — confirm against bc_algos.utils.obs_utils.
print(
    ObsUtils.MODALITY_TO_ENC_CORE,
    ObsUtils.OBS_KEY_TO_SHAPE,
    ObsUtils.OBS_KEY_TO_MODALITY,
    ObsUtils.OBS_GROUP_TO_KEY,
    sep="\n",
)

### Factory Functions

In [None]:
# Smoke test: every factory should build its object from the config without raising.

# Datasets: same config, train vs. validation selected via the `train` flag.
trainset = RobomimicDataset.factory(train=True, config=config)
validset = RobomimicDataset.factory(train=False, config=config)

# Model pieces are chained: each module's input width is the previous
# module's reported output width, so the construction order matters.
obs_group_enc = ObservationGroupEncoder.factory(config=config)
encoder_out_dim = obs_group_enc.output_dim

transformer = Transformer.factory(config=config, input_dim=encoder_out_dim)
backbone_out_dim = transformer.output_dim

act_dec = ActionDecoder.factory(config=config, input_dim=backbone_out_dim)

### Test Policy

In [None]:
# Leading dimensions of the synthetic inputs (used as batch size and sequence
# length in the shape assertion on the policy output).
B, T = 4, 10

# Per-key observation shapes from the config.
obs_shapes = config.observation.shapes

# torch.rand samples uniformly from [0, 1); scaling by 2 and shifting by -1
# maps the samples to [-1, 1).
robot0_eef_pos = torch.rand(B, T, *obs_shapes.robot0_eef_pos) * 2 - 1
robot0_eef_quat = torch.rand(B, T, *obs_shapes.robot0_eef_quat) * 2 - 1
agentview_image = torch.rand(B, T, *obs_shapes.agentview_image) * 2 - 1

# Nested input structure: observation group -> observation key -> tensor.
# The "goal" group reuses the very same image tensor as "obs".
inputs = OrderedDict(
    obs={
        "robot0_eef_pos": robot0_eef_pos,
        "robot0_eef_quat": robot0_eef_quat,
        "agentview_image": agentview_image,
    },
    goal={
        "agentview_image": agentview_image,
    },
)

In [None]:
# Assemble the full policy from the encoder, backbone and decoder built by the
# factory cell.
bc_transformer = BC_Transformer(
    act_dec=act_dec,
    backbone=transformer,
    obs_group_enc=obs_group_enc,
)

# Forward the synthetic inputs and check that the output has shape
# (B, T, *action_shape).
policy_out = bc_transformer(inputs)
expected_shape = [B, T, *config.policy.action_shape]
assert list(policy_out.shape) == expected_shape

In [None]:
# Undo the encoder-core registrations made earlier in this notebook
# (low-dimensional first, then RGB).
for modality in (ObsUtils.Modality.LOW_DIM, ObsUtils.Modality.RGB):
    ObsUtils.unregister_encoder_core(modality)

In [None]:
import torch
import torch.nn as nn
# Sanity check of nn.MSELoss on a single scalar pair.
# torch.tensor is the documented way to build a tensor from Python data; the
# legacy torch.Tensor(...) constructor is avoided.
y = torch.tensor([1.])      # presumably the ground-truth value
yh = torch.tensor([0.78])   # presumably the prediction ("y hat")

loss = nn.MSELoss()

# The criterion's signature is (input, target), so the prediction goes first.
# MSE is symmetric in its arguments, so the printed value is unaffected.
print(loss(yh, y).item())

# The criterion is itself an nn.Module.
assert isinstance(loss, nn.Module)

In [37]:
from torch.utils.data import DataLoader
from bc_algos.utils.tensor_utils import slice
from tqdm import tqdm
# Peek at the first training batch and print the shapes of the target/input
# split used for training.
# NOTE: `slice` here is the helper imported from bc_algos.utils.tensor_utils in
# this cell (it shadows the Python builtin), and `inputs` overwrites the
# synthetic `inputs` built earlier in the notebook.
train_loader = DataLoader(trainset, batch_size=config.train.batch_size, shuffle=True)

with tqdm(total=len(train_loader), unit='batch') as progress_bar:
    for batch in train_loader:
        print(batch["actions"].shape)
        print(config.dataset.frame_stack)

        # Index along dim 1 at which the target actions start.
        pi = int(config.dataset.frame_stack)

        # Targets: actions from index `pi` onward along dim 1.
        target = batch["actions"][:, pi:, :]

        # Inputs: the batch cut along dim 1 with start=0, end=pi + 1
        # (presumably an exclusive end, i.e. the first pi + 1 steps —
        # confirm against tensor_utils.slice).
        inputs = slice(x=batch, dim=1, start=0, end=pi + 1)

        print(target.shape)
        print(inputs["obs"]["agentview_image"].shape)

        # Advance the bar before leaving the loop; placed after `break` this
        # call was unreachable and the bar never moved.
        progress_bar.update(1)

        # Only the first batch is inspected.
        break

  0%|          | 0/1698 [00:00<?, ?batch/s]

torch.Size([16, 19, 7])
9
torch.Size([16, 10, 7])
torch.Size([16, 10, 3, 84, 84])



