In [None]:
import gym
import os
import tensorflow as tf
import numpy as np
import sys
import time
import gym_snake
import json
import importlib
import random
from PIL import Image
from tqdm import tqdm
from matplotlib import pyplot as plt

sys.path.insert(0, '..')
os.chdir('..')
from utils.Buffer import ReplayBuffer
from utils.Conv import ConvHead
from rl.models import get_policy_architecture, get_value_architecture, get_vision_architecture
from algos.PPO import PPO_agent
from algos.DQN import DQN_agent
from utils.Loader import load_agent
from utils.utils import *
from utils.Env import get_env

%matplotlib notebook
# %load_ext line_profiler

In [None]:
# Load the 'snake-model-zoo' run for the 'snake' environment.
# NOTE(review): `load` is not imported by name in this notebook (only `load_agent` is,
# and it is unused in the visible cells); presumably `load` comes in through
# `from utils.utils import *` — confirm.
agent = load('snake-model-zoo', 'snake')

In [None]:
# Load saved agent state from the model zoo. The path is relative to the working
# directory, which the first cell moved up one level with os.chdir('..').
agent.load('model_zoo/snake-03-06-22')

In [None]:
# Produce a video from a rollout of at most 1000 steps in the agent's own environment.
generate_video_from_rollout(agent, agent.env, t_max=1000)

In [None]:
# Rebinds `agent`: from here on it refers to the 'tetris-simple-7' run, not the snake agent.
# NOTE(review): the meaning of override=True is defined by `load` — confirm before re-running.
agent = load('tetris-simple-7', 'tetris-simple', override=True)

In [None]:
# Train for 100 epochs with rollouts capped at 500 steps, with display enabled.
agent.train(epochs=100, t_max=500, display=True)

In [None]:
# Switch the agent to evaluation mode and print the result of five displayed rollouts.
agent.eval()
for _ in range(5):
    print(agent.collect_rollout(t_max=1000, display=True, eval=True))

In [None]:
# Settings for the comparison experiment; all four are passed to compare_algos below.
runs = 5
epochs = 100
# NOTE: `env` is an environment *name* here; a later cell rebinds `env` to an
# environment object, so this cell must be re-run before compare_algos if that
# later cell has already executed.
env = 'lunarlander'
base_name = 'lunarlander-DQN-eps_anneal_compare2'

In [None]:
# Callback specifications handed to compare_algos under the 'callbacks' key.
# Each entry names a callback "type" plus the "kwargs" for it.
# NOTE(review): how these specs are instantiated is decided by project code not
# shown here — the per-entry remarks below are read off the names and should be confirmed.
cb = [
    {
        # presumably pre-fills the replay buffer with 50 episodes before training
        "type": "InitBufferCallback",
        "kwargs": {
            "episodes": 50
        }
    },
    {
        # anneals the agent attribute named by "target" ("epsilon") along the schedule
        "type": "AnnealingSchedulerCallback",
        "kwargs": {
            "target": "epsilon",
            "schedule": [
                {
                    # a single segment: linear from 0.4 to 0.01 over a length of 200
                    # (units of "length" are not visible here — TODO confirm)
                    "type": "Schedule",
                    "kwargs": {
                        "length": 200,
                        "start_val": 0.4,
                        "end_val": 0.01,
                        "fn": "linear"
                    }
                }
            ]
        }
    }
]

In [None]:
# Run the comparison with the callback spec from `cb` as the overriding settings.
# An alternative override that was tried before: {'algo': ['DDQN', 'PER', 'Dueling']}
hist1, hist2 = compare_algos(
    base_name,
    runs,
    epochs,
    env,
    {'callbacks': cb},
)

In [None]:
# Plot the two histories returned by compare_algos against each other.
plot_runs(hist1, hist2)

In [None]:
# NOTE(review): same call as the previous cell; one of the two cells is likely a
# leftover from an earlier comparison and could be removed.
plot_runs(hist1, hist2) # blue is hist1, orange is hist2

In [None]:
# tetris = importlib.import_module('pytris-effect.src.gameui')

In [None]:
# Select the experiment for the manual setup below.
run_name = 'snake'  # also used as the config file stem: configs/<run_name>.json
action = 'train'  # later cells compare this against 'train', 'evaluate' and 'test'
algo = ('DDQN', 'Dueling')  # later cells choose the DQN or PPO path by substring match on these names

In [None]:
# Read the JSON config for the selected run and set the checkpoint directory.
# NOTE(review): the first cell already executed os.chdir('..'), so the '..' used
# here resolves relative to that new working directory, while other cells use
# cwd-relative paths such as 'model_zoo/...'. Confirm these two paths point
# where intended.
cfg_fp = os.path.join('..', 'configs', run_name + '.json')
with open(cfg_fp, 'r') as f:
    config = json.load(f)
ckpt_folder = os.path.join('..', 'checkpoints')

In [None]:
env_name = config['env']
if run_name == 'tetris':
    env = tetris.GameUI(graphic_mode=False, its_per_sec=2, sec_per_tick=0.5)
else:
    env = gym.make(env_name).env if 'use_raw_env' in config else gym.make(env_name)

In [None]:
env.reset().shape

In [None]:
def show_img(arr, scaling=30):
    """Open an enlarged view of an image array in the system image viewer.

    Every source pixel is blown up into a ``scaling`` x ``scaling`` square
    (nearest-neighbour upscaling), so small grid observations become visible.

    Args:
        arr: array-like of shape (height, width, channels) with at least three
            channels; only the first three are used and they are displayed
            as RGB.
        scaling: integer side length, in output pixels, of each source pixel.
    """
    pixels = np.asarray(arr)[:, :, :3]
    # Repeat rows, then columns. This yields out[i, j] == arr[i // scaling, j // scaling],
    # the same mapping the previous per-pixel Python loop computed, but vectorised.
    data = np.repeat(np.repeat(pixels, scaling, axis=0), scaling, axis=1).astype(np.uint8)
    img = Image.fromarray(data, 'RGB')
    # img.save('my.png')
    img.show()

In [None]:
# Line-profile env.drawMatrix (only when action == 'evaluate').
# NOTE(review): %lprun is provided by the line_profiler extension, and the
# `%load_ext line_profiler` line in the first cell is commented out, so this
# would need that extension loaded first.
if action == 'evaluate':
    %lprun -f env.drawMatrix env.drawMatrix()

In [None]:
# Show a subsampled view of the initial observation (only when action == 'evaluate').
if action == 'evaluate':
    arr = env.reset()[::10,::10,:]  # keep every 10th row and column, all channels
    img = Image.fromarray(arr, 'RGB')
    img.show()
    #show_img(env.reset())

In [None]:
# Disabled scratch cell: take one step with action 1 and display the result.
# NOTE: if re-enabled, it rebinds the global `action` (a string such as 'train'
# in the other cells) to the integer 1, which would break the later
# `action == ...` checks.
if False:
    action = 1
    obs, reward, dn, info = env.step(action)
    show_img(obs)
    print(reward, dn, info)

In [None]:
def do_step():
    """Advance the global ``env`` by one uniformly random action out of seven.

    When the step reports that the episode is over, the environment is reset
    so the function can be called repeatedly (e.g. for timing or profiling).
    """
    move = random.choice(range(7))
    _obs, _reward, episode_over, _info = env.step(move)
    if not episode_over:
        return
    env.reset()

In [None]:
#%timeit env.reset()

In [None]:
#%timeit do_step()

In [None]:
#%lprun -f env.get_obs do_step()

In [None]:
# Build the policy network for the chosen environment and algorithm variants.
model = get_policy_architecture(env_name, algo=algo)
if not any('DQN' in name for name in algo):
    # Non-DQN runs use a separate value network.
    value = get_value_architecture(env_name)
else:
    # DQN-family runs (any algo name containing 'DQN') get a structural clone
    # of the policy model to act as the target network.
    target = tf.keras.models.clone_model(model)

In [None]:
# Construct the agent matching `algo`: a DQN agent when any algo name contains
# 'DQN', else a PPO agent when any contains 'PPO'. If neither matches, `agent`
# is left untouched (possibly still bound to an agent from an earlier cell).
#
# NOTE(review): mode=('DDQN') is a parenthesised *string*, not a one-element
# tuple (that would be ('DDQN',)). Confirm which form DQN_agent expects.
# NOTE(review): this cell reads config['env_name'] while the environment itself
# was created from config['env'] — confirm both keys exist in the config file.
# NOTE(review): run_name strings below are hard-coded to snake experiments and
# do not follow the notebook-level `run_name` variable.
if 'DQN' in "\n".join(algo):
    agent = DQN_agent(
        model,
        # buffer capacity comes from the config, defaulting to 20000
        ReplayBuffer(config.get("max_buf_size", 20000), mode='uniform'),
        target=target,
        env=env,
        mode=('DDQN'), # 'PER'
        learning_rate=config['learning_rate'],
        batch_size=config['batch_size'],
        update_steps=1,
        update_freq=4,
        multistep=5,
        alpha=1.5,
        beta=1.0,
        gamma=0.95,
        target_delay=1000,
        delta=1.0,
        # delta=0.000003,
        env_name=config['env_name'],
        algo_name='DQN',
        ckpt_folder=ckpt_folder,
        run_name='snake-DQN-pretrain-hard_update-uniform-multistep5-7'
    )
elif 'PPO' in "\n".join(algo):
    agent = PPO_agent(
        model,
        value,
        env=env,
        learning_rate=config['learning_rate'],
        minibatch_size=config['minibatch_size'],
        gamma=0.99,
        env_name=config['env_name'],
        run_name='snake-PPO-pretrain',
        ckpt_folder=ckpt_folder
    )

In [None]:
# Per-rollout step cap, taken from the run config and reused by the cells below.
t_max = config['t_max']

In [None]:
# Accumulates [state, next_state] pairs across every collect_rollout call.
p_buf = []

def collect_rollout(env, t_max, policy):
    """Run one episode of ``policy`` in ``env`` and record state transitions.

    Observations go through the global ``agent.preprocess`` and actions through
    ``agent.action_wrapper``. For each step taken, a ``[state, next_state]``
    pair of preprocessed observations is appended to the module-level
    ``p_buf``. The episode stops after ``t_max`` steps or as soon as the
    environment signals it is done, whichever comes first.

    Args:
        env: environment exposing ``reset()`` and a four-value ``step(action)``.
        t_max: maximum number of steps to take.
        policy: callable mapping a preprocessed state to an action.
    """
    state = agent.preprocess(env.reset())
    for _ in range(t_max):
        chosen = policy(state)
        raw_next, _reward, finished, _info = env.step(agent.action_wrapper(chosen))
        next_state = agent.preprocess(raw_next)
        p_buf.append([state, next_state])
        if finished:
            return
        state = next_state

In [None]:
# Optional pretraining of the vision front-end on transitions gathered by
# random interaction, followed by swapping the pretrained head into the agent.
# NOTE: `vision` and `pretrained_model` are only defined when this branch runs
# (pretrain is True and action == 'train').
pretrain = True
if pretrain and action == 'train': # only necessary for tasks on raw pixels (vision)
    # Vision backbone with a linear 16-unit projection stacked on its output.
    # Note this rebinds the global `model` that was built for the policy earlier.
    model = get_vision_architecture(env_name)
    out = tf.keras.layers.Dense(16, activation=None)(model.output)
    embed = tf.keras.Model(inputs=model.input, outputs=out)
    # get some data from random interactions with the env
    # (500 rollouts; the random policy draws from 4 actions — hard-coded, so it
    # assumes the environment has exactly four actions; TODO confirm)
    # p_buf is a module-level list, so re-running this cell keeps adding to it.
    for i in tqdm(range(500)):
        collect_rollout(env, t_max, lambda x: np.random.choice(4))
    print("Collected {} samples".format(len(p_buf)))
    head = ConvHead(embed, p_buf)
    head.train(6)  # presumably the number of training epochs — TODO confirm against ConvHead
    
    # Cut head.model at its second-to-last layer and use that sub-model as the
    # vision head of a freshly built policy network.
    out = head.model.layers[-2].output
    vision = tf.keras.Model(inputs=head.model.input, outputs=out)
    pretrained_model = get_policy_architecture(env_name, algo=algo, head=vision)
    agent.set_model(pretrained_model)

In [None]:
# Rebuild the policy network on top of the existing `vision` head and install
# it in the agent (same two statements that end the pretraining cell).
# NOTE(review): `vision` exists only if the pretraining branch has run in this
# kernel; on a fresh kernel with pretraining skipped this cell raises NameError.
pretrained_model = get_policy_architecture(env_name, algo=algo, head=vision)
agent.set_model(pretrained_model)

In [None]:
# Load agent state from its checkpoint and start an empty training-history list.
# Re-running this cell discards any history collected so far.
agent.load_from_checkpoint()
hist = []

In [None]:
# Train the agent when the notebook is in 'train' mode.
if action == 'train':
    if any('DQN' in name for name in algo):
        # Fill the buffer with 500 rollouts of a uniformly random 4-action
        # policy (no training, no display) before learning starts.
        for i in tqdm(range(500)):
            agent.collect_rollout(
                t_max=t_max,
                policy=lambda x: np.random.choice(4),
                train=False,
                display=False,
            )
        # Optional manual override that was used while debugging:
        #   print(agent.epsilon); agent.epsilon = 0.05
        # Whatever agent.train returns is appended onto the running history.
        hist.extend(agent.train(epochs=config['train_epochs'], t_max=t_max, display=False))
    elif any('PPO' in name for name in algo):
        # The PPO path does not record its return value in `hist`.
        agent.train(
            epochs=config['train_epochs'],
            t_max=t_max,
            buf_size=3000,
            min_buf_size=600,
            display=False,
        )

In [None]:
# Plot the accumulated training history (the values gathered in `hist` from
# agent.train) against their position in the list. `plt` is already imported
# in the first cell, so it is not re-imported here, and no throwaway global
# copy of `hist` is created.
fig, ax = plt.subplots()
ax.plot(range(len(hist)), hist)
ax.set_xlabel('index in hist')
ax.set_ylabel('value returned by agent.train')
ax.set_title('Training history')
plt.show()

In [None]:
def test_rollout(t_max, env, close=True):
    import sys
    obs = agent.preprocess(env.reset())
    reward = 0
    for i in range(t_max):
        # print(agent.get_policy(obs))
        # act = agent.get_action(obs, greedy=True)[0]
        act = agent.get_action(obs, mode='greedy')[0][0]
        obs, r, dn, info = env.step(agent.action_wrapper(act))
        env.render()
        print(act, file=sys.stderr)
        time.sleep(0.05)
        obs = agent.preprocess(obs)
        reward += r
        if dn:
            break

    print("Total reward: {}".format(reward), file=sys.stderr)
    if close: env.close()

In [None]:
# In 'test' mode, play one rendered greedy episode of up to 10000 steps and
# close the environment afterwards.
if action == 'test':
    test_rollout(10000, env, close=True)

In [None]:
# agent.train(4, t_max=500, min_buf_size=10)

In [None]:
# %lprun -f agent.train agent.train(1, t_max=500, buf_size=2000, min_buf_size=10)