This experiment uses a higher learning rate and larger rollouts.

In [1]:
import deep_rl

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load (or reload) the autoreload extension and set mode 2, so edited modules
# are re-imported before each cell executes.
%reload_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.nn import functional as F
from torch.autograd import Variable
from torch import nn, optim
import torch.utils.data

# load as dask array
import time

import logging
import sys
import os
import glob
import numpy as np
import datetime
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

In [4]:
from deep_rl.utils.logger import get_logger, get_default_log_dir

from deep_rl.network.network_heads import CategoricalActorCriticNet, QuantileNet, OptionCriticNet, DeterministicActorCriticNet, GaussianActorCriticNet
from deep_rl.network.network_bodies import FCBody
from deep_rl.utils.normalizer import RunningStatsNormalizer, RescaleNormalizer
from deep_rl.component.task import ParallelizedTask

In [5]:
from world_models_sonic.models.vae import VAE5, loss_function_vae
from world_models_sonic.helpers.summarize import TorchSummarizeDf
from world_models_sonic.models.rnn import MDNRNN, MDNRNN2
from world_models_sonic.models.inverse_model import InverseModel
from world_models_sonic.models.world_model import WorldModel
from world_models_sonic.custom_envs.env import make_env
from world_models_sonic.custom_envs.wrappers import RandomGameReset
from world_models_sonic import config
from world_models_sonic.helpers.deep_rl import PPOAgent, run_iterations, SonicWorldModelDeepRL, CategoricalWorldActorCriticNet, Config

Importing 0 potential games...
Imported 0 games


# Init

In [6]:
# Run-wide settings shared by every later cell.
cuda = torch.cuda.is_available()
env_name = 'sonic256'
z_dim = 512  # latent dimensions
channels = 3*4

# RNN
action_dim = 10
image_size = 128

verbose = True  # Set this true to render (and make it go slower)

NAME = 'RNN_v3b_128im_512z_1512_v6k_VAE5_all'

# Build the run's output directory path once and derive everything else from it,
# so the checkpoint path, the log directory and the created folder cannot drift apart.
log_dir = './outputs/{NAME}'.format(NAME=NAME)
ppo_save_file = '{log_dir}/PPO_512z_all_g.pkl'.format(log_dir=log_dir)

# exist_ok=True makes this a no-op when the directory is already there.
os.makedirs(log_dir, exist_ok=True)

# Log to file and stream
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(NAME)

print(log_dir)

deep_rl_logger = get_logger(
    NAME,
    file_name='deep_rl_ppo.log',
    level=logging.INFO,
    log_dir=log_dir, )

./outputs/RNN_v3b_128im_512z_1512_v6k_VAE5_all


# World model

In [7]:
# Instantiate the VAE
# TODO swap z and k dim, since it's inconsistent with other models
vae = VAE5(
    image_size=image_size,
    z_dim=128,
    conv_dim=64,
    code_dim=8,
    k_dim=z_dim,
    channels=channels,
)

# Instantiate the MDN-RNN; its hidden state is twice as wide as the latent.
hidden_size = z_dim*2
n_mixture = 5
temp = 0.0
mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)

# Inverse model, using the same hidden width as the RNN.
finv = InverseModel(z_dim, action_dim, hidden_size=hidden_size)

# Combine the three sub-models, switch to training mode and move to the GPU if present.
world_model = WorldModel(
    vae,
    mdnrnn,
    finv,
    logger=deep_rl_logger,
    lambda_vae_kld=1 / 256.,
    lambda_finv=1/50,
    lambda_vae=1/8,
    lambda_loss=1000,
).train()
if cuda:
    world_model = world_model.cuda()

In [8]:
# Optimiser for world models
import torch.optim.lr_scheduler

torch.cuda.empty_cache()  # release unused cached GPU memory first

optimizer = optim.Adam(
    params=world_model.parameters(),
    lr=3e-5,
)

# The world model carries its optimizer as an attribute.
world_model.optimizer = optimizer

# Train

In [9]:
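# NOTE: everything in this cell is commented out. It is a reference copy of a
# RunningStatsNormalizer with state_dict()/load_state_dict() support; the live
# code uses the RunningStatsNormalizer imported from deep_rl.utils.normalizer.
# from deep_rl.utils.normalizer import  BaseNormalizer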

# class RunningStatsNormalizer(BaseNormalizer):
#     def __init__(self, read_only=False):
#         BaseNormalizer.__init__(self, read_only)
#         self.needs_reset = True
#         self.read_only = read_only

#     def reset(self, x_size):
#         self.m = np.zeros(x_size)
#         self.v = np.zeros(x_size)
#         self.n = 0.0
#         self.needs_reset = False

#     def state_dict(self):
#         return {'m': self.m, 'v': self.v, 'n': self.n}

#     def load_state_dict(self, stored):
#         self.m = stored['m']
#         self.v = stored['v']
#         self.n = stored['n']
#         self.needs_reset = False

#     def __call__(self, x):
#         if np.isscalar(x) or len(x.shape) == 1:
#             # if dim of x is 1, it can be interpreted as 1 vector entry or batches of scalar entry,
#             # fortunately resetting the size to 1 applies to both cases
#             if self.needs_reset: self.reset(1)
#             return self.nomalize_single(x)
#         elif len(x.shape) == 2:
#             if self.needs_reset: self.reset(x.shape[1])
#             new_x = np.zeros(x.shape)
#             for i in range(x.shape[0]):
#                 new_x[i] = self.nomalize_single(x[i])
#             return new_x
#         else:
#             assert 'Unsupported Shape'

#     def nomalize_single(self, x):
#         is_scalar = np.isscalar(x)
#         if is_scalar:
#             x = np.asarray([x])

#         if not self.read_only:
#             new_m = self.m * (self.n / (self.n + 1)) + x / (self.n + 1)
#             self.v = self.v * (self.n / (self.n + 1)) + (x - self.m) * (x - new_m) / (self.n + 1)
#             self.m = new_m
#             self.n += 1

#         std = (self.v + 1e-6) ** .5
#         x = (x - self.m) / std
#         if is_scalar:
#             x = np.asscalar(x)
#         return x

In [10]:
# Input width of the actor-critic body (passed to FCBody below): the sum of the
# MDN-RNN's latent, hidden-state and action dimensions.
z_state_dim = (
    world_model.mdnrnn.z_dim
    + world_model.mdnrnn.hidden_size
    + world_model.mdnrnn.action_dim
)


def task_fn(log_dir):
    """Create one SonicWorldModelDeepRL task.

    The module-level `image_size` and `verbose` are looked up when this
    function is called, not when it is defined.

    Args:
        log_dir: forwarded unchanged to SonicWorldModelDeepRL.

    Returns:
        A SonicWorldModelDeepRL whose environment factory wraps a 'sonic'
        environment (max 1000 steps per episode, colour frames) in RandomGameReset.
    """
    def build_env():
        sonic_env = make_env(
            'sonic', max_episode_steps=1000, to_gray=False, image_size=image_size)
        return RandomGameReset(sonic_env)

    return SonicWorldModelDeepRL(env_fn=build_env, log_dir=log_dir, verbose=verbose)

# PPO agent configuration.
# NOTE: this rebinds the name `config`, which was previously the module imported
# with `from world_models_sonic import config`; from here on `config` is this object.
config = Config()

# Overrides the `verbose = True` set in the Init cell. `task_fn` and the lambdas
# below read `verbose` when they are called, so this value is the one they see
# as long as they are invoked after this line.
verbose = False  # Set this true to render (and make it go slower)
# One worker when rendering, otherwise eight parallel environments.
config.num_workers = 1 if verbose else 8
config.task_fn = lambda: ParallelizedTask(
    task_fn, config.num_workers, single_process=config.num_workers == 1)
# Policy optimizer: Adam with learning rate 3e-4 (the world model has its own optimizer).
config.optimizer_fn = lambda params: torch.optim.Adam(params, 3e-4)
# NOTE(review): z_shape=(32, 16) has 32*16 = 512 elements, which matches z_dim — confirm
# that this is the intended relationship if z_dim changes.
config.network_fn = lambda state_dim, action_dim: CategoricalWorldActorCriticNet(
    state_dim, action_dim, FCBody(z_state_dim, hidden_units=(64, 64), gate=F.relu), gpu=0 if cuda else -1, world_model_fn=lambda: world_model,
    render=(config.num_workers==1 and verbose),
    z_shape=(32, 16)
)
# PPO / GAE hyperparameters.
config.discount = 0.99
config.logger = deep_rl_logger
config.use_gae = True
config.gae_tau = 0.95
config.entropy_weight = 0.001
config.gradient_clip = 0.4
config.rollout_length = 1*34
config.optimization_epochs = 4
config.num_mini_batches = 8
config.ppo_ratio_clip = 0.2
config.iteration_log_interval = 20

# World-model training settings; how they are used is decided by PPOAgent/run_iterations.
config.train_world_model = True
config.world_model_batch_size = 2

# I tuned these so the intrinsic reward was 1) within an order of magnitude of the extrinsic. 2) smaller, 3) negative when stuck
# TODO use reward normalisers etc to reduce the need for these hyperparameters
config.curiosity = True
config.curiosity_only = True
config.curiosity_weight = 0.01
config.curiosity_boredom = 1 # how many standard deviations above the mean does its new experience need to be, so it's not bored
# Running-statistics normalizers for both reward streams, reset to size 1 before first use.
config.intrinsic_reward_normalizer = RunningStatsNormalizer() #RescaleNormalizer()
config.intrinsic_reward_normalizer.reset(1)
config.reward_normalizer = RunningStatsNormalizer()
config.reward_normalizer.reset(1)
agent = PPOAgent(config)

# Report the batch geometry implied by the configuration above.
steps_per_rollout = config.rollout_length*config.num_workers
print('rollout of ', steps_per_rollout)
print('mini batch', steps_per_rollout/config.num_mini_batches)
print('sequence of batch', config.rollout_length)

# Resume from the previous checkpoint when one exists.
if os.path.isfile(ppo_save_file):
    print('loading ppo_save_file', ppo_save_file, 'modified', time.ctime(os.path.getmtime(ppo_save_file)))
    agent.load(ppo_save_file)

    # also load normalizers
    # Each normalizer lives in its own file next to the agent checkpoint. A missing
    # file is reported and skipped, so the normalizer keeps its freshly reset state
    # instead of the whole cell failing.
    for suffix, normalizer in (
        ('-intrinsic_reward_normalizer.pkl', config.intrinsic_reward_normalizer),
        ('-reward_normalizer.pkl', config.reward_normalizer),
    ):
        normalizer_file = ppo_save_file.replace('.pkl', suffix)
        if os.path.isfile(normalizer_file):
            normalizer.load_state_dict(torch.load(normalizer_file))
        else:
            print("couldn't find normalizer save", normalizer_file)
else:
    print("couldn't find save")

game: SonicTheHedgehog-Genesis state: MarbleZone.Act1
game: SonicTheHedgehog2-Genesis state: EmeraldHillZone.Act2
game: SonicAndKnuckles3-Genesis state: MushroomHillZone.Act2
game: SonicTheHedgehog2-Genesis state: MysticCaveZone.Act1
game: SonicTheHedgehog2-Genesis state: ChemicalPlantZone.Act1
game: SonicTheHedgehog-Genesis state: LabyrinthZone.Act2
game: SonicTheHedgehog-Genesis state: ScrapBrainZone.Act2
game: SonicAndKnuckles3-Genesis state: MarbleGardenZone.Act1
rollout of  272
mini batch 34.0
sequence of batch 34
loading ppo_save_file ./outputs/RNN_v3b_128im_512z_1512_v6k_VAE5_all/PPO_512z_all_g.pkl modified Fri Jun  8 20:44:01 2018


In [11]:
# # DEBUG

# # reset from checkpoint
# agent.load('./outputs/RNN_v3b_128im_512z_1512_v6j_VAE5_all/PPO_512z_all_g-20180606_02-06-59.pkl')

# # # Reset just rnn
# world_model.mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
# world_model.mdnrnn.cuda()

# # if we want to reset the actor
# from deep_rl.network.network_heads import ActorCriticNet
# agent.network.network = ActorCriticNet(agent.network.z_state_dim, action_dim, FCBody(z_state_dim, hidden_units=(64, 64), gate=F.relu), None, None)
# agent.network.network.cuda()

In [None]:
# Train until an error or a manual interrupt, then checkpoint and re-raise.
try:
    run_iterations(agent, log_dir=log_dir)
except BaseException:
    # BaseException has the same reach as the original bare `except:`; it is spelled
    # out so that a kernel interrupt (KeyboardInterrupt) still triggers the save below.
    try:
        if config.num_workers == 1:
            agent.task.tasks[0].env.close()
        else:
            for worker_task in agent.task.tasks:
                worker_task.close()
    except Exception as close_error:
        # Closing the environments is best-effort: a failure here must not
        # prevent the checkpoint from being written.
        print("failed to close environments:", repr(close_error))

    print("saving", ppo_save_file)
    agent.save(ppo_save_file)
    torch.save(config.intrinsic_reward_normalizer.state_dict(), ppo_save_file.replace('.pkl', '-intrinsic_reward_normalizer.pkl'))
    torch.save(config.reward_normalizer.state_dict(), ppo_save_file.replace('.pkl', '-reward_normalizer.pkl'))

    # Backup since it sometimes gets corrupted
    ts = datetime.datetime.utcnow().strftime('%Y%m%d_%H-%M-%S')
    backup_file = ppo_save_file.replace('.pkl', '-%s.pkl' % ts)
    print("saving backup", backup_file)
    agent.save(backup_file)
    # Re-raise the original training exception so the cell still shows why it stopped.
    raise

2018-06-08 20:44:45,721 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=0.1273, loss_inv= 13.1328=0.0200 * 656.6401, loss_vae=11.4084=0.1250 * (91.1336 + 0.0039 * 34.2132)
2018-06-08 20:44:45,723 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: total steps 272, min/mean/max reward 0.0000/0.0000/0.0000 of 8
2018-06-08 20:44:45,725 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: running min/mean/max reward 0.0000/0.0000/0.0000 of 8 27.4534 step/s
2018-06-08 20:46:26,373 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=0.1728, loss_inv= 13.1178=0.0200 * 655.8893, loss_vae=7.6373=0.1250 * (60.9600 + 0.0039 * 35.3535)
2018-06-08 20:46:26,374 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: total steps 2992, min/mean/max reward 0.0000/3.4119/9.0474 of 8
2018-06-08 20:46:26,375 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: running min/mean/max reward 0.0000/1.8316/9.0474 of 88 27.0987 step/s
2018-06-08 20:48:04,399 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=0.1791, los

In [None]:
# Manual checkpoint: the agent first, then each reward normalizer in its own file.
agent.save(ppo_save_file)
for suffix, normalizer in (
    ('-intrinsic_reward_normalizer.pkl', config.intrinsic_reward_normalizer),
    ('-reward_normalizer.pkl', config.reward_normalizer),
):
    torch.save(normalizer.state_dict(), ppo_save_file.replace('.pkl', suffix))

2018-06-06 16:36:42,501 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=**35.4942**, loss_inv= 7.3279=0.0100 * **732.7912**, loss_vae=1.6531=0.0156 * (**105.5821** + 0.0010 * 224.7847)

2018-06-06 20:24:16,225 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=35.3646, loss_inv= 7.0060=0.0100 * 700.6001, loss_vae=1.4467=0.0156 * (92.4289 + 0.0010 * 160.6168) total steps 384120

2018-06-07 06:49:28,590 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=12.3602, loss_inv= 7.2967=0.0100 * 729.6658, loss_vae=3.7269=0.0156 * (238.4971 + 0.0010 * 24.8053)  total steps 6120, (12h)

2018-06-07 22:23:10,358 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=**2.1091**, loss_inv= 13.5913=0.0200 * **679.5642**, loss_vae=19.2946=0.1250 * (**154.1847** + 0.0039 * 43.9611)  total steps 6120, (24h)

2018-06-08 06:22:56,834 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=0.9232, loss_inv= 13.1371=0.0200 * 656.8546, loss_vae=14.0925=0.1250 * (112.5510 + 0.0039 * 48.4184) total steps 744120, 32 h  27.1891 step/s

2018-06-08 09:24:36,719 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=0.7665, loss_inv= 13.1950=0.0200 * 659.7520, loss_vae=15.3335=0.1250 * (122.4766 + 0.0039 * 49.0904)  INFO: total steps 264120

2018-06-08 11:30:40,861 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=0.6207, loss_inv= 13.1358=0.0200 * 656.7897, loss_vae=14.3136=0.1250 * (114.3138 + 0.0039 * 50.0098)  total steps 468120

2018-06-08 13:49:25,792 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=0.5512, loss_inv= 13.1396=0.0200 * 656.9778, loss_vae=14.7301=0.1250 * (117.6453 + 0.0039 * 50.1072)  total steps 696120,

~~2018-06-08 16:43:15,161 - RNN_v3b_128im_512z_1512_v6k_VAE5_all - INFO: loss_rnn=0.4015, loss_inv= 13.1253=0.0200 * 656.2637, loss_vae=10.7760=0.1250 * (86.0128 + 0.0039 * 50.0830) INFO: total steps 156120,~~ nope, this was just Sonic 1

In [None]:
# TODO plot rewards over time

In [None]:
# Print the shell commands for watching this run in TensorBoard.
# The string previously opened with four quote characters, so the printed text
# began with a stray `"`; it now opens with a plain triple quote.
print(f"""
# To monitor with tensorboard at http://localhost:6006/
cd ./outputs/{NAME}/
tensorboard  --logdir .
""")

In [None]:
# Save the agent checkpoint; the bare last expression makes the notebook
# display the path that was written.
agent.save(ppo_save_file)
ppo_save_file

# Summarize model

In [None]:
# Display the repr of the agent's inner network as the cell output.
agent.network.network

In [None]:
from IPython.display import display

# Run one dummy forward pass through each sub-model and show a per-layer summary table.
# Tensors go to the GPU only when one is available; the earlier unconditional
# `.cuda()` calls failed on CPU-only machines even though the `cuda` flag existed.
device = torch.device('cuda' if cuda else 'cpu')

with torch.no_grad():
    # NOTE(review): the dummy image has 3 channels while the VAE was built with
    # `channels = 3*4` — confirm that VAE5 / WorldModel accept this input.
    img = np.random.randn(image_size, image_size, 3)
    action = np.array(np.random.randint(0,action_dim))[np.newaxis]
    action = torch.from_numpy(action).float().to(device)[np.newaxis]
    # HWC -> NCHW with a leading batch dimension of 1.
    gpu_img = torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2)).float().to(device)

    with TorchSummarizeDf(vae) as tdf:
        x, mu_vae, logvar_vae = vae.forward(gpu_img)
        z = vae.sample(mu_vae, logvar_vae)
        df_vae = tdf.make_df()

    display(df_vae[df_vae.level<2])

    with TorchSummarizeDf(mdnrnn) as tdf:
        # Repeat the single latent along a new sequence axis to form a length-2 sequence.
        pi, mu, sigma, hidden_state = mdnrnn.forward(z.unsqueeze(1).repeat((1,2,1)))
        z_next = mdnrnn.sample(pi, mu, sigma)
        df_mdnrnn = tdf.make_df()

    display(df_mdnrnn)

    with TorchSummarizeDf(finv) as tdf:
        finv(z.repeat((1,2,1)), z_next)
        df_finv = tdf.make_df()
    display(df_finv)

    with TorchSummarizeDf(world_model) as tdf:
        world_model(gpu_img, action)
        df_world_model = tdf.make_df()
    display(df_world_model[df_world_model.level<2])

    # Drop the temporaries so their (possibly GPU) memory can be reclaimed.
    del img, action, gpu_img, x, mu, z, z_next, mu_vae, pi, sigma, logvar_vae