In [1]:
import gym, ray
from gym import spaces
import numpy as np
from scipy.spatial import distance
# import pdb
import MultiAgentEnv as ma_env

from policy import PolicyNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from ray import tune
import ray.rllib.agents.ppo as ppo
import os
from ray.tune.logger import pretty_print
from ray.tune.logger import Logger


from typing import Dict
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch

from datetime import datetime



In [15]:
# Both atom types currently share one network class; registering it under two
# distinct names ("modelC", "modelH") lets gen_policy select "model<atom>".
model_C = model_H = PolicyNetwork
for _model_name, _model_cls in (("modelC", model_C), ("modelH", model_H)):
    ModelCatalog.register_custom_model(_model_name, _model_cls)
del _model_name, _model_cls

# Spaces used by every policy: a 3-component action bounded to [-0.1, 0.1]
# and a flat 768-element observation bounded to [-10000, 10000].
act_space = spaces.Box(low=-0.1, high=0.1, shape=(3,))
obs_space = spaces.Box(low=-10000, high=10000, shape=(768,))

def gen_policy(atom):
    """Build the multi-agent policy spec for one atom type.

    Args:
        atom: atom symbol such as "C" or "H"; selects the custom model
            registered under the name "model<atom>".

    Returns:
        A 4-tuple ``(None, obs_space, act_space, config)`` where ``config``
        only overrides the custom model. ``None`` leaves the policy class
        to the trainer's default.
    """
    policy_config = {"model": {"custom_model": f"model{atom}"}}
    return None, obs_space, act_space, policy_config



# One policy spec per atom type, keyed "policy_<atom>" (C first, then H).
policies = {"policy_" + atom: gen_policy(atom) for atom in ("C", "H")}
policy_ids = [*policies]

def policy_mapping_fn(agent_id, episode=None, worker=None, **kwargs):
    """Map an agent id to the id of the policy that controls it.

    Agents whose id starts with "C" (e.g. "C_1") use "policy_C"; every
    other agent (e.g. "H_1") uses "policy_H".

    Args:
        agent_id: agent id string produced by the environment.
        episode: ignored. RLlib's documented mapping-fn signature is
            ``(agent_id, episode, worker, **kwargs)`` and passes ``episode``
            positionally, so it must be accepted here; the default keeps
            one-argument calls working.
        worker: ignored; accepted for the same reason as ``episode``.
        **kwargs: any further arguments RLlib supplies; ignored.

    Returns:
        "policy_C" or "policy_H".
    """
    return "policy_C" if agent_id.startswith("C") else "policy_H"

def env_creator(env_config):
    """Env factory for Tune's registry: build one MA_env from ``env_config``."""
    env_cls = ma_env.MA_env
    return env_cls(env_config)

register_env("MA_env", env_creator)

# Start from PPO's defaults. The copy is shallow, which is fine here because
# only top-level keys are replaced below (no nested default dict is mutated).
config = ppo.DEFAULT_CONFIG.copy()
config.update({
    # Both per-atom policies are trained; agents are routed by policy_mapping_fn.
    "multiagent": {
        "policy_mapping_fn": policy_mapping_fn,
        "policies": policies,
        "policies_to_train": ["policy_C", "policy_H"],
    },
    "log_level": "WARN",
    "framework": "torch",
    # GPU count comes from the environment and defaults to 0 (CPU only).
    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    # Atom list handed to MA_env: one carbon and four hydrogens.
    "env_config": {"atoms": ["C", "H", "H", "H", "H"]},
    "rollout_fragment_length": 16,
})

In [5]:
# ignore_reinit_error lets this cell be re-run in the same kernel without Ray
# raising because it is already initialized.
ray.init(ignore_reinit_error=True)

2021-08-04 20:53:39,421	INFO services.py:1330 -- View the Ray dashboard at http://127.0.0.1:8265


{'node_ip_address': '172.16.0.114',
 'raylet_ip_address': '172.16.0.114',
 'redis_address': '172.16.0.114:6379',
 'object_store_address': '/tmp/ray/session_2021-08-04_20-53-36_742007_12874/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-08-04_20-53-36_742007_12874/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2021-08-04_20-53-36_742007_12874',
 'metrics_export_port': 61087,
 'node_id': '87116b37d9082c1f12b52c0185841d80d402e577b183cc54149c3396'}

In [16]:
# PPO trainer for the registered "MA_env" environment, built from `config`.
agent = ppo.PPOTrainer(env="MA_env", config=config)



In [17]:
# Checkpoint to evaluate. Set RLLIB_CHECKPOINT to point elsewhere instead of
# editing a user-specific absolute path; the default lives under ~/ray_results.
CHECKPOINT_PATH = os.environ.get(
    "RLLIB_CHECKPOINT",
    os.path.expanduser(
        "~/ray_results/PPO_MA_env_2021-08-04_18-33-53ks1kndl_/checkpoint_000001/checkpoint-1"
    ),
)
agent.restore(CHECKPOINT_PATH)

2021-08-04 21:11:54,733	INFO trainable.py:379 -- Restored on 172.16.0.114 from checkpoint: /home/sarvesh211999/ray_results/PPO_MA_env_2021-08-04_18-33-53ks1kndl_/checkpoint_000001/checkpoint-1
2021-08-04 21:11:54,734	INFO trainable.py:387 -- Current state after restoring: {'_iteration': 1, '_timesteps_total': None, '_time_total': 263.95628547668457, '_episodes_total': 15}


In [18]:
# Build the evaluation env through the same factory and env_config used for
# training, rather than `{}`, which would rely on MA_env's own defaults.
# A copy is passed so the env cannot mutate the trainer's config dict.
env = env_creator(dict(config["env_config"]))

In [19]:
# Reset the environment to start an episode; `obs` is a dict keyed by agent id
# (e.g. "C_1", "H_1"), which the per-agent action loop iterates over.
obs = env.reset()

Reset called
forces = [4.691725675958892, -2.977502852185873, 4.385788820676909, -0.8213429966917942, -0.28131015568930023, -0.5851022919157759, -4.413608403535599, 2.0069012055883864, 1.1928383410509424, -2.3557062108455606, -1.2869259594802276, -4.341285085981355, 2.9049156820448228, 2.567220027594209, -0.6727541760726693] 	 energies = -17.74350464157722


In [20]:
# The trainer is multi-agent, so it has no "default_policy". Passing the whole
# observation dict to compute_action raises KeyError: 'default_policy'.
# Instead, query each agent's observation with the policy it is mapped to.
# With full_fetch=True each value is compute_action's full return
# (action, RNN state outputs, extra fetches) rather than the bare action.
agent_fetches = {
    agent_id: agent.compute_action(
        agent_obs,
        policy_id=config["multiagent"]["policy_mapping_fn"](agent_id),
        full_fetch=True,
    )
    for agent_id, agent_obs in obs.items()
}
agent_fetches

In [21]:
# One action per agent, each computed by the policy that agent is mapped to.
action = {
    agent_id: agent.compute_action(
        agent_obs,
        policy_id=config["multiagent"]["policy_mapping_fn"](agent_id),
    )
    for agent_id, agent_obs in obs.items()
}

In [22]:
# Display the computed per-agent actions (dict: agent id -> action array).
action

{'C_1': array([-0.1, -0.1,  0.1], dtype=float32),
 'H_1': array([-0.1,  0.1,  0.1], dtype=float32),
 'H_2': array([-0.1,  0.1,  0.1], dtype=float32),
 'H_3': array([-0.1, -0.1,  0.1], dtype=float32),
 'H_4': array([0.1, 0.1, 0.1], dtype=float32)}