In [None]:
from bc_algos.models.obs_core import LowDimCore, ViTMAECore, ResNet18Core
from bc_algos.models.obs_nets import ActionDecoder, ObservationGroupEncoder
from bc_algos.models.backbone import Transformer, MLP
from bc_algos.models.policy_nets import BC_MLP, BC_Transformer
from bc_algos.models.loss import DiscountedMSELoss, DiscountedL1Loss
import bc_algos.utils.obs_utils as ObsUtils
import bc_algos.utils.tensor_utils as TensorUtils
import bc_algos.utils.constants as Const
import torch
import torch.nn as nn
from collections import OrderedDict

### Encoder Core Unit Tests

In [None]:
# Shared fixtures for the unit tests in this notebook.
# Seed the global RNG first so the random inputs below are identical on every
# "Restart Kernel -> Run All" (they were previously drawn unseeded).
torch.manual_seed(0)

B = 4  # leading (batch) dimension of every test input
state_shape = [7]  # per-sample shape of the low-dim input
img_shape_vitmae = [3, 224, 224]  # per-sample image shape fed to ViTMAECore
img_shape_resnet18 = [3, 256, 256]  # per-sample image shape fed to ResNet18Core
output_shape = [512]  # requested encoder output shape
hidden_dims = [64, 18, 64]  # hidden layer sizes passed to LowDimCore

# torch.rand samples from [0, 1); 2*x - 1 rescales the inputs to [-1, 1).
x_low_dim = 2 * torch.rand(B, *state_shape) - 1
x_rgb_vitmae = 2 * torch.rand(B, *img_shape_vitmae) - 1
x_rgb_resnet18 = 2 * torch.rand(B, *img_shape_resnet18) - 1

In [None]:
# LowDimCore built from the input shape alone: the output is expected to keep
# the input's per-sample shape, and to agree with the reported output_shape.
low_dim_core = LowDimCore(input_shape=state_shape)
y_low_dim = low_dim_core(x_low_dim)
assert tuple(y_low_dim.shape) == (B, *state_shape)
assert tuple(y_low_dim.shape) == (B, *low_dim_core.output_shape)

# LowDimCore with an explicit output shape and hidden layers.
# The RNG is seeded right before construction; the resulting y_low_dim is
# compared against the group encoder's output in a later cell.
torch.manual_seed(0)
low_dim_core = LowDimCore(
    input_shape=state_shape,
    output_shape=output_shape,
    hidden_dims=hidden_dims,
)
y_low_dim = low_dim_core(x_low_dim)
assert tuple(y_low_dim.shape) == (B, *output_shape)
assert tuple(y_low_dim.shape) == (B, *low_dim_core.output_shape)

# ViTMAECore: one 768-dim vector per sample.
vitmae_core = ViTMAECore(input_shape=img_shape_vitmae)
y_rgb = vitmae_core(x_rgb_vitmae)
assert tuple(y_rgb.shape) == (B, 768)
assert tuple(y_rgb.shape) == (B, *vitmae_core.output_shape)

# ResNet18Core: a (64, 512) feature map per sample.
# y_rgb is rebound here; this ResNet output is what later cells compare against.
resnet_core = ResNet18Core(input_shape=img_shape_resnet18)
y_rgb = resnet_core(x_rgb_resnet18)
assert tuple(y_rgb.shape) == (B, 64, 512)
assert tuple(y_rgb.shape) == (B, *resnet_core.output_shape)

### Observation Encoder Unit Tests

In [None]:
# Observation keys used throughout the encoder and policy tests.
low_dim_key, rgb_key = "low_dim", "rgb"

# Which observation keys belong to each observation group.
obs_group_to_keys = OrderedDict(
    [
        ("obs", [low_dim_key, rgb_key]),
        ("goal", [rgb_key]),
    ]
)

# Nested inputs: group name -> observation key -> batched tensor.
# The same ResNet-sized image batch is reused for both groups.
inputs = OrderedDict(
    [
        ("obs", {low_dim_key: x_low_dim, rgb_key: x_rgb_resnet18}),
        ("goal", {rgb_key: x_rgb_resnet18}),
    ]
)

In [None]:
# Map each modality to the encoder core class that should handle it.
ObsUtils.register_encoder_core_class(
    core=LowDimCore,
    modality=Const.Modality.LOW_DIM,
)
ObsUtils.register_encoder_core_class(
    core=ResNet18Core,
    modality=Const.Modality.RGB,
)

# Same seed as used before building the standalone LowDimCore above.
# NOTE(review): presumably so the registered low-dim core ends up with the same
# weights as that standalone core (a later cell checks outputs with
# torch.equal) -- confirm where the core is actually instantiated.
torch.manual_seed(0)

# Per-key encoder configuration: the low-dim key gets an explicit output shape
# and hidden layers, the rgb key only an input shape.
ObsUtils.register_encoder_core(
    obs_key=low_dim_key,
    modality=Const.Modality.LOW_DIM,
    input_shape=state_shape,
    output_shape=output_shape,
    hidden_dims=hidden_dims,
)
ObsUtils.register_encoder_core(
    obs_key=rgb_key,
    modality=Const.Modality.RGB,
    input_shape=img_shape_resnet18,
)

In [None]:
# Build the group encoder from the registered cores and encode both groups.
obs_group_enc = ObservationGroupEncoder(obs_group_to_key=obs_group_to_keys)

latent_dict = obs_group_enc(inputs)
obs_latent, goal_latent = latent_dict["obs"], latent_dict["goal"]

# Each group's latent is flat: (batch, reported output_dim for that group).
assert tuple(obs_latent.shape) == (B, obs_group_enc.output_dim["obs"])
assert tuple(goal_latent.shape) == (B, obs_group_enc.output_dim["goal"])

# Unflatten into rows of size output_shape so the latents can be compared
# with the outputs of the standalone cores computed earlier.
obs_latent = obs_latent.view(B, -1, *output_shape)
goal_latent = goal_latent.view(B, -1, *output_shape)

# Row 0 of the "obs" latent is the low-dim encoding; the remaining rows are the
# rgb encoding. The "goal" latent consists of the rgb encoding only.
assert torch.equal(obs_latent[:, 0], y_low_dim)
assert torch.equal(obs_latent[:, 1:], y_rgb)
assert torch.equal(goal_latent, y_rgb)

### Backbone Unit Tests

In [None]:
# Sequence lengths and feature sizes for the backbone tests.
T_src, T_tgt = 4, 2
embed_dim, output_dim = 512, 128

# Flat MLP input whose width is the combined "obs" + "goal" latent size.
# NOTE(review): this draws from randn (normal), unlike the uniform rand-based
# inputs on the following lines -- confirm that is intended.
x = torch.randn(
    B,
    obs_group_enc.output_dim["obs"] + obs_group_enc.output_dim["goal"],
) * 2 - 1

# Transformer source / target sequences, uniform in [-1, 1).
x_src = torch.rand(B, T_src, embed_dim) * 2 - 1
x_tgt = torch.rand(B, T_tgt, embed_dim) * 2 - 1

In [None]:
# MLP backbone: flat (B, obs + goal latent size) input -> (B, output_dim).
mlp = MLP(
    embed_dim=obs_group_enc.output_dim["obs"] + obs_group_enc.output_dim["goal"],
    output_dim=output_dim,
)
y = mlp(x)
assert tuple(y.shape) == (B, output_dim)

# Transformer backbone: the output is expected to have the target sequence's
# shape, (B, T_tgt, embed_dim).
transformer = Transformer(embed_dim=embed_dim)
y = transformer(x_src, x_tgt)
assert tuple(y.shape) == (B, T_tgt, embed_dim)

### Discounted Loss Unit Tests

In [None]:
# Random prediction / target pair of shape (B, T_tgt, embed_dim).
src = torch.randn(B, T_tgt, embed_dim) * 2 - 1
tgt = torch.randn(B, T_tgt, embed_dim) * 2 - 1

# Float mask of ones with the last timestep of every sample zeroed out.
mask = torch.ones(B, T_tgt, dtype=torch.float32)
mask[:, -1] = 0.0

In [None]:
# Reference torch losses and their discounted counterparts (discount = 1.0).
l1_loss, l2_loss = nn.L1Loss(), nn.MSELoss()
disc_l1_loss = DiscountedL1Loss(discount=1.0)
disc_l2_loss = DiscountedMSELoss(discount=1.0)

# With discount 1.0 and no mask the discounted losses must match torch's.
assert torch.isclose(disc_l1_loss(src, tgt), l1_loss(src, tgt))
assert torch.isclose(disc_l2_loss(src, tgt), l2_loss(src, tgt))

# With the mask (last timestep zeroed) they must match the plain losses
# computed on all timesteps except the last.
assert torch.isclose(disc_l1_loss(src, tgt, mask), l1_loss(src[:, :-1], tgt[:, :-1]))
assert torch.isclose(disc_l2_loss(src, tgt, mask), l2_loss(src[:, :-1], tgt[:, :-1]))

# With a discount below 1 only the result's rank is checked: it must be a scalar.
disc_l1_loss.discount = disc_l2_loss.discount = 0.9
assert disc_l1_loss(src, tgt, mask).ndim == 0
assert disc_l2_loss(src, tgt, mask).ndim == 0

### Policy Unit Tests

In [None]:
# Per-step action shape expected from the policies.
action_shape = [7]
# Sequence lengths the "obs" and "goal" inputs are repeated to for the
# transformer policy test.
T_obs, T_goal = 4, 3

In [None]:
# Convert the inputs to sequence form for the policies.
# NOTE(review): `inputs` is rebound here, so re-running this cell without
# re-running the cell that builds `inputs` applies to_sequence twice.
inputs = TensorUtils.to_sequence(inputs)

# MLP policy: group encoder -> MLP backbone -> action decoder.
action_dec = ActionDecoder(
    action_shape=action_shape,
    input_dim=mlp.output_dim,
)
bc_mlp = BC_MLP(
    obs_group_enc=obs_group_enc,
    backbone=mlp,
    action_dec=action_dec,
)
actions = bc_mlp(inputs)

# One action per sample, with a sequence dimension of length 1.
assert tuple(actions.shape) == (B, 1, *action_shape)

In [None]:
# Repeat each group's sequence to the requested length.
# NOTE(review): this overwrites entries of `inputs` in place, so re-running
# this cell repeats the already-repeated sequences.
inputs["obs"] = TensorUtils.repeat_seq(inputs["obs"], T_obs)
inputs["goal"] = TensorUtils.repeat_seq(inputs["goal"], T_goal)

# Transformer policy: a fresh decoder sized to the transformer's output,
# predicting a chunk of T_tgt actions.
action_dec = ActionDecoder(
    action_shape=action_shape,
    input_dim=transformer.output_dim,
)
bc_transformer = BC_Transformer(
    obs_group_enc=obs_group_enc,
    backbone=transformer,
    action_dec=action_dec,
    action_chunk=T_tgt,
)
actions = bc_transformer(inputs)

# One action per step of the chunk.
assert tuple(actions.shape) == (B, T_tgt, *action_shape)

In [None]:
# Undo the registrations made for these tests: core classes first, then the
# per-key cores (same order as the original cleanup).
for _modality in (Const.Modality.LOW_DIM, Const.Modality.RGB):
    ObsUtils.unregister_encoder_core_class(modality=_modality)

for _obs_key in (low_dim_key, rgb_key):
    ObsUtils.unregister_encoder_core(obs_key=_obs_key)