In [1]:
from bc_algos.models.obs_core import LowDimCore, ViTMAECore, ResNet18Core
from bc_algos.models.obs_nets import ObservationEncoder, ObservationGroupEncoder, ActionDecoder
from bc_algos.models.backbone import Transformer, MLP
from bc_algos.models.policy_nets import BC_MLP, BC_Transformer
import bc_algos.utils.obs_utils as ObsUtils
import torch
from collections import OrderedDict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Bind each observation modality to the encoder core used to embed it:
# low-dimensional vectors -> LowDimCore, RGB images -> ViTMAECore.
for core_cls, modality in (
    (LowDimCore, ObsUtils.Modality.LOW_DIM),
    (ViTMAECore, ObsUtils.Modality.RGB),
):
    ObsUtils.register_encoder_core(core_cls, modality)

In [3]:
# Fixed seed so the random test inputs and freshly initialized weights are reproducible.
torch.manual_seed(0)

B = 4  # batch size
# Sequence length for the (B, T, ...) inputs used by the backbone and policy tests.
# Kept small (cheap ViTMAE forward passes) and distinct from B so the shape
# asserts would catch swapped batch/time axes.
T = 3
state_shape = [6,]
img_shape_vitmae = [3, 224, 224,]
img_shape_resnet18 = [3, 256, 256,]
# The RGB modality is registered to ViTMAECore, so the encoder and policy tests
# feed images of the ViTMAE input shape.
img_shape = img_shape_vitmae
output_shape = [1, 512,]
hidden_dims=[64, 18, 64,]  # NOTE(review): 18 looks like it may be a typo for 128 — confirm

In [4]:
# Random single-step inputs, uniform on [-1, 1), one per encoder-core test below.
x_low_dim = torch.rand(B, *state_shape) * 2 - 1
x_rgb_vitmae = torch.rand(B, *img_shape_vitmae) * 2 - 1
x_rgb_resnet18 = torch.rand(B, *img_shape_resnet18) * 2 - 1

### Encoder Cores

In [5]:
def _check_core(core, x, expected_shape):
    """Run `core` on `x` and assert the output has shape `[B, *expected_shape]`.

    Also asserts the output agrees with the core's advertised `output_shape`,
    so the declared shape cannot drift from the actual forward-pass shape.
    Returns the output tensor.
    """
    out = core(x)
    assert list(out.shape) == [B, *expected_shape]
    assert list(out.shape) == [B, *core.output_shape]
    return out


# Low-dim core without an output shape: output shape equals the input shape.
low_dim_core = LowDimCore(input_shape=state_shape)
y_low_dim = _check_core(low_dim_core, x_low_dim, state_shape)

# Low-dim core with an explicit output shape and hidden dims.
low_dim_core = LowDimCore(input_shape=state_shape, output_shape=output_shape, hidden_dims=hidden_dims)
y_low_dim = _check_core(low_dim_core, x_low_dim, output_shape)

# ViTMAE core: a 768-d feature per image.
vitmae_core = ViTMAECore(input_shape=img_shape_vitmae)
y_rgb = _check_core(vitmae_core, x_rgb_vitmae, [768,])

# ResNet-18 core: a (2, 512) output per image.
resnet_core = ResNet18Core(input_shape=img_shape_resnet18)
y_rgb = _check_core(resnet_core, x_rgb_resnet18, [2, 512,])

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
# Random inputs, uniform on [-1, 1): single-step (B, ...) and sequence (B, T, ...) variants.
x_low_dim = torch.rand(B, *state_shape) * 2 - 1
x_rgb = torch.rand(B, *img_shape) * 2 - 1
xT_low_dim = torch.rand(B, T, *state_shape) * 2 - 1
xT_rgb = torch.rand(B, T, *img_shape) * 2 - 1

# Nested observation dicts keyed by group ("obs", "goal"), then by observation key.
inputs = OrderedDict([
    ("obs", {"robot0_eef_pos": x_low_dim}),
    ("goal", {"agentview_image": x_rgb}),
])
inputsT = OrderedDict([
    ("obs", {"robot0_eef_pos": xT_low_dim}),
    ("goal", {"agentview_image": xT_rgb}),
])

### Observation Encoder

In [None]:
def _make_single_key_encoder(obs_key, modality, input_shape):
    """Build an ObservationEncoder with exactly one registered observation key."""
    encoder = ObservationEncoder()
    encoder.register_obs_key(
        obs_key=obs_key,
        modality=modality,
        input_shape=input_shape,
    )
    return encoder


obs_enc = _make_single_key_encoder("robot0_eef_pos", ObsUtils.Modality.LOW_DIM, state_shape)
goal_enc = _make_single_key_encoder("agentview_image", ObsUtils.Modality.RGB, img_shape)

# Group encoder: one sub-encoder per top-level key of the input dict.
group_enc = ObservationGroupEncoder()
for group_name, group_encoder in (("obs", obs_enc), ("goal", goal_enc)):
    group_enc.register_obs_group(obs_group=group_name, obs_enc=group_encoder)

# The nested single-step inputs map to one (B, output_dim) embedding.
embed = group_enc(inputs)
assert list(embed.shape) == [B, group_enc.output_dim,]

### Backbone Models

In [None]:
# Backbone inputs, uniform on [-1, 1): a flat (B, embed_dim) batch for the MLP
# and a (B, T, embed_dim) sequence for the Transformer.
embed_dim = group_enc.output_dim
x = torch.rand(B, embed_dim).mul(2).sub(1)
xT = torch.rand(B, T, embed_dim).mul(2).sub(1)

In [None]:
# MLP backbone: (B, embed_dim) -> (B, mlp_out_dim).
mlp_out_dim = 128
mlp = MLP(input_dim=embed_dim, output_dim=mlp_out_dim)
y = mlp(x)
assert list(y.shape) == [B, mlp_out_dim]

# Transformer backbone: (B, T, embed_dim) -> (B, T, embed_dim), shape-preserving.
transformer = Transformer(input_dim=embed_dim, nlayers=2, nhead=2)
y = transformer(xT)
assert list(y.shape) == [B, T, embed_dim]

### Policy Networks

In [None]:
# Both policies reuse the same `group_enc`; each gets its own action decoder
# sized to its backbone's output dim. Sequence inputs map to (B, T, *output_shape).
expected_action_shape = [B, T, *output_shape]

act_dec = ActionDecoder(action_shape=output_shape, input_dim=mlp.output_dim)
bc_mlp = BC_MLP(obs_group_enc=group_enc, backbone=mlp, act_dec=act_dec)
actions_mlp = bc_mlp(inputsT)
assert list(actions_mlp.shape) == expected_action_shape

act_dec = ActionDecoder(action_shape=output_shape, input_dim=transformer.output_dim)
bc_transformer = BC_Transformer(obs_group_enc=group_enc, backbone=transformer, act_dec=act_dec)
actions_transformer = bc_transformer(inputsT)
assert list(actions_transformer.shape) == expected_action_shape

In [None]:
# Remove the modality -> core registrations made at the top of the notebook
# (presumably so the global registry is clean for re-runs in the same kernel).
for modality in (ObsUtils.Modality.LOW_DIM, ObsUtils.Modality.RGB):
    ObsUtils.unregister_encoder_core(modality)