In [1]:
import gym
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" #### REMOVE THIS LINE WHEN CUDA CONFIG IS FIXED
import tensorflow as tf
import numpy as np
import sys
import time
import gym_snake
import json
import importlib
import random
from PIL import Image
from tqdm import tqdm

sys.path.insert(0, '..')
from utils.Buffer import ReplayBuffer
from utils.Conv import ConvHead
from rl.models import get_policy_architecture, get_value_architecture, get_vision_architecture
from algos.PPO import PPO_agent
from algos.DQN import DQN_agent
from utils.Loader import load_agent

# %load_ext line_profiler
%matplotlib notebook

In [2]:
"""
agent = load_agent(
    os.path.join('..', 'configs', 'cartpole.json'),
    run_name="cartpole-DQN-load_test",
    ckpt_folder=os.path.join('..', 'checkpoints')
)
"""

'\nagent = load_agent(\n    os.path.join(\'..\', \'configs\', \'cartpole.json\'),\n    run_name="cartpole-DQN-load_test",\n    ckpt_folder=os.path.join(\'..\', \'checkpoints\')\n)\n'

In [3]:
"""
hist = []
agent.load_from_checkpoint()
hist += agent.train(100, t_max=1000, display=True)
"""

'\nhist = []\nagent.load_from_checkpoint()\nhist += agent.train(100, t_max=1000, display=True)\n'

In [4]:
# tetris = importlib.import_module('pytris-effect.src.gameui')

In [5]:
# Experiment settings.
run_name = 'snake'  # selects ../configs/<run_name>.json; 'tetris' switches to the pytris GameUI env
action = 'train'  # one of 'train', 'evaluate', 'test' -- later cells only do work for their mode
algo = ('DDQN', 'Dueling')  # later cells choose DQN vs PPO by substring match ('DQN' / 'PPO') on these names

In [6]:
# Read the experiment's JSON config, then point at the shared checkpoint directory.
cfg_fp = os.path.join('..', 'configs', '{}.json'.format(run_name))
with open(cfg_fp) as cfg_file:
    config = json.load(cfg_file)
ckpt_folder = os.path.join('..', 'checkpoints')

In [7]:
# Build the environment named in the config.
env_name = config['env']
if run_name == 'tetris':
    # Imported lazily: the module-level import of the pytris UI is commented out earlier in the
    # notebook, so without this the tetris branch raised NameError on `tetris`.
    tetris = importlib.import_module('pytris-effect.src.gameui')
    env = tetris.GameUI(graphic_mode=False, its_per_sec=2, sec_per_tick=0.5)
elif config.get('use_raw_env'):
    # Treated as a boolean flag: `"use_raw_env": false` now yields the wrapped env
    # (previously the mere presence of the key selected the raw env).
    # `.env` strips the outermost wrapper added by gym.make (presumably a TimeLimit wrapper -- confirm).
    env = gym.make(env_name).env
else:
    env = gym.make(env_name)

In [8]:
# Sanity check: shape of the observation returned by env.reset() (note: this resets the env).
env.reset().shape

(150, 150, 3)

In [9]:
def show_img(arr, scaling=30):
    """Display an RGB observation enlarged by nearest-neighbour upscaling.

    Each pixel of ``arr`` becomes a ``scaling`` x ``scaling`` square so that small grid
    observations are visible; the result is opened with PIL's ``Image.show``.

    Args:
        arr: array indexed as ``arr[row, col, channel]`` with at least 3 channels; only the
            first 3 channels are used and values are cast to uint8.
        scaling: integer upscaling factor per spatial axis (default 30, the previously
            hard-coded value).
    """
    # np.repeat along both spatial axes replaces the original per-pixel Python triple loop,
    # which ran H * W * 3 * scaling**2 interpreted iterations.
    rgb = np.asarray(arr)[:, :, :3]
    data = np.repeat(np.repeat(rgb, scaling, axis=0), scaling, axis=1).astype(np.uint8)
    img = Image.fromarray(data, 'RGB')
    # img.save('my.png')
    img.show()

In [10]:
# Line-profile env.drawMatrix() in evaluate mode.
# NOTE(review): `%lprun` requires `%load_ext line_profiler`, which is commented out in the first
# cell, and `drawMatrix` presumably exists only on the tetris GameUI env -- confirm before enabling.
if action == 'evaluate':
    %lprun -f env.drawMatrix env.drawMatrix()

In [11]:
if action == 'evaluate':
    # Preview a fresh observation, subsampled by keeping every 10th row and column
    # (e.g. a 150x150 snake frame becomes 15x15).
    frame = env.reset()
    arr = frame[::10, ::10, :]
    img = Image.fromarray(arr, 'RGB')
    img.show()
    # Alternative: show_img(env.reset()) for an enlarged full-resolution view.

In [12]:
# Manual single-step check of the environment (disabled; flip to True to run).
if False:
    # Named `manual_action` so it does not clobber the global `action` ('train'/'evaluate'/'test')
    # that gates the later cells.
    manual_action = 1
    obs, reward, dn, info = env.step(manual_action)
    show_img(obs)
    print(reward, dn, info)

In [13]:
def do_step(n_actions=None):
    """Take one uniformly random action in the global ``env``; reset it when the episode ends.

    Used as the benchmark target of the commented-out ``%timeit`` / ``%lprun`` cells.

    Args:
        n_actions: number of discrete actions to sample from. Defaults to
            ``env.action_space.n`` when the env exposes a gym-style discrete action space,
            otherwise 7 (the value this helper previously hard-coded; the snake env only
            has 4 actions, so sampling from 7 could send invalid actions).
    """
    if n_actions is None:
        space = getattr(env, 'action_space', None)
        n_actions = getattr(space, 'n', 7)
    _, _, dn, _ = env.step(random.randrange(n_actions))
    if dn:
        env.reset()

In [14]:
#%timeit env.reset()

In [15]:
#%timeit do_step()

In [16]:
#%lprun -f env.get_obs do_step()

In [17]:
# Build the policy network; DQN variants additionally need a target network, PPO a value network.
model = get_policy_architecture(env_name, algo=algo)
if not any('DQN' in name for name in algo):
    value = get_value_architecture(env_name)
else:
    # clone_model copies the architecture with newly created weights (not the trained values);
    # presumably DQN_agent syncs them periodically (cf. target_delay) -- TODO confirm.
    target = tf.keras.models.clone_model(model)

In [18]:
# Construct the agent for the selected algorithm family (hyperparameters below are hard-coded
# per experiment; learning_rate / batch sizes come from the config).
if any('DQN' in name for name in algo):
    agent = DQN_agent(
        model,
        # (TODO): Move args for ReplayBuffer into DQN
        ReplayBuffer(config.get("max_buf_size", 20000), mode='uniform'),
        target=target,
        env=env,
        # A plain string (the former `('DDQN')` was not a tuple); the old `# 'PER'` note suggests
        # something like ('DDQN', 'PER') may have been intended -- TODO confirm what DQN_agent accepts.
        mode='DDQN',
        learning_rate=config['learning_rate'],
        batch_size=config['batch_size'],
        update_steps=1,
        update_freq=4,
        multistep=5,
        alpha=1.5,
        beta=1.0,
        gamma=0.95,
        target_delay=1000,
        delta=1.0,
        # delta=0.000003,
        env_name=config['env_name'],
        algo_name='DQN',
        ckpt_folder=ckpt_folder,
        run_name='snake-DQN-pretrain-hard_update-uniform-multistep5-7'
    )
elif any('PPO' in name for name in algo):
    agent = PPO_agent(
        model,
        value,
        env=env,
        learning_rate=config['learning_rate'],
        minibatch_size=config['minibatch_size'],
        gamma=0.99,
        env_name=config['env_name'],
        run_name='snake-PPO-pretrain',
        ckpt_folder=ckpt_folder
    )
else:
    # Fail here rather than leave `agent` undefined (or stale from a previous run) for later cells.
    raise ValueError("Unsupported algo {!r}: expected a name containing 'DQN' or 'PPO'".format(algo))

In [19]:
# Step cap passed to every rollout / training call below (read from the config).
t_max = config['t_max']

In [20]:
p_buf = []  # default buffer of [s, s'] preprocessed transition pairs for vision pre-training


def collect_rollout(env, t_max, policy, buf=None):
    """Run one episode of at most ``t_max`` steps and record consecutive observation pairs.

    Observations are passed through the global ``agent.preprocess`` and actions through
    ``agent.action_wrapper``; rewards are ignored.

    Args:
        env: environment with gym-style ``reset()`` and ``step(action)``.
        t_max: maximum number of steps in the episode.
        policy: callable mapping a preprocessed observation to an action.
        buf: list to append ``[s, s_next]`` pairs to; defaults to the module-level ``p_buf``
            (the previous, implicit behaviour).

    Returns:
        Number of transitions appended.
    """
    if buf is None:
        buf = p_buf
    s = agent.preprocess(env.reset())
    n_added = 0
    for _ in range(t_max):
        act = policy(s)
        ss, _, dn, _ = env.step(agent.action_wrapper(act))
        ss = agent.preprocess(ss)
        buf.append([s, ss])
        n_added += 1
        s = ss
        if dn:
            break
    return n_added

In [21]:
# Optional self-supervised pre-training of the convolutional vision encoder.
pretrain = True
if pretrain and action == 'train':  # only necessary for tasks on raw pixels (vision)
    N_PRETRAIN_EPISODES = 500  # random-policy episodes collected for pre-training
    N_ACTIONS = 4  # snake action count (the agent predicts 4 Q-values per state)
    EMBED_DIM = 16  # width of the linear embedding placed on top of the vision backbone
    PRETRAIN_EPOCHS = 6  # passed to ConvHead.train (one "[k] Loss" line is printed per epoch)

    # Named `vision_base` (not `model`) so the policy network built earlier is not clobbered.
    vision_base = get_vision_architecture(env_name)
    embedding = tf.keras.layers.Dense(EMBED_DIM, activation=None)(vision_base.output)
    embed = tf.keras.Model(inputs=vision_base.input, outputs=embedding)

    # Get some data from random interactions with the env. Clearing first keeps the cell
    # idempotent: re-running it no longer appends a second copy of the samples.
    p_buf.clear()
    for _ in tqdm(range(N_PRETRAIN_EPISODES)):
        collect_rollout(env, t_max, lambda x: np.random.choice(N_ACTIONS))
    print("Collected {} samples".format(len(p_buf)))
    head = ConvHead(embed, p_buf)
    head.train(PRETRAIN_EPOCHS)

    # Reuse ConvHead's penultimate activations as the vision head (presumably the last layer is
    # the pre-training prediction layer -- TODO confirm).
    vision = tf.keras.Model(inputs=head.model.input, outputs=head.model.layers[-2].output)
    pretrained_model = get_policy_architecture(env_name, algo=algo, head=vision)
    # NOTE(review): confirm set_model also rebuilds/syncs the DQN target network, which was
    # cloned from the original (non-pretrained) policy architecture.
    agent.set_model(pretrained_model)

100%|██████████| 500/500 [00:01<00:00, 457.76it/s]


Collected 15340 samples


  0%|          | 0/239 [00:00<?, ?it/s]

[0] Loss: [0.16207144 0.16422944 0.17887095 0.17409125 0.19734879 0.21781886
 0.23387843 0.22500688 0.23673627 0.1937929  0.24326453 0.25857937
 0.17051055 0.20156612 0.21865784 0.21861552 0.18118884 0.19640914
 0.2025929  0.22165535 0.20173624 0.20688468 0.21547227 0.17386135
 0.21331543 0.21538275 0.24038132 0.23202287 0.1916675  0.19486418
 0.15530969 0.13813093 0.20592816 0.21268854 0.17688681 0.16875999
 0.18357621 0.17544466 0.18546654 0.1651627  0.27501094 0.2845768
 0.27642527 0.29021144 0.17672037 0.17219049 0.21144046 0.18698743
 0.17732346 0.18187226 0.26007393 0.28570256 0.20874463 0.19590168
 0.2097269  0.20679654 0.17465702 0.17358631 0.22824326 0.24328586
 0.28171524 0.24776584 0.24015933 0.2051436  0.27498612 0.33187854
 0.15795347 0.17067455 0.21202    0.21311164 0.23877276 0.23929784
 0.16703236 0.14166465 0.17781574 0.16438174 0.15650536 0.18442966
 0.24041952 0.2079033  0.2299893  0.23280607 0.20981313 0.22526228
 0.24233896 0.20656031 0.23173124 0.22923258 0.270834

  0%|          | 0/239 [00:00<?, ?it/s]

[1] Loss: [0.06608044 0.07096907 0.0868551  0.06263551 0.04648707 0.04968729
 0.04721485 0.03682361 0.03976986 0.03505346 0.05929387 0.06991001
 0.05749486 0.05228494 0.04985422 0.04354837 0.04065986 0.04368109
 0.03940775 0.03472828 0.03732676 0.04097142 0.04393915 0.03408255
 0.03781205 0.04436415 0.02870641 0.02737951 0.03803553 0.04434922
 0.04586969 0.03829176 0.06061662 0.04370287 0.04050909 0.03675894
 0.04232659 0.04525478 0.07392286 0.07023486 0.05406747 0.05565314
 0.04819098 0.04888806 0.04207579 0.04111275 0.05111394 0.04756966
 0.06411922 0.04901112 0.06108206 0.0499912  0.03603443 0.03855403
 0.06306186 0.06362385 0.05657302 0.05354837 0.05325956 0.0515463
 0.05331401 0.04907485 0.07831968 0.09182312 0.04262915 0.03161718
 0.04138502 0.0437209  0.06222383 0.06522892 0.04171403 0.04547331
 0.05722599 0.06518463 0.04003819 0.03790011 0.06562233 0.05431224
 0.0596202  0.05330799 0.06059129 0.05579082 0.04254686 0.03628898
 0.04860356 0.04298042 0.05865681 0.04513944 0.038353

  0%|          | 0/239 [00:00<?, ?it/s]

[2] Loss: [0.03651669 0.04229024 0.05569711 0.05452789 0.03677306 0.03046176
 0.05449818 0.06663302 0.04291207 0.03697745 0.0641152  0.05121865
 0.04275385 0.04533607 0.03816033 0.03768339 0.04733712 0.04448603
 0.04699548 0.05901052 0.04663393 0.01956399 0.03435628 0.03192681
 0.03838997 0.05149022 0.04673856 0.03772877 0.04097513 0.05036984
 0.04807089 0.03889022 0.02323935 0.02445655 0.05227857 0.05944159
 0.04191131 0.04997728 0.05360296 0.05451753 0.03659081 0.03415033
 0.0493088  0.05013776 0.03447326 0.04404224 0.0602532  0.06551577
 0.06974792 0.05170915 0.06444231 0.04678233 0.05620944 0.07084816
 0.03593076 0.04066257 0.06110825 0.05054953 0.02757229 0.01887234
 0.06325857 0.06284911 0.02683798 0.03467828 0.03105195 0.03475263
 0.04164911 0.03778234 0.05091904 0.04359056 0.05845844 0.04506823
 0.03020418 0.04177947 0.0293866  0.04341442 0.04346883 0.04675227
 0.03172495 0.04121622 0.05378237 0.04296017 0.04023663 0.04158319
 0.04032715 0.03795957 0.03152762 0.04175606 0.07341

  0%|          | 0/239 [00:00<?, ?it/s]

[3] Loss: [0.03845309 0.04579066 0.05518716 0.05361039 0.03579831 0.02469344
 0.01971919 0.02784969 0.02648689 0.03380919 0.0244562  0.02571085
 0.03444288 0.03517738 0.04794062 0.06162735 0.01590949 0.02209456
 0.05152367 0.0386446  0.04308474 0.03793183 0.04706028 0.03035263
 0.0409775  0.05237413 0.03104175 0.04130465 0.02741797 0.0314894
 0.06332452 0.05259414 0.04055731 0.05016332 0.04082781 0.03484916
 0.02067821 0.02527783 0.02727887 0.02916835 0.06522166 0.04931908
 0.05360284 0.04298502 0.01890577 0.02068627 0.03861409 0.03070395
 0.02973093 0.04151338 0.0340594  0.04210494 0.02859811 0.02952087
 0.03721852 0.04967336 0.06754504 0.03998899 0.04377296 0.05087035
 0.03574838 0.03907729 0.03842021 0.04523966 0.05148607 0.03412048
 0.04267645 0.02766808 0.02516358 0.0329487  0.03992095 0.04851993
 0.05080765 0.04540567 0.06495339 0.07095345 0.04011546 0.03157413
 0.04927352 0.04822543 0.04200134 0.03076501 0.04469045 0.05348495
 0.04237165 0.03880769 0.04736083 0.05602064 0.027296

  0%|          | 0/239 [00:00<?, ?it/s]

[4] Loss: [0.07432232 0.04493627 0.04294223 0.04096032 0.03906918 0.05358807
 0.03667955 0.03302206 0.04028605 0.03529242 0.04998352 0.05795919
 0.05518873 0.05120972 0.0525756  0.04548169 0.03581625 0.03640327
 0.03020453 0.0397019  0.04524422 0.020797   0.03283569 0.03527138
 0.04068377 0.02317345 0.02977543 0.03561272 0.03218897 0.0342539
 0.02013012 0.02912482 0.03491794 0.03500136 0.05594141 0.05024605
 0.04663849 0.05703757 0.04874103 0.04875229 0.0503183  0.0423066
 0.04426927 0.04531255 0.03774369 0.02892774 0.03318269 0.02866987
 0.03315325 0.02994227 0.03539295 0.04820392 0.05243016 0.03742301
 0.05655768 0.0722189  0.0211629  0.01800555 0.04090874 0.03662692
 0.04305533 0.04159042 0.03205605 0.04015957 0.05638218 0.06665453
 0.04358838 0.04146018 0.05579618 0.05330103 0.04072233 0.03358936
 0.05197082 0.04254153 0.03696806 0.03605074 0.03581255 0.0426849
 0.04062825 0.04444099 0.0418625  0.02843546 0.03404472 0.02711272
 0.05813572 0.04098496 0.04770242 0.03558632 0.06245985

  0%|          | 0/239 [00:00<?, ?it/s]

[5] Loss: [0.04530404 0.06506164 0.02913353 0.03491594 0.05039276 0.05554381
 0.03593262 0.04023889 0.05723673 0.05027205 0.04109735 0.03817065
 0.0298716  0.03176341 0.04523849 0.03366725 0.03512605 0.04500847
 0.02953389 0.04669889 0.04236947 0.05278236 0.06045643 0.03392676
 0.04538282 0.03654708 0.02728711 0.02864299 0.03663272 0.03013776
 0.0242861  0.02630979 0.03278621 0.02076043 0.04418003 0.03399241
 0.02503235 0.03080395 0.06749966 0.07030056 0.03570505 0.04772815
 0.05695973 0.05585954 0.03430215 0.04935735 0.03622068 0.03977071
 0.05536965 0.03950199 0.06537244 0.04659582 0.0355219  0.02867193
 0.02926116 0.0245997  0.03663009 0.03199714 0.05411894 0.05270815
 0.07208761 0.07046887 0.04352183 0.04116809 0.04003951 0.03262783
 0.02278956 0.036877   0.04348972 0.04031864 0.02613353 0.03116534
 0.05848362 0.04721993 0.03004063 0.03663519 0.03499538 0.03794961
 0.04152729 0.05391518 0.04455994 0.04360804 0.04890003 0.0486824
 0.01865504 0.02369533 0.02997198 0.02785135 0.030314

TypeError: get_policy_architecture() got an unexpected keyword argument 'pretrain'

In [22]:
# Re-runs the last step of the pre-training cell (whose saved output shows a TypeError from an
# earlier version). Guarded so it no longer raises NameError on `vision` when pre-training is skipped.
# NOTE(review): when the pre-training cell succeeds, this builds and installs the policy a second
# time -- consider deleting this cell.
if pretrain and action == 'train':
    pretrained_model = get_policy_architecture(env_name, algo=algo, head=vision)
    agent.set_model(pretrained_model)

In [23]:
# Restore agent state from its checkpoint (presumably located via the run_name / ckpt_folder given
# at construction) and start a fresh list for the training cell to append train() results to.
# NOTE(review): this runs after agent.set_model(...) in the pre-training cells -- confirm whether
# loading the checkpoint overwrites the freshly pre-trained weights.
agent.load_from_checkpoint()
hist = []

In [None]:
if action == 'train':
    if any('DQN' in name for name in algo):
        # Warm-up: fill the replay buffer from 500 uniformly random rollouts (no learning yet).
        def _random_policy(_obs):
            return np.random.choice(4)

        for _ in tqdm(range(500)):
            agent.collect_rollout(t_max=t_max, policy=_random_policy, train=False, display=False)
        # Optional exploration override:
        # print(agent.epsilon); agent.epsilon = 0.05
        hist += agent.train(epochs=config['train_epochs'], t_max=t_max, display=False)
    elif any('PPO' in name for name in algo):
        # NOTE: PPO's train() result is not recorded in `hist`.
        agent.train(epochs=config['train_epochs'], t_max=t_max, buf_size=3000, min_buf_size=600, display=False)

100%|██████████| 500/500 [00:01<00:00, 469.00it/s]


Training epochs:   0%|          | 0/20000 [00:00<?, ?it/s]

[5] Average reward: -1.0
Predicted reward: [[-0.04494638 -0.06519236 -0.06496485 -0.09102722]]
Buffer size: 14711
Saving to checkpoint...
[10] Average reward: -0.8
Predicted reward: [[-0.0541863  -0.14556536 -0.15801829 -0.136219  ]]
Buffer size: 14784
Saving to checkpoint...
[15] Average reward: -1.0
Predicted reward: [[-0.21347167 -0.13125837 -0.06782893 -0.00234007]]
Buffer size: 14910
Saving to checkpoint...
[20] Average reward: -0.8
Predicted reward: [[ 0.04583263  0.02079096  0.09753878 -0.06431076]]
Buffer size: 15035
Saving to checkpoint...
[25] Average reward: -0.8
Predicted reward: [[0.02698723 0.07697456 0.06024702 0.012311  ]]
Buffer size: 15109
Saving to checkpoint...
[30] Average reward: -1.0
Predicted reward: [[-0.08768529 -0.06441404  0.01772811  0.02166178]]
Buffer size: 15200
Saving to checkpoint...
[35] Average reward: -0.8
Predicted reward: [[ 0.01833391 -0.02038789 -0.03147763 -0.02695323]]
Buffer size: 15564
Saving to checkpoint...
[40] Average reward: -0.8
Predic

In [None]:
from matplotlib import pyplot as plt

# Plot the values returned by agent.train, accumulated in `hist` by the DQN training cell.
fig, ax = plt.subplots()
ax.plot(hist)  # x defaults to 0..len(hist)-1, same as plotting against range(len(hist))
ax.set(title='DQN training history', xlabel='entry in hist', ylabel='agent.train output')
plt.show()

In [None]:
def test_rollout(t_max, env, close=True):
    import sys
    obs = agent.preprocess(env.reset())
    reward = 0
    for i in range(t_max):
        # print(agent.get_policy(obs))
        # act = agent.get_action(obs, greedy=True)[0]
        act = agent.get_action(obs, mode='greedy')[0][0]
        obs, r, dn, info = env.step(agent.action_wrapper(act))
        env.render()
        print(act, file=sys.stderr)
        time.sleep(0.05)
        obs = agent.preprocess(obs)
        reward += r
        if dn:
            break

    print("Total reward: {}".format(reward), file=sys.stderr)
    if close: env.close()

In [None]:
# In test mode, watch one greedy rollout of up to 10000 steps; the env is closed afterwards.
if action == 'test':
    test_rollout(10000, env, close=True)

In [None]:
# agent.train(4, t_max=500, min_buf_size=10)

In [None]:
# %lprun -f agent.train agent.train(1, t_max=500, buf_size=2000, min_buf_size=10)