In [1]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np

In [6]:
import ray
from ray.rllib import agents
ray.init(log_to_driver=False, ignore_reinit_error=True)  # re-running this cell is a no-op instead of raising

from envs.point_mass_env import PointMassEnv 
from ray.tune.logger import pretty_print

2021-04-09 22:33:51,299	INFO services.py:1174 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8266[39m[22m


In [3]:
from ray.tune.registry import register_env

def env_creator(env_config):
    """Build a PointMassEnv for RLlib's env registry.

    Args:
        env_config: the trainer's ``env_config`` as passed by RLlib (dict-like).
            An optional ``'env_name'`` key selects the id handed to
            ``PointMassEnv``; when absent (or when ``env_config`` is None),
            ``'maze2d-open-dense-v0'`` is used, matching the previous behavior.

    Returns:
        A new ``PointMassEnv`` instance.
    """
    env_name = (env_config or {}).get('env_name', 'maze2d-open-dense-v0')
    return PointMassEnv(env_name)


register_env("point_mass_1", env_creator)

In [4]:
from ray.rllib.agents.callbacks import DefaultCallbacks

class CustomCallbacks(DefaultCallbacks):
    """RLlib callbacks that record goal-reaching metrics once per episode.

    When an episode ends, two entries are written to ``episode.custom_metrics``:

    * ``dist``    -- Euclidean distance between ``obs[0:2]`` and ``obs[4:6]`` of
      the episode's last observation (treated as agent position and target).
    * ``success`` -- the last step's ``info['success']`` value cast to ``int``.

    RLlib reports their per-iteration means as ``custom_metrics['dist_mean']``
    and ``custom_metrics['success_mean']`` in the ``trainer.train()`` results.
    """

    def on_episode_end(self, worker, base_env,
                       policies, episode,
                       **kwargs):
        # Success flag from the environment's final info dict.
        reached_goal = int(episode.last_info_for()['success'])

        # NOTE(review): assumes obs[0:2] is the agent position and obs[4:6] the
        # goal position in PointMassEnv's observation vector -- confirm.
        final_obs = episode.last_observation_for()
        offset = final_obs[:2] - final_obs[4:6]

        episode.custom_metrics['dist'] = np.linalg.norm(offset)
        episode.custom_metrics["success"] = reached_goal

# Their Model

In [5]:
# Baseline: RLlib PPO defaults for every hyperparameter not listed here,
# with a 1000-step train batch and a two-layer 128-unit MLP.
config = {
    'num_workers': 1,
    'train_batch_size': 1000,
    'log_level': 'ERROR',
    'framework': 'torch',
    'callbacks': CustomCallbacks,
    'model': {
        'fcnet_hiddens': [128, 128],
    },
}
trainer = agents.ppo.PPOTrainer(env='point_mass_1', config=config)
for i in range(1000):
    # One train() call is one training iteration (one train batch), not one episode.
    results = trainer.train()
    # custom_metrics may be empty (e.g. before any episode has finished);
    # .get prints None instead of raising KeyError mid-run.
    custom = results['custom_metrics']
    print(f"iter {i}, mean rew {results['episode_reward_mean']}, "
          f"success_mean {custom.get('success_mean')}, dist {custom.get('dist_mean')}")

RaySystemError: System error: Ray has not been started yet. You can start Ray with 'ray.init()'.

# Our Model

In [10]:
# "Our" PPO hyperparameters with a two-layer 64-unit MLP and a 200-step train batch.
config = {
    'num_workers': 1,
    'train_batch_size': 200,
    'log_level': 'ERROR',
    'framework': 'torch',
    'callbacks': CustomCallbacks,
    'lambda': .99,
    'num_sgd_iter': 4,
    'lr': 1e-5,
    'vf_loss_coeff': .05,
    'entropy_coeff': .01,
    'clip_param': .2,
    # NOTE(review): RLlib documents vf_clip_param as sensitive to the reward
    # scale; 0.2 may over-constrain value-function updates -- confirm.
    'vf_clip_param': .2,
    'grad_clip': .5,
    'model': {
        'fcnet_hiddens': [64, 64],
    },
}
trainer = agents.ppo.PPOTrainer(env='point_mass_1', config=config)
for i in range(1000):
    # One train() call is one training iteration (one train batch), not one episode.
    results = trainer.train()
    # custom_metrics may be empty (e.g. before any episode has finished);
    # .get prints None instead of raising KeyError mid-run.
    custom = results['custom_metrics']
    print(f"iter {i}, mean rew {results['episode_reward_mean']}, "
          f"success_mean {custom.get('success_mean')}, dist {custom.get('dist_mean')}")



episode 0, mean rew 1.0692866670458385,success_mean 0.0, dist 1.9715892846976069
episode 1, mean rew 0.8894788660171702,success_mean 0.0, dist 1.7873492818147683
episode 2, mean rew 0.7707202149347199,success_mean 0.0, dist 1.747219061249293
episode 3, mean rew 1.4870071118250645,success_mean 0.0, dist 1.6146722070907749
episode 4, mean rew 1.4751989947015547,success_mean 0.0, dist 1.6230163641327033
episode 5, mean rew 1.3336213128305683,success_mean 0.0, dist 1.5965751981046032
episode 6, mean rew 1.216758760241136,success_mean 0.0, dist 1.6749183910915009
episode 7, mean rew 1.1032890394815946,success_mean 0.0, dist 1.6637052698397556
episode 8, mean rew 1.0496191374475825,success_mean 0.0, dist 1.6709231376646045
episode 9, mean rew 1.100692563938615,success_mean 0.0, dist 1.6377839572556991
episode 10, mean rew 1.0230912662509597,success_mean 0.0, dist 1.6849154009430667
episode 11, mean rew 1.0660560557408691,success_mean 0.0, dist 1.6120337411991557
episode 12, mean rew 1.002128

episode 101, mean rew 1.2461697698656606,success_mean 0.0, dist 1.5057621882110346
episode 102, mean rew 1.2452612304584156,success_mean 0.0, dist 1.515415764849753
episode 103, mean rew 1.246289680431168,success_mean 0.0, dist 1.5108534689326043
episode 104, mean rew 1.2626394432501244,success_mean 0.0, dist 1.508515509044965
episode 105, mean rew 1.260262873966135,success_mean 0.0, dist 1.5078802411763073
episode 106, mean rew 1.2559027805817002,success_mean 0.0, dist 1.514632267456452
episode 107, mean rew 1.3047084425947046,success_mean 0.0, dist 1.51176339255784
episode 108, mean rew 1.314726969193822,success_mean 0.0, dist 1.5219233772410072
episode 109, mean rew 1.3116824226067294,success_mean 0.0, dist 1.5235500902645105
episode 110, mean rew 1.3028018296585038,success_mean 0.0, dist 1.5265747097833091
episode 111, mean rew 1.3131203781149978,success_mean 0.0, dist 1.518290222611409
episode 112, mean rew 1.3235313418206311,success_mean 0.0, dist 1.52616385670905
episode 113, me

episode 201, mean rew 1.6848208974493502,success_mean 0.0, dist 1.8203422107948564
episode 202, mean rew 1.6861296814105555,success_mean 0.0, dist 1.7989577358904179
episode 203, mean rew 1.6913557042115572,success_mean 0.0, dist 1.8080467853670985
episode 204, mean rew 1.6949922621009728,success_mean 0.0, dist 1.8252827901678563
episode 205, mean rew 1.6951911623208402,success_mean 0.0, dist 1.823357407097734
episode 206, mean rew 1.6793156383480083,success_mean 0.0, dist 1.8341055962938657
episode 207, mean rew 1.6753047729473738,success_mean 0.0, dist 1.836161191289773
episode 208, mean rew 1.6816033598580105,success_mean 0.0, dist 1.8523076478646785
episode 209, mean rew 1.6776109116419764,success_mean 0.0, dist 1.8771294519802206
episode 210, mean rew 1.670334391604258,success_mean 0.0, dist 1.8926127622147215
episode 211, mean rew 1.6751322143922434,success_mean 0.0, dist 1.9067347168883475
episode 212, mean rew 1.6827198039446147,success_mean 0.0, dist 1.9224638312448767
episode

episode 301, mean rew 1.735326896350896,success_mean 0.0, dist 2.2616406493463925
episode 302, mean rew 1.71279831691212,success_mean 0.0, dist 2.282541466787379
episode 303, mean rew 1.7203734005862206,success_mean 0.0, dist 2.29270567235481
episode 304, mean rew 1.7339406481722248,success_mean 0.0, dist 2.291837101359293
episode 305, mean rew 1.7370140549687418,success_mean 0.0, dist 2.304319831918282
episode 306, mean rew 1.7408194566251423,success_mean 0.0, dist 2.304994964568728
episode 307, mean rew 1.7422918560962275,success_mean 0.0, dist 2.3038252981972924
episode 308, mean rew 1.7256818326021184,success_mean 0.0, dist 2.3111561813576205
episode 309, mean rew 1.7273682818721243,success_mean 0.0, dist 2.3221491359412756
episode 310, mean rew 1.7405505454856458,success_mean 0.0, dist 2.3224438688791564
episode 311, mean rew 1.7522566680353635,success_mean 0.0, dist 2.3152238202545505
episode 312, mean rew 1.7591288029019017,success_mean 0.0, dist 2.3148139435892237
episode 313, 

KeyboardInterrupt: 

# Our Model, their Hparams

In [13]:
# "Our" two-layer 64-unit MLP with the baseline's 1000-step train batch.
# The tuned keys from the "Our Model" run (lambda, num_sgd_iter, lr,
# vf_loss_coeff, entropy_coeff, clip_param, vf_clip_param, grad_clip) are
# intentionally omitted so RLlib's PPO defaults apply.
config = {
    'num_workers': 1,
    'log_level': 'ERROR',
    'framework': 'torch',
    'callbacks': CustomCallbacks,
    'train_batch_size': 1000,
    'model': {
        'fcnet_hiddens': [64, 64],
    },
}
trainer = agents.ppo.PPOTrainer(env='point_mass_1', config=config)
for i in range(1000):
    # One train() call is one training iteration (one train batch), not one episode.
    results = trainer.train()
    # custom_metrics may be empty (e.g. before any episode has finished);
    # .get prints None instead of raising KeyError mid-run.
    custom = results['custom_metrics']
    print(f"iter {i}, mean rew {results['episode_reward_mean']}, "
          f"success_mean {custom.get('success_mean')}, dist {custom.get('dist_mean')}")



episode 0, mean rew 0.6130870731157936,success_mean 0.0, dist 1.5921449389108202
episode 1, mean rew 1.2603985894650935,success_mean 0.0, dist 1.3978968519769315
episode 2, mean rew 1.6397175687765189,success_mean 0.0, dist 1.2829645914351056
episode 3, mean rew 1.8413079892639617,success_mean 0.0, dist 1.270840889907793
episode 4, mean rew 2.156830142168495,success_mean 0.0, dist 1.2729382458015874
episode 5, mean rew 2.3076811655311666,success_mean 0.0, dist 1.4099414971591138
episode 6, mean rew 2.4904063874259212,success_mean 0.0, dist 1.462403880433009
episode 7, mean rew 2.6058275301371445,success_mean 0.0, dist 1.5258415119505744
episode 8, mean rew 2.7845356533821137,success_mean 0.0, dist 1.5302406507837254
episode 9, mean rew 2.9226877111560543,success_mean 0.0, dist 1.509508168439038
episode 10, mean rew 3.0721563019944016,success_mean 0.0, dist 1.5143305709169286
episode 11, mean rew 3.2288280920816184,success_mean 0.0, dist 1.5142189454024064
episode 12, mean rew 3.3611624

episode 101, mean rew 8.345373523075823,success_mean 0.0, dist 0.2951726270229723
episode 102, mean rew 8.356438734781317,success_mean 0.0, dist 0.2942270329889533
episode 103, mean rew 8.39247823321373,success_mean 0.0, dist 0.29020816374299174
episode 104, mean rew 8.404170441718131,success_mean 0.0, dist 0.28588151144984125
episode 105, mean rew 8.42185639306079,success_mean 0.0, dist 0.2838739543432075
episode 106, mean rew 8.430484393961484,success_mean 0.0, dist 0.2827646849160595
episode 107, mean rew 8.439319391089974,success_mean 0.0, dist 0.2800070430440916
episode 108, mean rew 8.450960452982194,success_mean 0.0, dist 0.27784148786707064
episode 109, mean rew 8.459675586075612,success_mean 0.0, dist 0.2716126293272808
episode 110, mean rew 8.468959933220281,success_mean 0.0, dist 0.266991552798083
episode 111, mean rew 8.477030998035438,success_mean 0.0, dist 0.2635744936802888
episode 112, mean rew 8.484304935379548,success_mean 0.0, dist 0.2604694787992328
episode 113, mea

episode 200, mean rew 8.670010692981137,success_mean 0.0, dist 0.19653951107540757
episode 201, mean rew 8.677397130525575,success_mean 0.0, dist 0.19548016959439823
episode 202, mean rew 8.684561203712027,success_mean 0.0, dist 0.19581341275780192
episode 203, mean rew 8.696182679075651,success_mean 0.0, dist 0.19478019982661463
episode 204, mean rew 8.707525275362809,success_mean 0.0, dist 0.1942658968538919
episode 205, mean rew 8.721975572524888,success_mean 0.0, dist 0.19394566038316235
episode 206, mean rew 8.730244952162838,success_mean 0.0, dist 0.19342177062203753
episode 207, mean rew 8.73648376552747,success_mean 0.0, dist 0.1923770757022396
episode 208, mean rew 8.743030931343798,success_mean 0.0, dist 0.19066925823995562
episode 209, mean rew 8.751989838489926,success_mean 0.0, dist 0.18828599651582345
episode 210, mean rew 8.764782269033804,success_mean 0.0, dist 0.18695115174402197
episode 211, mean rew 8.78521710660791,success_mean 0.0, dist 0.18536280970484625
episode 


KeyboardInterrupt



# HParam Sweep

In [17]:
# Mixed setting: the baseline's ("their") values for train_batch_size,
# num_sgd_iter, lr, entropy_coeff, clip_param and vf_clip_param -- the last
# five are omitted so RLlib's PPO defaults apply -- combined with "our"
# lambda, vf_loss_coeff and grad_clip and a three-layer 128-unit MLP.
config = {
    'num_workers': 1,
    'log_level': 'ERROR',
    'framework': 'torch',
    'callbacks': CustomCallbacks,
    'train_batch_size': 1000,
    'lambda': .99,
    'vf_loss_coeff': .05,
    'grad_clip': .5,
    'model': {
        'fcnet_hiddens': [128, 128, 128],
    },
}
trainer = agents.ppo.PPOTrainer(env='point_mass_1', config=config)
for i in range(1000):
    # One train() call is one training iteration (one train batch), not one episode.
    results = trainer.train()
    # custom_metrics may be empty (e.g. before any episode has finished);
    # .get prints None instead of raising KeyError mid-run.
    custom = results['custom_metrics']
    print(f"iter {i}, mean rew {results['episode_reward_mean']}, "
          f"success_mean {custom.get('success_mean')}, dist {custom.get('dist_mean')}")



episode 0, mean rew 0.6786170131035255,success_mean 0.0, dist 1.9146594153270586
episode 1, mean rew 0.8691278398070891,success_mean 0.0, dist 1.7980147369420094
episode 2, mean rew 1.5717367175124572,success_mean 0.1, dist 1.5715564388750791
episode 3, mean rew 1.8888202755523775,success_mean 0.07692307692307693, dist 1.473106687198602
episode 4, mean rew 2.278760592825362,success_mean 0.06060606060606061, dist 1.4602916886052058
episode 5, mean rew 2.695640443949792,success_mean 0.05, dist 1.371014686746001
episode 6, mean rew 2.94998418856758,success_mean 0.043478260869565216, dist 1.363091157403699
episode 7, mean rew 3.1563118683216977,success_mean 0.03773584905660377, dist 1.312200164911673
episode 8, mean rew 3.3545484502806935,success_mean 0.03333333333333333, dist 1.2964396418590842
episode 9, mean rew 3.5326509792568284,success_mean 0.030303030303030304, dist 1.2693076212726981
episode 10, mean rew 3.7078602976985096,success_mean 0.0273972602739726, dist 1.250226518656329
epi

KeyboardInterrupt: 