In [1]:
import deep_rl

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.nn import functional as F
from torch.autograd import Variable
from torch import nn, optim
import torch.utils.data

# load as dask array
import dask.array as da
import dask
import h5py

import logging
import sys
import os
import glob
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

  from ._conv import register_converters as _register_converters


In [None]:
from deep_rl.utils import Config
from deep_rl.utils.logger import get_logger, get_default_log_dir

from deep_rl.network.network_heads import CategoricalActorCriticNet, QuantileNet, OptionCriticNet, DeterministicActorCriticNet, GaussianActorCriticNet
from deep_rl.network.network_bodies import FCBody

from deep_rl.component.task import ParallelizedTask

In [1]:
from world_models_sonic.models.vae import VAE6, loss_function_vae
from world_models_sonic.helpers.summarize import TorchSummarizeDf
from world_models_sonic.helpers.dataset import load_cache_data
from world_models_sonic.models.rnn import MDNRNN2
from world_models_sonic.models.inverse_model import InverseModel
from world_models_sonic.models.world_model import WorldModel
from world_models_sonic.custom_envs.env import make_env
from world_models_sonic import config
from world_models_sonic.helpers.deep_rl import PPOAgent, run_iterations, SonicWorldModelDeepRL, CategoricalWorldActorCriticNet

Importing 0 potential games...
Imported 0 games


  return f(*args, **kwds)
  from ._conv import register_converters as _register_converters


# Init

In [5]:
# Run-wide settings: device flag, model dimensions, run name and checkpoint paths.
cuda = torch.cuda.is_available()
env_name = 'sonic256'
z_dim = 256  # latent dimensions

# RNN
action_dim = 10
image_size = 256

verbose = False  # Set this true to render (and make it go slower)

# NAME ='RNN_v3b_256im_512z_1512_v5_greenfield'
NAME = 'RNN_v3b_256im_512z_v6_greenfield'
ppo_save_file = './outputs/models/PPO_greenfields_256z_v5_wd_in_net.pkl'

# Every world-model checkpoint of this run lives in one directory; build the
# path once so the four uses below cannot drift apart.
output_dir = './outputs/{NAME}'.format(NAME=NAME)
save_file_rnn = '{output_dir}/mdnrnn_state_dict.pkl'.format(output_dir=output_dir)
save_file_vae = '{output_dir}/vae_state_dict.pkl'.format(output_dir=output_dir)
save_file_finv = '{output_dir}/finv_state_dict.pkl'.format(output_dir=output_dir)

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(output_dir, exist_ok=True)
# The PPO checkpoint is written to a different folder; create it up front so a
# save triggered by an interrupt cannot fail on a missing directory.
os.makedirs(os.path.dirname(ppo_save_file), exist_ok=True)

# Log to stdout (no file handler is configured here)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(NAME)

# World model

In [7]:
# Build the three world-model components (VAE, MDN-RNN, inverse model), restore
# their weights from disk and combine them into one WorldModel in eval mode.
#
# Checkpoints are deserialized onto the CPU (map_location='cpu') and then copied
# into the module's parameters by load_state_dict, so a checkpoint written on a
# GPU machine also loads on a CPU-only one.

# Load VAE
# TODO swap z and k dim, since it's inconsistent with other models
vae = VAE6(image_size=image_size, z_dim=32, conv_dim=48, code_dim=8, k_dim=z_dim)
if cuda:
    vae.cuda()

# Resume from the VAE checkpoint when one exists
if os.path.isfile(save_file_vae):
    state_dict = torch.load(save_file_vae, map_location='cpu')
    vae.load_state_dict(state_dict)
    print('loaded save_file {save_file}'.format(save_file=save_file_vae))
else:
    # A missing VAE checkpoint is tolerated (unlike the two loads below), but it
    # must not go unnoticed: the VAE then keeps its freshly constructed weights.
    logger.warning('no VAE checkpoint found at %s; VAE weights were not restored', save_file_vae)

# Load MDRNN (its checkpoint is required: torch.load raises if the file is missing)
hidden_size, n_mixture, temp = 128, 3, 0.0

mdnrnn = MDNRNN2(z_dim, action_dim, hidden_size, n_mixture, temp)
if cuda:
    mdnrnn = mdnrnn.cuda()
state_dict = torch.load(save_file_rnn, map_location='cpu')
mdnrnn.load_state_dict(state_dict)
print('loaded {save_file}'.format(save_file=save_file_rnn))

# Load the inverse model (checkpoint required as well)
finv = InverseModel(z_dim, action_dim, hidden_size=256)
if cuda:
    finv = finv.cuda()
state_dict = torch.load(save_file_finv, map_location='cpu')
finv.load_state_dict(state_dict)
print('loaded {save_file}'.format(save_file=save_file_finv))

world_model = WorldModel(vae, mdnrnn, finv)
world_model = world_model.eval() # Samples without randomness
if cuda:
    world_model = world_model.cuda()

loaded ./outputs/RNN_v3b_256im_512z_v6_greenfield/mdnrnn_state_dict.pkl


# Summarize

In [10]:
from IPython.display import display

# Smoke-test every model with one random frame and one random action, and show
# the layer table that TorchSummarizeDf builds for each forward pass.
with torch.no_grad():
    img = np.random.randn(image_size, image_size, 3)
    action = np.array(np.random.randint(0, action_dim))[np.newaxis]
    # (1,) -> (1, 1): add a leading batch axis to the sampled action index
    action = torch.from_numpy(action).float()[np.newaxis]
    # (H, W, 3) -> (1, H, W, 3) -> (1, 3, H, W)
    gpu_img = torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2)).float()
    # Only move the inputs to the GPU when one is available; the models were
    # placed with the same `cuda` flag, so inputs and weights stay on one device.
    if cuda:
        action = action.cuda()
        gpu_img = gpu_img.cuda()

    with TorchSummarizeDf(vae) as tdf:
        x, mu_vae, logvar_vae = vae.forward(gpu_img)
        z = vae.sample(mu_vae, logvar_vae)
        df_vae = tdf.make_df()

    # Only the top levels of the VAE are shown to keep the table short
    display(df_vae[df_vae.level<2])

    with TorchSummarizeDf(mdnrnn) as tdf:
        # Repeat the latent along a new axis 1 to form a 2-step sequence
        pi, mu, sigma, hidden_state = mdnrnn.forward(z.unsqueeze(1).repeat((1,2,1)))
        z_next = mdnrnn.sample(pi, mu, sigma)
        df_mdnrnn = tdf.make_df()

    display(df_mdnrnn)

    with TorchSummarizeDf(finv) as tdf:
        finv(z.repeat((1,2,1)), z_next)
        df_finv = tdf.make_df()
    display(df_finv)

    with TorchSummarizeDf(world_model) as tdf:
        world_model(gpu_img, action)
        df_world_model = tdf.make_df()
    display(df_world_model[df_world_model.level<2])

    # Drop the temporaries so they do not linger in the kernel namespace
    del img, action, gpu_img, x, mu, z, z_next, mu_vae, pi, sigma, logvar_vae

Total parameters 8909862
Total trainable parameters 8909862


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
4,encoder.0,BasicConv2d,"[(-1, 3, 256, 256)]","[(-1, 48, 256, 256)]",1440,1
42,encoder.1,ConvBlock5,"[(-1, 48, 256, 256)]","[(-1, 96, 128, 128)]",93213,1
80,encoder.2,ConvBlock5,"[(-1, 96, 128, 128)]","[(-1, 144, 64, 64)]",281034,1
118,encoder.3,ConvBlock5,"[(-1, 144, 64, 64)]","[(-1, 192, 32, 32)]",566055,1
156,encoder.4,ConvBlock5,"[(-1, 192, 32, 32)]","[(-1, 240, 16, 16)]",948276,1
194,encoder.5,ConvBlock5,"[(-1, 240, 16, 16)]","[(-1, 288, 8, 8)]",1427697,1
232,encoder.6,ConvBlock5,"[(-1, 288, 8, 8)]","[(-1, 32, 8, 8)]",351550,1
233,mu,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,0
234,logvar,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,0
235,z,Linear,"[(-1, 256)]","[(-1, 2048)]",526336,0


Total parameters 1778688
Total trainable parameters 1778688


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
1,rnn,LSTM,"[[(-1, 2, 266)], [[(-1, 1, 128)], [(-1, 1, 128...","[[(-1, 2, 128)], [[(-1, 1, 128)], [(-1, 1, 128...",202752,0
2,ln1,Linear,"[(-1, 128), (-1, 128)]","[(-1, 128), (-1, 128)]",16512,0
3,ln2,Linear,"[(-1, 128), (-1, 128)]","[(-1, 640), (-1, 640)]",82560,0
4,mdn,Linear,"[(-1, 640), (-1, 640)]","[(-1, 2304), (-1, 2304)]",1476864,0


Total parameters 199690
Total trainable parameters 199690


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
1,ln1,Linear,"[(-1, 2, 512)]","[(-1, 2, 256)]",131328,0
2,ln2,Linear,"[(-1, 2, 256)]","[(-1, 2, 256)]",65792,0
3,ln3,Linear,"[(-1, 2, 256)]","[(-1, 2, 10)]",2570,0


Total parameters 10888240
Total trainable parameters 10888240


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
233,vae.mu,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,1
234,vae.logvar,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,1
235,vae.z,Linear,"[(-1, 256)]","[(-1, 2048)]",526336,1
465,vae.sigmoid,Sigmoid,"[(-1, 3, 256, 256)]","[(-1, 3, 256, 256)]",0,1
466,mdnrnn.rnn,LSTM,"[[(-1, 1, 266)], [[(-1, 1, 128)], [(-1, 1, 128...","[[(-1, 1, 128)], [[(-1, 1, 128)], [(-1, 1, 128...",202752,1
467,mdnrnn.ln1,Linear,"[(-1, 128)]","[(-1, 128)]",16512,1
468,mdnrnn.ln2,Linear,"[(-1, 128)]","[(-1, 640)]",82560,1
469,mdnrnn.mdn,Linear,"[(-1, 640)]","[(-1, 2304)]",1476864,1


# Train

In [20]:
# Directory handed to the tasks and to run_iterations for this run's logs.
log_dir = './outputs/{NAME}'.format(NAME=NAME)
# Sum of the MDN-RNN's latent size and hidden size; used below as the input
# width of the policy's FCBody.
z_state_dim = world_model.mdnrnn.z_dim + world_model.mdnrnn.hidden_size
print(log_dir)


def task_fn(log_dir):
    """Create one Sonic task (GreenHillZone) for the PPO agent.

    Args:
        log_dir: directory forwarded unchanged to SonicWorldModelDeepRL.

    Returns:
        A SonicWorldModelDeepRL built with max_steps=1000 and the notebook's
        `verbose` flag, whose env_fn creates the environment via make_env.
    """
    return SonicWorldModelDeepRL(
        # Use the notebook-level env_name instead of repeating the literal, so
        # the environment id is configured in exactly one place.
        env_fn=lambda: make_env(
            env_name, state='GreenHillZone', game='SonicTheHedgehog-Genesis'),
        max_steps=1000,
        log_dir=log_dir,
        verbose=verbose
    )


# PPO hyper-parameters and agent construction.
# NOTE: this rebinds `config`, which the import cell bound to the
# `world_models_sonic.config` module; that module is no longer reachable under
# this name afterwards. The name is kept because later cells read
# `config.num_workers`.
config = Config()

config.num_workers = 3
# With a single worker the tasks run in-process (single_process=True)
config.task_fn = lambda: ParallelizedTask(
    task_fn, config.num_workers, single_process=config.num_workers == 1)
config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=3e-4)
# The body is sized by z_state_dim rather than by the state_dim argument.
# NOTE(review): world_model_fn/render are passed to CategoricalActorCriticNet,
# while CategoricalWorldActorCriticNet is imported but never used - confirm
# which head is intended.
config.network_fn = lambda state_dim, action_dim: CategoricalActorCriticNet(
    state_dim, action_dim, FCBody(z_state_dim, hidden_units=(64, 64), gate=F.relu), gpu=0 if cuda else -1, world_model_fn=lambda: world_model,
    render=(config.num_workers==1 and verbose)
)
config.discount = 0.99
# Reuse log_dir so the logger writes next to the rest of this run's output
config.logger = get_logger(
    NAME,
    file_name='deep_rl_ppo.log',
    level=logging.INFO,
    log_dir=log_dir, )
config.use_gae = True
config.gae_tau = 0.95
config.entropy_weight = 0.0001
config.gradient_clip = 0.2
config.rollout_length = 128
config.optimization_epochs = 10
config.num_mini_batches = 4
config.ppo_ratio_clip = 0.2
config.iteration_log_interval = 1
agent = PPOAgent(config)
# Resume from an earlier PPO checkpoint when one exists
if os.path.isfile(ppo_save_file):
    print('loading', ppo_save_file)
    agent.load(ppo_save_file)

./outputs/RNN_v3b_256im_512z_v6_greenfield
game: SonicTheHedgehog-Genesis state: GreenHillZone.Act3
loading ./outputs/models/PPO_greenfields_256z_v5_wd_in_net.pkl


In [21]:
# Train until run_iterations returns or is interrupted. On any exception,
# including KeyboardInterrupt from stopping the cell, close the environments,
# save the agent, then re-raise.
try:
    run_iterations(agent, log_dir=log_dir)
except BaseException:  # explicit: KeyboardInterrupt must be handled here too
    try:
        if config.num_workers == 1:
            agent.task.tasks[0].env.close()
        else:
            for task in agent.task.tasks:
                task.close()
    finally:
        # Save even when closing an environment fails, so progress is not lost
        print("saving", ppo_save_file)
        agent.save(ppo_save_file)
    raise

2018-05-25 18:15:31,753 - RNN_v3b_256im_512z_v6_greenfield - INFO: total steps 128, min/mean/max reward 0.0000/0.0000/0.0000 of 1
2018-05-25 18:15:31,754 - RNN_v3b_256im_512z_v6_greenfield - INFO: running min/mean/max reward 0.0000/0.0000/0.0000 of 1 19.8994 s/rollout
2018-05-25 18:15:50,486 - RNN_v3b_256im_512z_v6_greenfield - INFO: total steps 256, min/mean/max reward 0.0000/0.0000/0.0000 of 1
2018-05-25 18:15:50,487 - RNN_v3b_256im_512z_v6_greenfield - INFO: running min/mean/max reward 0.0000/0.0000/0.0000 of 1 18.7332 s/rollout


saving ./outputs/models/PPO_greenfields_256z_v5_wd_in_net.pkl


KeyboardInterrupt: 


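To monitor training with TensorBoard, run the following from the directory that contains the logs (adjust the path to your own checkout):
```sh
cd ~/Documents/projects/retro_sonic_comp/world-models-pytorch/log
tensorboard --logdir .
# then open http://localhost:6006/#scalars in a browser
```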
