In this experiment: a higher learning rate and larger rollouts.

In [1]:
import deep_rl

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.nn import functional as F
from torch.autograd import Variable
from torch import nn, optim
import torch.utils.data

# load as dask array
import time

import logging
import sys
import os
import glob
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

In [4]:
from deep_rl.utils import Config
from deep_rl.utils.logger import get_logger, get_default_log_dir

from deep_rl.network.network_heads import CategoricalActorCriticNet, QuantileNet, OptionCriticNet, DeterministicActorCriticNet, GaussianActorCriticNet
from deep_rl.network.network_bodies import FCBody

from deep_rl.component.task import ParallelizedTask

In [5]:
from world_models_sonic.models.vae import VAE5, loss_function_vae
from world_models_sonic.helpers.summarize import TorchSummarizeDf
from world_models_sonic.helpers.dataset import load_cache_data
from world_models_sonic.models.rnn import MDNRNN2
from world_models_sonic.models.inverse_model import InverseModel
from world_models_sonic.models.world_model import WorldModel
from world_models_sonic.custom_envs.env import make_env
from world_models_sonic.custom_envs.wrappers import RandomGameReset
from world_models_sonic import config
from world_models_sonic.helpers.deep_rl import PPOAgent, run_iterations, SonicWorldModelDeepRL, CategoricalWorldActorCriticNet

Importing 0 potential games...
Imported 0 games


  from ._conv import register_converters as _register_converters


# Init

In [6]:
# Run-wide configuration: device, model sizes, run name and file locations.
cuda = torch.cuda.is_available()
env_name = 'sonic256'
z_dim = 512  # latent dimensions

# RNN
action_dim = 10
image_size = 256

verbose = True  # Set this true to render (and make it go slower)

# NAME ='RNN_v3b_256im_512z_1512_v5_greenfield'
NAME = 'RNN_v3b_256im_512z_1512_v6c_VAE5_all'
ppo_save_file = './outputs/models/PPO_512z_all_c.pkl'

# Pretrained world-model weights for this run live under ./outputs/{NAME}/
save_file_rnn = './outputs/{NAME}/mdnrnn_state_dict.pkl'.format(NAME=NAME)
save_file_vae = './outputs/{NAME}/vae_state_dict.pkl'.format(NAME=NAME)
save_file_finv = './outputs/{NAME}/finv_state_dict.pkl'.format(NAME=NAME)

# Create the run directory AND the directory of the PPO checkpoint up front.
# agent.save(ppo_save_file) is called later and will most likely fail if
# ./outputs/models does not exist. exist_ok avoids a check-then-create race.
for _out_dir in ('./outputs/{NAME}'.format(NAME=NAME), os.path.dirname(ppo_save_file)):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

# Log to stdout only (the PPO agent's file log is configured via get_logger)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(NAME)

# World model

In [7]:
# Load the pretrained world model: VAE, MDN-RNN and inverse model (finv),
# then wrap them in a single WorldModel.

# When CUDA is unavailable, remap GPU-saved checkpoint tensors to CPU storage
# so torch.load does not fail on CPU-only machines. With CUDA this is None,
# i.e. torch.load's default behaviour.
map_location = None if cuda else (lambda storage, loc: storage)

# Load VAE
# TODO swap z and k dim, since it's inconsistent with other models
vae = VAE5(image_size=image_size, z_dim=128, conv_dim=64, code_dim=8, k_dim=z_dim)
if cuda:
    vae = vae.cuda()

# Resume from saved VAE weights if present
if os.path.isfile(save_file_vae):
    vae.load_state_dict(torch.load(save_file_vae, map_location=map_location))
    print('loaded save_file {save_file}'.format(save_file=save_file_vae))
else:
    logger.warning('VAE weights not found at %s; the VAE is randomly initialised', save_file_vae)

# Load MDRNN; positional args are (z_dim, action_dim, hidden_size, n_mixture, temp)
hidden_size, n_mixture, temp = 512, 2, 0.0

mdnrnn = MDNRNN2(z_dim, action_dim, hidden_size, n_mixture, temp)
if cuda:
    mdnrnn = mdnrnn.cuda()
mdnrnn.load_state_dict(torch.load(save_file_rnn, map_location=map_location))
print('loaded {save_file}'.format(save_file=save_file_rnn))

# Load inverse model
finv = InverseModel(z_dim, action_dim, hidden_size=256)
if cuda:
    finv = finv.cuda()
finv.load_state_dict(torch.load(save_file_finv, map_location=map_location))
print('loaded {save_file}'.format(save_file=save_file_finv))

world_model = WorldModel(vae, mdnrnn, finv)
# Switch to eval mode. The original note says this makes sampling deterministic;
# NOTE(review): confirm against the WorldModel / MDNRNN2 sampling code.
world_model = world_model.eval()
if cuda:
    world_model = world_model.cuda()

loaded save_file ./outputs/RNN_v3b_256im_512z_1512_v6c_VAE5_all/vae_state_dict.pkl
loaded ./outputs/RNN_v3b_256im_512z_1512_v6c_VAE5_all/mdnrnn_state_dict.pkl
loaded ./outputs/RNN_v3b_256im_512z_1512_v6c_VAE5_all/finv_state_dict.pkl


# Summarize models (layer shapes and parameter counts)

In [8]:
from IPython.display import display

# Sanity check: push a random image and action through each model and display
# layer-by-layer input/output shapes and parameter counts.


def _to_device(tensor):
    """Move `tensor` to the GPU only when CUDA is available; otherwise return it unchanged."""
    return tensor.cuda() if cuda else tensor


with torch.no_grad():
    img = np.random.randn(image_size, image_size, 3)
    action = np.array(np.random.randint(0, action_dim))[np.newaxis]
    action = _to_device(Variable(torch.from_numpy(action)).float())[np.newaxis]
    # (H, W, C) -> (1, C, H, W): add a batch axis, then move channels first
    img_tensor = _to_device(Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float())

    with TorchSummarizeDf(vae) as tdf:
        x, mu_vae, logvar_vae = vae.forward(img_tensor)
        z = vae.sample(mu_vae, logvar_vae)
        df_vae = tdf.make_df()
    display(df_vae[df_vae.level < 2])

    # Feed a 2-step sequence (the same latent repeated) through the MDN-RNN
    with TorchSummarizeDf(mdnrnn) as tdf:
        pi, mu, sigma, hidden_state = mdnrnn.forward(z.unsqueeze(1).repeat((1, 2, 1)))
        z_next = mdnrnn.sample(pi, mu, sigma)
        df_mdnrnn = tdf.make_df()
    display(df_mdnrnn)

    with TorchSummarizeDf(finv) as tdf:
        finv(z.repeat((1, 2, 1)), z_next)
        df_finv = tdf.make_df()
    display(df_finv)

    with TorchSummarizeDf(world_model) as tdf:
        world_model(img_tensor, action)
        df_world_model = tdf.make_df()
    display(df_world_model[df_world_model.level < 2])

    # Free the dummy tensors, which may hold GPU memory
    del img, action, img_tensor, x, mu, z, z_next, mu_vae, pi, sigma, logvar_vae, hidden_state

Total parameters 21877123
Total trainable parameters 21877123


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
4,encoder.0,ConvBlock4,"[(-1, 3, 256, 256)]","[(-1, 64, 256, 256)]",1920,1
8,encoder.1,ConvBlock4,"[(-1, 64, 256, 256)]","[(-1, 128, 128, 128)]",131456,1
12,encoder.2,ConvBlock4,"[(-1, 128, 128, 128)]","[(-1, 192, 64, 64)]",393792,1
16,encoder.3,ConvBlock4,"[(-1, 192, 64, 64)]","[(-1, 256, 32, 32)]",787200,1
20,encoder.4,ConvBlock4,"[(-1, 256, 32, 32)]","[(-1, 320, 16, 16)]",1311680,1
24,encoder.5,ConvBlock4,"[(-1, 320, 16, 16)]","[(-1, 384, 8, 8)]",1967232,1
25,encoder.6,Conv2d,"[(-1, 384, 8, 8)]","[(-1, 128, 8, 8)]",49280,1
26,mu,Linear,"[(-1, 8192)]","[(-1, 512)]",4194816,0
27,logvar,Linear,"[(-1, 8192)]","[(-1, 512)]",4194816,0
28,z,Linear,"[(-1, 512)]","[(-1, 8192)]",4202496,0


Total parameters 11565056
Total trainable parameters 11565056


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
1,rnn,LSTM,"[[(-1, 2, 522)], [[(-1, 1, 512)], [(-1, 1, 512...","[[(-1, 2, 512)], [[(-1, 1, 512)], [(-1, 1, 512...",2121728,0
2,ln1,Linear,"[(-1, 512), (-1, 512)]","[(-1, 512), (-1, 512)]",262656,0
3,ln2,Linear,"[(-1, 512), (-1, 512)]","[(-1, 2560), (-1, 2560)]",1313280,0
4,mdn,Linear,"[(-1, 2560), (-1, 2560)]","[(-1, 3072), (-1, 3072)]",7867392,0


Total parameters 330762
Total trainable parameters 330762


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
1,ln1,Linear,"[(-1, 2, 1024)]","[(-1, 2, 256)]",262400,0
2,ln2,Linear,"[(-1, 2, 256)]","[(-1, 2, 256)]",65792,0
3,ln3,Linear,"[(-1, 2, 256)]","[(-1, 2, 10)]",2570,0


Total parameters 33772941
Total trainable parameters 33772941


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
26,vae.mu,Linear,"[(-1, 8192)]","[(-1, 512)]",4194816,1
27,vae.logvar,Linear,"[(-1, 8192)]","[(-1, 512)]",4194816,1
28,vae.z,Linear,"[(-1, 512)]","[(-1, 8192)]",4202496,1
54,vae.sigmoid,Sigmoid,"[(-1, 3, 256, 256)]","[(-1, 3, 256, 256)]",0,1
55,mdnrnn.rnn,LSTM,"[[(-1, 1, 522)], [[(-1, 1, 512)], [(-1, 1, 512...","[[(-1, 1, 512)], [[(-1, 1, 512)], [(-1, 1, 512...",2121728,1
56,mdnrnn.ln1,Linear,"[(-1, 512)]","[(-1, 512)]",262656,1
57,mdnrnn.ln2,Linear,"[(-1, 512)]","[(-1, 2560)]",1313280,1
58,mdnrnn.mdn,Linear,"[(-1, 2560)]","[(-1, 3072)]",7867392,1


# Train

In [9]:
log_dir = './outputs/{NAME}'.format(NAME=NAME)
# Input size of the policy's FCBody: MDN-RNN latent size + hidden size
# (presumably z and the RNN hidden state concatenated; confirm in CategoricalWorldActorCriticNet)
z_state_dim = world_model.mdnrnn.z_dim + world_model.mdnrnn.hidden_size
print(log_dir)


def task_fn(log_dir):
    """Create one SonicWorldModelDeepRL task around a RandomGameReset-wrapped `env_name` env.

    Args:
        log_dir: directory passed through to SonicWorldModelDeepRL for its logs.
    """
    return SonicWorldModelDeepRL(
        env_fn=lambda: RandomGameReset(make_env(env_name)),
        max_steps=4000,
        log_dir=log_dir,
        verbose=verbose,
    )


# NOTE: this rebinds `config`, shadowing the `world_models_sonic.config` module
# imported earlier; from here on `config` is the DeepRL PPO Config.
config = Config()

config.num_workers = 1
config.task_fn = lambda: ParallelizedTask(
    task_fn, config.num_workers, single_process=config.num_workers == 1)
config.optimizer_fn = lambda params: torch.optim.Adam(params, 3e-4, eps=1e-5)
config.network_fn = lambda state_dim, action_dim: CategoricalWorldActorCriticNet(
    state_dim, action_dim, FCBody(z_state_dim, hidden_units=(64, 64), gate=F.relu),
    gpu=0 if cuda else -1,
    world_model_fn=lambda: world_model,
    render=(config.num_workers == 1 and verbose),
    z_shape=(32, 16),
)
config.discount = 0.99
config.logger = get_logger(
    NAME,
    file_name='deep_rl_ppo.log',
    level=logging.INFO,
    log_dir=log_dir,
)
config.use_gae = True
config.gae_tau = 0.95
config.entropy_weight = 0.001
config.gradient_clip = 0.5
config.rollout_length = 512
config.optimization_epochs = 10
config.num_mini_batches = 32
config.ppo_ratio_clip = 0.3
config.iteration_log_interval = 10
agent = PPOAgent(config)

# Resume PPO weights from the last saved run, if any
if os.path.isfile(ppo_save_file):
    print('loading', ppo_save_file, 'modified', time.ctime(os.path.getmtime(ppo_save_file)))
    agent.load(ppo_save_file)
else:
    print("couldn't find save")

# ppo_checkpoint_file = '%s/%s-%s-model-%s.pkl' % (log_dir, agent.__class__.__name__, config.tag, agent.task.name)
# if os.path.isfile(ppo_checkpoint_file):
#     print('loading ppo_checkpoint_file', ppo_checkpoint_file, 'modified', time.ctime(os.path.getmtime(ppo_checkpoint_file)))
#     agent.load(ppo_checkpoint_file)
# else:
#     print("couldn't find checkpoint")

./outputs/RNN_v3b_256im_512z_1512_v6c_VAE5_all
game: SonicAndKnuckles3-Genesis state: LaunchBaseZone.Act2
game: SonicTheHedgehog2-Genesis state: MysticCaveZone.Act1
reseting to AngelIslandZone.Act1.state
reseting to WingFortressZone.state
loading ./outputs/models/PPO_512z_all_c.pkl modified Fri May 25 21:09:35 2018
reseting to WingFortressZone.state
reseting to WingFortressZone.state
reseting to OilOceanZone.Act1.state
reseting to CarnivalNightZone.Act2.state
reseting to MetropolisZone.Act2.state
reseting to MysticCaveZone.Act2.state
reseting to OilOceanZone.Act2.state
reseting to MysticCaveZone.Act2.state
reseting to WingFortressZone.state
reseting to WingFortressZone.state
reseting to LavaReefZone.Act2.state
reseting to MarbleGardenZone.Act1.state
reseting to DeathEggZone.Act2.state
reseting to OilOceanZone.Act2.state
reseting to HillTopZone.Act1.state
reseting to MysticCaveZone.Act2.state
reseting to MushroomHillZone.Act2.state
reseting to EmeraldHillZone.Act2.state
reseting to Casi

reseting to AquaticRuinZone.Act1.state
reseting to SandopolisZone.Act2.state
reseting to AquaticRuinZone.Act2.state
reseting to MarbleGardenZone.Act2.state
reseting to OilOceanZone.Act1.state
reseting to EmeraldHillZone.Act1.state
reseting to AquaticRuinZone.Act2.state
reseting to MysticCaveZone.Act2.state
reseting to AquaticRuinZone.Act1.state
reseting to SandopolisZone.Act2.state
reseting to IcecapZone.Act1.state
reseting to MetropolisZone.Act1.state
reseting to OilOceanZone.Act1.state
reseting to WingFortressZone.state
reseting to IcecapZone.Act1.state
reseting to AquaticRuinZone.Act2.state
reseting to HillTopZone.Act1.state
reseting to CasinoNightZone.Act1.state
reseting to LaunchBaseZone.Act1.state
reseting to AquaticRuinZone.Act2.state
reseting to AquaticRuinZone.Act2.state
reseting to WingFortressZone.state
reseting to MushroomHillZone.Act1.state
reseting to ChemicalPlantZone.Act1.state
reseting to LavaReefZone.Act2.state
reseting to CasinoNightZone.Act1.state
reseting to Hydroc

reseting to EmeraldHillZone.Act2.state
reseting to MarbleGardenZone.Act1.state
reseting to ChemicalPlantZone.Act2.state
reseting to AngelIslandZone.Act1.state
reseting to MetropolisZone.Act1.state
reseting to DeathEggZone.Act2.state
reseting to DeathEggZone.Act1.state
reseting to HillTopZone.Act1.state
reseting to WingFortressZone.state
reseting to WingFortressZone.state
reseting to LavaReefZone.Act2.state
reseting to HydrocityZone.Act2.state
reseting to SandopolisZone.Act1.state
reseting to HydrocityZone.Act2.state
reseting to IcecapZone.Act2.state
reseting to DeathEggZone.Act2.state
reseting to EmeraldHillZone.Act2.state
reseting to CarnivalNightZone.Act2.state
reseting to EmeraldHillZone.Act2.state
reseting to MarbleGardenZone.Act2.state
reseting to CasinoNightZone.Act1.state
reseting to AquaticRuinZone.Act1.state
reseting to SandopolisZone.Act1.state
reseting to SandopolisZone.Act2.state
reseting to ChemicalPlantZone.Act2.state
reseting to DeathEggZone.Act1.state
reseting to Aquati

Process ProcessWrapper-1:
Process ProcessWrapper-2:
Traceback (most recent call last):
  File "/home/wassname/.pyenv/versions/3.5.3/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/media/oldhome/wassname/Documents/projects/retro_sonic_comp/DeepRL/deep_rl/component/task.py", line 177, in run
    op, data = self.pipe.recv()
  File "/home/wassname/.pyenv/versions/3.5.3/lib/python3.5/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/home/wassname/.pyenv/versions/3.5.3/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/wassname/.pyenv/versions/3.5.3/lib/python3.5/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/wassname/.pyenv/versions/3.5.3/lib/python3.5/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
  File "/media/oldhome/wassname/Documents/projec

In [10]:
# Train until run_iterations returns or is interrupted (e.g. KeyboardInterrupt).
# In every case, close the environments, then save the PPO agent. The nested
# finally makes the save run even if closing an environment raises. Any
# exception from training is re-raised automatically after cleanup.
try:
    run_iterations(agent, log_dir=log_dir)
finally:
    try:
        if config.num_workers == 1:
            agent.task.tasks[0].env.close()
        else:
            for task in agent.task.tasks:
                task.close()
    finally:
        print("saving", ppo_save_file)
        agent.save(ppo_save_file)

2018-05-25 21:13:23,623 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 1024, min/mean/max reward 0.0000/0.6750/1.3500 of 2
2018-05-25 21:13:23,624 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 0.6750/0.6750/0.6750 of 1 14.6599 s/rollout
2018-05-25 21:18:08,745 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 11264, min/mean/max reward 1.8509/10.4010/18.9510 of 2
2018-05-25 21:18:08,746 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 0.6750/8.4259/16.1538 of 11 14.2928 s/rollout
2018-05-25 21:22:48,856 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 21504, min/mean/max reward 4.6975/15.6533/26.6091 of 2
2018-05-25 21:22:48,857 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 0.6750/9.9723/16.1538 of 21 14.1560 s/rollout
2018-05-25 21:27:32,763 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 31744, min/mean/max reward 5.9409/7.4569/8.9728 of 2
2018-05-25 21:27:32,

2018-05-25 23:21:35,411 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 297984, min/mean/max reward 16.2402/21.6942/27.1483 of 2
2018-05-25 23:21:35,412 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 3.1768/14.8894/35.7751 of 100 13.0797 s/rollout
2018-05-25 23:26:10,907 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 308224, min/mean/max reward 19.4836/23.6183/27.7530 of 2
2018-05-25 23:26:10,909 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 3.1768/14.9821/35.7751 of 100 13.1485 s/rollout
2018-05-25 23:30:45,352 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 318464, min/mean/max reward 6.0592/11.5861/17.1130 of 2
2018-05-25 23:30:45,354 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 3.1768/15.1579/30.1297 of 100 13.2133 s/rollout
2018-05-25 23:35:20,902 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 328704, min/mean/max reward 4.0285/13.9563/23.8842 of 2
2

2018-05-26 01:29:32,708 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 594944, min/mean/max reward 27.4205/31.5618/35.7031 of 2
2018-05-26 01:29:32,709 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 2.9500/22.5786/73.6673 of 100 13.1479 s/rollout
2018-05-26 01:33:54,195 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 605184, min/mean/max reward 19.4836/23.6183/27.7530 of 2
2018-05-26 01:33:54,196 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 2.9500/23.8471/73.6673 of 100 13.1507 s/rollout
2018-05-26 01:38:16,709 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 615424, min/mean/max reward 19.4836/23.6183/27.7530 of 2
2018-05-26 01:38:16,710 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 2.9500/25.1395/73.6673 of 100 13.1581 s/rollout
2018-05-26 01:42:39,306 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 625664, min/mean/max reward 3.6940/17.0568/30.4196 of 2


2018-05-26 03:35:55,545 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 891904, min/mean/max reward 10.5132/13.1393/15.7654 of 2
2018-05-26 03:35:55,546 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 2.4079/19.1951/58.4675 of 100 13.0577 s/rollout
2018-05-26 03:40:16,187 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 902144, min/mean/max reward 6.6848/13.0265/19.3682 of 2
2018-05-26 03:40:16,188 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 2.4079/20.0716/58.4675 of 100 13.0580 s/rollout
2018-05-26 03:44:37,056 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 912384, min/mean/max reward 4.1155/9.9419/15.7683 of 2
2018-05-26 03:44:37,058 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 8.0958/20.1745/58.4675 of 100 13.0559 s/rollout
2018-05-26 03:48:58,886 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 922624, min/mean/max reward 19.4836/36.3888/53.2941 of 2
20

2018-05-26 05:42:05,803 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 1188864, min/mean/max reward 20.7766/58.1597/95.5428 of 2
2018-05-26 05:42:05,804 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 2.9452/11.3615/58.1597 of 100 13.0524 s/rollout
2018-05-26 05:46:27,336 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 1199104, min/mean/max reward 13.1350/13.4693/13.8035 of 2
2018-05-26 05:46:27,337 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 2.9452/14.2497/58.1597 of 100 13.0547 s/rollout
2018-05-26 05:50:48,983 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 1209344, min/mean/max reward 11.2955/12.2153/13.1350 of 2
2018-05-26 05:50:48,983 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: running min/mean/max reward 4.2058/14.8634/58.1597 of 100 13.0572 s/rollout
2018-05-26 05:55:09,982 - RNN_v3b_256im_512z_1512_v6c_VAE5_all - INFO: total steps 1219584, min/mean/max reward 13.0228/16.2661/19.5095 

saving ./outputs/models/PPO_512z_all_c.pkl


KeyboardInterrupt: 


To monitor training with TensorBoard:
```sh
cd ~/Documents/projects/retro_sonic_comp/world-models-pytorch/log 
tensorboard  --logdir .
#then open http://localhost:6006/#scalars
```
