In [1]:
import deep_rl

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.nn import functional as F
from torch.autograd import Variable
from torch import nn, optim
import torch.utils.data

# load as dask array
import dask.array as da
import dask
import h5py

import logging
import sys
import os
import glob
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

  from ._conv import register_converters as _register_converters


In [4]:
from world_models_sonic.models.vae import VAE6, loss_function_vae
from world_models_sonic.helpers.summarize import TorchSummarizeDf
from world_models_sonic.helpers.dataset import load_cache_data
from world_models_sonic.models.rnn import MDNRNN2
from world_models_sonic.models.inverse_model import InverseModel
from world_models_sonic.models.world_model import WorldModel
from world_models_sonic import config
from world_models_sonic.helpers.deep_rl import PPOAgent, run_iterations, SonicWorldModelDeepRL

Importing 0 potential games...
Imported 0 games


# Init

In [5]:
# --- Hardware / environment -------------------------------------------------
cuda = torch.cuda.is_available()
env_name = 'sonic256'

# --- Model dimensions -------------------------------------------------------
z_dim = 256  # latent dimensions
action_dim = 10
image_size = 256

verbose = True  # Set this true to render (and make it go slower)

# --- Experiment name and checkpoint paths -----------------------------------
# Earlier run name: 'RNN_v3b_256im_512z_1512_v5_greenfield'
NAME = 'RNN_v3b_256im_512z_v6_greenfield'

# PPO agent checkpoint; the two normalizer files sit next to it and share its stem.
ppo_save_file = './outputs/models/PPO_greenfields_256z_v3.pkl'
ppo_save_file_reward_norm = ppo_save_file.replace('.pkl', '') + '_reward_norm.pkl'
ppo_save_file_state_norm = ppo_save_file.replace('.pkl', '') + '_state_norm.pkl'

# World-model component checkpoints all live under ./outputs/<NAME>/
save_file_rnn = './outputs/{NAME}/mdnrnn_state_dict.pkl'.format(NAME=NAME)
save_file_vae = './outputs/{NAME}/vae_state_dict.pkl'.format(NAME=NAME)
save_file_finv = './outputs/{NAME}/finv_state_dict.pkl'.format(NAME=NAME)

# Create the experiment directory on first use.
os.makedirs('./outputs/{NAME}'.format(NAME=NAME), exist_ok=True)

# --- Logging: INFO and above goes to stdout ----------------------------------
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(NAME)

# Load Data

# Load VAE

In [6]:
# Load VAE
# TODO swap z and k dim, since it's inconsistent with other models
vae = VAE6(image_size=image_size, z_dim=32, conv_dim=48, code_dim=8, k_dim=z_dim)
if cuda:
    vae.cuda()

# Resume from the saved checkpoint when one exists.
if os.path.isfile(save_file_vae):
    # map_location='cpu' lets a checkpoint written on a GPU machine deserialize
    # on a CPU-only one; load_state_dict then copies the values into the
    # module's own parameters, wherever they live.
    state_dict = torch.load(save_file_vae, map_location='cpu')
    vae.load_state_dict(state_dict)
    print('loaded save_file {save_file}'.format(save_file=save_file_vae))

loaded save_file ./outputs/RNN_v3b_256im_512z_v6_greenfield/vae_state_dict.pkl


# Load RNN

In [7]:
# Load MDRNN
# action_dim comes from the config cell; the remaining hyper-parameters are set here.
hidden_size, n_mixture, temp = 128, 3, 0.0

mdnrnn = MDNRNN2(z_dim, action_dim, hidden_size, n_mixture, temp)

if cuda:
    mdnrnn = mdnrnn.cuda()

# Resume from the saved checkpoint when one exists.
if os.path.isfile(save_file_rnn):
    # map_location='cpu' keeps the load working on machines without CUDA;
    # load_state_dict copies the values into the module's own parameters.
    state_dict = torch.load(save_file_rnn, map_location='cpu')
    mdnrnn.load_state_dict(state_dict)
    print('loaded {save_file}'.format(save_file=save_file_rnn))

loaded ./outputs/RNN_v3b_256im_512z_v6_greenfield/mdnrnn_state_dict.pkl


# Inverse Model (finv)

In [8]:
# Inverse model: built from the latent size and the action count.
finv = InverseModel(z_dim, action_dim, hidden_size=256)
# Honour the `cuda` flag like the VAE and MDN-RNN cells do, instead of
# calling .cuda() unconditionally (which fails on a CPU-only machine).
if cuda:
    finv = finv.cuda()

# Resume from the saved checkpoint when one exists.
if os.path.isfile(save_file_finv):
    # map_location='cpu' keeps the load working on machines without CUDA;
    # load_state_dict copies the values into the module's own parameters.
    state_dict = torch.load(save_file_finv, map_location='cpu')
    finv.load_state_dict(state_dict)
    print('loaded {save_file}'.format(save_file=save_file_finv))

loaded ./outputs/RNN_v3b_256im_512z_v6_greenfield/finv_state_dict.pkl


# Init

In [9]:
# Combine the three sub-models and switch to eval mode in one step
# (eval mode: samples without randomness).
world_model = WorldModel(vae, mdnrnn, finv).eval()

# summarize

In [10]:
from IPython.display import display

# Smoke-test every model on one random frame / action pair and display the
# per-layer tables collected by TorchSummarizeDf.
with torch.no_grad():
    img = np.random.randn(image_size, image_size, 3)
    action = np.array(np.random.randint(0, action_dim))[np.newaxis]
    action = torch.from_numpy(action).float()[np.newaxis]
    # (H, W, C) -> (1, C, H, W)
    gpu_img = torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2)).float()
    # Move the inputs to the GPU only when one is available, matching how the
    # models themselves are placed.
    if cuda:
        action = action.cuda()
        gpu_img = gpu_img.cuda()

    with TorchSummarizeDf(vae) as tdf:
        x, mu_vae, logvar_vae = vae.forward(gpu_img)
        z = vae.sample(mu_vae, logvar_vae)
        df_vae = tdf.make_df()
    display(df_vae[df_vae.level < 2])

    with TorchSummarizeDf(mdnrnn) as tdf:
        # Duplicate the single latent / action along axis 1 to length 2
        # (presumably the sequence axis -- TODO confirm against MDNRNN2).
        pi, mu, sigma, hidden_state = mdnrnn.forward(
            z.unsqueeze(1).repeat((1, 2, 1)), action.repeat((1, 2)))
        z_next = mdnrnn.sample(pi, mu, sigma)
        df_mdnrnn = tdf.make_df()
    display(df_mdnrnn)

    with TorchSummarizeDf(finv) as tdf:
        finv(z.repeat((1, 2, 1)), z_next)
        df_finv = tdf.make_df()
    display(df_finv)

    with TorchSummarizeDf(world_model) as tdf:
        world_model(gpu_img, action)
        df_world_model = tdf.make_df()
    display(df_world_model[df_world_model.level < 2])

    # Drop the throw-away tensors so they do not linger in the kernel namespace.
    del img, action, gpu_img, x, mu, z, z_next, mu_vae, pi, sigma, logvar_vae

Total parameters 8909862
Total trainable parameters 8909862


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
4,encoder.0,BasicConv2d,"[(-1, 3, 256, 256)]","[(-1, 48, 256, 256)]",1440,1
42,encoder.1,ConvBlock5,"[(-1, 48, 256, 256)]","[(-1, 96, 128, 128)]",93213,1
80,encoder.2,ConvBlock5,"[(-1, 96, 128, 128)]","[(-1, 144, 64, 64)]",281034,1
118,encoder.3,ConvBlock5,"[(-1, 144, 64, 64)]","[(-1, 192, 32, 32)]",566055,1
156,encoder.4,ConvBlock5,"[(-1, 192, 32, 32)]","[(-1, 240, 16, 16)]",948276,1
194,encoder.5,ConvBlock5,"[(-1, 240, 16, 16)]","[(-1, 288, 8, 8)]",1427697,1
232,encoder.6,ConvBlock5,"[(-1, 288, 8, 8)]","[(-1, 32, 8, 8)]",351550,1
233,mu,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,0
234,logvar,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,0
235,z,Linear,"[(-1, 256)]","[(-1, 2048)]",526336,0


Total parameters 1778688
Total trainable parameters 1778688


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
1,rnn,LSTM,"[[(-1, 2, 266)], [[(-1, 1, 128)], [(-1, 1, 128...","[[(-1, 2, 128)], [[(-1, 1, 128)], [(-1, 1, 128...",202752,0
2,ln1,Linear,"[(-1, 128), (-1, 128)]","[(-1, 128), (-1, 128)]",16512,0
3,ln2,Linear,"[(-1, 128), (-1, 128)]","[(-1, 640), (-1, 640)]",82560,0
4,mdn,Linear,"[(-1, 640), (-1, 640)]","[(-1, 2304), (-1, 2304)]",1476864,0


Total parameters 199690
Total trainable parameters 199690


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
1,ln1,Linear,"[(-1, 2, 512)]","[(-1, 2, 256)]",131328,0
2,ln2,Linear,"[(-1, 2, 256)]","[(-1, 2, 256)]",65792,0
3,ln3,Linear,"[(-1, 2, 256)]","[(-1, 2, 10)]",2570,0


Total parameters 10888240
Total trainable parameters 10888240


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
233,vae.mu,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,1
234,vae.logvar,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,1
235,vae.z,Linear,"[(-1, 256)]","[(-1, 2048)]",526336,1
465,vae.sigmoid,Sigmoid,"[(-1, 3, 256, 256)]","[(-1, 3, 256, 256)]",0,1
466,mdnrnn.rnn,LSTM,"[[(-1, 1, 266)], [[(-1, 1, 128)], [(-1, 1, 128...","[[(-1, 1, 128)], [[(-1, 1, 128)], [(-1, 1, 128...",202752,1
467,mdnrnn.ln1,Linear,"[(-1, 128)]","[(-1, 128)]",16512,1
468,mdnrnn.ln2,Linear,"[(-1, 128)]","[(-1, 640)]",82560,1
469,mdnrnn.mdn,Linear,"[(-1, 640)]","[(-1, 2304)]",1476864,1


# Env wrappers

In [11]:
from deep_rl.utils import Config
from deep_rl.utils.logger import get_logger, get_default_log_dir

from deep_rl.network.network_heads import CategoricalActorCriticNet, QuantileNet, OptionCriticNet, DeterministicActorCriticNet, GaussianActorCriticNet
from deep_rl.network.network_bodies import FCBody

from deep_rl.component.task import ParallelizedTask
from deep_rl.utils.misc import run_episodes, run_iterations

# Train

In [12]:
import datetime

# Log directory name: checkpoint file name followed by a UTC timestamp.
timestamp = datetime.datetime.utcnow().strftime('%Y%m%d_%H-%M-%S')
log_dir = get_default_log_dir(os.path.basename(ppo_save_file) + timestamp)
print(log_dir)


def task_fn(log_dir):
    """Create a SonicWorldModelDeepRL task on GreenHillZone that logs to `log_dir`.

    The task receives the notebook-level `world_model` through a zero-argument
    factory and renders according to the notebook-level `verbose` flag.
    """
    return SonicWorldModelDeepRL(
        'sonic256',
        max_steps=1000,
        log_dir=log_dir,
        world_model_func=lambda: world_model,
        state='GreenHillZone',
        game='SonicTheHedgehog-Genesis',
        verbose=verbose,
    )


# NOTE: this rebinds `config`, shadowing the `world_models_sonic.config`
# module imported at the top of the notebook.
config = Config()

# Environment / workers
config.num_workers = 1
config.task_fn = lambda: ParallelizedTask(
    task_fn, config.num_workers, single_process=config.num_workers == 1)

# Optimiser and network
config.optimizer_fn = lambda params: torch.optim.RMSprop(params, 1e-3)
config.network_fn = lambda state_dim, action_dim: CategoricalActorCriticNet(
    state_dim, action_dim, FCBody(state_dim), gpu=-1)

# Returns / logging
config.discount = 0.99
config.logger = get_logger(
    NAME,
    file_name='deep_rl_ppo.log',
    level=logging.INFO,
    log_dir='./outputs/{NAME}'.format(NAME=NAME),
)

# PPO hyper-parameters
config.use_gae = True
config.gae_tau = 0.95
config.entropy_weight = 0.0001
config.gradient_clip = 0.4
config.rollout_length = 128
config.optimization_epochs = 10
config.num_mini_batches = 4
config.ppo_ratio_clip = 0.2
config.iteration_log_interval = 10

agent = PPOAgent(config)
env = agent.task.tasks[0].env

# Resume the agent and both normalizers from disk when a checkpoint exists.
if os.path.isfile(ppo_save_file):
    print('loading', ppo_save_file)
    agent.load(ppo_save_file)
    agent.config.state_normalizer.load_state_dict(torch.load(ppo_save_file_state_norm))
    agent.config.reward_normalizer.load_state_dict(torch.load(ppo_save_file_reward_norm))

./log/PPO_greenfields_256z_v2.pkl20180524_21-12-07-180525-051207
game: SonicTheHedgehog-Genesis state: GreenHillZone.Act3
reseting to GreenHillZone.Act1.state
loading ./outputs/models/PPO_greenfields_256z_v2.pkl


In [13]:
try:
    run_iterations(agent)
except BaseException:
    # Reached on any failure, including a manual kernel interrupt
    # (KeyboardInterrupt): close the environment and checkpoint the agent and
    # both normalizers, then let the original exception propagate.
    agent.task.tasks[0].env.close()
    print("saving", ppo_save_file)
    agent.save(ppo_save_file)
    torch.save(agent.config.state_normalizer.state_dict(), ppo_save_file_state_norm)
    torch.save(agent.config.reward_normalizer.state_dict(), ppo_save_file_reward_norm)
    raise

saving ./outputs/models/PPO_greenfields_256z_v2.pkl


KeyboardInterrupt: 

In [None]:
# Manual checkpoint: the agent first, then the state and reward normalizers.
agent.save(ppo_save_file)
for normalizer, norm_path in (
    (agent.config.state_normalizer, ppo_save_file_state_norm),
    (agent.config.reward_normalizer, ppo_save_file_reward_norm),
):
    torch.save(normalizer.state_dict(), norm_path)

TODO:
- [ ] save normalizers?

To monitor with tensorboard
```sh
cd ~/Documents/projects/retro_sonic_comp/world-models-pytorch/log 
tensorboard --logdir .
#then open http://localhost:6006/#scalars
```

For video
- Record a screencast with Kazam
- Transcode with HandBrake
    - profile: normal
    - cropping:
        - right: 670
        - bottom: 250
    - web optimized

In [None]:
# One-line summary of training progress: step count plus the mean / max / min
# of the most recent episode rewards.
episode_rewards = agent.last_episode_rewards
config.logger.info(
    'total steps %d, mean/max/min reward %f/%f/%f' % (
        agent.total_steps,
        np.mean(episode_rewards),
        np.max(episode_rewards),
        np.min(episode_rewards),
    )
)

In [None]:
# config.logger.vanilla_logger
# IPython-only introspection (`??` shows the object's source); this is not
# valid plain Python and should be removed before sharing the notebook.
config.logger??

In [None]:
# %debug

In [None]:
# Scratch check of the logger: write a dummy scalar (tag '1', value 1),
# show scalar_summary's source via IPython `??`, then flush the writer's
# file_writer.
config.logger.scalar_summary('1',1)
config.logger.scalar_summary??
config.logger.writer.file_writer.flush()

In [None]:
# Scratch: write the same dummy scalar (tag '1', value 1) again.
config.logger.scalar_summary('1',1)

# Debug

In [None]:

# Debug rollout: 100 random actions on a fresh task, rendering every step.
task = task_fn(log_dir)
try:
    task.env.reset()
    task.env.render()
    for step in tqdm(range(100)):
        action = task.env.action_space.sample()
        # The same action is applied to the wrapped env and then again to the
        # unwrapped env before rendering.
        task.env.step(action)
        task.env.unwrapped.step(action)
        task.env.render()
except BaseException:
    # On any error or interrupt, close the unwrapped env and re-raise.
    task.env.unwrapped.close()
    raise

In [None]:
# Close the unwrapped env of the debug task once the rollout above finished cleanly.
task.env.unwrapped.close()

In [None]:
# Debug rollout: 1000 random actions with rendering on a separately created env.
# NOTE(review): `make_env` is not imported or defined in any cell visible in
# this notebook -- confirm where it is supposed to come from.
# This also rebinds the notebook-level `env` set in the training cell.
env = make_env('sonic256')
env.reset()
env.render()
for i in tqdm(range(1000)):
    action = env.action_space.sample()
    env.step(action)
    env.render()

## Debug MDN-RNN (TODO)

Is it because the observations were transformed when I saved them to disk (converted to uint8 and back)?

In [None]:
# Fresh task for the MDN-RNN debugging cells below.
task = task_fn(log_dir)

In [None]:
# Put the world model in eval mode before the debugging cells below.
world_model.eval()

In [None]:
# Reach three `.env` levels below task.env, reset, and take one random step
# to obtain a single observation to inspect.
env = task.env.env.env.env
env.reset()
action = env.action_space.sample()
observation, reward, done, info = env.step(action)
action = np.array(action)
plt.imshow(observation)

In [None]:
# Apply the same dtype round-trip to the live frame that the commented-out
# cache code below applied: scale by 255 and cast to uint8, then divide by
# 255 back to float32.
observation = (observation * 255.0).astype(np.uint8)
observation = (observation / 255.).astype(np.float32)

# Action: cast to uint8, then add a leading axis (and cast to uint8 again).
action = action.astype(np.uint8)
action = action[None].astype(np.uint8)

# Cache-side counterparts kept for reference:
# done_data = np.concatenate(done_data, axis=0).astype(np.uint8)
# reward_data = np.concatenate(reward_data, axis=0).astype(np.float32)
# observations = da.from_array(h5py.File(data_cache_file, mode='r')['x'], chunks=(chunksize, image_size, image_size, 3))
# rewards = rewards[:, None].astype(np.float32)
# dones = dones.astype(np.uint8)
# print("Loaded from cache", data_cache_file)

In [None]:
# Action: numpy -> float CUDA tensor with a leading axis, tiled twice along axis 1.
action = torch.from_numpy(action).unsqueeze(0).cuda().float().repeat((1, 2))
print(observation.shape)
# Observation: add a leading axis, swap axes 1 and 3, move to the GPU, and
# duplicate along the leading axis.
observation = (
    torch.from_numpy(observation)
    .unsqueeze(0)
    .transpose(1, 3)
    .cuda()
    .repeat((2, 1, 1, 1))
)
print(observation.shape), action

In [None]:
# Encode the debug observation, show the sampled latent as a 16x16 image, then
# decode it back to a frame.
_, mu_vae, logvar_vae = world_model.vae.forward(observation)
z = vae.sample(mu_vae, logvar_vae)
# Copy to host memory as a numpy array first: `observation` was moved to the
# GPU, and matplotlib cannot read a CUDA tensor.
zz = z.reshape((2, 16, 16)).data.cpu().numpy()
plt.imshow(zz[0])
plt.title('z')
plt.show()

x = world_model.vae.decode(z)
# First item of the batch, axes 0 and 2 swapped for imshow.
x = x.cpu().data[0].transpose(0, 2).numpy()
plt.imshow(x)
plt.title('z decoded')

In [None]:

# NOTE(review): `pi` is deleted at the end of the summary cell and only
# re-assigned by a later debug cell (mdnrnn.forward), so this cell depends on
# out-of-order execution.
# print(pi.shape)
# for i in range(pi.size(2)):
#     pi[:,:,i,:]
print(pi.shape)
# IPython-only `?` help lookup; not valid plain Python.
torch.distributions.Multinomial?

In [None]:
# world_model.mdnrnn.eval()
# NOTE(review): `k` is first assigned in the following cell, so this display
# only works after that cell has run.
k

In [None]:
# mdnrnn.sample??
# Step-by-step version of sampling from the mixture outputs (pi, mu, sigma):
# draw `k` over axis 2, use it to weight-and-sum mu and sigma over that axis,
# then sample from the resulting Normal.
# presumably `k` is a one-hot selection of a mixture component -- TODO confirm
# against MDNRNN2.multinomial_on_axis.
k = mdnrnn.multinomial_on_axis(pi, axis=2)
mu_k = (mu * k).sum(2)
sigma_k = (sigma * k).sum(2)
z_normals = torch.distributions.Normal(mu_k, sigma_k)

z_normals.rsample()

In [None]:
# Run the MDN-RNN on the debug latent / action and visualise its sampled
# next-latent, both raw (16x16) and decoded by the VAE.
# NOTE(review): `batch_size` and `seq_len` are only assigned in the
# data-loader cell further down, so this cell depends on that one having run.
z = z.view(batch_size, seq_len, -1).cuda()
action = action.view(batch_size, seq_len).cuda()
pi, mu, sigma, hidden_state = mdnrnn.forward(z, action)
print(pi.shape, pi.dtype, pi.is_cuda)
print(mu.shape, mu.dtype, mu.is_cuda)
print(sigma.shape, sigma.dtype, sigma.is_cuda)
print(mdnrnn.training)
z_next_pred = mdnrnn.sample(pi, mu, sigma)
# print(mu)
# Copy to host memory as a numpy array first: the inputs were moved to the
# GPU above, and matplotlib cannot read a CUDA tensor.
zz = z_next_pred.reshape((2, 16, 16)).data.cpu().numpy()
plt.imshow(zz[0])
plt.show()

x = world_model.vae.decode(z_next_pred[0])
# First decoded item, axes 0 and 2 swapped for imshow.
x = x.cpu().data[0].transpose(0, 2).numpy()
plt.imshow(x)

In [None]:
# IPython-only `?` help lookup (leading `?`); this does not run the assertion.
?np.testing.assert_almost_equal(pi.cpu().data.numpy(), 1, decimal=4)

In [None]:
# Shape of the `pi` output returned by mdnrnn.forward in the cell above.
pi.shape

In [None]:
# decode average?
# Decode one slice per mixture component and show the first decoded image.
# NOTE(review): this decodes slices of `pi`, not `mu` -- if the intent is to
# view each component's mean latent, `mu[0,:,i]` looks like the right input;
# confirm before relying on these plots.
for i in range(n_mixture):
    x = world_model.vae.decode(pi[0,:,i])[0]
    x = x.transpose(0,2).data.cpu().numpy()
    plt.imshow(x)
    plt.title('mixture %s'%i)
    plt.show()
    # x.shape
    # x.shape

In [None]:
# pi, mu, sigma, hidden_state = world_model.mdnrnn.forward(z[:, None], action, hidden_state=None)
# z_next_pred = world_model.mdnrnn.sample(pi, mu, sigma)

# x = world_model.vae.decode(z_next_pred[0])
# x = x.cpu().data[0].transpose(0,2).numpy()
# plt.imshow(x)

In [None]:
# Plot reconstructions
def _show_panel(position, image, title):
    """Draw `image` with `title` and hidden axes in slot `position` of a 2x3 grid."""
    plt.subplot(2, 3, position)
    plt.axis("off")
    plt.imshow(image)
    plt.title(title)


def plot_results(loader, n=2, epoch=0, figsize=(9,6)):
    """Plot VAE reconstructions and MDN-RNN next-frame predictions for one batch.

    Takes the first batch from `loader`, runs it through the VAE, the MDN-RNN
    and the inverse model, prints the predicted vs. true actions, and draws
    one 2x3 figure per selected sequence index (original, true next, actual
    changes / reconstruction, predicted next, predicted changes).

    Relies on notebook-level names being defined before the call: `vae`,
    `mdnrnn`, `finv`, `cuda`, `action_dim`, `batch_size`, `seq_len`,
    `lambda_vae_kld` and `C`.

    Args:
        loader: iterable yielding `(observations, actions, rewards, dones)`.
        n: number of sequence indices to plot, spread evenly over
            `0 .. seq_len - 2`.
        epoch: only used in the figure title.
        figsize: matplotlib figure size for each figure.
    """
    with torch.no_grad():
        observations, actions, rewards, dones = next(iter(loader))

        # Swap axes 1 and 3 so the remaining sizes unpack as channels/height/width.
        X = observations.transpose(1, 3)
        _, channels, height, width = X.size()
        if cuda:
            X = X.cuda()
        Y, mu_vae, logvar = vae.forward(X)
        loss_recon, loss_KLD = loss_function_vae(Y, X, mu_vae, logvar)
        loss_vae = loss_recon + lambda_vae_kld * torch.abs(loss_KLD - C)

        # TODO do we want to sample in test or training mode?
        z_v = vae.sample(mu_vae, logvar)

        # Regroup the flat batch into (batch, sequence, ...) blocks.
        z_v = z_v.view(batch_size, seq_len, -1)
        Y = Y.view((batch_size, seq_len, channels, height, width))
        X = X.view((batch_size, seq_len, channels, height, width))
        loss_vae = loss_vae.view(batch_size, seq_len)
        actions = actions.view(batch_size, seq_len)

        # MDN-RNN forward
        actions_v = actions.float()
        if cuda:
            z_v = z_v.cuda()
            actions_v = actions_v.cuda()
        pi, mu, sigma, hidden_state = mdnrnn.forward(z_v, actions_v)
        z_true_next = z_v[:, 1:]
        loss_mdn_rnn = mdnrnn.rnn_loss(z_true_next, pi[:, :-1], mu[:, :-1], sigma[:, :-1])

        # Decode the mean of `mu` over axis 2 as the predicted next frame.
        mu2 = mu.mean(2).view((batch_size * seq_len, mdnrnn.z_dim))
        X_pred = vae.decode(mu2)
        X_pred = X_pred.view((batch_size, seq_len, channels, height, width))

        # Finv forward
        print(pi.shape, pi.dtype, pi.is_cuda)
        print(mu.shape, mu.dtype, mu.is_cuda)
        print(sigma.shape, sigma.dtype, sigma.is_cuda)
        print(mdnrnn.training)
        z_next_pred = mdnrnn.sample(pi, mu, sigma)
        action_pred = finv(z_v[:, 1:], z_next_pred[:, :-1]).float()

        # Build the one-hot targets on the same device as the actions, so the
        # index lookup and the loss work with or without CUDA.
        actions_v_hot = torch.eye(action_dim, device=actions_v.device)[actions_v.long()]
        loss_inv = F.binary_cross_entropy_with_logits(action_pred, actions_v_hot[:, 1:])
        action_pred_int = action_pred.max(-1)[1]
        print(action_pred_int)

        loss = loss_vae.mean(1) + loss_mdn_rnn.mean(1) + loss_inv.mean()

        for i in np.linspace(0, seq_len - 2, n):
            batch = np.random.randint(0, batch_size)
            i = int(i)
            # Swap axes 0 and 2 back so each image is channel-last for imshow.
            y = Y[batch][i].cpu().data.transpose(0, 2).numpy()
            x_orig = X[batch][i].transpose(0, 2).data.cpu().numpy()
            x_next = X[batch][i + 1].transpose(0, 2).data.cpu().numpy()
            x_pred = X_pred[batch][i].transpose(0, 2).data.cpu().numpy()
            loss_vae_i = loss_vae[batch][i].cpu().data.item()
            loss_mdnrnn_i = loss_mdn_rnn[batch][i].cpu().data.item()
            loss_inv_i = loss_inv.cpu().data.item()
            loss_i = loss[batch].cpu().data.item()

            print('action_pred', action_pred_int[batch][i].data.cpu().item())
            print('action_true', actions_v[:, 1:][batch][i].data.cpu().item())
            print('finv loss {:2.4f}'.format(loss_inv_i))

            plt.figure(figsize=figsize)

            # Top row: ground truth. Bottom row: model outputs.
            _show_panel(1, x_orig, 'original')
            _show_panel(4, y, 'reconstructed \nloss_vae={:2.4f}'.format(loss_vae_i))
            _show_panel(2, x_next, 'true next')
            _show_panel(5, x_pred, 'pred next \nloss_mdnrnn={:2.4f}'.format(loss_mdnrnn_i))
            _show_panel(3, np.abs(x_orig - x_next), 'actual changes')
            # Compare the whole reconstructed frame with the prediction; `y` is
            # already a single image, so it must not be indexed by `i` again.
            _show_panel(6, np.abs(y - x_pred), 'predicted changes')

            plt.suptitle('epoch {}, seq index {}, batch={}. loss {:2.4f}'.format(
                epoch,
                i,
                batch,
                loss_i
            ))
    #         plt.subplots_adjust(wspace=-.4, hspace=.1)#, bottom=0.1, right=0.8, top=0.9)
            plt.show()
        


In [None]:
from world_models_sonic.config import base_vae_data_dir

# Tiny loader settings for debugging: single sequences of length 2.
seq_len = 2
batch_size = 1
chunksize = seq_len * 20

data_cache_file = os.path.join(base_vae_data_dir, 'sonic_rnn_256_v30.hdf5')

loader_train, loader_test = load_cache_data(
    basedir=base_vae_data_dir,
    env_name=env_name,
    data_cache_file=data_cache_file,
    image_size=image_size,
    chunksize=chunksize,
    action_dim=action_dim,
    batch_size=batch_size,
    seq_len=seq_len,
)

# Cell output: sizes of the train and test datasets.
len(loader_train.dataset), len(loader_test.dataset)

In [None]:
# Globals read by plot_results when it forms the VAE loss:
#   loss_recon + lambda_vae_kld * |loss_KLD - C|
lambda_vae_kld=1
C=0
# Plot four sequence positions from the first training batch.
plot_results(loader_train, n=4, epoch=0)

In [None]:
# Pull the first training batch and display the raw observations tensor.
observations, actions, rewards, dones = next(iter(loader_train))
observations

In [None]:
# Value distribution of the cached (loader) observations, then the first frame.
plt.hist(observations.cpu().data.numpy().flatten(), bins=55)
plt.show()
plt.imshow(observations[0])

In [None]:
# Value distribution of the live-env `observation`, for comparison with the
# loader histogram above.
plt.hist(observation.cpu().data.numpy().flatten(), bins=55)
# Trailing literal so the cell's displayed value is `1` rather than plt.hist's return value.
1

In [None]:
# Display the live-env observation tensor.
observation

In [None]:
# Apply the uint8 round-trip (x255 -> uint8 -> /255 -> float32) to the live
# observation and view the result as a tensor.
# This rebinds the notebook-level `x`.
x=((observation.cpu().data.numpy()*255).astype(np.uint8)/255.0).astype(np.float32)
torch.from_numpy(x)