In [14]:
import os
import torch
from trainers.torch.networks import SplitValueSharedActorCritic
from mlagents_envs.base_env import ObservationSpec, DimensionProperty, ObservationType
from mlagents.trainers.settings import NetworkSettings, TrainerSettings
from mlagents_envs.base_env import ActionSpec
from trainers.policy.torch_policy import TorchPolicy
from mlagents_envs.base_env import BehaviorSpec

In [3]:
# Path to the trained Seeker checkpoint. Set the SEEKER_CHECKPOINT environment
# variable to point elsewhere instead of editing this machine-specific path.
load_path = os.environ.get(
    "SEEKER_CHECKPOINT",
    "/home/rmarr/Projects/visibility-game-env/results/dualrun1221/Seeker/checkpoint.pt",
)
# map_location="cpu" lets a checkpoint written on a GPU machine be opened on a
# CPU-only kernel; load_state_dict copies the tensors into the modules later.
saved_state_dict = torch.load(load_path, map_location="cpu")

In [16]:
# --- Configuration: spaces and settings used to rebuild the trained policy ---

# NOTE(review): the TorchPolicy construction cell passes seed=0, so this value is
# not used there; confirm which seed is intended.
seed = 5404

# Observation space: a single 3-element position vector.
observation_specs = [
    ObservationSpec(
        name="position_observation",
        shape=(3,),  # 3D vector
        dimension_property=(DimensionProperty.NONE,),  # Must be a tuple
        observation_type=ObservationType.DEFAULT,
    )
]

# Action space: one discrete branch with 5 choices, no continuous actions.
action_spec = ActionSpec(
    continuous_size=0,
    discrete_branches=(5,),
)

behavior_spec = BehaviorSpec(
    observation_specs=observation_specs,
    action_spec=action_spec,
)

# Alternative observation specs kept for reference (not used):
# position_obs_spec = [ObservationSpec(
#             name="position_observation",
#             shape=(3,),  # 3D vector
#             dimension_property=(DimensionProperty.NONE,),  # Must be a tuple
#             observation_type=ObservationType.DEFAULT
#         )]
# crumbs_obs_spec = [ObservationSpec(
#             name="crumbs_observation",
#             shape=(9,),  # 9-element vector
#             dimension_property=(DimensionProperty.NONE,),  # Must be a tuple
#             observation_type=ObservationType.DEFAULT
#         )]

network_settings = NetworkSettings(
    deterministic=False,
    memory=None,
    hidden_units=128,
    num_layers=2,
)

# Attach network_settings explicitly; otherwise TrainerSettings falls back to its
# own default NetworkSettings and the object above has no effect.
trainer_settings = TrainerSettings(
    dual_critic=True,
    network_settings=network_settings,
)

# Actor options; keep in sync with the TorchPolicy construction cell.
# stream_names is not referenced by any cell shown here.
stream_names = ["default"]
conditional_sigma = False
tanh_squash = False

# Drives the checkpoint-loading loop: any truthy value skips modules whose name
# contains "policy"; "position_only" additionally restricts modules whose name
# contains "optimizer:critic" to their position_network tensors.
load_critic_only = "position_only"

In [17]:
# Rebuild the policy from the spec and settings defined in the configuration cell.
policy = TorchPolicy(
    # NOTE(review): the configuration cell defines seed = 5404 but this call uses 0;
    # confirm which seed is intended before comparing runs.
    seed=0,
    behavior_spec=behavior_spec,
    trainer_settings=trainer_settings,
    tanh_squash=tanh_squash,
    separate_critic=True,
    condition_sigma_on_obs=conditional_sigma,
    # NOTE(review): the checkpoint-loading loop uses the notebook-level
    # load_critic_only = "position_only", not this value; confirm which mode
    # TorchPolicy itself should receive.
    load_critic_only='default',
)

In [24]:
# Name -> module mapping iterated by the checkpoint-loading loop; each name is
# used as a key into saved_state_dict.
# NOTE(review): the loop special-cases names containing "optimizer:critic";
# confirm policy.get_modules() actually returns such an entry (optimizer/critic
# modules may live on the trainer's optimizer rather than the policy).
modules = policy.get_modules()

In [25]:
# Restore checkpoint weights into each module returned by policy.get_modules().
# When load_critic_only is truthy, modules whose name contains "policy" keep their
# fresh initialisation. With "position_only", modules whose name contains
# "optimizer:critic" receive only their position_network tensors. Every other
# nn.Module is loaded non-strictly and mismatched keys are reported.
for name, mod in modules.items():
    try:
        if load_critic_only:
            # Only use the critic network: skip policy (actor) modules.
            if "policy" in name.lower():
                continue

            # Only use position_network from the critic network.
            # str() so a boolean flag does not crash on .lower().
            if "position_only" in str(load_critic_only).lower() and "optimizer:critic" in name.lower():
                position_state_dict = {
                    key: value
                    for key, value in saved_state_dict[name].items()
                    if "position_network" in key
                }
                if not position_state_dict:
                    # Loading an empty dict with strict=False would silently change nothing.
                    print(f"No position_network keys found for module {name} in checkpoint. Nothing loaded.")
                    continue
                load_result = mod.load_state_dict(position_state_dict, strict=False)
                # Missing keys are expected here (everything outside position_network);
                # unexpected keys mean the checkpoint layout does not match the module.
                position_unexpected = getattr(load_result, "unexpected_keys", None)
                if position_unexpected:
                    print(
                        f"Did not expect these keys {position_unexpected} in checkpoint. Ignoring."
                    )
                print(f"Loaded {len(position_state_dict)} position_network tensors into module {name}.")
                continue

        if isinstance(mod, torch.nn.Module):
            missing_keys, unexpected_keys = mod.load_state_dict(
                saved_state_dict[name], strict=False
            )
            if missing_keys:
                print(
                    f"Did not find these keys {missing_keys} in checkpoint. Initializing."
                )
            if unexpected_keys:
                # Plain print(): the builtin has no .warning attribute.
                print(
                    f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring."
                )
        else:
            # Not an nn.Module (e.g. an optimizer): load the whole entry at once.
            mod.load_state_dict(saved_state_dict[name])

    # KeyError: the module is not present in saved_state_dict.
    # ValueError: raised by an optimizer's load_state_dict when its parameters
    # have changed (optimizers are not nn.Modules and use a different loader).
    # RuntimeError: size mismatch between same-named tensors; layers whose shape
    # did not change are still assigned.
    except (KeyError, ValueError, RuntimeError) as err:
        print(f"Failed to load for module {name}. Initializing")
        print(f"Module loading error : {err}")

In [28]:
# Probe points for the critic: every combination of the half-integer offsets
# -9.5, -8.5, ..., 9.5 in columns 0 and 2, with column 1 fixed at 0.5
# (presumably height in a y-up frame). Rows are ordered with column 0 as the
# outer loop, giving a (400, 3) tensor wrapped in a one-element list.
cell_centres = torch.arange(-9, 11) - 0.5
outer_inner_pairs = torch.cartesian_prod(cell_centres, cell_centres)
fixed_middle = torch.full_like(outer_inner_pairs[:, 0], 0.5)
inputs = [torch.stack((outer_inner_pairs[:, 0], fixed_middle, outer_inner_pairs[:, 1]), dim=1)]

-9.5
-8.5
-7.5
-6.5
-5.5
-4.5
-3.5
-2.5
-1.5
-0.5
0.5
1.5
2.5
3.5
4.5
5.5
6.5
7.5
8.5
9.5


In [None]:
# Rebuild the tensor whose truncated printed repr was pasted here. The pasted
# text contained literal `...` (Python's Ellipsis), which torch.tensor cannot
# convert, so the cell could never run.
# Columns 0-2: the x-major probe grid (visible in the pasted repr as
# [-9.5, 0.5, -9.5], [-9.5, 0.5, -8.5], ..., [9.5, 0.5, 9.5]).
# Remaining columns: a crumbs block of zeros.
# NOTE(review): the repr elided the middle columns; the width of 9 comes from the
# commented-out crumbs_observation spec and all-zero values are an assumption —
# confirm before relying on these inputs.
CRUMB_FEATURES = 9
cell_centres = torch.arange(-9, 11) - 0.5
outer_inner_pairs = torch.cartesian_prod(cell_centres, cell_centres)
position_grid = torch.stack(
    (
        outer_inner_pairs[:, 0],
        torch.full_like(outer_inner_pairs[:, 0], 0.5),
        outer_inner_pairs[:, 1],
    ),
    dim=1,
)
crumbs_block = torch.zeros(position_grid.shape[0], CRUMB_FEATURES, dtype=position_grid.dtype)
inputs = [torch.cat((position_grid, crumbs_block), dim=1)]

In [26]:
# Call the actor's critic_pass_position on the probe inputs (a one-element list of
# tensors); judging by the name it evaluates only the critic's position branch — confirm.
# Run the input-building cell first: the NameError below came from executing this
# cell (In [26]) before `inputs` was defined (In [28]).
# NOTE(review): the loading loop skips modules whose name contains "policy"; if
# policy.actor belongs to such a module, these values come from untrained weights.
policy.actor.critic_pass_position(inputs)

NameError: name 'inputs' is not defined