In [1]:
import gym
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" #### REMOVE THIS LINE WHEN CUDA CONFIG IS FIXED
import tensorflow as tf
import numpy as np
import sys
import time
import gym_snake
import json
import importlib
import random
from PIL import Image
from tqdm import tqdm

sys.path.insert(0, '..')
from utils.Buffer import ReplayBuffer
from utils.Conv import ConvHead
from rl.models import get_policy_architecture, get_value_architecture, get_vision_architecture
from algos.PPO import PPO_agent
from algos.DQN import DQN_agent

# %load_ext line_profiler
%matplotlib notebook

In [2]:
# tetris = importlib.import_module('pytris-effect.src.gameui')

In [3]:
# Name of the run; used below to pick ../configs/<run_name>.json and to select
# the tetris vs. gym environment branch.
run_name = 'snake'
# Notebook mode; later cells check for 'train', 'evaluate' and 'test'.
action = 'train'
# Algorithm tags; later cells branch on whether 'DQN' or 'PPO' is in this tuple.
algo = ('DQN', 'Dueling')

In [4]:
# Checkpoints are stored next to the configs, one level above this notebook.
ckpt_folder = os.path.join('..', 'checkpoints')

# Per-run settings are read from ../configs/<run_name>.json.
cfg_fp = os.path.join('..', 'configs', '{}.json'.format(run_name))
with open(cfg_fp, 'r') as f:
    config = json.load(f)

In [5]:
env_name = config['env']
if run_name == 'tetris':
    # The tetris UI module is only needed for this run, so import it lazily here.
    # Without this, `tetris` is undefined because the dedicated import cell is
    # commented out, and this branch would fail with a NameError.
    tetris = importlib.import_module('pytris-effect.src.gameui')
    env = tetris.GameUI(graphic_mode=False, its_per_sec=2, sec_per_tick=0.5)
else:
    env = gym.make(env_name)
    # The mere presence of the 'use_raw_env' key (whatever its value) selects
    # the unwrapped environment.
    if 'use_raw_env' in config:
        env = env.env

In [6]:
# Sanity check: display the shape of the initial observation returned by reset().
env.reset().shape

(150, 150, 3)

In [7]:
def show_img(arr, scaling=30):
    """Display an image array enlarged by an integer factor.

    Every source pixel becomes a ``scaling`` x ``scaling`` block of identical
    pixels, so small grid-like observations are large enough to inspect.

    Args:
        arr: Array-like indexed as ``[row, column, channel]``. Only the first
            three channels are used, and values are cast to ``uint8``.
        scaling: Integer enlargement factor applied to both spatial axes.
            Defaults to 30, the value that used to be hard-coded.
    """
    pixels = np.asarray(arr)[:, :, :3].astype(np.uint8)
    # Repeating along both spatial axes yields the same result as copying each
    # pixel one by one, without a Python-level loop over every output element.
    data = np.repeat(np.repeat(pixels, scaling, axis=0), scaling, axis=1)
    img = Image.fromarray(data, 'RGB')
    # img.save('my.png')
    img.show()

In [8]:
if action == 'evaluate':
    # Line-profile the environment's drawMatrix call.
    # NOTE(review): %lprun is provided by the line_profiler extension, whose
    # `%load_ext line_profiler` line is commented out in the first cell — load
    # the extension before running the notebook in this mode.
    %lprun -f env.drawMatrix env.drawMatrix()

In [9]:
if action == 'evaluate':
    # Preview the initial observation, keeping every 10th pixel along both
    # spatial axes and all channels.
    arr = env.reset()[::10,::10,:]
    img = Image.fromarray(arr, 'RGB')
    img.show()
    #show_img(env.reset())

In [10]:
if False:  # manual debugging snippet; flip to True to step the env once by hand
    # Use a dedicated name here: assigning to `action` would overwrite the
    # notebook's mode string ('train' / 'evaluate' / 'test') that later cells
    # branch on, silently disabling them.
    debug_action = 1
    obs, reward, dn, info = env.step(debug_action)
    show_img(obs)
    print(reward, dn, info)

In [11]:
def do_step(n_actions=None):
    """Advance the notebook-level ``env`` by one uniformly random action.

    The environment is reset when the step reports the episode as done, so the
    function can be called repeatedly (e.g. under ``%timeit``).

    Args:
        n_actions: Number of discrete actions to sample from. When omitted it
            is read from ``env.action_space.n`` if the environment exposes it,
            and otherwise falls back to 7 (the value that used to be
            hard-coded).
    """
    if n_actions is None:
        # Prefer the environment's own action count so that no out-of-range
        # action is sent to environments with fewer than 7 actions.
        space = getattr(env, 'action_space', None)
        n_actions = getattr(space, 'n', 7)
    _, _, dn, _ = env.step(random.randrange(int(n_actions)))
    if dn:
        env.reset()

In [12]:
#%timeit env.reset()

In [13]:
#%timeit do_step()

In [14]:
#%lprun -f env.get_obs do_step()

In [15]:
# Policy / Q-network for the selected algorithm.
model = get_policy_architecture(env_name, algo=algo)
if 'DQN' in algo:
    # clone_model reproduces the architecture but creates freshly initialised
    # weights, so copy the weights across explicitly to make the target
    # network start identical to `model`.
    target = tf.keras.models.clone_model(model)
    target.set_weights(model.get_weights())
else:
    # Non-DQN runs use a separate value network instead of a target network.
    value = get_value_architecture(env_name)

In [16]:
# Build the agent for the selected algorithm from `model` plus the target
# network (DQN) or value network (PPO) created in the previous cell.
if 'DQN' in algo:
    agent = DQN_agent(
        model,
        # (TODO): Move args for ReplayBuffer into DQN
        # Buffer capacity is read from the config, falling back to 20000.
        ReplayBuffer(config.get("max_buf_size", 20000), mode='uniform'),
        target=target,
        env=env,
        # NOTE(review): ('DDQN') is a plain string, not a 1-tuple (that would
        # need a trailing comma) — confirm which form DQN_agent expects.
        mode=('DDQN'), # 'PER'
        learning_rate=config['learning_rate'],
        batch_size=config['batch_size'],
        update_steps=1,
        update_freq=4,
        multistep=5,
        alpha=1.5,
        beta=1.0,
        gamma=0.95,
        target_delay=1000,
        delta=1.0,
        # delta=0.000003,
        # NOTE(review): this reads config['env_name'], while the environment
        # itself is created from config['env'] — confirm both keys exist.
        env_name=config['env_name'],
        algo_name='DQN',
        ckpt_folder=ckpt_folder,
        run_name='snake-DQN-pretrain-hard_update-uniform-multistep5-5'
    )
elif 'PPO' in algo:
    agent = PPO_agent(
        model,
        value,
        env=env,
        learning_rate=config['learning_rate'],
        minibatch_size=config['minibatch_size'],
        gamma=0.99,
        env_name=config['env_name'],
        run_name='snake-PPO-pretrain',
        ckpt_folder=ckpt_folder
    )

In [17]:
# Per-rollout step limit from the config; passed as t_max to the rollout
# collection and training calls below.
t_max = config['t_max']

In [18]:
# Notebook-level store of [state, next_state] pairs gathered by collect_rollout.
p_buf = []


def collect_rollout(env, t_max, policy):
    """Play one episode and record consecutive preprocessed observations.

    Starting from a fresh reset, ``policy`` is queried with the current
    preprocessed state, the chosen action is passed through
    ``agent.action_wrapper`` and applied to ``env``. Each transition is stored
    in the global ``p_buf`` as ``[state, next_state]``. The episode ends after
    ``t_max`` steps or as soon as the environment reports it is done.
    """
    current = agent.preprocess(env.reset())
    for _step in range(t_max):
        chosen = policy(current)
        raw_next, _reward, finished, _info = env.step(agent.action_wrapper(chosen))
        following = agent.preprocess(raw_next)
        p_buf.append([current, following])
        if finished:
            return
        current = following

In [19]:
# Optional vision pre-training stage, run only in 'train' mode.
pretrain = True
if pretrain and action == 'train': # only necessary for tasks on raw pixels (vision)
    embed = get_vision_architecture(env_name)
    # get some data from random interactions with the env
    # NOTE(review): p_buf is a notebook-level list that is not cleared here, so
    # re-running this cell appends to the samples from earlier runs.
    for i in tqdm(range(500)):
        # Random policy: ignores the state and draws uniformly from 4 actions.
        collect_rollout(env, t_max, lambda x: np.random.choice(4))
    print("Collected {} samples".format(len(p_buf)))
    head = ConvHead(embed, p_buf)
    head.train(6)
    
    # Expose the output of the head's second-to-last layer as its own model.
    out = head.model.layers[-2].output
    vision = tf.keras.Model(inputs=head.model.input, outputs=out)
    # NOTE(review): pretrained_model is not referenced again in the visible
    # cells, and the agent was already constructed from `model` — confirm the
    # pre-trained network is actually meant to be used for training.
    pretrained_model = get_policy_architecture(env_name, algo=algo, pretrain=vision)

100%|██████████| 500/500 [00:01<00:00, 478.26it/s]

Collected 14453 samples





  0%|          | 0/225 [00:00<?, ?it/s]

[0] Loss: [0.09307133 0.08805098 0.1230799  0.13021004 0.0880185  0.11400352
 0.14983298 0.15032071 0.13313848 0.125382   0.14798413 0.14268923
 0.20014411 0.19308825 0.10734501 0.09796516 0.15062109 0.15624909
 0.16062999 0.14636958 0.21961443 0.19605972 0.22441262 0.19783959
 0.18830147 0.18348753 0.15732811 0.12852766 0.11299545 0.12922427
 0.20187263 0.19025165 0.15708141 0.1380813  0.1343213  0.11293899
 0.13321531 0.12582491 0.11186472 0.13796696 0.23163427 0.23343693
 0.13907497 0.13698103 0.16049136 0.20446755 0.1124256  0.10173924
 0.12114715 0.12034106 0.21479034 0.19829915 0.15562432 0.16402315
 0.18751632 0.17315833 0.18989809 0.16708255 0.12942217 0.13907868
 0.16002885 0.1242775  0.14145528 0.13295877 0.17853296 0.17334443
 0.14578609 0.14513028 0.19483477 0.19287898 0.18659134 0.19433878
 0.11810233 0.1306259  0.20806473 0.17389207 0.13683502 0.14739028
 0.14568031 0.14540078 0.12714218 0.13556686 0.20043488 0.17022417
 0.11415653 0.09915541 0.1564895  0.16081233 0.15215

  0%|          | 0/225 [00:00<?, ?it/s]

[1] Loss: [0.08069696 0.06952894 0.05375838 0.06677462 0.08574428 0.08150017
 0.06659737 0.05088126 0.03079374 0.03225385 0.05729793 0.06749725
 0.07141192 0.06627582 0.03381082 0.06222618 0.05714087 0.06563023
 0.04653515 0.0618677  0.05486274 0.04257086 0.04326961 0.02935478
 0.03072287 0.02876476 0.04294356 0.05013074 0.05562707 0.06784251
 0.03888958 0.04217744 0.06150097 0.07398859 0.0740888  0.0558664
 0.05050935 0.05924526 0.04455234 0.04736103 0.05856825 0.04865573
 0.06566683 0.04728223 0.07552149 0.06846961 0.06431076 0.07249643
 0.07259826 0.0475819  0.04638346 0.0325957  0.04459222 0.05427106
 0.09895363 0.08344954 0.03792978 0.0358173  0.05256628 0.05470241
 0.06501411 0.05030365 0.05717375 0.0558335  0.03406027 0.03694902
 0.03912221 0.04269006 0.0382542  0.05231444 0.02758526 0.03084766
 0.07542432 0.08737613 0.02136563 0.01789146 0.04387948 0.0455508
 0.06181699 0.05165013 0.04938443 0.04689608 0.06160506 0.06458031
 0.06791691 0.06569298 0.06223798 0.07090948 0.0469179

  0%|          | 0/225 [00:00<?, ?it/s]

[2] Loss: [0.03055952 0.03679064 0.04988742 0.04812478 0.02804299 0.03315771
 0.0338087  0.03549401 0.0552864  0.06846392 0.06824289 0.05238279
 0.03377852 0.03717086 0.04155423 0.04466004 0.05422344 0.05945678
 0.05181535 0.05553811 0.05448621 0.07958236 0.02679933 0.04300687
 0.03880384 0.03839058 0.04599491 0.04580816 0.03881588 0.03656403
 0.06739261 0.04752773 0.03120575 0.04467125 0.0182112  0.01693921
 0.04058313 0.05055325 0.03524451 0.03572703 0.03887005 0.0396244
 0.04333559 0.07113092 0.06844048 0.04862629 0.04491714 0.06131336
 0.03900247 0.04682746 0.04493868 0.03342726 0.02739711 0.04554051
 0.02939444 0.03298733 0.04668454 0.04121552 0.03192267 0.03750737
 0.04043548 0.04558646 0.04527524 0.05078357 0.06705981 0.0420308
 0.05350422 0.06501158 0.04785181 0.05403553 0.04146181 0.04019069
 0.02607631 0.02565541 0.04933422 0.04105052 0.05182121 0.04491791
 0.06838065 0.06010645 0.03646328 0.03085103 0.08059835 0.06899423
 0.07301004 0.04024048 0.05016606 0.05492176 0.0254037

  0%|          | 0/225 [00:00<?, ?it/s]

[3] Loss: [0.04908782 0.04762285 0.05388691 0.05137071 0.0572384  0.04507666
 0.05553929 0.06873049 0.03160487 0.02376461 0.04076912 0.03465714
 0.04543669 0.05206829 0.05261564 0.06397304 0.06174389 0.07239935
 0.05424506 0.0548534  0.04608164 0.04357877 0.0529219  0.04912912
 0.02132355 0.020961   0.05411478 0.04767129 0.05058946 0.03171999
 0.05932987 0.05145971 0.03519034 0.02853065 0.03606766 0.04353313
 0.05239749 0.05006587 0.04879293 0.03928277 0.03316757 0.0324785
 0.04889863 0.05251304 0.04155439 0.03575277 0.0724899  0.05511487
 0.04213324 0.03598309 0.05591566 0.0544179  0.03374112 0.0701111
 0.04354433 0.04633152 0.04804539 0.03871378 0.04192482 0.0384791
 0.05562237 0.03532547 0.05646754 0.04310688 0.04946909 0.04129008
 0.04409382 0.04461006 0.03206426 0.03564734 0.02541309 0.03338563
 0.04010674 0.03927039 0.04874812 0.04385082 0.03482851 0.04078761
 0.03696679 0.03794277 0.03643734 0.03990407 0.03357722 0.02884897
 0.03406895 0.04242399 0.0347821  0.05078148 0.03205214

  0%|          | 0/225 [00:00<?, ?it/s]

[4] Loss: [0.03214146 0.0338574  0.04857971 0.05920762 0.04656979 0.0489878
 0.03232387 0.0241944  0.02961327 0.03608162 0.03065651 0.04061222
 0.05417483 0.0656703  0.03909378 0.04706008 0.02645934 0.01521402
 0.04154513 0.04053309 0.02782937 0.04474166 0.03994885 0.05003527
 0.04109413 0.0624642  0.02423389 0.01017717 0.02585901 0.02961898
 0.04065834 0.02975772 0.0483307  0.04676063 0.07183118 0.06559954
 0.03987803 0.02932897 0.0363347  0.06152781 0.05941777 0.04710889
 0.06756462 0.06500492 0.05069375 0.04725711 0.0444295  0.03956562
 0.05126524 0.056915   0.04634612 0.03983348 0.03424932 0.02929766
 0.02331321 0.01989938 0.0428693  0.04016604 0.0329407  0.04071037
 0.03068308 0.02632184 0.0538184  0.04182496 0.03728766 0.03764407
 0.03781104 0.04098009 0.04441852 0.03242008 0.03523916 0.04607099
 0.04504864 0.05392614 0.04449615 0.05858412 0.06684392 0.0404642
 0.03754137 0.02990552 0.04903559 0.0532108  0.02557804 0.02918741
 0.03515541 0.0389764  0.04187919 0.03648016 0.0199433

  0%|          | 0/225 [00:00<?, ?it/s]

[5] Loss: [0.03505196 0.0399691  0.03983459 0.02971544 0.04253129 0.05012533
 0.01976922 0.038383   0.04452203 0.03821186 0.03148089 0.02578117
 0.03615927 0.03947281 0.06425937 0.0547542  0.05698005 0.05080581
 0.04003004 0.04406574 0.04444344 0.04984261 0.0432989  0.05042571
 0.04942934 0.04425447 0.03749426 0.04864088 0.03882334 0.03949568
 0.03763978 0.03242958 0.06720599 0.08775312 0.01975565 0.01949976
 0.05014361 0.04978589 0.03786414 0.05313593 0.0328842  0.02707512
 0.03814159 0.03270101 0.0458215  0.03834472 0.04723456 0.0325404
 0.04142407 0.04420083 0.022715   0.02860026 0.04929208 0.04258433
 0.02717546 0.02064011 0.03174585 0.03330657 0.03539924 0.04144128
 0.0353392  0.03456359 0.0594635  0.07093821 0.06046783 0.06262735
 0.03764725 0.03365881 0.04411452 0.03623327 0.03060147 0.0457408
 0.05252559 0.05467099 0.04195033 0.03081959 0.05362513 0.04937164
 0.07536021 0.05643767 0.04373831 0.04542193 0.03873754 0.03781836
 0.05625566 0.09462573 0.06542093 0.08096468 0.0497875

In [20]:
# Presumably restores agent state from ckpt_folder/run_name — confirm against
# the agent implementation.
agent.load_from_checkpoint()
### WARNING: will overwrite existing runs
# History list; extended with the return value of agent.train() in the DQN
# branch of the training cell and plotted afterwards.
hist = []

In [None]:
if action == 'train':
    if 'DQN' in algo:
        # fill buffer with some random samples
        # The random policy ignores the state and draws uniformly from 4
        # actions; train=False is passed for these warm-up rollouts.
        for i in tqdm(range(500)):
            agent.collect_rollout(t_max=t_max, policy=lambda x: np.random.choice(4), train=False, display=False)
        #print(agent.epsilon)
        #agent.epsilon = 0.05
        # Extend the notebook-level history with whatever train() returns.
        hist += agent.train(epochs=config['train_epochs'], t_max=t_max, display=False)
    elif 'PPO' in algo:
        # NOTE(review): unlike the DQN branch, the return value is discarded
        # here, so `hist` is left unchanged for PPO runs.
        agent.train(epochs=config['train_epochs'], t_max=t_max, buf_size=3000, min_buf_size=600, display=False)

100%|██████████| 500/500 [00:01<00:00, 467.75it/s]


Training epochs:   0%|          | 0/20000 [00:00<?, ?it/s]

[5] Average reward: -1.0
Predicted reward: [[-0.10766181 -0.08803619 -0.10001855 -0.06081846]]
Buffer size: 14747
Saving to checkpoint...
[10] Average reward: -1.0
Predicted reward: [[-0.15911473 -0.0771579  -0.09232363 -0.08019051]]
Buffer size: 14803
Saving to checkpoint...
[15] Average reward: -0.6
Predicted reward: [[-0.07551421 -0.00872605 -0.00587697 -0.00992572]]
Buffer size: 14938
Saving to checkpoint...
[20] Average reward: -0.8
Predicted reward: [[-0.02514904  0.05228026  0.04356568  0.00485907]]
Buffer size: 15085
Saving to checkpoint...
[25] Average reward: -1.0
Predicted reward: [[-0.04146246  0.04002898 -0.01569206 -0.00385167]]
Buffer size: 15403
Saving to checkpoint...
[30] Average reward: -1.0
Predicted reward: [[0.00590582 0.01548624 0.02131607 0.01730995]]
Buffer size: 15516
Saving to checkpoint...
[35] Average reward: -1.0
Predicted reward: [[-0.01413933  0.01605348  0.00584279  0.00413394]]
Buffer size: 15704
Saving to checkpoint...
[40] Average reward: -0.8
Predic

[300] Average reward: -0.4
Predicted reward: [[-0.04185158 -0.12805587  0.05426696 -0.05608696]]
Buffer size: 40000
Saving to checkpoint...
[305] Average reward: -0.4
Predicted reward: [[-0.19587076 -0.26437587 -0.17650378 -0.22997981]]
Buffer size: 40000
Saving to checkpoint...
[310] Average reward: -0.6
Predicted reward: [[-0.36250913 -0.16373888 -0.3920886  -0.36049742]]
Buffer size: 40000
Saving to checkpoint...
[315] Average reward: -0.8
Predicted reward: [[ 0.04813107  0.34725714  0.03083775 -0.09888037]]
Buffer size: 40000
Saving to checkpoint...
[320] Average reward: -0.6
Predicted reward: [[-0.06462394 -0.08669166 -0.11473812 -0.08166464]]
Buffer size: 40000
Saving to checkpoint...
[325] Average reward: -0.6
Predicted reward: [[ 0.01721678 -0.01730771 -0.02629794 -0.04468776]]
Buffer size: 40000
Saving to checkpoint...
[330] Average reward: -1.0
Predicted reward: [[ 0.19045699  0.32718548 -0.03778636  0.0790229 ]]
Buffer size: 40000
Saving to checkpoint...
[335] Average reward

[595] Average reward: -0.4
Predicted reward: [[-0.26592442 -0.01538748 -0.16597584 -0.40632567]]
Buffer size: 40000
Saving to checkpoint...
[600] Average reward: -0.2
Predicted reward: [[0.09462348 0.07887654 0.05474317 0.00305627]]
Buffer size: 40000
Saving to checkpoint...
[605] Average reward: -1.0
Predicted reward: [[0.2092766  0.1338644  0.12929104 0.12068484]]
Buffer size: 40000
Saving to checkpoint...
[610] Average reward: -0.8
Predicted reward: [[-0.1127308   0.0391437  -0.06749873 -0.23854363]]
Buffer size: 40000
Saving to checkpoint...
[615] Average reward: -0.4
Predicted reward: [[-0.01658532 -0.0110094  -0.02125613 -0.0905906 ]]
Buffer size: 40000
Saving to checkpoint...
[620] Average reward: -0.8
Predicted reward: [[ 0.03755306 -0.00292496  0.01095085 -0.02526449]]
Buffer size: 40000
Saving to checkpoint...
[625] Average reward: -1.0
Predicted reward: [[-0.03784773  0.03102881 -0.02508795 -0.13795847]]
Buffer size: 40000
Saving to checkpoint...
[630] Average reward: -0.8
P

[890] Average reward: -0.8
Predicted reward: [[0.05967283 0.01052999 0.03381503 0.02097711]]
Buffer size: 40000
Saving to checkpoint...
[895] Average reward: -0.4
Predicted reward: [[0.44080657 0.01329568 0.25522488 0.47511822]]
Buffer size: 40000
Saving to checkpoint...
[900] Average reward: -0.6
Predicted reward: [[0.05741161 0.0080052  0.0353182  0.02961391]]
Buffer size: 40000
Saving to checkpoint...
[905] Average reward: -1.0
Predicted reward: [[0.26284632 0.101134   0.16848698 0.22385916]]
Buffer size: 40000
Saving to checkpoint...
[910] Average reward: -0.6
Predicted reward: [[-0.0182218  -0.0527364  -0.01836872 -0.03241897]]
Buffer size: 40000
Saving to checkpoint...
[915] Average reward: -1.0
Predicted reward: [[-0.48519516 -0.22531754 -0.31068605 -0.4934733 ]]
Buffer size: 40000
Saving to checkpoint...
[920] Average reward: -0.8
Predicted reward: [[-0.02145499 -0.03552729 -0.01892143 -0.04700822]]
Buffer size: 40000
Saving to checkpoint...
[925] Average reward: -0.6
Predicted

[1185] Average reward: 0.0
Predicted reward: [[1.2751616 3.3900933 6.0824103 3.1378326]]
Buffer size: 40000
Saving to checkpoint...
[1190] Average reward: -1.0
Predicted reward: [[ 0.07647717 -0.02170861  0.06291783  0.06141078]]
Buffer size: 40000
Saving to checkpoint...
[1195] Average reward: -0.6
Predicted reward: [[-0.39823198 -0.16815436 -0.24876654 -0.41050673]]
Buffer size: 40000
Saving to checkpoint...
[1200] Average reward: -0.6
Predicted reward: [[-0.38249588 -0.14631677 -0.231148   -0.3894918 ]]
Buffer size: 40000
Saving to checkpoint...
[1205] Average reward: -0.8
Predicted reward: [[0.06751943 0.08047605 0.07215428 0.03199983]]
Buffer size: 40000
Saving to checkpoint...
[1210] Average reward: -1.0
Predicted reward: [[-0.03399825 -0.0124557  -0.00594354 -0.06008697]]
Buffer size: 40000
Saving to checkpoint...
[1215] Average reward: -0.8
Predicted reward: [[0.06718469 0.04418015 0.05696988 0.03962302]]
Buffer size: 40000
Saving to checkpoint...
[1220] Average reward: -1.0
Pr

[1480] Average reward: -0.6
Predicted reward: [[-0.26648402 -0.17539215 -0.19243908 -0.2762289 ]]
Buffer size: 40000
Saving to checkpoint...
[1485] Average reward: -0.6
Predicted reward: [[-0.22377181 -0.11244249 -0.15030193 -0.23525715]]
Buffer size: 40000
Saving to checkpoint...
[1490] Average reward: -1.0
Predicted reward: [[ 0.09508681 -0.0125885   0.05649972  0.08238029]]
Buffer size: 40000
Saving to checkpoint...
[1495] Average reward: -0.6
Predicted reward: [[0.25371027 0.09494328 0.18383574 0.2263478 ]]
Buffer size: 40000
Saving to checkpoint...
[1500] Average reward: -0.6
Predicted reward: [[-0.22855842 -0.09542358 -0.14751494 -0.2587632 ]]
Buffer size: 40000
Saving to checkpoint...
[1505] Average reward: -0.8
Predicted reward: [[-0.36720848 -0.22768664 -0.25618362 -0.37525725]]
Buffer size: 40000
Saving to checkpoint...
[1510] Average reward: -0.8
Predicted reward: [[-0.07651043 -0.06153393 -0.04049087 -0.09320736]]
Buffer size: 40000
Saving to checkpoint...
[1515] Average re

[1775] Average reward: -1.0
Predicted reward: [[-0.13543153 -0.05438149 -0.09697163 -0.16707063]]
Buffer size: 40000
Saving to checkpoint...
[1780] Average reward: -0.8
Predicted reward: [[ 0.2542622  -0.0038116   0.13843822  0.2587874 ]]
Buffer size: 40000
Saving to checkpoint...
[1785] Average reward: -1.0
Predicted reward: [[-0.430691   -0.24156594 -0.31780744 -0.46721864]]
Buffer size: 40000
Saving to checkpoint...
[1790] Average reward: -0.6
Predicted reward: [[0.6553538  0.27686858 0.45347786 0.64556646]]
Buffer size: 40000
Saving to checkpoint...
[1795] Average reward: -1.0
Predicted reward: [[-0.30414867 -0.19711232 -0.22309399 -0.32351875]]
Buffer size: 40000
Saving to checkpoint...
[1800] Average reward: -0.6
Predicted reward: [[-0.04292178 -0.09475899 -0.04774809 -0.05093622]]
Buffer size: 40000
Saving to checkpoint...
[1805] Average reward: -1.0
Predicted reward: [[-0.09664202 -0.05966139 -0.06883144 -0.12477517]]
Buffer size: 40000
Saving to checkpoint...
[1810] Average re

[2070] Average reward: -0.8
Predicted reward: [[ 0.09273314 -0.01370907  0.05283999  0.08808422]]
Buffer size: 40000
Saving to checkpoint...
[2075] Average reward: -1.0
Predicted reward: [[-0.12468266 -0.19256234 -0.13793063 -0.13017559]]
Buffer size: 40000
Saving to checkpoint...
[2080] Average reward: -0.8
Predicted reward: [[ 0.06836724 -0.03688717  0.02812195  0.07236052]]
Buffer size: 40000
Saving to checkpoint...
[2085] Average reward: -1.0
Predicted reward: [[-0.7993119  -0.6398468  -0.6664438  -0.79115486]]
Buffer size: 40000
Saving to checkpoint...
[2090] Average reward: -0.6
Predicted reward: [[-0.25782323 -0.27779603 -0.2281797  -0.23878098]]
Buffer size: 40000
Saving to checkpoint...
[2095] Average reward: -0.8
Predicted reward: [[-0.56629205 -0.43784523 -0.45903778 -0.5690398 ]]
Buffer size: 40000
Saving to checkpoint...
[2100] Average reward: -1.0
Predicted reward: [[-0.02043772 -0.16224813 -0.04766321 -0.00871921]]
Buffer size: 40000
Saving to checkpoint...
[2105] Averag

[2365] Average reward: -0.8
Predicted reward: [[-0.15801036 -0.19513547 -0.15238607 -0.14265835]]
Buffer size: 40000
Saving to checkpoint...
[2370] Average reward: -0.6
Predicted reward: [[-0.21543789 -0.20612311 -0.18683445 -0.21166897]]
Buffer size: 40000
Saving to checkpoint...
[2375] Average reward: -1.0
Predicted reward: [[-0.39072335 -0.35431635 -0.34637702 -0.3928851 ]]
Buffer size: 40000
Saving to checkpoint...
[2380] Average reward: -0.6
Predicted reward: [[-0.02580917  0.00373161 -0.01135075 -0.03438902]]
Buffer size: 40000
Saving to checkpoint...
[2385] Average reward: -0.8
Predicted reward: [[-0.39006412 -0.33132565 -0.3522309  -0.4084164 ]]
Buffer size: 40000
Saving to checkpoint...
[2390] Average reward: -1.0
Predicted reward: [[-0.20545268 -0.26874924 -0.21566296 -0.21036386]]
Buffer size: 40000
Saving to checkpoint...
[2395] Average reward: -1.0
Predicted reward: [[-0.15248752 -0.19242835 -0.15557241 -0.16220093]]
Buffer size: 40000
Saving to checkpoint...
[2400] Averag

[2660] Average reward: -0.8
Predicted reward: [[-0.20253205 -0.17320931 -0.17694223 -0.22261441]]
Buffer size: 40000
Saving to checkpoint...
[2665] Average reward: -0.6
Predicted reward: [[-0.10931671 -0.1176486  -0.10696304 -0.12059808]]
Buffer size: 40000
Saving to checkpoint...
[2670] Average reward: -1.0
Predicted reward: [[-0.23882568 -0.21751082 -0.21550941 -0.25199652]]
Buffer size: 40000
Saving to checkpoint...
[2675] Average reward: -0.2
Predicted reward: [[-0.14418268 -0.16095662 -0.13480282 -0.15541613]]
Buffer size: 40000
Saving to checkpoint...
[2680] Average reward: -1.0
Predicted reward: [[-0.45252812 -0.40119898 -0.40572774 -0.46942174]]
Buffer size: 40000
Saving to checkpoint...
[2685] Average reward: -0.6
Predicted reward: [[0.0533644  0.08396798 0.0660854  0.03593498]]
Buffer size: 40000
Saving to checkpoint...
[2690] Average reward: -1.0
Predicted reward: [[-0.13658118 -0.16177213 -0.13576424 -0.14063036]]
Buffer size: 40000
Saving to checkpoint...
[2695] Average re

[2955] Average reward: -0.8
Predicted reward: [[-0.42784715 -0.41317558 -0.40179634 -0.44036555]]
Buffer size: 40000
Saving to checkpoint...
[2960] Average reward: -0.8
Predicted reward: [[-0.07607567 -0.01651204 -0.04675156 -0.09719473]]
Buffer size: 40000
Saving to checkpoint...
[2965] Average reward: -0.6
Predicted reward: [[-0.09323096 -0.08485615 -0.08077586 -0.10693061]]
Buffer size: 40000
Saving to checkpoint...
[2970] Average reward: -1.0
Predicted reward: [[-0.15539575 -0.08828115 -0.12012839 -0.16896975]]
Buffer size: 40000
Saving to checkpoint...
[2975] Average reward: -1.0
Predicted reward: [[-0.12585759 -0.11367428 -0.11256218 -0.1329658 ]]
Buffer size: 40000
Saving to checkpoint...
[2980] Average reward: -0.8
Predicted reward: [[0.05455998 0.1320264  0.08742556 0.03612518]]
Buffer size: 40000
Saving to checkpoint...
[2985] Average reward: -0.8
Predicted reward: [[-0.25684094 -0.25088763 -0.23167527 -0.25401068]]
Buffer size: 40000
Saving to checkpoint...
[2990] Average re

[3250] Average reward: -0.8
Predicted reward: [[-0.02382088  0.00868344 -0.00275075 -0.03422916]]
Buffer size: 40000
Saving to checkpoint...
[3255] Average reward: -0.4
Predicted reward: [[-0.10823309 -0.08071876 -0.08393502 -0.12191665]]
Buffer size: 40000
Saving to checkpoint...
[3260] Average reward: -0.8
Predicted reward: [[-0.3273127  -0.30704427 -0.30334544 -0.34790635]]
Buffer size: 40000
Saving to checkpoint...
[3265] Average reward: -1.0
Predicted reward: [[-0.23818338 -0.2097168  -0.21720803 -0.25107503]]
Buffer size: 40000
Saving to checkpoint...
[3270] Average reward: -0.4
Predicted reward: [[-0.20990562 -0.18357265 -0.19138956 -0.22612512]]
Buffer size: 40000
Saving to checkpoint...
[3275] Average reward: -0.4
Predicted reward: [[0.4574898  0.46468723 0.45500487 0.45506728]]
Buffer size: 40000
Saving to checkpoint...
[3280] Average reward: -0.6
Predicted reward: [[-0.2317903  -0.22679412 -0.2237091  -0.23620594]]
Buffer size: 40000
Saving to checkpoint...
[3285] Average re

[3545] Average reward: -0.6
Predicted reward: [[-0.06751859 -0.07036328 -0.05803335 -0.08445168]]
Buffer size: 40000
Saving to checkpoint...
[3550] Average reward: -0.8
Predicted reward: [[0.1326127  0.15088105 0.14390409 0.128757  ]]
Buffer size: 40000
Saving to checkpoint...
[3555] Average reward: -0.6
Predicted reward: [[0.8931371  0.81210303 0.86397386 0.8964689 ]]
Buffer size: 40000
Saving to checkpoint...
[3560] Average reward: -0.8
Predicted reward: [[0.05522621 0.07160437 0.06590962 0.04786098]]
Buffer size: 40000
Saving to checkpoint...
[3565] Average reward: -0.6
Predicted reward: [[-0.2045238  -0.20656061 -0.2064054  -0.20708513]]
Buffer size: 40000
Saving to checkpoint...
[3570] Average reward: -0.6
Predicted reward: [[0.5456325 0.473356  0.5196612 0.5444714]]
Buffer size: 40000
Saving to checkpoint...


In [None]:
# Plot the collected training history against its index.
from matplotlib import pyplot as plt
a = hist[::1]  # stride 1 keeps every entry; use a larger stride to subsample
plt.plot(range(len(a)), a)
plt.show()

In [None]:
def test_rollout(t_max, env, close=True):
    import sys
    obs = agent.preprocess(env.reset())
    reward = 0
    for i in range(t_max):
        # print(agent.get_policy(obs))
        # act = agent.get_action(obs, greedy=True)[0]
        act = agent.get_action(obs, mode='greedy')[0][0]
        obs, r, dn, info = env.step(agent.action_wrapper(act))
        env.render()
        print(act, file=sys.stderr)
        time.sleep(0.05)
        obs = agent.preprocess(obs)
        reward += r
        if dn:
            break

    print("Total reward: {}".format(reward), file=sys.stderr)
    if close: env.close()

In [None]:
if action == 'test':
    # Watch one greedy episode of at most 10000 steps, then close the env.
    test_rollout(10000, env, close=True)

In [None]:
# agent.train(4, t_max=500, min_buf_size=10)

In [None]:
# %lprun -f agent.train agent.train(1, t_max=500, buf_size=2000, min_buf_size=10)