<a href="https://colab.research.google.com/github/ppfenninger/Sensorimotor_Learning_Final/blob/main/Sensorimotor_Testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [410]:
## Installation
!pip install pybullet > /dev/null 2>&1
!pip install git+https://github.com/taochenshh/easyrl.git > /dev/null 2>&1
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!pip install git+https://github.com/ppfenninger/airobot.git > /dev/null 2>&1
# !pip install git+https://github.com/ppfenninger/Sensorimotor_Learning_Final.git > /dev/null 2>&1

In [411]:
import os
import torch
import gym
import pprint
import time
import pybullet as p
import pybullet_data as pd
import pybullet_envs
import airobot as ar
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
from typing import Any
from matplotlib import animation
from IPython.display import HTML
from matplotlib import pylab
from dataclasses import dataclass
from gym import spaces
from gym.envs.registration import registry, register
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from tqdm.notebook import tqdm
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pathlib import Path
from copy import deepcopy
from itertools import count
from easyrl.agents.ppo_agent import PPOAgent
from easyrl.utils.common import save_traj
from easyrl.configs import cfg
from easyrl.configs import set_config
from easyrl.configs.basic_config import BasicConfig
from easyrl.configs.command_line import cfg_from_cmd
from easyrl.engine.ppo_engine import PPOEngine
from easyrl.models.categorical_policy import CategoricalPolicy
from easyrl.models.diag_gaussian_policy import DiagGaussianPolicy
from easyrl.models.mlp import MLP
from easyrl.models.value_net import ValueNet
from easyrl.agents.base_agent import BaseAgent
from easyrl.utils.torch_util import DictDataset
from easyrl.utils.torch_util import load_state_dict
from easyrl.utils.torch_util import load_torch_model
from easyrl.runner.nstep_runner import EpisodicRunner
from easyrl.utils.torch_util import save_model
from easyrl.utils.torch_util import action_entropy
from easyrl.utils.torch_util import action_from_dist
from easyrl.utils.torch_util import action_log_prob
from easyrl.utils.torch_util import clip_grad
from easyrl.utils.common import set_random_seed
from easyrl.utils.gym_util import make_vec_env
from easyrl.utils.common import load_from_json
from easyrl.utils.torch_util import freeze_model
from easyrl.utils.torch_util import move_to
from easyrl.utils.torch_util import torch_float
from easyrl.utils.torch_util import torch_to_np
from base64 import b64encode
from IPython import display as ipythondisplay

In [412]:
del sys.modules["de_agent"]
del sys.modules["de_runner"]
del sys.modules["utils"]
# del sys.modules["de_env"]
del sys.modules["de_mover_env"]
del sys.modules["de_engine"]

# install our library
!rm Sensorimotor_Learning_Final -r
!git clone -b testing https://github.com/ppfenninger/Sensorimotor_Learning_Final.git
import sys
sys.path.insert(0, './Sensorimotor_Learning_Final/deepexploration/')
import de_agent
from de_agent import DeepExplorationAgent
# import de_engine
import de_runner # this should work once everything compiles
from de_runner import DeepExplorationRunner
import utils
from utils import eval_agent, load_expert_agent, create_actor, create_critic

import de_engine
from de_engine import DeepExplorationEngine

import de_mover_env
from de_mover_env import URRobotGym

module_name = __name__

env_name = 'URRobotGym-v1'
if env_name in registry.env_specs:
    del registry.env_specs[env_name]
register(
    id=env_name,
    entry_point=f'{module_name}:URRobotGym',
)

Cloning into 'Sensorimotor_Learning_Final'...
remote: Enumerating objects: 272, done.[K
remote: Counting objects: 100% (79/79), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 272 (delta 66), reused 44 (delta 41), pack-reused 193[K
Receiving objects: 100% (272/272), 153.62 KiB | 12.80 MiB/s, done.
Resolving deltas: 100% (141/141), done.


# Setup

In [413]:
class DEConfig(BasicConfig):
    """Config object that ``set_configs`` installs as ``cfg.alg``; subclasses easyrl's ``BasicConfig``."""
    # NOTE(review): not read anywhere in this notebook - presumably consumed by the
    # project/easyrl code; confirm before changing.
    num_traj = 1

def set_configs(exp_name='de'):
    """Replace ``cfg.alg`` with a fresh ``DEConfig`` and fill in the experiment settings.

    Args:
        exp_name: Sub-directory of ``<cwd>/data`` that becomes ``cfg.alg.save_dir``.

    Side effects:
        Overwrites the global ``cfg.alg`` and prints the selected device.
    """
    cfg.alg = DEConfig()
    alg = cfg.alg

    save_root = Path.cwd().absolute().joinpath('data').as_posix()
    settings = {
        'seed': 9037987,
        'num_envs': 1,
        'episode_steps': 30,
        'max_steps': 600000,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'env_name': 'URRobotGym-v1',
        'save_dir': f'{save_root}/{exp_name}',
        'policy_lr': 3e-4,
        'value_lr': 1e-3,
        'linear_decay_lr': False,
        'max_decay_steps': 1e6,
        'eval_num_envs': None,
        'opt_epochs': 10,
        'normalize_adv': False,
        'clip_vf_loss': False,
        'vf_loss_type': 'mse',
        'vf_coef': 0.05,  # was 0.5
        'ent_coef': 0.01,
        'clip_range': 0.2,
        'linear_decay_clip_range': False,
        'gae_lambda': 0.95,
        'rew_discount': 0.99,
        'use_amsgrad': True,
        'sgd': False,
        'momentum': 0.00,
        'tanh_on_dist': False,
        'std_cond_in': False,
        'log_interval': 5,
    }
    for option, value in settings.items():
        setattr(alg, option, value)
    # diff_cfg carries a copy of the save directory in its own dict.
    setattr(alg, 'diff_cfg', dict(save_dir=alg.save_dir))

    banner = '===================================='
    print(banner)
    print(f'      Device:{alg.device}')
    print(banner)

In [414]:
# Runner Tests
# set_configs()
# env = URRobotPusherGym(max_episode_length=100)

# critics = []
# for index in range(4):
#   ob_size = env.observation_space.shape[0]
#   critic_body = MLP(input_size=ob_size,
#                      hidden_sizes=[64],
#                      output_size=64,
#                      hidden_act=nn.Tanh,
#                      output_act=nn.Tanh)
#   critic = ValueNet(critic_body, in_features=64)
#   critics.append(critic)

# actor = create_actor(env=env)
# agent = DeepExplorationAgent(actor=actor, critics=critics, env=env)
# runner = DeepExplorationRunner(agent=agent, env=env)

# traj = runner(time_steps=cfg.alg.episode_steps)

In [415]:
def train_de(use_sparse_reward=False, use_subgoal=False, with_obstacle=False, apply_collision_penalty=False, push_exp=False,
              max_steps=200000, num_critics=5):
    """Train a ``DeepExplorationAgent`` on ``URRobotGym-v1`` and run one evaluation episode.

    Args:
        use_sparse_reward: Forwarded to the environment as ``use_sparse_reward``.
        use_subgoal: Forwarded to the environment as ``use_subgoal``.
        with_obstacle: Forwarded to the environment as ``with_obstacle``.
        apply_collision_penalty: Forwarded to the environment as ``apply_collision_penalty``.
        push_exp: Currently unused; kept so existing calls keep working.
        max_steps: Value stored in ``cfg.alg.max_steps`` for the training run.
        num_critics: Number of critics created with ``create_critic`` and handed to the agent.

    Returns:
        ``cfg.alg.save_dir``, the directory configured for this run's outputs.
    """
    # set_configs already fixes the seed, device, environment name, episode length and the
    # save directory (<cwd>/data/push); only the run-specific values are set here.
    set_configs(exp_name='push')
    cfg.alg.max_steps = max_steps
    cfg.alg.deque_size = 20

    # The device banner is printed by set_configs; report only the step budget here.
    print(f'====================================')
    print(f'      Total number of steps:{cfg.alg.max_steps}')
    print(f'====================================')

    set_random_seed(cfg.alg.seed)
    env_kwargs = dict(use_sparse_reward=use_sparse_reward,
                      with_obstacle=with_obstacle,
                      use_subgoal=use_subgoal,
                      apply_collision_penalty=apply_collision_penalty,
                      is_exploratory=False)
    env = make_vec_env(cfg.alg.env_name,
                       cfg.alg.num_envs,
                       seed=cfg.alg.seed,
                       env_kwargs=env_kwargs)
    env.reset()
    ob_size = env.observation_space.shape[0]

    actor_body = MLP(input_size=ob_size,
                     hidden_sizes=[64],
                     output_size=64,
                     hidden_act=nn.Tanh,
                     output_act=nn.Tanh)

    # Choose the policy head that matches the environment's action space.
    if isinstance(env.action_space, gym.spaces.Discrete):
        act_size = env.action_space.n
        actor = CategoricalPolicy(actor_body,
                                  in_features=64,
                                  action_dim=act_size)
    elif isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        actor = DiagGaussianPolicy(actor_body,
                                   in_features=64,
                                   action_dim=act_size,
                                   tanh_on_dist=cfg.alg.tanh_on_dist,
                                   std_cond_in=cfg.alg.std_cond_in)
    else:
        raise TypeError(f'Unknown action space type: {env.action_space}')

    # Each critic is created by a separate create_critic call.
    critics = [create_critic(env) for _ in range(num_critics)]

    agent = DeepExplorationAgent(actor=actor, critics=critics, env=env)
    runner = DeepExplorationRunner(agent=agent, env=env)
    engine = DeepExplorationEngine(agent=agent, runner=runner, env=env)
    engine.train()
    stat_info, raw_traj_info = engine.eval(render=False, save_eval_traj=True, eval_num=1, sleep_time=0.0)
    pprint.pprint(stat_info)
    return cfg.alg.save_dir

# Testing

In [None]:
# env = URRobotPusherGym(max_episode_length=100)
# Train with a 120k-step budget; train_de returns the directory configured for its outputs.
saved_dir_pusher = train_de(max_steps = 120000)

[32m[INFO][0m[2023-05-08 02:33:48]: [32mCreating 1 environments.[0m
INFO:EasyRL:Creating 1 environments.


      Device:cuda
      Device:cuda
      Total number of steps:120000
Use sparse reward:False
Use subgoal:False
With obstacle in the scene:False
Apply collision penalty:False


[31m[ERROR][0m[2023-05-08 02:33:49]: [31mNot a valid git repo: /usr/local/lib/python3.10/dist-packages[0m
ERROR:EasyRL:Not a valid git repo: /usr/local/lib/python3.10/dist-packages


i have this many people telling me what to do 5
iter: 0


[32m[INFO][0m[2023-05-08 02:33:50]: [32mExploration steps: 0[0m
INFO:EasyRL:Exploration steps: 0
[32m[INFO][0m[2023-05-08 02:33:50]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000000000.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000000000.pt.
[32m[INFO][0m[2023-05-08 02:33:50]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.


We are done! 24
We are done! 24
iter: 1
iter: 2
iter: 3
iter: 4
iter: 5
iter: 6
iter: 7
iter: 8
iter: 9
iter: 10
iter: 11
iter: 12
iter: 13
iter: 14
iter: 15
iter: 16
iter: 17
iter: 18
iter: 19
iter: 20
iter: 21
iter: 22
iter: 23
iter: 24
iter: 25
iter: 26
iter: 27
iter: 28
iter: 29
iter: 30
iter: 31
iter: 32
iter: 33
iter: 34
iter: 35
iter: 36
iter: 37
iter: 38
iter: 39
iter: 40
iter: 41
iter: 42
iter: 43
iter: 44
iter: 45
iter: 46
iter: 47
iter: 48
iter: 49
iter: 50
iter: 51
iter: 52
iter: 53
iter: 54
iter: 55
iter: 56
iter: 57
iter: 58
iter: 59
iter: 60
iter: 61
iter: 62
iter: 63
iter: 64
iter: 65
iter: 66
iter: 67
iter: 68
iter: 69
iter: 70
iter: 71
iter: 72
iter: 73
iter: 74
iter: 75
iter: 76
iter: 77
iter: 78
iter: 79
iter: 80
iter: 81
iter: 82
iter: 83
iter: 84
iter: 85
iter: 86
iter: 87
iter: 88
iter: 89
iter: 90
iter: 91
iter: 92
iter: 93
iter: 94
iter: 95
iter: 96
iter: 97
iter: 98
iter: 99
iter: 100
We are done! 24


[32m[INFO][0m[2023-05-08 02:35:14]: [32mExploration steps: 3000[0m
INFO:EasyRL:Exploration steps: 3000
[32m[INFO][0m[2023-05-08 02:35:14]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000003000.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000003000.pt.
[32m[INFO][0m[2023-05-08 02:35:14]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.


We are done! 24
iter: 101
iter: 102
iter: 103
iter: 104
iter: 105
iter: 106
iter: 107
iter: 108
iter: 109
iter: 110
iter: 111
iter: 112
iter: 113
iter: 114
iter: 115
iter: 116
iter: 117
iter: 118
iter: 119
iter: 120
iter: 121
iter: 122
iter: 123
iter: 124
iter: 125
iter: 126
iter: 127
iter: 128
iter: 129
iter: 130
iter: 131
iter: 132
iter: 133
iter: 134
iter: 135
iter: 136
iter: 137
iter: 138
iter: 139
iter: 140
iter: 141
iter: 142
iter: 143
iter: 144
iter: 145
iter: 146
iter: 147
iter: 148
iter: 149
iter: 150
iter: 151
iter: 152
iter: 153
iter: 154
iter: 155
iter: 156
iter: 157
iter: 158
iter: 159
iter: 160
iter: 161
iter: 162
iter: 163
iter: 164
iter: 165
iter: 166
iter: 167
iter: 168
iter: 169
iter: 170
iter: 171
iter: 172
iter: 173
iter: 174
iter: 175
iter: 176
iter: 177
iter: 178
iter: 179
iter: 180
iter: 181
iter: 182
iter: 183
iter: 184
iter: 185
iter: 186
iter: 187
iter: 188
iter: 189
iter: 190
iter: 191
iter: 192
iter: 193
iter: 194
iter: 195
iter: 196
iter: 197
iter: 198
iter

[32m[INFO][0m[2023-05-08 02:36:40]: [32mExploration steps: 6000[0m
INFO:EasyRL:Exploration steps: 6000
[32m[INFO][0m[2023-05-08 02:36:40]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000006000.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000006000.pt.
[32m[INFO][0m[2023-05-08 02:36:40]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.


We are done! 24
iter: 201
iter: 202
iter: 203
iter: 204
iter: 205
iter: 206
iter: 207
iter: 208
iter: 209
iter: 210
iter: 211
iter: 212
iter: 213
iter: 214
iter: 215
iter: 216
iter: 217
iter: 218
iter: 219
iter: 220
iter: 221
iter: 222
iter: 223
iter: 224
iter: 225
iter: 226
iter: 227
iter: 228
iter: 229
iter: 230
iter: 231
iter: 232
iter: 233
iter: 234
iter: 235
iter: 236
iter: 237
iter: 238
iter: 239
iter: 240
iter: 241
iter: 242
iter: 243
iter: 244
iter: 245
iter: 246
iter: 247
iter: 248
iter: 249
iter: 250
iter: 251
iter: 252
iter: 253
iter: 254
iter: 255
iter: 256
iter: 257
iter: 258
iter: 259
iter: 260
iter: 261
iter: 262
iter: 263
iter: 264
iter: 265
iter: 266
iter: 267
iter: 268
iter: 269
iter: 270
iter: 271
iter: 272
iter: 273
iter: 274
iter: 275
iter: 276
iter: 277
iter: 278
iter: 279
iter: 280
iter: 281
iter: 282
iter: 283
iter: 284
iter: 285
iter: 286
iter: 287
iter: 288
iter: 289
iter: 290
iter: 291
iter: 292
iter: 293
iter: 294
iter: 295
iter: 296
iter: 297
iter: 298
iter

[32m[INFO][0m[2023-05-08 02:38:05]: [32mExploration steps: 9000[0m
INFO:EasyRL:Exploration steps: 9000
[32m[INFO][0m[2023-05-08 02:38:05]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000009000.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000009000.pt.
[32m[INFO][0m[2023-05-08 02:38:05]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.


We are done! 24
We are done! 24
iter: 301
iter: 302
iter: 303
iter: 304
iter: 305
iter: 306
iter: 307
iter: 308
iter: 309
iter: 310
iter: 311
iter: 312
iter: 313
iter: 314
iter: 315
iter: 316
iter: 317
iter: 318
iter: 319
iter: 320
iter: 321
iter: 322
iter: 323
iter: 324
iter: 325
iter: 326
iter: 327
iter: 328
iter: 329
iter: 330
iter: 331
iter: 332
iter: 333
iter: 334
iter: 335
iter: 336
iter: 337
iter: 338
iter: 339
iter: 340
iter: 341
iter: 342
iter: 343
iter: 344
iter: 345
iter: 346
iter: 347
iter: 348
iter: 349
iter: 350
iter: 351
iter: 352
iter: 353
iter: 354
iter: 355
iter: 356
iter: 357
iter: 358
iter: 359
iter: 360
iter: 361
iter: 362
iter: 363
iter: 364
iter: 365
iter: 366
iter: 367
iter: 368
iter: 369
iter: 370
iter: 371
iter: 372
iter: 373
iter: 374
iter: 375
iter: 376
iter: 377
iter: 378
iter: 379
iter: 380
iter: 381
iter: 382
iter: 383
iter: 384
iter: 385
iter: 386
iter: 387
iter: 388
iter: 389
iter: 390
iter: 391
iter: 392
iter: 393
iter: 394
iter: 395
iter: 396
iter: 39

[32m[INFO][0m[2023-05-08 02:39:30]: [32mExploration steps: 12000[0m
INFO:EasyRL:Exploration steps: 12000
[32m[INFO][0m[2023-05-08 02:39:30]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000012000.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000012000.pt.
[32m[INFO][0m[2023-05-08 02:39:30]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/model_best.pt.


We are done! 24
We are done! 24
iter: 401
iter: 402
iter: 403
iter: 404
iter: 405
iter: 406
iter: 407
iter: 408
iter: 409
iter: 410
iter: 411
iter: 412
iter: 413
iter: 414
iter: 415
iter: 416
iter: 417
iter: 418
iter: 419
iter: 420
iter: 421
iter: 422
iter: 423
iter: 424
iter: 425
iter: 426
iter: 427
iter: 428
iter: 429
iter: 430
iter: 431
iter: 432
iter: 433
iter: 434
iter: 435
iter: 436
iter: 437
iter: 438
iter: 439
iter: 440
iter: 441
iter: 442
iter: 443
iter: 444
iter: 445
iter: 446
iter: 447
iter: 448
iter: 449
iter: 450
iter: 451
iter: 452
iter: 453
iter: 454
iter: 455
iter: 456
iter: 457
iter: 458
iter: 459
iter: 460
iter: 461
iter: 462
iter: 463
iter: 464
iter: 465
iter: 466
iter: 467
iter: 468
iter: 469
iter: 470
iter: 471
iter: 472
iter: 473
iter: 474
iter: 475
iter: 476
iter: 477
iter: 478
iter: 479
iter: 480
iter: 481
iter: 482
iter: 483
iter: 484
iter: 485
iter: 486
iter: 487
iter: 488
iter: 489
iter: 490
iter: 491
iter: 492
iter: 493
iter: 494
iter: 495
iter: 496
iter: 49

[32m[INFO][0m[2023-05-08 02:40:55]: [32mExploration steps: 15000[0m
INFO:EasyRL:Exploration steps: 15000
[32m[INFO][0m[2023-05-08 02:40:55]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000015000.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000015000.pt.


We are done! 24
iter: 501
iter: 502
iter: 503
iter: 504
iter: 505
iter: 506
iter: 507
iter: 508
iter: 509
iter: 510
iter: 511
iter: 512
iter: 513
iter: 514
iter: 515
iter: 516
iter: 517
iter: 518
iter: 519
iter: 520
iter: 521
iter: 522
iter: 523
iter: 524
iter: 525
iter: 526
iter: 527
iter: 528
iter: 529
iter: 530
iter: 531
iter: 532
iter: 533
iter: 534
iter: 535
iter: 536
iter: 537
iter: 538
iter: 539
iter: 540
iter: 541
iter: 542
iter: 543
iter: 544
iter: 545
iter: 546
iter: 547
iter: 548
iter: 549
iter: 550
iter: 551
iter: 552
iter: 553
iter: 554
iter: 555
iter: 556
iter: 557
iter: 558
iter: 559
iter: 560
iter: 561
iter: 562
iter: 563
iter: 564
iter: 565
iter: 566
iter: 567
iter: 568
iter: 569
iter: 570
iter: 571
iter: 572
iter: 573
iter: 574
iter: 575
iter: 576
iter: 577
iter: 578
iter: 579
iter: 580
iter: 581
iter: 582
iter: 583
iter: 584
iter: 585
iter: 586
iter: 587
iter: 588
iter: 589
iter: 590
iter: 591
iter: 592
iter: 593
iter: 594
iter: 595
iter: 596
iter: 597
iter: 598
iter

[32m[INFO][0m[2023-05-08 02:42:17]: [32mExploration steps: 18000[0m
INFO:EasyRL:Exploration steps: 18000
[32m[INFO][0m[2023-05-08 02:42:17]: [32mSaving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000018000.pt.[0m
INFO:EasyRL:Saving checkpoint: /content/data/push/seed_9037987/model/ckpt_000000018000.pt.


We are done! 24
iter: 601
iter: 602
iter: 603
iter: 604
iter: 605
iter: 606
iter: 607
iter: 608
iter: 609
iter: 610
iter: 611
iter: 612
iter: 613
iter: 614
iter: 615
iter: 616
iter: 617
iter: 618
iter: 619
iter: 620
iter: 621
iter: 622
iter: 623
iter: 624
iter: 625
iter: 626
iter: 627
iter: 628
iter: 629
iter: 630
iter: 631
iter: 632
iter: 633
iter: 634
iter: 635
iter: 636
iter: 637
iter: 638
iter: 639
iter: 640
iter: 641
iter: 642
iter: 643
iter: 644
iter: 645
iter: 646
iter: 647
iter: 648
iter: 649
iter: 650
iter: 651
iter: 652
iter: 653
iter: 654
iter: 655
iter: 656
iter: 657
iter: 658
iter: 659
iter: 660
iter: 661
iter: 662
iter: 663
iter: 664
iter: 665
iter: 666
iter: 667
iter: 668
iter: 669
iter: 670
iter: 671
iter: 672
iter: 673
iter: 674
iter: 675
iter: 676
iter: 677
iter: 678
iter: 679
iter: 680
iter: 681
iter: 682
iter: 683
iter: 684
iter: 685
iter: 686
iter: 687
iter: 688
iter: 689
iter: 690
iter: 691
iter: 692
iter: 693
iter: 694
iter: 695
iter: 696
iter: 697
iter: 698
iter

In [None]:
# Show the output directory returned by train_de; the training logs are searched for below it.
print(saved_dir_pusher)

def read_tf_log(log_dir):
    """Load the training curves from the newest TensorBoard event file under ``log_dir``.

    Args:
        log_dir: Directory that is searched recursively for ``events.*`` files.

    Returns:
        ``None`` when no event file exists or one of the expected scalar tags is missing.
        Otherwise the tuple ``(steps, returns, success_rate, losses, vf_losses, pg_losses)``
        of lists, where ``steps`` holds the step indices of the
        ``train/episode_success`` series.
    """
    log_dir = Path(log_dir)
    log_files = list(log_dir.glob('**/events.*'))
    print(log_files)
    if not log_files:
        return None
    # glob yields files in arbitrary order, so select the most recently written one
    # explicitly instead of relying on the position in the list.
    log_file = max(log_files, key=lambda candidate: candidate.stat().st_mtime)
    event_acc = EventAccumulator(log_file.as_posix())
    event_acc.Reload()
    try:
        scalar_success = event_acc.Scalars('train/episode_success')
        success_rate = [event.value for event in scalar_success]
        steps = [event.step for event in scalar_success]
        returns = [event.value for event in event_acc.Scalars('train/episode_return/mean')]
        losses = [event.value for event in event_acc.Scalars('train/total_loss')]
        vf_losses = [event.value for event in event_acc.Scalars('train/vf_loss')]
        pg_losses = [event.value for event in event_acc.Scalars('train/pg_loss')]
    except KeyError as missing_tag:
        # Scalars() raises KeyError for a tag that was never logged; report it instead of
        # silently hiding every kind of failure.
        print(f'Scalar tag {missing_tag} not found in {log_file}')
        return None
    return steps, returns, success_rate, losses, vf_losses, pg_losses

# read_tf_log returns None when no event file or scalar tag is found; fail with a clear
# message instead of an opaque "cannot unpack NoneType" error.
tf_log = read_tf_log(saved_dir_pusher)
if tf_log is None:
    raise RuntimeError(f'No usable TensorBoard scalars found under {saved_dir_pusher}')
steps, returns, success_rates, losses, vf_losses, pg_losses = tf_log

In [None]:
# Print the raw series that were read from the training log.
for logged_series in (returns, steps, success_rates):
    print(logged_series)

# Overlay the return and the three loss curves on one axis; the legend follows plot order.
curves = {'returns': returns, 'losses': losses, 'vf': vf_losses, 'pg': pg_losses}
for curve in curves.values():
    plt.plot(steps, curve)
plt.legend(list(curves))

In [None]:
# Logged 'train/episode_success' values against their training step.
plt.plot(steps, success_rates)

In [None]:
def play_video(video_dir, video_file=None):
    """Re-encode a rendered episode with ffmpeg and embed it in the notebook output.

    Args:
        video_dir: Directory searched recursively for ``render_video.mp4`` when
            ``video_file`` is not given; the last match in sorted path order is used.
        video_file: Optional explicit path of the video to show.

    Raises:
        FileNotFoundError: If no ``render_video.mp4`` exists under ``video_dir``.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    import subprocess

    if video_file is None:
        video_files = sorted(Path(video_dir).glob('**/render_video.mp4'))
        if not video_files:
            raise FileNotFoundError(f'No render_video.mp4 found under {video_dir}')
        video_file = video_files[-1]
    else:
        video_file = Path(video_file)
    compressed_file = video_file.parent.joinpath('comp.mp4')
    # An argument list (no shell) keeps paths with spaces or shell metacharacters intact.
    # -y overwrites a comp.mp4 left by an earlier call, so a stale video is never shown.
    subprocess.run(
        ['ffmpeg', '-y', '-i', str(video_file),
         '-filter:v', 'setpts=2.0*PTS',  # doubles the presentation timestamps
         '-vcodec', 'libx264', compressed_file.as_posix()],
        check=True,
    )
    mp4 = compressed_file.read_bytes()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    display(HTML("""
    <video width=400 controls>
        <source src="%s" type="video/mp4">
    </video>
    """ % data_url))


# Show the render_video.mp4 found under the training output directory
# (the last match in sorted path order).
play_video(saved_dir_pusher)