This experiment uses a higher learning rate and larger rollouts.

In [1]:
import deep_rl

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.nn import functional as F
from torch.autograd import Variable
from torch import nn, optim
import torch.utils.data

# load as dask array
import time

import logging
import sys
import os
import glob
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

In [4]:
from deep_rl.utils import Config
from deep_rl.utils.logger import get_logger, get_default_log_dir

from deep_rl.network.network_heads import CategoricalActorCriticNet, QuantileNet, OptionCriticNet, DeterministicActorCriticNet, GaussianActorCriticNet
from deep_rl.network.network_bodies import FCBody

from deep_rl.component.task import ParallelizedTask

In [5]:
from world_models_sonic.models.vae import VAE5, loss_function_vae
from world_models_sonic.helpers.summarize import TorchSummarizeDf
from world_models_sonic.helpers.dataset import load_cache_data
from world_models_sonic.models.rnn import MDNRNN2
from world_models_sonic.models.inverse_model import InverseModel
from world_models_sonic.models.world_model import WorldModel
from world_models_sonic.custom_envs.env import make_env
from world_models_sonic.custom_envs.wrappers import RandomGameReset
from world_models_sonic import config
from world_models_sonic.helpers.deep_rl import PPOAgent, run_iterations, SonicWorldModelDeepRL, CategoricalWorldActorCriticNet

Importing 0 potential games...
Imported 0 games


  from ._conv import register_converters as _register_converters


# Init

In [6]:
cuda = torch.cuda.is_available()
env_name = 'sonic256'
z_dim = 512  # latent dimensions

# RNN
action_dim = 10
image_size = 256

verbose = False  # Set this true to render (and make it go slower)

# NAME ='RNN_v3b_256im_512z_1512_v5_greenfield'
NAME = 'RNN_v3b_256im_512z_1512_v6d_VAE5_all'  # this run: its output directory and save files
CHECKPOINT_NAME = 'RNN_v3b_256im_512z_1512_v6c_VAE5_all'  # earlier run whose weights are loaded first
ppo_save_file = './outputs/models/PPO_512z_all_d.pkl'
ppo_oldsave_file = './outputs/models/PPO_512z_all_c.pkl'

# Per-run directories; file paths below are built from these so they stay in sync.
output_dir = './outputs/{NAME}'.format(NAME=NAME)
checkpoint_dir = './outputs/{NAME}'.format(NAME=CHECKPOINT_NAME)

save_file_rnn = '{}/mdnrnn_state_dict.pkl'.format(output_dir)
save_file_vae = '{}/vae_state_dict.pkl'.format(output_dir)
save_file_finv = '{}/finv_state_dict.pkl'.format(output_dir)

checkpoint_file_rnn = '{}/mdnrnn_state_dict.pkl'.format(checkpoint_dir)
checkpoint_file_vae = '{}/vae_state_dict.pkl'.format(checkpoint_dir)
checkpoint_file_finv = '{}/finv_state_dict.pkl'.format(checkpoint_dir)

# exist_ok avoids the check-then-create race of an isdir() test followed by makedirs().
os.makedirs(output_dir, exist_ok=True)

# Log to file and stream
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(NAME)

deep_rl_logger = get_logger(
    NAME,
    file_name='deep_rl_ppo.log',
    level=logging.INFO,
    log_dir=output_dir, )

# World model

In [7]:
def _load_weights(model, path, description):
    """Load a saved state dict from `path` into `model` if the file exists.

    Args:
        model: torch module that receives the weights.
        path: state-dict file written by ``torch.save``.
        description: label used in the log line, e.g. 'vae checkpoint_file'.

    Returns:
        True if weights were loaded, False if `path` does not exist.
    """
    if not os.path.isfile(path):
        return False
    # map_location='cpu' lets checkpoints saved from a GPU run load on a CPU-only machine;
    # load_state_dict then copies the tensors onto whatever device `model` lives on.
    state_dict = torch.load(path, map_location=None if cuda else 'cpu')
    model.load_state_dict(state_dict)
    print('loaded {description} {save_file}'.format(description=description, save_file=path))
    return True


# Load VAE
# TODO swap z and k dim, since it's inconsistent with other models
vae = VAE5(image_size=image_size, z_dim=128, conv_dim=64, code_dim=8, k_dim=z_dim)
if cuda:
    vae.cuda()
# The earlier run's checkpoint is loaded first; this run's own save (if any) then overrides it.
_load_weights(vae, checkpoint_file_vae, 'vae checkpoint_file')
_load_weights(vae, save_file_vae, 'vae save_file')

# Load MDRNN
hidden_size, n_mixture, temp = 512, 2, 0.0

mdnrnn = MDNRNN2(z_dim, action_dim, hidden_size, n_mixture, temp)
if cuda:
    mdnrnn = mdnrnn.cuda()
_load_weights(mdnrnn, checkpoint_file_rnn, 'mdnrnn checkpoint_file')
_load_weights(mdnrnn, save_file_rnn, 'mdnrnn save_file')

# Load inverse model
finv = InverseModel(z_dim, action_dim, hidden_size=256)
if cuda:
    finv = finv.cuda()
_load_weights(finv, checkpoint_file_finv, 'finv checkpoint_file')
_load_weights(finv, save_file_finv, 'finv save_file')

world_model = WorldModel(vae, mdnrnn, finv, logger=deep_rl_logger)
# NOTE(review): the intent is to sample without randomness, but train() normally enables
# training-mode behaviour (e.g. dropout). Any determinism presumably comes from temp=0.0
# in the MDN-RNN; confirm WorldModel does not override train().
world_model = world_model.train()
if cuda:
    world_model = world_model.cuda()

loaded vae checkpoint_file ./outputs/RNN_v3b_256im_512z_1512_v6c_VAE5_all/vae_state_dict.pkl
loaded mdnrnn checkpoint_file ./outputs/RNN_v3b_256im_512z_1512_v6c_VAE5_all/mdnrnn_state_dict.pkl
loaded finv checkpoint_file ./outputs/RNN_v3b_256im_512z_1512_v6c_VAE5_all/finv_state_dict.pkl


In [8]:
import torch.optim.lr_scheduler

# Release cached, unused GPU memory before building the optimizer.
torch.cuda.empty_cache()

# Adam over every world-model parameter, with the LR reduced when a minimised metric plateaus.
adam = optim.Adam(world_model.parameters(), lr=3e-5)
plateau = optim.lr_scheduler.ReduceLROnPlateau(adam, mode='min', patience=8, verbose=True)

# Attach both to the model; presumably WorldModel uses them during training — verify.
world_model.optimizer, world_model.scheduler = adam, plateau
optimizer, scheduler = adam, plateau

# Train

In [9]:
log_dir = './outputs/{NAME}'.format(NAME=NAME)
# Input size of the policy body: presumably the latent z concatenated with the RNN hidden state.
z_state_dim = world_model.mdnrnn.z_dim + world_model.mdnrnn.hidden_size
print(log_dir)


def task_fn(log_dir):
    """Build one Sonic task (random game/level on reset) for the deep_rl framework."""
    return SonicWorldModelDeepRL(
        env_fn=lambda: RandomGameReset(make_env(env_name)),
        max_steps=4000,
        log_dir=log_dir,
        verbose=verbose
    )


def _load_agent_if_exists(agent, path):
    """Load PPO weights from `path` into `agent` if the file exists.

    Returns True if weights were loaded, False otherwise.
    """
    if not os.path.isfile(path):
        print("couldn't find save", path)
        return False
    # Report the modification time of the file actually being loaded.
    print('loading', path, 'modified', time.ctime(os.path.getmtime(path)))
    agent.load(path)
    return True


# NOTE: this rebinds `config`, shadowing `world_models_sonic.config` imported earlier.
config = Config()

config.num_workers = 1
config.task_fn = lambda: ParallelizedTask(
    task_fn, config.num_workers, single_process=config.num_workers == 1)
config.optimizer_fn = lambda params: torch.optim.Adam(params, 1e-3, eps=1e-5)
config.network_fn = lambda state_dim, action_dim: CategoricalWorldActorCriticNet(
    state_dim, action_dim,
    FCBody(z_state_dim, hidden_units=(64, 64), gate=F.relu),
    gpu=0 if cuda else -1,
    world_model_fn=lambda: world_model,
    render=(config.num_workers == 1 and verbose),
    z_shape=(32, 16)
)
config.discount = 0.99
config.logger = deep_rl_logger
config.use_gae = True
config.gae_tau = 0.95
config.entropy_weight = 0.00001
config.gradient_clip = 0.5
config.rollout_length = 64
config.optimization_epochs = 10
config.num_mini_batches = 8
config.ppo_ratio_clip = 0.3
config.iteration_log_interval = 10
agent = PPOAgent(config)

# Load the earlier run's weights first; this run's own save, if present, overrides them.
for ppo_file in (ppo_oldsave_file, ppo_save_file):
    _load_agent_if_exists(agent, ppo_file)

# ppo_checkpoint_file = '%s/%s-%s-model-%s.pkl' % (log_dir, agent.__class__.__name__, config.tag, agent.task.name)
# if os.path.isfile(ppo_checkpoint_file):
#     print('loading ppo_checkpoint_file', ppo_checkpoint_file, 'modified', time.ctime(os.path.getmtime(ppo_checkpoint_file)))
#     agent.load(ppo_checkpoint_file)
# else:
#     print("couldn't find checkpoint")

./outputs/RNN_v3b_256im_512z_1512_v6d_VAE5_all
game: SonicTheHedgehog2-Genesis state: HillTopZone.Act1
reseting to ChemicalPlantZone.Act2.state
loading ./outputs/models/PPO_512z_all_c.pkl modified Sat May 26 15:29:48 2018
loading ./outputs/models/PPO_512z_all_d.pkl modified Sat May 26 15:29:48 2018


In [None]:
try:
    run_iterations(agent, log_dir=log_dir)
except BaseException:
    # BaseException (equivalent to a bare `except:`) so that interrupting the kernel
    # (KeyboardInterrupt), the normal way to stop training, still saves the agent.
    # Save before tearing down the environment so a failure while closing cannot lose the weights.
    print("saving", ppo_save_file)
    agent.save(ppo_save_file)
    try:
        if config.num_workers == 1:
            agent.task.tasks[0].env.close()
        else:
            for task in agent.task.tasks:
                task.close()
    except Exception as close_error:
        # Don't let a cleanup failure mask the original exception re-raised below.
        print('failed to close environment(s):', close_error)
    raise

2018-05-26 15:30:51,054 - RNN_v3b_256im_512z_1512_v6d_VAE5_all - INFO: total steps 64, min/mean/max reward 0.0000/0.0000/0.0000 of 1
2018-05-26 15:30:51,055 - RNN_v3b_256im_512z_1512_v6d_VAE5_all - INFO: running min/mean/max reward 0.0000/0.0000/0.0000 of 1 4.1229 s/rollout
2018-05-26 15:31:31,046 - RNN_v3b_256im_512z_1512_v6d_VAE5_all - INFO: total steps 704, min/mean/max reward 0.0000/0.0000/0.0000 of 1
2018-05-26 15:31:31,047 - RNN_v3b_256im_512z_1512_v6d_VAE5_all - INFO: running min/mean/max reward 0.0000/0.0000/0.0000 of 11 4.0105 s/rollout
2018-05-26 15:32:10,798 - RNN_v3b_256im_512z_1512_v6d_VAE5_all - INFO: total steps 1344, min/mean/max reward 0.0000/0.0000/0.0000 of 1
2018-05-26 15:32:10,799 - RNN_v3b_256im_512z_1512_v6d_VAE5_all - INFO: running min/mean/max reward 0.0000/0.0000/0.0000 of 21 3.9937 s/rollout
2018-05-26 15:32:49,879 - RNN_v3b_256im_512z_1512_v6d_VAE5_all - INFO: total steps 1984, min/mean/max reward 0.0000/0.0000/0.0000 of 1
2018-05-26 15:32:49,880 - RNN_v3b_2

In [None]:
# Post-mortem debugger for the most recent exception (e.g. one re-raised by the training cell).
# Interactive only: skip this cell when running the notebook top to bottom.
%debug

In [None]:
# NOTE(review): scratch code that appears to come from a %debug session. `hidden_states`,
# `hidden_state`, `concat` and `self` are not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel. It seems to inspect hidden-state shapes passed to an `rnn` call.
[[h.shape for h in hh] for hh in hidden_states]
self.rnn(concat, [h[0] for h in hidden_state])

In [None]:
# Scratch check: shape of a nested list holding two 1x1 entries, i.e. (2, 1, 1).
np.shape([[[512]], [[512]]])


To monitor training with TensorBoard:
```sh
cd ~/Documents/projects/retro_sonic_comp/world-models-pytorch/log
tensorboard --logdir .
# then open http://localhost:6006/#scalars
```


# Summarize model architectures

In [None]:
from IPython.display import display

with torch.no_grad():
    # Random dummy inputs: one HxWx3 image (transposed to NCHW below) and one discrete action.
    img = np.random.randn(image_size, image_size, 3)
    action = np.array(np.random.randint(0, action_dim))[np.newaxis]
    action = Variable(torch.from_numpy(action)).float()[np.newaxis]
    gpu_img = Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float()
    # Only move to the GPU when one is available, so this cell also runs on CPU-only machines.
    if cuda:
        action = action.cuda()
        gpu_img = gpu_img.cuda()

    with TorchSummarizeDf(vae) as tdf:
        x, mu_vae, logvar_vae = vae.forward(gpu_img)
        z = vae.sample(mu_vae, logvar_vae)
        df_vae = tdf.make_df()

    display(df_vae[df_vae.level < 2])

    with TorchSummarizeDf(mdnrnn) as tdf:
        # Repeat z along a new sequence axis to feed a length-2 sequence.
        pi, mu, sigma, hidden_state = mdnrnn.forward(z.unsqueeze(1).repeat((1, 2, 1)))
        z_next = mdnrnn.sample(pi, mu, sigma)
        df_mdnrnn = tdf.make_df()

    display(df_mdnrnn)

    with TorchSummarizeDf(finv) as tdf:
        finv(z.repeat((1, 2, 1)), z_next)
        df_finv = tdf.make_df()
    display(df_finv)

    with TorchSummarizeDf(world_model) as tdf:
        world_model(gpu_img, action)
        df_world_model = tdf.make_df()
    display(df_world_model[df_world_model.level < 2])

    # Drop the dummy tensors so they don't linger in the kernel (and on the GPU).
    del img, action, gpu_img, x, mu, z, z_next, mu_vae, pi, sigma, logvar_vae, hidden_state