In [1]:
import gym
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" #### REMOVE THIS LINE WHEN CUDA CONFIG IS FIXED
import tensorflow as tf
import numpy as np
import sys
import time
import gym_snake
import json
import importlib
import random
from PIL import Image
from tqdm import tqdm

sys.path.insert(0, '..')
from utils.Buffer import ReplayBuffer
from utils.Conv import ConvHead
from rl.models import get_policy_architecture, get_value_architecture, get_vision_architecture
from algos.PPO import PPO_agent
from algos.DQN import DQN_agent
from utils.Loader import load_agent

# %load_ext line_profiler
%matplotlib notebook

In [2]:
"""
agent = load_agent(
    os.path.join('..', 'configs', 'cartpole.json'),
    run_name="cartpole-DQN-load_test",
    ckpt_folder=os.path.join('..', 'checkpoints')
)
"""

'\nagent = load_agent(\n    os.path.join(\'..\', \'configs\', \'cartpole.json\'),\n    run_name="cartpole-DQN-load_test",\n    ckpt_folder=os.path.join(\'..\', \'checkpoints\')\n)\n'

In [3]:
"""
hist = []
agent.load_from_checkpoint()
hist += agent.train(100, t_max=1000, display=True)
"""

'\nhist = []\nagent.load_from_checkpoint()\nhist += agent.train(100, t_max=1000, display=True)\n'

In [4]:
# tetris = importlib.import_module('pytris-effect.src.gameui')

In [5]:
# Notebook-wide switches:
#   run_name - selects ../configs/<run_name>.json ('tetris' uses the pytris GameUI instead of gym)
#   action   - 'train' | 'evaluate' | 'test'; gates the corresponding cells below
#   algo     - algorithm name tags; later cells look for the substrings 'DQN' / 'PPO' in them
run_name, action, algo = 'snake', 'train', ('DDQN', 'Dueling')

In [6]:
# Load this run's JSON hyper-parameter config and fix the checkpoint root directory.
cfg_fp = os.path.join('..', 'configs', run_name + '.json')
# JSON text is UTF-8; pass the encoding explicitly so the platform default
# (e.g. cp1252 on Windows) cannot mis-decode the file.
with open(cfg_fp, 'r', encoding='utf-8') as f:
    config = json.load(f)
ckpt_folder = os.path.join('..', 'checkpoints')

In [7]:
# Create the environment named by the config.
env_name = config['env']
if run_name == 'tetris':
    # The Tetris GameUI is not a gym env. Import it here: the module-level import cell for it
    # is commented out, so `tetris` would otherwise be undefined (NameError). The package
    # directory name contains a hyphen, hence importlib rather than an import statement.
    tetris = importlib.import_module('pytris-effect.src.gameui')
    env = tetris.GameUI(graphic_mode=False, its_per_sec=2, sec_per_tick=0.5)
else:
    # `.env` presumably strips the wrapper gym.make adds (e.g. TimeLimit) when the config asks
    # for a raw env. Test the flag's value, not just its presence, so "use_raw_env": false
    # keeps the wrapped env.
    env = gym.make(env_name).env if config.get('use_raw_env', False) else gym.make(env_name)

In [8]:
# Sanity check: shape of the initial observation (this run printed (150, 150, 3) for snake).
env.reset().shape

(150, 150, 3)

In [9]:
def upscale_rgb(arr, scaling=30):
    """Nearest-neighbour upscale of an image array to a uint8 RGB array.

    Every pixel of ``arr`` becomes a ``scaling x scaling`` block; only the
    first three channels are kept.

    Args:
        arr: array-like of shape (H, W, C) with C >= 3.
        scaling: non-negative integer upscale factor applied to height and width.

    Returns:
        np.ndarray of shape (H * scaling, W * scaling, 3) with dtype uint8.

    Raises:
        ValueError: if ``arr`` is not 3-D or has fewer than 3 channels.
    """
    rgb = np.asarray(arr)
    if rgb.ndim != 3 or rgb.shape[2] < 3:
        raise ValueError("expected an (H, W, C>=3) array, got shape {}".format(rgb.shape))
    rgb = rgb[:, :, :3]
    # Vectorised replacement for a per-pixel Python triple loop: repeat rows, then columns.
    return np.repeat(np.repeat(rgb, scaling, axis=0), scaling, axis=1).astype(np.uint8)


def show_img(arr, scaling=30):
    """Open an image viewer on ``arr`` upscaled by ``scaling`` (default 30) per axis.

    Args:
        arr: array-like of shape (H, W, C) with C >= 3; only the first three
            channels are shown, interpreted as RGB.
        scaling: integer upscale factor so small observation grids are visible.
    """
    img = Image.fromarray(upscale_rgb(arr, scaling), 'RGB')
    # img.save('my.png')
    img.show()

In [10]:
if action == 'evaluate':
    %lprun -f env.drawMatrix env.drawMatrix()

In [11]:
if action == 'evaluate':
    # Preview the initial observation, subsampled by 10 along height and width.
    arr = env.reset()[::10, ::10, :]
    img = Image.fromarray(arr, 'RGB')
    img.show()
    # For a full-resolution, 30x-upscaled view use: show_img(env.reset())

In [12]:
# Manual single-step check (disabled). Change `False` to `True` to step the env once
# with a fixed action and inspect the resulting frame, reward, done flag and info.
if False:
    # Use a distinct name: rebinding the notebook-level `action` ('train' / 'evaluate' /
    # 'test') to an int would silently disable every later `action == ...` cell.
    step_action = 1
    obs, reward, dn, info = env.step(step_action)
    show_img(obs)
    print(reward, dn, info)

In [13]:
def do_step(n_actions=7, environment=None):
    """Take one uniformly random action, resetting the env when the episode ends.

    Used as the target of the (commented-out) %timeit / %lprun profiling cells.

    Args:
        n_actions: number of discrete actions to sample from (default 7, as before).
            NOTE(review): the snake cells sample from 4 actions; 7 presumably targets
            the Tetris env - confirm.
        environment: env to step; defaults to the notebook-global ``env``.
    """
    if environment is None:
        environment = env
    _, _, done, _ = environment.step(random.choice(range(n_actions)))
    if done:
        environment.reset()

In [14]:
#%timeit env.reset()

In [15]:
#%timeit do_step()

In [16]:
#%lprun -f env.get_obs do_step()

In [17]:
# Build the main network. DQN variants also get a target network; otherwise (PPO)
# a separate value network is built.
# Accept a bare string as well as a tuple of names: "\n".join on a string splits it into
# characters, so e.g. algo='DDQN' would never match 'DQN'.
algo_names = (algo,) if isinstance(algo, str) else tuple(algo)
model = get_policy_architecture(env_name, algo=algo)
if any('DQN' in name for name in algo_names):
    # clone_model copies the architecture with freshly initialised weights (not a weight copy);
    # presumably DQN_agent syncs the target to `model` (cf. target_delay) - confirm.
    target = tf.keras.models.clone_model(model)
else:
    value = get_value_architecture(env_name)

In [18]:
if 'DQN' in "\n".join(algo):
    agent = DQN_agent(
        model,
        # (TODO): Move args for ReplayBuffer into DQN
        ReplayBuffer(config.get("max_buf_size", 20000), mode='uniform'),
        target=target,
        env=env,
        mode=('DDQN'), # 'PER'
        learning_rate=config['learning_rate'],
        batch_size=config['batch_size'],
        update_steps=1,
        update_freq=4,
        multistep=5,
        alpha=1.5,
        beta=1.0,
        gamma=0.95,
        target_delay=1000,
        delta=1.0,
        # delta=0.000003,
        env_name=config['env_name'],
        algo_name='DQN',
        ckpt_folder=ckpt_folder,
        run_name='snake-DQN-pretrain-hard_update-uniform-multistep5-6'
    )
elif 'PPO' in "\n".join(algo):
    agent = PPO_agent(
        model,
        value,
        env=env,
        learning_rate=config['learning_rate'],
        minibatch_size=config['minibatch_size'],
        gamma=0.99,
        env_name=config['env_name'],
        run_name='snake-PPO-pretrain',
        ckpt_folder=ckpt_folder
    )

In [19]:
# Maximum number of steps per episode; used by the pretraining rollouts, the DQN
# buffer warm-up and agent.train below.
t_max = config['t_max']

In [20]:
# Pairs of consecutive preprocessed observations, used to pretrain the vision encoder.
p_buf = []

def collect_rollout(env, t_max, policy, buf=None, rl_agent=None):
    """Run one episode and record consecutive preprocessed observation pairs.

    Each step appends ``[s, s_next]`` (both passed through ``rl_agent.preprocess``)
    to ``buf``. Rewards are ignored: these pairs are training data for the
    vision-pretraining cell, not for the RL agent.

    Args:
        env: environment with gym-style ``reset()`` and 4-tuple ``step(action)``.
        t_max: maximum number of steps in the episode.
        policy: callable mapping a preprocessed observation to a raw action.
        buf: list to append pairs to; defaults to the module-level ``p_buf``.
        rl_agent: object providing ``preprocess`` and ``action_wrapper``;
            defaults to the notebook-global ``agent``.

    Returns:
        Number of pairs appended during this rollout.
    """
    if buf is None:
        buf = p_buf
    if rl_agent is None:
        rl_agent = agent
    start_len = len(buf)
    s = rl_agent.preprocess(env.reset())
    for _ in range(t_max):
        act = policy(s)
        s_next, _, done, _ = env.step(rl_agent.action_wrapper(act))
        s_next = rl_agent.preprocess(s_next)
        buf.append([s, s_next])
        s = s_next
        if done:
            break
    return len(buf) - start_len

In [21]:
# Self-supervised pretraining of a vision encoder on (obs, next_obs) pairs gathered with a
# random policy. Only necessary for tasks on raw pixels (vision).
pretrain = True
N_PRETRAIN_ROLLOUTS = 500
N_PRETRAIN_EPOCHS = 6
if pretrain and action == 'train':
    embed = get_vision_architecture(env_name)
    # Start from an empty buffer so re-running this cell does not duplicate samples.
    # Rebinding (rather than clearing) leaves any list held by an earlier ConvHead untouched.
    p_buf = []
    # get some data from random interactions with the env
    # NOTE(review): 4 = number of snake actions (the agent predicts 4 values per state).
    for i in tqdm(range(N_PRETRAIN_ROLLOUTS)):
        collect_rollout(env, t_max, lambda x: np.random.choice(4))
    print("Collected {} samples".format(len(p_buf)))
    head = ConvHead(embed, p_buf)
    head.train(N_PRETRAIN_EPOCHS)

    # Take the penultimate layer's activations as the learned embedding.
    out = head.model.layers[-2].output
    vision = tf.keras.Model(inputs=head.model.input, outputs=out)
    # FIXME(review): nothing later in this notebook uses `pretrained_model`; `agent` was
    # already built from `model` above, so this pretraining currently has no effect on training.
    pretrained_model = get_policy_architecture(env_name, algo=algo, pretrain=vision)

100%|██████████| 500/500 [00:01<00:00, 491.44it/s]


Collected 15340 samples


  0%|          | 0/239 [00:00<?, ?it/s]

[0] Loss: [0.22921282 0.26290193 0.12300461 0.12752995 0.11707391 0.15183185
 0.12748024 0.09673619 0.19803968 0.16643393 0.13425826 0.1545998
 0.08676375 0.07365814 0.1318234  0.16210482 0.08700682 0.10762814
 0.07453926 0.10087536 0.10958755 0.14552695 0.14574803 0.1252699
 0.14847773 0.13899545 0.11840709 0.1272275  0.17439987 0.20915326
 0.19524778 0.17320062 0.15796182 0.12654833 0.12276284 0.11939698
 0.11618607 0.11209881 0.11710346 0.13498604 0.1002903  0.0934117
 0.10714595 0.13876866 0.09840295 0.10085947 0.12309004 0.12536836
 0.09456156 0.12299266 0.17776307 0.16987197 0.18386348 0.1752637
 0.11528827 0.11021537 0.12458029 0.11667293 0.11671758 0.11232901
 0.17166162 0.2000716  0.1212991  0.10975193 0.12694229 0.14155324
 0.13857977 0.1081904  0.1234675  0.16209778 0.20601407 0.20295605
 0.14436077 0.1246126  0.16309425 0.13841558 0.10621558 0.11281113
 0.19111906 0.22662428 0.11762059 0.12247466 0.133143   0.146335
 0.10317672 0.09064768 0.15939142 0.14965695 0.08122177 0.

  0%|          | 0/239 [00:00<?, ?it/s]

[1] Loss: [0.04029174 0.04489665 0.04832982 0.04216554 0.04574964 0.03660686
 0.06399546 0.0709819  0.04659969 0.05755356 0.02703181 0.0410537
 0.06195014 0.05658235 0.05925698 0.06547911 0.0494944  0.03984912
 0.0590899  0.0894155  0.04201325 0.03549167 0.05642061 0.05546975
 0.04336316 0.0411713  0.0590832  0.05630757 0.06264707 0.05360743
 0.06640325 0.05431776 0.0435481  0.05724861 0.03452111 0.05112505
 0.04369593 0.06383638 0.03901803 0.03699951 0.0457134  0.03315241
 0.06023696 0.0682625  0.04574271 0.04599875 0.06782895 0.0663662
 0.05664791 0.04524819 0.03904472 0.05024296 0.05516161 0.07909564
 0.05928591 0.05901829 0.05771784 0.05427924 0.06338928 0.06343716
 0.07984376 0.07576228 0.0369681  0.03674037 0.0592887  0.05927235
 0.07247845 0.0512591  0.05222975 0.05634772 0.06515016 0.0582396
 0.03207615 0.0326538  0.0630405  0.06188006 0.03117936 0.0262541
 0.04549374 0.04857834 0.06109101 0.05544307 0.04180067 0.04443981
 0.04023527 0.03678816 0.03933122 0.05109398 0.05258431 

  0%|          | 0/239 [00:00<?, ?it/s]

[2] Loss: [0.04300529 0.05741789 0.02949149 0.04367517 0.03796415 0.03583062
 0.03479189 0.0499949  0.04471017 0.04257011 0.02320085 0.03129834
 0.04933061 0.04795079 0.04829975 0.04549236 0.04839857 0.04308864
 0.04128105 0.03332819 0.06758244 0.05519855 0.04625449 0.05746291
 0.03083989 0.02950994 0.04948584 0.0449586  0.03712494 0.04419211
 0.03414374 0.04122782 0.07662147 0.05575178 0.0525094  0.05502785
 0.06685682 0.0702917  0.0658897  0.05627571 0.06856452 0.04952583
 0.04250816 0.04239889 0.04973667 0.05461917 0.03239356 0.03538334
 0.03866177 0.04959716 0.0182526  0.02116719 0.0453904  0.04593368
 0.03196087 0.05328958 0.02721568 0.02692061 0.05168466 0.04138392
 0.0228628  0.02629223 0.04111489 0.03645758 0.07712051 0.05757553
 0.01562515 0.01941114 0.04824956 0.04054635 0.04647271 0.0326884
 0.06433125 0.04401124 0.05352043 0.05576983 0.03709669 0.03177122
 0.02647421 0.04239442 0.04917867 0.03572547 0.027128   0.02561799
 0.05509416 0.06412733 0.04005285 0.03700274 0.048857

  0%|          | 0/239 [00:00<?, ?it/s]

[3] Loss: [0.05036193 0.05223871 0.03915145 0.04665126 0.04789048 0.04297912
 0.02845554 0.02402087 0.03389608 0.03468594 0.05227201 0.04506958
 0.05283229 0.05907467 0.03919463 0.03084424 0.04246049 0.04145984
 0.0454597  0.04823953 0.03712408 0.02895919 0.04885987 0.05138497
 0.04227104 0.02642248 0.0407482  0.05470575 0.03214408 0.02913301
 0.04885745 0.04802742 0.03768533 0.04337101 0.05466463 0.06225794
 0.0521436  0.04176433 0.06057215 0.05670295 0.03673268 0.02920806
 0.02992549 0.03115466 0.04034725 0.03978996 0.04650857 0.05453834
 0.04185672 0.0306126  0.03795753 0.05249297 0.035193   0.04538126
 0.04292319 0.06119847 0.03174902 0.02030126 0.02590664 0.02751601
 0.04405541 0.03668127 0.03525884 0.0280512  0.04851555 0.04703248
 0.04137217 0.05256812 0.04200057 0.03598359 0.04470385 0.027852
 0.04288898 0.0471281  0.03892104 0.03024575 0.04174748 0.03988989
 0.03891396 0.02401215 0.032889   0.03086352 0.03550171 0.04679535
 0.04869046 0.05822576 0.03319585 0.04736598 0.0749398

  0%|          | 0/239 [00:00<?, ?it/s]

[4] Loss: [0.03115119 0.02337028 0.03250003 0.02595126 0.04767809 0.05652824
 0.03612578 0.0319135  0.03194577 0.03662231 0.04242696 0.04907579
 0.05662163 0.04818492 0.06163044 0.05637051 0.03125926 0.02312207
 0.05064761 0.04791186 0.07090805 0.05862759 0.04892192 0.04956776
 0.04378021 0.04115648 0.07376156 0.07683154 0.03648309 0.03910259
 0.06425255 0.05342804 0.02642041 0.04181585 0.0381895  0.04179161
 0.03318887 0.02777741 0.02751732 0.03660753 0.0619569  0.04668658
 0.04168559 0.04625313 0.06463359 0.07359751 0.03982521 0.03723287
 0.04112272 0.04155536 0.02494253 0.02785578 0.04381517 0.03780241
 0.03077905 0.04271168 0.02545066 0.03186544 0.04462107 0.03396027
 0.04067665 0.03453539 0.02213425 0.02525274 0.04739132 0.04124903
 0.04064483 0.0355098  0.0369911  0.04060873 0.03594515 0.03693855
 0.04744289 0.04475193 0.03527401 0.03694947 0.02679177 0.0288363
 0.02706868 0.02149264 0.06025117 0.04886607 0.02505201 0.02392531
 0.04394412 0.06457856 0.02689337 0.02186836 0.023758

  0%|          | 0/239 [00:00<?, ?it/s]

[5] Loss: [0.03618854 0.04181621 0.04324885 0.04190331 0.05366328 0.0531702
 0.07341421 0.07209782 0.04251619 0.04543882 0.06115434 0.03894209
 0.03867267 0.04320775 0.02820909 0.01705811 0.0390548  0.03646542
 0.03399859 0.03570601 0.05371594 0.05006413 0.04925235 0.06753691
 0.041731   0.05494013 0.04512675 0.06199396 0.02812941 0.03949092
 0.04091566 0.04482966 0.0588223  0.05762515 0.02738499 0.03746816
 0.02943849 0.02849346 0.03074321 0.02956468 0.04450908 0.04128332
 0.04578294 0.04807186 0.02662081 0.01946401 0.02332863 0.02478257
 0.02773257 0.02531618 0.04134595 0.04525616 0.05078652 0.03732973
 0.0490083  0.05443706 0.03562012 0.04541828 0.06864782 0.05847213
 0.04511335 0.03765792 0.0427177  0.03425774 0.05188077 0.06697496
 0.03938233 0.04603267 0.03688255 0.03164749 0.03318258 0.04038472
 0.0363289  0.03798346 0.03690832 0.04066602 0.04289086 0.04217483
 0.02029346 0.01922898 0.01320202 0.0170903  0.06065333 0.04677466
 0.03525647 0.03658151 0.0352665  0.03376456 0.021434

In [22]:
# Restore the agent's saved state (presumably from ckpt_folder/<run_name> - confirm in load_from_checkpoint).
agent.load_from_checkpoint()
### WARNING: the training cell below saves checkpoints for this run_name and will overwrite existing runs
# Fresh history; the training cell extends it with agent.train's results.
hist = []

In [None]:
# Train the agent. For DQN the replay buffer is first warmed up with random-policy rollouts.
if action == 'train':
    # Accept a bare string as well as a tuple of names: "\n".join on a string splits it into
    # characters, so e.g. algo='DDQN' would never match 'DQN'.
    algo_names = (algo,) if isinstance(algo, str) else tuple(algo)
    if any('DQN' in name for name in algo_names):
        # fill buffer with some random samples (train=False: presumably no updates - confirm)
        # NOTE(review): 4 = number of snake actions; derive it from the env for other tasks.
        for i in tqdm(range(500)):
            agent.collect_rollout(t_max=t_max, policy=lambda x: np.random.choice(4), train=False, display=False)
        #print(agent.epsilon)
        #agent.epsilon = 0.05
        hist += agent.train(epochs=config['train_epochs'], t_max=t_max, display=False)
    elif any('PPO' in name for name in algo_names):
        ppo_result = agent.train(epochs=config['train_epochs'], t_max=t_max, buf_size=3000, min_buf_size=600, display=False)
        # Record PPO history as well so the plotting cell is not empty for PPO runs.
        # PPO_agent.train's return type is not visible here; only list/tuple results are kept.
        if isinstance(ppo_result, (list, tuple)):
            hist.extend(ppo_result)

100%|██████████| 500/500 [00:01<00:00, 478.33it/s]


Training epochs:   0%|          | 0/20000 [00:00<?, ?it/s]

[5] Average reward: -1.0
Predicted reward: [[-0.17022735 -0.15910876 -0.14909872 -0.13629511]]
Buffer size: 14991
Saving to checkpoint...
[10] Average reward: -1.0
Predicted reward: [[-0.13725631 -0.05270778 -0.086668   -0.05755933]]
Buffer size: 15080
Saving to checkpoint...
[15] Average reward: -1.0
Predicted reward: [[-0.07598333 -0.01828219 -0.05113086 -0.03390285]]
Buffer size: 15129
Saving to checkpoint...
[20] Average reward: -1.0
Predicted reward: [[-0.08268044 -0.0236209  -0.04741701 -0.01205611]]
Buffer size: 15200
Saving to checkpoint...
[25] Average reward: -1.0
Predicted reward: [[-0.07056684 -0.04058401 -0.01149733 -0.00271421]]
Buffer size: 15275
Saving to checkpoint...
[30] Average reward: -1.0
Predicted reward: [[-0.06655702  0.01047969 -0.04606633 -0.00151416]]
Buffer size: 15434
Saving to checkpoint...
[35] Average reward: -1.0
Predicted reward: [[-0.06103962 -0.01771689 -0.037844   -0.04822411]]
Buffer size: 15521
Saving to checkpoint...
[40] Average reward: -0.8
Pr

[300] Average reward: -0.4
Predicted reward: [[-0.21999177 -0.19345817 -0.27021182 -0.24769203]]
Buffer size: 33767
Saving to checkpoint...
[305] Average reward: -0.8
Predicted reward: [[ 0.05773135 -0.14098406  0.05218875 -0.12361807]]
Buffer size: 34403
Saving to checkpoint...
[310] Average reward: -1.0
Predicted reward: [[0.26641798 0.11854055 0.2757152  0.2007705 ]]
Buffer size: 34919
Saving to checkpoint...
[315] Average reward: -1.0
Predicted reward: [[-0.02143336 -0.0580307   0.00232598 -0.11535773]]
Buffer size: 35236
Saving to checkpoint...
[320] Average reward: -1.0
Predicted reward: [[-0.09193728 -0.18681043 -0.02021901 -0.08801465]]
Buffer size: 35403
Saving to checkpoint...
[325] Average reward: -0.6
Predicted reward: [[-0.10219733  0.07371285 -0.05119086  0.12247571]]
Buffer size: 35841
Saving to checkpoint...
[330] Average reward: -0.4
Predicted reward: [[-0.13370283 -0.25286472  0.01367985 -0.00959967]]
Buffer size: 36466
Saving to checkpoint...
[335] Average reward: -1

In [None]:
from matplotlib import pyplot as plt

# Plot the training history collected from agent.train (entries presumably per-epoch
# rewards - confirm against the agent's train()).
fig, ax = plt.subplots()
ax.plot(hist)  # x defaults to 0..len(hist)-1
ax.set(title="Training history ({})".format(run_name), xlabel="Entry", ylabel="hist value")
plt.show()

In [None]:
def test_rollout(t_max, env, close=True):
    import sys
    obs = agent.preprocess(env.reset())
    reward = 0
    for i in range(t_max):
        # print(agent.get_policy(obs))
        # act = agent.get_action(obs, greedy=True)[0]
        act = agent.get_action(obs, mode='greedy')[0][0]
        obs, r, dn, info = env.step(agent.action_wrapper(act))
        env.render()
        print(act, file=sys.stderr)
        time.sleep(0.05)
        obs = agent.preprocess(obs)
        reward += r
        if dn:
            break

    print("Total reward: {}".format(reward), file=sys.stderr)
    if close: env.close()

In [None]:
# Greedy, rendered evaluation episode; only runs in 'test' mode.
if action == 'test':
    test_rollout(t_max=10000, env=env, close=True)

In [None]:
# agent.train(4, t_max=500, min_buf_size=10)

In [None]:
# %lprun -f agent.train agent.train(1, t_max=500, buf_size=2000, min_buf_size=10)