In [10]:
from bc_algos.models.obs_core import EncoderCore, VisualCore
from bc_algos.models.obs_nets import ObservationEncoder, ObservationGroupEncoder, ActionDecoder
from bc_algos.models.backbone import Transformer
import bc_algos.utils.obs_utils as ObsUtils
import torch
from collections import OrderedDict

In [None]:
# Shared fixtures for the encoder-core / observation-encoder shape tests.
# Seed the global torch RNG so the random inputs are identical on every
# Restart & Run All (and so any failure can be reproduced exactly).
SEED = 0
torch.manual_seed(SEED)

B = 4                         # leading dimension of every test input
state_shape = [3,]            # per-sample shape of the low-dim input
img_shape = [3, 224, 162,]    # per-sample shape of the image input (presumably C, H, W — confirm)
output_shape = [4, 4,]        # output shape requested from the cores / action decoder
hidden_dims = [64, 18, 64,]   # passed as `hidden_dims` wherever `output_shape` is given

# torch.rand samples uniformly from [0, 1); rescale to [-1, 1).
x_low_dim = 2*torch.rand(B, *state_shape)-1
x_rgb = 2*torch.rand(B, *img_shape)-1

### Encoder Cores

In [None]:
def _batched(shape):
    """Return ``shape`` as a list with the batch size ``B`` prepended."""
    return [B, *shape]


# EncoderCore built from the input shape only: the output is expected to
# keep the per-sample input shape and to agree with `output_shape`.
enc_core = EncoderCore(input_shape=state_shape)
y_low_dim = enc_core(x_low_dim)
assert list(y_low_dim.shape) == _batched(state_shape)
assert list(y_low_dim.shape) == _batched(enc_core.output_shape)

# EncoderCore with an explicit output shape (and hidden dims): the output
# must take the requested shape.
enc_core = EncoderCore(
    hidden_dims=hidden_dims,
    output_shape=output_shape,
    input_shape=state_shape,
)
y_low_dim = enc_core(x_low_dim)
assert list(y_low_dim.shape) == _batched(output_shape)
assert list(y_low_dim.shape) == _batched(enc_core.output_shape)

# VisualCore built from the image shape only: expects a flat 768-wide
# feature per sample.
# NOTE(review): 768 is presumably the default backbone's feature width — confirm against VisualCore.
visual_core = VisualCore(input_shape=img_shape)
y_rgb = visual_core(x_rgb)
assert list(y_rgb.shape) == _batched([768])
assert list(y_rgb.shape) == _batched(visual_core.output_shape)

# VisualCore with an explicit output shape (and hidden dims).
visual_core = VisualCore(
    hidden_dims=hidden_dims,
    output_shape=output_shape,
    input_shape=img_shape,
)
y_rgb = visual_core(x_rgb)
assert list(y_rgb.shape) == _batched(output_shape)
assert list(y_rgb.shape) == _batched(visual_core.output_shape)

### Observation Encoder and Action Decoder

In [None]:
# End-to-end shape test: per-key ObservationEncoders -> ObservationGroupEncoder
# -> ActionDecoder.
#
# The encoder cores are registered with ObsUtils for the duration of the test
# and unregistered afterwards. The body runs inside try/finally so that a
# failing assertion or forward pass cannot leave the cores registered in the
# running kernel (presumably module-level state in ObsUtils — confirm), which
# would otherwise make a re-run of this cell start from a dirty state.
ObsUtils.register_encoder_core(EncoderCore, ObsUtils.Modality.LOW_DIM)
ObsUtils.register_encoder_core(VisualCore, ObsUtils.Modality.RGB)

try:
    # Encoder for the "obs" group: a single low-dim key.
    obs_enc = ObservationEncoder()
    obs_enc.register_obs_key(
        obs_key="robot0_eef_pos",
        modality=ObsUtils.Modality.LOW_DIM,
        input_shape=state_shape,
    )

    # Encoder for the "goal" group: a single RGB key.
    goal_enc = ObservationEncoder()
    goal_enc.register_obs_key(
        obs_key="agentview_image",
        modality=ObsUtils.Modality.RGB,
        input_shape=img_shape,
    )

    # Combine both groups into one encoder.
    group_enc = ObservationGroupEncoder()
    group_enc.register_obs_group(obs_group="obs", obs_enc=obs_enc)
    group_enc.register_obs_group(obs_group="goal", obs_enc=goal_enc)

    # Nested input: group name -> observation key -> batched tensor.
    inputs = OrderedDict({"obs": {"robot0_eef_pos": x_low_dim}, "goal": {"agentview_image": x_rgb}})
    y = group_enc(inputs)
    # The group encoder must return one flat feature vector per sample.
    assert list(y.shape) == [B, group_enc.output_dim,]

    # Decode the encoded features into actions of the requested shape.
    act_dec = ActionDecoder(action_shape=output_shape, input_dim=group_enc.output_dim)
    action = act_dec(y)
    assert list(action.shape) == [B, *output_shape]
finally:
    # Always undo the registrations, even when the test above fails.
    ObsUtils.unregister_encoder_core(ObsUtils.Modality.LOW_DIM)
    ObsUtils.unregister_encoder_core(ObsUtils.Modality.RGB)

In [11]:
# Dimensions of the random input used by the backbone test; `x` has shape
# (B, T, embed_dim).
B, T, embed_dim = 4, 10, 128

# torch.rand samples uniformly from [0, 1); rescale to [-1, 1).
x = torch.rand(B, T, embed_dim) * 2 - 1

### Backbone Models

In [14]:
# Transformer backbone built with nlayers=2 and nhead=2 over embed_dim-wide inputs.
transformer = Transformer(
    nhead=2,
    nlayers=2,
    input_dim=embed_dim,
)
y = transformer(x)
# The backbone output must have the same (B, T, embed_dim) shape as its input.
assert tuple(y.shape) == (B, T, embed_dim)