In this notebook we load a saved dreamer, and run it, to look at params, speed and improve hackability

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# TODO make this a proper package
import os, sys
sys.path.append('..')


from dreamer import parse_args, main, make_env, make_dataset, count_steps,Dreamer

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Loading textures from cache


In [4]:
# emulate cli
argv = f"../dreamer.py --configs craftax_small2 --logdir ../logdir/craftax_small2"
argv = argv.split()
print(argv)
config = parse_args(argv)
config

['../dreamer.py', '--configs', 'craftax_small2', '--logdir', '../logdir/craftax_small2']


Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=32, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 32, 'kernel_size': 4, 'minres': 2, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_uni

In [5]:
from loguru import logger
from tqdm.auto import tqdm
import pathlib

import torch
from torch import nn
from torch import distributions as torchd

import exploration as expl
import models
import tools
import envs.wrappers as wrappers
from parallel import Parallel, Damy

# from main
tools.set_seed_everywhere(config.seed)
if config.deterministic_run:
    tools.enable_deterministic_run()
logdir = pathlib.Path(config.logdir).expanduser()
config.traindir = config.traindir or logdir / "train_eps"
config.evaldir = config.evaldir or logdir / "eval_eps"
config.steps //= config.action_repeat
config.eval_every //= config.action_repeat
config.log_every //= config.action_repeat
config.time_limit //= config.action_repeat

logger.info(f"Logdir {logdir}")
logdir.mkdir(parents=True, exist_ok=True)
config.traindir.mkdir(parents=True, exist_ok=True)
config.evaldir.mkdir(parents=True, exist_ok=True)
step = count_steps(config.traindir)
# step in logger is environmental step
tlogger = tools.Logger(logdir, config.action_repeat * step)
logger.add(logdir/"logger.log")

logger.info("Create envs.")
if config.offline_traindir:
    directory = config.offline_traindir.format(**vars(config))
else:
    directory = config.traindir
train_eps = tools.load_episodes(directory, limit=config.dataset_size)
if config.offline_evaldir:
    directory = config.offline_evaldir.format(**vars(config))
else:
    directory = config.evaldir
eval_eps = tools.load_episodes(directory, limit=1)
make = lambda mode, id: make_env(config, mode, id)
train_envs = [make("train", i) for i in range(config.envs)]
eval_envs = [make("eval", i) for i in range(config.envs)]
if config.parallel:
    train_envs = [Parallel(env, "process") for env in train_envs]
    eval_envs = [Parallel(env, "process") for env in eval_envs]
else:
    train_envs = [Damy(env) for env in train_envs]
    eval_envs = [Damy(env) for env in eval_envs]
acts = train_envs[0].action_space
logger.info(f"Action Space {acts}" )
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]

state = None
if not config.offline_traindir:
    prefill = max(0, config.prefill - count_steps(config.traindir))
    logger.info(f"Prefill dataset ({prefill} steps).")
    if hasattr(acts, "discrete"):
        random_actor = tools.OneHotDist(
            torch.zeros(config.num_actions).repeat(config.envs, 1)
        )
    else:
        random_actor = torchd.independent.Independent(
            torchd.uniform.Uniform(
                torch.Tensor(acts.low).repeat(config.envs, 1),
                torch.Tensor(acts.high).repeat(config.envs, 1),
            ),
            1,
        )

    def random_agent(o, d, s):
        action = random_actor.sample()
        logprob = random_actor.log_prob(action)
        return {"action": action, "logprob": logprob}, None

    state = tools.simulate(
        random_agent,
        train_envs,
        train_eps,
        config.traindir,
        tlogger,
        limit=config.dataset_size,
        steps=prefill,
    )
    tlogger.step += prefill * config.action_repeat
    logger.info(f"Logger: ({tlogger.step} steps).")

logger.info("Simulate agent.")
train_dataset = make_dataset(train_eps, config)
eval_dataset = make_dataset(eval_eps, config)

[32m2024-06-06 17:08:10.379[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1mLogdir ../logdir/craftax_small2[0m
[32m2024-06-06 17:08:10.384[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m36[0m - [1mCreate envs.[0m
[32m2024-06-06 17:08:41.190[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m57[0m - [1mAction Space Box(0.0, 1.0, (43,), float32)[0m
[32m2024-06-06 17:08:41.191[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1mPrefill dataset (26 steps).[0m
[32m2024-06-06 17:09:31.587[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m92[0m - [1mLogger: (2500 steps).[0m
[32m2024-06-06 17:09:31.588[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m94[0m - [1mSimulate agent.[0m


In [6]:
train_envs[0].observation_space

  multiarray.copyto(a, fill_value, casting='unsafe')


Dict('image': Box(0, 255, (130, 110, 3), uint8), 'is_first': Box(0, 0, (1,), uint8), 'is_last': Box(0, 0, (1,), uint8), 'is_terminal': Box(0, 0, (1,), uint8), 'log_achievement_cast_fireball': Box(-inf, inf, (1,), float32), 'log_achievement_cast_iceball': Box(-inf, inf, (1,), float32), 'log_achievement_collect_coal': Box(-inf, inf, (1,), float32), 'log_achievement_collect_diamond': Box(-inf, inf, (1,), float32), 'log_achievement_collect_drink': Box(-inf, inf, (1,), float32), 'log_achievement_collect_iron': Box(-inf, inf, (1,), float32), 'log_achievement_collect_ruby': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapling': Box(-inf, inf, (1,), float32), 'log_achievement_collect_sapphire': Box(-inf, inf, (1,), float32), 'log_achievement_collect_stone': Box(-inf, inf, (1,), float32), 'log_achievement_collect_wood': Box(-inf, inf, (1,), float32), 'log_achievement_damage_necromancer': Box(-inf, inf, (1,), float32), 'log_achievement_defeat_archer': Box(-inf, inf, (1,), float32), 'l

In [7]:
config = parse_args(argv)
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
agent = Dreamer(
    train_envs[0].observation_space,
    train_envs[0].action_space,
    config,
    tlogger,
    train_dataset,
).to(config.device)
print(agent)
agent.requires_grad_(requires_grad=False)
if (logdir / "latest.pt").exists():
    checkpoint = torch.load(logdir / "latest.pt")
    agent.load_state_dict(checkpoint["agent_state_dict"])
    tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
    agent._should_pretrain._once = False
    logger.warning(f"Loaded model from {logdir / 'latest.pt'}")

[32m2024-06-06 17:09:31.695[0m | [1mINFO    [0m | [36mnetworks[0m:[36m__init__[0m:[36m324[0m - [1mEncoder CNN shapes: {'state_map': (12, 12, 166)}[0m
[32m2024-06-06 17:09:31.696[0m | [1mINFO    [0m | [36mnetworks[0m:[36m__init__[0m:[36m325[0m - [1mEncoder MLP shapes: {'state_inventory': (102,)}[0m
[32m2024-06-06 17:09:31.913[0m | [1mINFO    [0m | [36mnetworks[0m:[36m__init__[0m:[36m391[0m - [1mDecoder CNN shapes: {'state_map': (12, 12, 166)}[0m
[32m2024-06-06 17:09:31.914[0m | [1mINFO    [0m | [36mnetworks[0m:[36m__init__[0m:[36m392[0m - [1mDecoder MLP shapes: {'state_inventory': (102,)}[0m
[32m2024-06-06 17:09:32.650[0m | [1mINFO    [0m | [36mmodels[0m:[36m__init__[0m:[36m102[0m - [1mOptimizer model_opt has 2357196 variables.[0m
[32m2024-06-06 17:09:32.657[0m | [1mINFO    [0m | [36mmodels[0m:[36m__init__[0m:[36m281[0m - [1mOptimizer actor_opt has 356651 variables.[0m
[32m2024-06-06 17:09:32.657[0m | [1mINFO    

  self.pid = os.fork()


Dreamer(
  (_wm): OptimizedModule(
    (_orig_mod): WorldModel(
      (encoder): MultiEncoder(
        (_cnn): ConvEncoder(
          (layers): Sequential(
            (0): Conv2dSamePad(166, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)
            (1): ImgChLayerNorm(
              (norm): LayerNorm((32,), eps=0.001, elementwise_affine=True)
            )
            (2): SiLU()
            (3): Conv2dSamePad(32, 64, kernel_size=(4, 4), stride=(2, 2), bias=False)
            (4): ImgChLayerNorm(
              (norm): LayerNorm((64,), eps=0.001, elementwise_affine=True)
            )
            (5): SiLU()
          )
        )
        (_mlp): MLP(
          (layers): Sequential(
            (Encoder_linear0): Linear(in_features=102, out_features=16, bias=False)
            (Encoder_norm0): LayerNorm((16,), eps=0.001, elementwise_affine=True)
            (Encoder_act0): SiLU()
            (Encoder_linear1): Linear(in_features=16, out_features=16, bias=False)
            (Encoder

- note model_opt includes actor.wm
  - encoder
  - rssm
  - heads
- actor

## Now lets play

In [8]:
assert state is not None
import numpy as np

# state

In [9]:
from tools import convert, add_to_cache
envs = train_envs
cache = train_eps

step, episode = 0, 0
done = np.ones(len(envs), bool)
length = np.zeros(len(envs), np.int32)
obs = [None] * len(envs)
agent_state = None
reward = [0] * len(envs)

indices = [index for index, d in enumerate(done) if d]
results = [envs[i].reset() for i in indices]
results = [r() for r in results]
for index, result in zip(indices, results):
    t = result.copy()
    t = {k: convert(v) for k, v in t.items()}
    # action will be added to transition in add_to_cache
    t["reward"] = 0.0
    t["discount"] = 1.0
    # initial state should be added to cache
    add_to_cache(cache, envs[index].id, t)
    # replace obs with done by initial state
    obs[index] = result
# step agents

In [10]:
# from tools.simulate

# step
# step, episode, done, length, obs, agent_state, reward = state
obs2 = {k: np.stack([o[k] for o in obs]) for k in obs[0] if "log_" not in k}
action, agent_state = agent(obs2, done, agent_state)

  ret = F.conv2d(


AssertionError: (torch.Size([256, 32, 8, 8, 166]), torch.Size([256, 32, 12, 12, 166]))

In [None]:
from torchinfo import summary

summary(agent, input=(obs, done, agent_state), depth=3)

Layer (type:depth-idx)                                            Param #
Dreamer                                                           --
├─OptimizedModule: 1-1                                            --
│    └─WorldModel: 2-1                                            --
│    │    └─MultiEncoder: 3-1                                     (44,480)
│    │    └─RSSM: 3-2                                             (2,397,952)
│    │    └─ModuleDict: 3-3                                       (1,580,204)
├─OptimizedModule: 1-2                                            --
│    └─ImagBehavior: 2-2                                          4,022,636
│    │    └─WorldModel: 3-4                                       (recursive)
│    │    └─MLP: 3-5                                              (536,875)
│    │    └─MLP: 3-6                                              (525,311)
│    │    └─MLP: 3-7                                              (525,311)
├─OptimizedModule: 1-3               

In [None]:
summary(agent, input=(obs, done, agent_state), depth=4)

Layer (type:depth-idx)                                            Param #
Dreamer                                                           --
├─OptimizedModule: 1-1                                            --
│    └─WorldModel: 2-1                                            --
│    │    └─MultiEncoder: 3-1                                     --
│    │    │    └─ConvEncoder: 4-1                                 (42,528)
│    │    │    └─MLP: 4-2                                         (1,952)
│    │    └─RSSM: 3-2                                             512
│    │    │    └─Sequential: 4-3                                  (273,664)
│    │    │    └─GRUCell: 4-4                                     (1,182,720)
│    │    │    └─Sequential: 4-5                                  (131,584)
│    │    │    └─Sequential: 4-6                                  (283,136)
│    │    │    └─Linear: 4-7                                      (263,168)
│    │    │    └─Linear: 4-8                     

In [None]:
# agent._wm.heads

## Fine grained torchinfo

In [None]:
wm = agent._wm
data = next(agent._dataset) 
# self._train()
# post, context, mets = wm._train(data)
data = wm.preprocess(data)
embed = wm.encoder(data)
post, prior = wm.dynamics.observe(
    embed, data["action"], data["is_first"]
)

In [None]:
summary(wm.encoder, input_data=(data,), depth=4, col_names=["input_size", "output_size", "num_params", ])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
MultiEncoder                             [256, 16, 130, 110, 3]    [256, 16, 592]            --
├─ConvEncoder: 1-1                       [256, 16, 12, 12, 166]    [256, 16, 576]            --
│    └─Sequential: 2-1                   [4096, 166, 12, 12]       [4096, 16, 6, 6]          --
│    │    └─Conv2dSamePad: 3-1           [4096, 166, 12, 12]       [4096, 16, 6, 6]          (42,496)
│    │    └─ImgChLayerNorm: 3-2          [4096, 16, 6, 6]          [4096, 16, 6, 6]          --
│    │    │    └─LayerNorm: 4-1          [4096, 6, 6, 16]          [4096, 6, 6, 16]          (32)
│    │    └─SiLU: 3-3                    [4096, 16, 6, 6]          [4096, 16, 6, 6]          --
├─MLP: 1-2                               [256, 16, 102]            [256, 16, 16]             --
│    └─Sequential: 2-2                   [256, 16, 102]            [256, 16, 16]             --
│    │    └─Linear: 3-4    

In [None]:
# heads
feat = wm.dynamics.get_feat(post)
for name, head in wm.heads.items():
    try:
        o = summary(head, input_data=(feat,), depth=3, col_names=["input_size", "output_size", "num_params", ])
        print(name)
        print(o)
    except Exception as e:
        print(f"Summary Failed for {name} {e}")
        continue

decoder
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
MultiDecoder                             [256, 16, 1536]           --                        --
├─ConvDecoder: 1-1                       [256, 16, 1536]           [256, 16, 8, 8, 166]      --
│    └─Linear: 2-1                       [256, 16, 1536]           [256, 16, 256]            (393,472)
│    └─Sequential: 2-2                   [4096, 16, 4, 4]          [4096, 166, 8, 8]         --
│    │    └─ConvTranspose2d: 3-1         [4096, 16, 4, 4]          [4096, 166, 8, 8]         (42,662)
├─MLP: 1-2                               [256, 16, 1536]           --                        --
│    └─Sequential: 2-3                   [256, 16, 1536]           [256, 16, 16]             --
│    │    └─Linear: 3-2                  [256, 16, 1536]           [256, 16, 16]             (24,576)
│    │    └─LayerNorm: 3-3               [256, 16, 16]             [256, 16, 16]             (32)
│    │

In [None]:
# fail as no call method
# summary(wm.dynamics, input_data=(embed, data["action"], data["is_first"]), depth=3, col_names=["output_size", "num_params", ])

In [None]:
actor = agent._task_behavior.actor

summary(actor.layers, input_data=(feat,), depth=3, col_names=["output_size", "num_params", "output_size" ])



Layer (type:depth-idx)                   Output Shape              Param #                   Output Shape
Sequential                               [256, 16, 256]            --                        [256, 16, 256]
├─Linear: 1-1                            [256, 16, 256]            (393,216)                 [256, 16, 256]
├─LayerNorm: 1-2                         [256, 16, 256]            (512)                     [256, 16, 256]
├─SiLU: 1-3                              [256, 16, 256]            --                        [256, 16, 256]
├─Linear: 1-4                            [256, 16, 256]            (65,536)                  [256, 16, 256]
├─LayerNorm: 1-5                         [256, 16, 256]            (512)                     [256, 16, 256]
├─SiLU: 1-6                              [256, 16, 256]            --                        [256, 16, 256]
├─Linear: 1-7                            [256, 16, 256]            (65,536)                  [256, 16, 256]
├─LayerNorm: 1-8              

In [None]:
value = agent._task_behavior.actor
summary(value.layers, input_data=(feat,), depth=3, col_names=["output_size", "num_params", "output_size" ])

Layer (type:depth-idx)                   Output Shape              Param #                   Output Shape
Sequential                               [256, 16, 256]            --                        [256, 16, 256]
├─Linear: 1-1                            [256, 16, 256]            (393,216)                 [256, 16, 256]
├─LayerNorm: 1-2                         [256, 16, 256]            (512)                     [256, 16, 256]
├─SiLU: 1-3                              [256, 16, 256]            --                        [256, 16, 256]
├─Linear: 1-4                            [256, 16, 256]            (65,536)                  [256, 16, 256]
├─LayerNorm: 1-5                         [256, 16, 256]            (512)                     [256, 16, 256]
├─SiLU: 1-6                              [256, 16, 256]            --                        [256, 16, 256]
├─Linear: 1-7                            [256, 16, 256]            (65,536)                  [256, 16, 256]
├─LayerNorm: 1-8              

In [None]:
8268

8268