In [30]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np

In [7]:
import ray
from ray.rllib import agents
# ray.init(log_to_driver=False) # Skip or set to ignore if already called

from envs.point_mass_env import PointMassEnv 
from ray.tune.logger import pretty_print

In [3]:
from ray.tune.registry import register_env

def env_creator(env_config):
    """Build the point-mass environment used by the trainers in this notebook.

    Args:
        env_config: Environment config supplied by the caller. It is accepted
            to satisfy the creator signature but is not consulted here.

    Returns:
        A ``PointMassEnv`` constructed for the 'maze2d-open-dense-v0' task.
    """
    maze_task = 'maze2d-open-dense-v0'
    return PointMassEnv(maze_task)


# Make the factory available under the name later passed as ``env=``.
register_env("point_mass_1", env_creator)

In [33]:
from ray.rllib.agents.callbacks import DefaultCallbacks

class CustomCallbacks(DefaultCallbacks):
    """Callbacks that attach per-episode custom metrics when an episode ends."""

    def on_episode_end(self, worker, base_env,
                       policies, episode,
                       **kwargs):
        """Record the 'success' flag and a final distance for the episode.

        The success value is read from the episode's last info dict and cast
        to ``int``. The distance is the Euclidean norm between two slices of
        the last observation: entries 0-1 and entries 4-5.

        Only ``episode`` is used; the remaining arguments are accepted to
        match the callback signature.
        """
        reached_goal = int(episode.last_info_for()['success'])

        final_obs = episode.last_observation_for()
        # NOTE(review): presumably obs[0:2] is the agent position and obs[4:6]
        # the goal position -- confirm against PointMassEnv's observation layout.
        agent_xy, goal_xy = final_obs[:2], final_obs[4:6]
        gap = np.linalg.norm(agent_xy - goal_xy)

        metrics = episode.custom_metrics
        metrics['dist'] = gap
        metrics["success"] = reached_goal

## Their Model

In [34]:
# Baseline run: PPO with a two-layer 128-unit fully connected network and
# otherwise unmodified trainer settings.
config = {
    'num_workers': 1,
    'train_batch_size': 1000,
    'log_level': 'ERROR',
    'framework': 'torch',
    'callbacks': CustomCallbacks,
    'model': {'fcnet_hiddens': [128, 128]},
}
trainer = agents.ppo.PPOTrainer(config=config, env='point_mass_1')

# Each pass through the loop is one call to trainer.train(); the line printed
# reports the mean reward plus the custom metrics written by CustomCallbacks.
for i in range(1000):
    results = trainer.train()
    custom = results['custom_metrics']
    print(f"episode {i}, mean rew {results['episode_reward_mean']},"
          f"success_mean {custom['success_mean']}, dist {custom['dist_mean']}")



episode 0, mean rew 0.6658158570884053,success_mean 0.0, dist 1.8483879893430792
episode 1, mean rew 1.2314404386166329,success_mean 0.0, dist 1.5610352158508474
episode 2, mean rew 1.6477773688427302,success_mean 0.0, dist 1.563366200375389
episode 3, mean rew 1.8504928162101968,success_mean 0.0, dist 1.4955678309802978
episode 4, mean rew 1.9593957100409085,success_mean 0.0, dist 1.59955965123598
episode 5, mean rew 1.9146525730481418,success_mean 0.0, dist 1.5073492052559156
episode 6, mean rew 2.0049054781002154,success_mean 0.0, dist 1.5086795796481445
episode 7, mean rew 2.0921235966823946,success_mean 0.0, dist 1.6612538182217316
episode 8, mean rew 2.270401313213012,success_mean 0.0, dist 1.664423114581528
episode 9, mean rew 2.3849066783655797,success_mean 0.0, dist 1.6679225662937367
episode 10, mean rew 2.5252396106105532,success_mean 0.0, dist 1.6926331242650496
episode 11, mean rew 2.635495202837774,success_mean 0.0, dist 1.6913264859054629
episode 12, mean rew 2.785791452

episode 101, mean rew 8.155858596104782,success_mean 0.21, dist 0.14694831714016815
episode 102, mean rew 8.16350362565229,success_mean 0.2, dist 0.14951498603869248
episode 103, mean rew 8.182317428332006,success_mean 0.19, dist 0.15345117107158246
episode 104, mean rew 8.206676192478351,success_mean 0.19, dist 0.1528815639049516
episode 105, mean rew 8.224802396780252,success_mean 0.17, dist 0.15878157248767222
episode 106, mean rew 8.242522133877435,success_mean 0.17, dist 0.16276811969523486
episode 107, mean rew 8.244389685974427,success_mean 0.16, dist 0.16786721919393932
episode 108, mean rew 8.264915912672224,success_mean 0.16, dist 0.17017872366388345
episode 109, mean rew 8.27510537789664,success_mean 0.16, dist 0.1699003941385565
episode 110, mean rew 8.293246140402344,success_mean 0.12, dist 0.17194141311147085
episode 111, mean rew 8.298718355037598,success_mean 0.11, dist 0.1762240462437103
episode 112, mean rew 8.324081888961251,success_mean 0.08, dist 0.1863083668997446

episode 200, mean rew 8.748219946154316,success_mean 0.49, dist 0.0977245832482682
episode 201, mean rew 8.736959675723105,success_mean 0.52, dist 0.09557703748520749
episode 202, mean rew 8.716447217862871,success_mean 0.56, dist 0.09396369192401148
episode 203, mean rew 8.700911750993376,success_mean 0.55, dist 0.09666474850996948
episode 204, mean rew 8.6934621158343,success_mean 0.53, dist 0.09900547589754544
episode 205, mean rew 8.68612840981226,success_mean 0.55, dist 0.09854984291539345
episode 206, mean rew 8.675717856165576,success_mean 0.52, dist 0.10110991976019137
episode 207, mean rew 8.662218046180687,success_mean 0.54, dist 0.1003372864734671
episode 208, mean rew 8.656125989604178,success_mean 0.57, dist 0.09782356026111774
episode 209, mean rew 8.652636130822412,success_mean 0.52, dist 0.10148247419107839
episode 210, mean rew 8.642591383807208,success_mean 0.53, dist 0.10326482623887748
episode 211, mean rew 8.640315751415583,success_mean 0.54, dist 0.102263738007690

episode 298, mean rew 8.40480130042934,success_mean 0.88, dist 0.05804475174004984
episode 299, mean rew 8.4077293944011,success_mean 0.88, dist 0.057880140770624255
episode 300, mean rew 8.434973959238592,success_mean 0.91, dist 0.0539626990220696
episode 301, mean rew 8.442423602057609,success_mean 0.89, dist 0.05437712588530539
episode 302, mean rew 8.43695205535938,success_mean 0.89, dist 0.05482540507557076
episode 303, mean rew 8.42914264353846,success_mean 0.89, dist 0.055282467291534194
episode 304, mean rew 8.414636754752946,success_mean 0.87, dist 0.05989697982995082
episode 305, mean rew 8.40333351525164,success_mean 0.86, dist 0.05964877578404543
episode 306, mean rew 8.422979202066774,success_mean 0.87, dist 0.05730171360281751
episode 307, mean rew 8.417449131779604,success_mean 0.87, dist 0.05727268419853868
episode 308, mean rew 8.403418337649834,success_mean 0.88, dist 0.05792266434489888
episode 309, mean rew 8.42832404889758,success_mean 0.88, dist 0.0589041756509370

episode 396, mean rew 8.4082381405178,success_mean 0.77, dist 0.06950712419630091
episode 397, mean rew 8.41166623031713,success_mean 0.76, dist 0.07097657573602172
episode 398, mean rew 8.42482992288276,success_mean 0.77, dist 0.07135498038310248
episode 399, mean rew 8.447958136997208,success_mean 0.78, dist 0.07192854338880565
episode 400, mean rew 8.47143948383677,success_mean 0.8, dist 0.07118313910461482
episode 401, mean rew 8.517748371359925,success_mean 0.83, dist 0.0682868984688225
episode 402, mean rew 8.557219388286676,success_mean 0.84, dist 0.06672602404526508
episode 403, mean rew 8.611828053752348,success_mean 0.83, dist 0.0661874049382415
episode 404, mean rew 8.631124105681614,success_mean 0.82, dist 0.06797314394513657
episode 405, mean rew 8.642247027950567,success_mean 0.81, dist 0.06937781334292649
episode 406, mean rew 8.674824214482046,success_mean 0.78, dist 0.0717238155391088
episode 407, mean rew 8.691973957383398,success_mean 0.74, dist 0.07602547236657906
e

episode 495, mean rew 8.622424518955025,success_mean 0.68, dist 0.08116756772327614
episode 496, mean rew 8.635417196869593,success_mean 0.65, dist 0.08595257774379123
episode 497, mean rew 8.656123717792644,success_mean 0.6, dist 0.09245987995207898
episode 498, mean rew 8.69617737621349,success_mean 0.56, dist 0.09638331461368667
episode 499, mean rew 8.76661800543949,success_mean 0.54, dist 0.0972918858932651
episode 500, mean rew 8.829523058541572,success_mean 0.55, dist 0.09763518881483907
episode 501, mean rew 8.839754781314973,success_mean 0.53, dist 0.10245744706729665
episode 502, mean rew 8.866370670877279,success_mean 0.47, dist 0.10970042853953743
episode 503, mean rew 8.881214908184797,success_mean 0.43, dist 0.11835589470419089
episode 504, mean rew 8.900985405558279,success_mean 0.41, dist 0.1219878438298149
episode 505, mean rew 8.930577460923459,success_mean 0.39, dist 0.12304516250291764
episode 506, mean rew 8.950063107437261,success_mean 0.38, dist 0.124855830555250

episode 594, mean rew 8.655943459094823,success_mean 0.65, dist 0.08319548544779193
episode 595, mean rew 8.676262542437408,success_mean 0.64, dist 0.08446516081293666
episode 596, mean rew 8.661285117056314,success_mean 0.64, dist 0.08497775430341843
episode 597, mean rew 8.67357731115606,success_mean 0.6, dist 0.08813937099893286
episode 598, mean rew 8.728587853355766,success_mean 0.64, dist 0.08479236382229588
episode 599, mean rew 8.762024684616371,success_mean 0.68, dist 0.08324709627477138
episode 600, mean rew 8.799863057387926,success_mean 0.64, dist 0.08741826171092297
episode 601, mean rew 8.835808105931893,success_mean 0.63, dist 0.08822111632300106
episode 602, mean rew 8.836906137811784,success_mean 0.6, dist 0.09104310619536626
episode 603, mean rew 8.863275362714166,success_mean 0.6, dist 0.09116499309647526
episode 604, mean rew 8.897520998752022,success_mean 0.55, dist 0.09432303982975197
episode 605, mean rew 8.90348669515623,success_mean 0.53, dist 0.098984172630417

episode 693, mean rew 9.293125362514878,success_mean 0.6, dist 0.09446882381338804
episode 694, mean rew 9.289446761647936,success_mean 0.65, dist 0.08727886007065248
episode 695, mean rew 9.259286379773497,success_mean 0.65, dist 0.08719955150646472
episode 696, mean rew 9.238618220113354,success_mean 0.66, dist 0.08696286873926119
episode 697, mean rew 9.208694189984223,success_mean 0.69, dist 0.08396761968662128
episode 698, mean rew 9.1939751278327,success_mean 0.74, dist 0.0803219932144478
episode 699, mean rew 9.171567363552288,success_mean 0.77, dist 0.07740147954976978
episode 700, mean rew 9.16137417143217,success_mean 0.83, dist 0.07238690560244386
episode 701, mean rew 9.149417765978098,success_mean 0.87, dist 0.06912757326322388
episode 702, mean rew 9.129917685109882,success_mean 0.88, dist 0.0684339755988453
episode 703, mean rew 9.107151762857404,success_mean 0.91, dist 0.06526434171698597
episode 704, mean rew 9.110161160086486,success_mean 0.94, dist 0.0616834508564565

episode 791, mean rew 8.824919741286678,success_mean 0.95, dist 0.05615074014665623
episode 792, mean rew 8.823916576365484,success_mean 0.95, dist 0.056473702502186925
episode 793, mean rew 8.84028104949659,success_mean 0.94, dist 0.05708156348677463
episode 794, mean rew 8.863592381711932,success_mean 0.94, dist 0.05775160816074126
episode 795, mean rew 8.866098928921959,success_mean 0.92, dist 0.05963173501743868
episode 796, mean rew 8.896585756739592,success_mean 0.92, dist 0.060789890133697885
episode 797, mean rew 8.913260053386576,success_mean 0.91, dist 0.06197931998229103
episode 798, mean rew 8.919895203309185,success_mean 0.89, dist 0.06425544727951739
episode 799, mean rew 8.9291170634695,success_mean 0.88, dist 0.06472212328557102
episode 800, mean rew 8.943723079996065,success_mean 0.87, dist 0.0659116472519107
episode 801, mean rew 8.97803915908642,success_mean 0.87, dist 0.06468638366760725
episode 802, mean rew 9.005011782649756,success_mean 0.86, dist 0.0636493683235

episode 890, mean rew 9.506805987145844,success_mean 1.0, dist 0.03295752353299984
episode 891, mean rew 9.519702439763533,success_mean 1.0, dist 0.033498985300088865
episode 892, mean rew 9.510707820013373,success_mean 1.0, dist 0.03477296924849165
episode 893, mean rew 9.511139990713671,success_mean 1.0, dist 0.03490854890698716
episode 894, mean rew 9.494323352164487,success_mean 1.0, dist 0.03486172965110878
episode 895, mean rew 9.486924580586123,success_mean 0.99, dist 0.03512748340332241
episode 896, mean rew 9.477603136553851,success_mean 0.99, dist 0.03541444767812242
episode 897, mean rew 9.480036747473664,success_mean 0.99, dist 0.03403730586770632
episode 898, mean rew 9.466573001869651,success_mean 0.99, dist 0.03406089497090541
episode 899, mean rew 9.446478773426847,success_mean 0.99, dist 0.034249073590812716
episode 900, mean rew 9.440413240167432,success_mean 0.99, dist 0.033756477811717714
episode 901, mean rew 9.430023324991403,success_mean 0.99, dist 0.033225710956

episode 988, mean rew 9.418910459216868,success_mean 0.94, dist 0.06005052479221939
episode 989, mean rew 9.43917517961419,success_mean 0.95, dist 0.05879863940759264
episode 990, mean rew 9.428188313611065,success_mean 0.96, dist 0.057416972002349985
episode 991, mean rew 9.427311749133485,success_mean 0.97, dist 0.055103439869639254
episode 992, mean rew 9.433163404147404,success_mean 0.98, dist 0.053816105482005654
episode 993, mean rew 9.420120847411527,success_mean 0.98, dist 0.051334671169026624
episode 994, mean rew 9.408329399998962,success_mean 0.98, dist 0.049878019279302487
episode 995, mean rew 9.411799018088695,success_mean 0.98, dist 0.04945088665368245
episode 996, mean rew 9.397529412697223,success_mean 0.98, dist 0.046569443738656166
episode 997, mean rew 9.395698677549198,success_mean 1.0, dist 0.043481788622602
episode 998, mean rew 9.355710093071677,success_mean 1.0, dist 0.04203337400561307
episode 999, mean rew 9.33442478916029,success_mean 1.0, dist 0.04036440402

## Our Model

In [None]:
# PPO run with hand-picked hyperparameters.
#
# The optimisation settings below (lambda, num_sgd_iter, lr, vf_loss_coeff,
# entropy_coeff, clip_param, vf_clip_param, grad_clip) are trainer-level PPO
# options, so they must sit at the top level of the config. When nested under
# 'model' they are not picked up as PPO settings and the run falls back to the
# PPO defaults. The 'model' sub-dict only describes the network architecture.
config = {
    'num_workers': 1,
    'train_batch_size': 200,
    'log_level': 'ERROR',
    'framework': 'torch',
    'callbacks': CustomCallbacks,
    # --- PPO hyperparameters (top-level trainer keys) ---
    'lambda': .99,
    'num_sgd_iter': 4,
    'lr': 1e-5,
    # RLlib's name for the value-loss weight is 'vf_loss_coeff'
    # (the original 'value_loss_coef' is not a recognised key).
    'vf_loss_coeff': .05,
    'entropy_coeff': .01,
    'clip_param': .2,
    'vf_clip_param': .2,
    'grad_clip': .5,
    # --- Network architecture ---
    'model': {
        'fcnet_hiddens': [64, 64],
    },
}
trainer = agents.ppo.PPOTrainer(env='point_mass_1', config=config)

# Each pass through the loop is one call to trainer.train(); the line printed
# reports the mean reward plus the custom metrics written by CustomCallbacks.
for i in range(1000):
    results = trainer.train()
    print(f"episode {i}, mean rew {results['episode_reward_mean']}," +
          f"success_mean {results['custom_metrics']['success_mean']}, dist {results['custom_metrics']['dist_mean']}")

In [28]:
results

{'episode_reward_max': 2.232037772966303,
 'episode_reward_min': 0.09500131149916595,
 'episode_reward_mean': 1.0471777207514443,
 'episode_len_mean': 150.0,
 'episodes_this_iter': 6,
 'policy_reward_min': {},
 'policy_reward_max': {},
 'policy_reward_mean': {},
 'custom_metrics': {'success_mean': 0.0, 'success_min': 0, 'success_max': 0},
 'hist_stats': {'episode_reward': [2.232037772966303,
   1.7740165433414394,
   0.15976652453021253,
   0.1101491564775453,
   0.09500131149916595,
   1.912095015694],
  'episode_lengths': [150, 150, 150, 150, 150, 150]},
 'sampler_perf': {'mean_env_wait_ms': 0.1896208935565167,
  'mean_raw_obs_processing_ms': 0.05933216639927456,
  'mean_inference_ms': 0.8938195822122214,
  'mean_action_processing_ms': 0.06338814040878556},
 'off_policy_estimator': {},
 'num_healthy_workers': 1,
 'timesteps_total': 1000,
 'timers': {'sample_time_ms': 1228.964,
  'sample_throughput': 813.694,
  'learn_time_ms': 2569.964,
  'learn_throughput': 389.11,
  'update_time_ms