In [None]:
from bc_algos.models.obs_core import EncoderCore, VisualCore
from bc_algos.models.obs_nets import ObservationEncoder, ObservationGroupEncoder
import bc_algos.utils.obs_utils as ObsUtils
import torch
from collections import OrderedDict

In [None]:
# Shared fixtures used by every test cell below.

# Seed the global torch RNG so the random inputs (and therefore any failure)
# are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

batch_size = 4
state_shape = [3,]            # passed as input_shape to the low-dim encoder
img_shape = [3, 224, 162,]    # passed as input_shape to the visual encoder
output_shape = [4, 4,]        # explicit output_shape requested from the cores
hidden_dim = [64, 18, 64,]    # hidden_dim forwarded to the cores

# torch.rand samples from [0, 1); rescale so the inputs lie in [-1, 1).
x_low_dim = 2*torch.rand(batch_size, *state_shape)-1
x_rgb = 2*torch.rand(batch_size, *img_shape)-1

### Encoder Cores

In [None]:
# Shape checks for the two encoder cores. Each case runs one forward pass and
# compares the batched result shape against (a) the shape we expect and
# (b) the shape the core itself reports via `output_shape`.

# EncoderCore, default configuration: the output is expected to keep the input shape.
enc_core = EncoderCore(input_shape=state_shape)
y_low_dim = enc_core(x_low_dim)
low_dim_shape = list(y_low_dim.shape)
assert low_dim_shape == [batch_size] + state_shape
assert low_dim_shape == [batch_size] + enc_core.output_shape

# EncoderCore, explicit output shape and hidden sizes.
enc_core = EncoderCore(
    input_shape=state_shape,
    output_shape=output_shape,
    hidden_dim=hidden_dim,
)
y_low_dim = enc_core(x_low_dim)
low_dim_shape = list(y_low_dim.shape)
assert low_dim_shape == [batch_size] + output_shape
assert low_dim_shape == [batch_size] + enc_core.output_shape

# VisualCore, default configuration: this test expects a flat 768-wide feature vector.
visual_core = VisualCore(input_shape=img_shape)
y_rgb = visual_core(x_rgb)
rgb_shape = list(y_rgb.shape)
assert rgb_shape == [batch_size, 768]
assert rgb_shape == [batch_size] + visual_core.output_shape

# VisualCore, explicit output shape and hidden sizes.
visual_core = VisualCore(
    input_shape=img_shape,
    output_shape=output_shape,
    hidden_dim=hidden_dim,
)
y_rgb = visual_core(x_rgb)
rgb_shape = list(y_rgb.shape)
assert rgb_shape == [batch_size] + output_shape
assert rgb_shape == [batch_size] + visual_core.output_shape

### Observation/Group Encoder 

In [None]:
# End-to-end check of ObservationEncoder / ObservationGroupEncoder:
# one low-dim key in the "obs" group and one RGB key in the "goal" group.

# Registration mutates ObsUtils' module-level state, so everything after it
# runs inside try/finally: the cores are unregistered even when an assertion
# or a forward pass fails, leaving the kernel clean for a re-run of this cell.
ObsUtils.register_encoder_core(EncoderCore, ObsUtils.Modality.LOW_DIM)
ObsUtils.register_encoder_core(VisualCore, ObsUtils.Modality.RGB)
try:
    obs_enc = ObservationEncoder()
    obs_enc.register_obs_key(
        obs_key="robot0_eef_pos",
        modality=ObsUtils.Modality.LOW_DIM,
        input_shape=state_shape,
    )

    goal_enc = ObservationEncoder()
    goal_enc.register_obs_key(
        obs_key="agentview_image",
        modality=ObsUtils.Modality.RGB,
        input_shape=img_shape,
    )

    group_enc = ObservationGroupEncoder()
    group_enc.register_obs_group(obs_group="obs", obs_enc=obs_enc)
    group_enc.register_obs_group(obs_group="goal", obs_enc=goal_enc)

    # Nested mapping: observation group -> observation key -> batched tensor.
    inputs = OrderedDict({"obs": {"robot0_eef_pos": x_low_dim}, "goal": {"agentview_image": x_rgb}})
    y = group_enc(inputs)

    # The encoded batch must match the shape the group encoder reports.
    assert list(y.shape) == [batch_size,]+group_enc.output_shape
finally:
    ObsUtils.unregister_encoder_core(ObsUtils.Modality.LOW_DIM)
    ObsUtils.unregister_encoder_core(ObsUtils.Modality.RGB)