In this notebook we load a saved Dreamer agent and run it, to inspect its parameters and speed and to make the code easier to hack on.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# TODO make this a proper package
import os, sys

# Make the repo root (parent of the notebook directory) importable. Store it as an absolute
# path so a later os.chdir cannot break imports, and skip the append when the cell is re-run.
REPO_ROOT = os.path.abspath("..")
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

from dreamer import parse_args, main, make_env, make_dataset, count_steps, Dreamer

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Loading textures from cache


In [3]:
# Emulate the CLI: build the argv that dreamer.py would receive, then parse it into a config.
CONFIG_NAME = "craftax_small2"  # change this to inspect a different config / logdir
argv = ["../dreamer.py", "--configs", CONFIG_NAME, "--logdir", f"../logdir/{CONFIG_NAME}"]
print(argv)
config = parse_args(argv)
config

['../dreamer.py', '--configs', 'craftax_small2', '--logdir', '../logdir/craftax_small2']


Namespace(act='SiLU', action_repeat=1, actor={'layers': 3, 'dist': 'onehot', 'entropy': 0.0003, 'unimix_ratio': 0.01, 'std': 'none', 'min_std': 0.1, 'max_std': 1.0, 'temp': 0.1, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 1.0}, batch_length=16, batch_size=256, compile=True, cont_head={'layers': 3, 'loss_scale': 1.0, 'outscale': 1.0}, critic={'layers': 2, 'dist': 'symlog_disc', 'slow_target': True, 'slow_target_update': 1, 'slow_target_fraction': 0.02, 'lr': 3e-05, 'eps': 1e-05, 'grad_clip': 100.0, 'outscale': 0.0}, dataset_size=1000000, debug=False, decoder={'mlp_keys': 'state_inventory', 'cnn_keys': 'state_map', 'act': 'SiLU', 'norm': True, 'cnn_depth': 16, 'kernel_size': 4, 'minres': 4, 'mlp_layers': 2, 'mlp_units': 16, 'cnn_sigmoid': False, 'image_dist': 'mse', 'vector_dist': 'symlog_mse', 'outscale': 1.0}, deterministic_run=False, device='cuda:0', disag_action_cond=False, disag_layers=4, disag_log=True, disag_models=10, disag_offset=1, disag_target='stoch', disag_uni

In [7]:
from loguru import logger
from tqdm.auto import tqdm
import pathlib

import torch
from torch import nn
from torch import distributions as torchd

import exploration as expl
import models
import tools
import envs.wrappers as wrappers
from parallel import Parallel, Damy

# --- Setup adapted from dreamer.main: directories, loggers, envs, prefill, datasets ---

# Start from a freshly parsed config. The `//=` lines below mutate it in place, so
# re-running this cell against the already-adjusted object would divide a second time.
config = parse_args(argv)
tools.set_seed_everywhere(config.seed)
if config.deterministic_run:
    tools.enable_deterministic_run()
logdir = pathlib.Path(config.logdir).expanduser()
config.traindir = config.traindir or logdir / "train_eps"
config.evaldir = config.evaldir or logdir / "eval_eps"
# Convert step counts from env steps to agent steps (one agent step = action_repeat env steps).
config.steps //= config.action_repeat
config.eval_every //= config.action_repeat
config.log_every //= config.action_repeat
config.time_limit //= config.action_repeat

logger.info(f"Logdir {logdir}")
logdir.mkdir(parents=True, exist_ok=True)
config.traindir.mkdir(parents=True, exist_ok=True)
config.evaldir.mkdir(parents=True, exist_ok=True)
step = count_steps(config.traindir)
# step in logger is environmental step
tlogger = tools.Logger(logdir, config.action_repeat * step)
# loguru sinks outlive the cell. Drop the file sink a previous run of this cell added,
# so each message is written to logger.log only once.
if "log_file_sink" in globals():
    logger.remove(log_file_sink)
log_file_sink = logger.add(logdir / "logger.log")

logger.info("Create envs.")
if config.offline_traindir:
    directory = config.offline_traindir.format(**vars(config))
else:
    directory = config.traindir
train_eps = tools.load_episodes(directory, limit=config.dataset_size)
if config.offline_evaldir:
    directory = config.offline_evaldir.format(**vars(config))
else:
    directory = config.evaldir
eval_eps = tools.load_episodes(directory, limit=1)
# NOTE(review): re-running this cell creates new envs without closing the previous ones
# (with config.parallel these are "process" workers); consider closing them first.
train_envs = [make_env(config, "train", i) for i in range(config.envs)]
eval_envs = [make_env(config, "eval", i) for i in range(config.envs)]
if config.parallel:
    train_envs = [Parallel(env, "process") for env in train_envs]
    eval_envs = [Parallel(env, "process") for env in eval_envs]
else:
    train_envs = [Damy(env) for env in train_envs]
    eval_envs = [Damy(env) for env in eval_envs]
acts = train_envs[0].action_space
logger.info(f"Action Space {acts}" )
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]

# `state` stays None when training from an offline directory (no prefill is simulated).
state = None
if not config.offline_traindir:
    # Top the episode cache up to config.prefill steps using a random policy.
    prefill = max(0, config.prefill - count_steps(config.traindir))
    logger.info(f"Prefill dataset ({prefill} steps).")
    if hasattr(acts, "discrete"):
        random_actor = tools.OneHotDist(
            torch.zeros(config.num_actions).repeat(config.envs, 1)
        )
    else:
        random_actor = torchd.independent.Independent(
            torchd.uniform.Uniform(
                torch.Tensor(acts.low).repeat(config.envs, 1),
                torch.Tensor(acts.high).repeat(config.envs, 1),
            ),
            1,
        )

    def random_agent(o, d, s):
        """Sample one random action per env; ignores observation, done flags and state."""
        action = random_actor.sample()
        logprob = random_actor.log_prob(action)
        return {"action": action, "logprob": logprob}, None

    state = tools.simulate(
        random_agent,
        train_envs,
        train_eps,
        config.traindir,
        tlogger,
        limit=config.dataset_size,
        steps=prefill,
    )
    tlogger.step += prefill * config.action_repeat
    logger.info(f"Logger: ({tlogger.step} steps).")

logger.info("Simulate agent.")
train_dataset = make_dataset(train_eps, config)
eval_dataset = make_dataset(eval_eps, config)

[32m2024-06-06 16:21:39.870[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1mLogdir ../logdir/craftax_small2[0m
[32m2024-06-06 16:21:39.887[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m36[0m - [1mCreate envs.[0m
[32m2024-06-06 16:22:16.800[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m57[0m - [1mAction Space Box(0.0, 1.0, (43,), float32)[0m
[32m2024-06-06 16:22:16.801[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1mPrefill dataset (2500 steps).[0m
[32m2024-06-06 16:23:40.174[0m | [1mINFO    [0m | [36mtools[0m:[36mwrite[0m:[36m85[0m - [1m[0] log_achievements/cast_fireball [31m0.0[0m[1m / log_achievements/cast_iceball [31m0.0[0m[1m / log_achievements/collect_coal [31m0.0[0m[1m / log_achievements/collect_diamond [31m0.0[0m[1m / log_achievements/collect_drink [31m0.0[0m[1m / log_achievements/collect_iron [31m0.0[0m[1m / log_achievements/collect_ruby

In [None]:
# Observation space of the first training env (all train envs are built by the same make_env call).
train_envs[0].observation_space

In [None]:
# Build the Dreamer agent and, if a checkpoint exists, restore its weights and optimizer state.
# The config is re-parsed so edits to the config files are picked up without re-creating the
# envs. The derived fields the setup cell computes are then re-applied; the bare re-parse
# previously dropped them (traindir/evaldir were None, step counts were not divided).
config = parse_args(argv)
config.traindir = config.traindir or logdir / "train_eps"
config.evaldir = config.evaldir or logdir / "eval_eps"
for key in ("steps", "eval_every", "log_every", "time_limit"):
    setattr(config, key, getattr(config, key) // config.action_repeat)
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
agent = Dreamer(
    train_envs[0].observation_space,
    train_envs[0].action_space,
    config,
    tlogger,
    train_dataset,
).to(config.device)
# Freeze all parameters; nothing in this notebook should need gradients by default.
agent.requires_grad_(requires_grad=False)
checkpoint_path = logdir / "latest.pt"
if checkpoint_path.exists():
    # map_location puts the tensors on the configured device even if the checkpoint was saved
    # from a different one. torch.load unpickles its input, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=config.device)
    agent.load_state_dict(checkpoint["agent_state_dict"])
    tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
    # Presumably disables the one-off pretraining trigger, since the model is already trained.
    agent._should_pretrain._once = False
    logger.warning(f"Loaded model from {logdir / 'latest.pt'}")

- Note: `model_opt` covers the world model (`agent._wm`), i.e.:
  - encoder
  - RSSM (`wm.dynamics`)
  - heads
- actor

## Now let's play

In [None]:
import numpy as np

# `state` is only produced by the prefill simulation in the setup cell;
# it stays None when an offline train directory is configured.
assert state is not None
state

In [None]:
from tools import convert, add_to_cache

# Recreate the bookkeeping that tools.simulate keeps internally, so the agent can be stepped
# one cell at a time below. Transitions go into `train_eps`, which is the same episode cache
# the training dataset was built from.
envs, cache = train_envs, train_eps

step = episode = 0
done = np.ones(len(envs), bool)  # every env starts "done", so every env gets reset below
length = np.zeros(len(envs), np.int32)
obs = [None] * len(envs)
agent_state = None
reward = [0] * len(envs)

# Reset the envs flagged as done. reset() returns a callable that yields the observation.
indices = [i for i in range(len(done)) if done[i]]
results = [envs[i].reset() for i in indices]
results = [pending() for pending in results]
for index, result in zip(indices, results):
    t = {k: convert(v) for k, v in result.items()}
    # the action is attached to the transition later, by add_to_cache
    t["reward"] = 0.0
    t["discount"] = 1.0
    add_to_cache(cache, envs[index].id, t)  # the initial state is cached as well
    obs[index] = result  # the fresh initial observation replaces the "done" slot

In [None]:
# `envs` is `train_envs` (see the reset cell), so this shows the same space as the earlier cell.
envs[0].observation_space

In [None]:
# Shapes of the two observation parts for the first env: (state_map, state_inventory).
tuple(obs[0][key].shape for key in ("state_map", "state_inventory"))

In [None]:
# One agent step by hand, following the body of tools.simulate.
# Batch the per-env observation dicts into one dict of stacked arrays, skipping "log_*" entries.
obs2 = {k: np.stack([o[k] for o in obs]) for k in obs[0] if "log_" not in k}
# training=False runs only the policy. With the default (training=True) the call may also take
# gradient steps on the freshly loaded weights and advance the agent's step counter.
# NOTE(review): confirm against Dreamer.__call__.
action, agent_state = agent(obs2, done, agent_state, training=False)

In [None]:
from torchinfo import summary

# torchinfo unpacks `input_data` as agent(*input_data, **kwargs). The old `input=` keyword is not
# a torchinfo argument: it was forwarded to the model and no forward pass was traced. Use the
# batched `obs2` (the agent consumes stacked observations), and keep training=False so
# summarising does not train.
summary(agent, input_data=[obs2, done, agent_state], training=False, depth=4)

In [None]:
# agent._wm.heads

## Fine-grained torchinfo summaries

In [None]:
# Replay the opening steps of the world model's training pass (wm._train) on one replay batch,
# to get real intermediate values to feed into the torchinfo summaries below.
wm = agent._wm
raw_batch = next(agent._dataset)  # consumes one batch from the dataset iterator
data = wm.preprocess(raw_batch)
embed = wm.encoder(data)
post, prior = wm.dynamics.observe(embed, data["action"], data["is_first"])

In [None]:
# Encoder: per-layer input/output shapes and parameter counts on the preprocessed batch.
summary(
    wm.encoder,
    input_data=(data,),
    col_names=["input_size", "output_size", "num_params"],
    depth=4,
)

In [None]:
# Summarise every world-model head on the latent features of the posterior.
feat = wm.dynamics.get_feat(post)
for head_name, head in wm.heads.items():
    try:
        report = summary(
            head,
            input_data=(feat,),
            depth=3,
            col_names=["input_size", "output_size", "num_params"],
        )
        print(head_name)
        print(report)
    except Exception as err:
        # best effort: report the failing head and move on to the next one
        print(f"Summary Failed for {head_name} {err}")

In [None]:
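# Fails: torchinfo cannot call wm.dynamics directly (it has no usable forward/__call__); its observe() is exercised in the replay cell instead.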
# summary(wm.dynamics, input_data=(embed, data["action"], data["is_first"]), depth=3, col_names=["output_size", "num_params", ])

In [None]:
# Actor MLP trunk (`.layers`) applied to the world-model features computed above.
# Columns now match the other summaries; "output_size" was listed twice and "input_size" was missing.
actor = agent._task_behavior.actor
summary(actor.layers, input_data=(feat,), depth=3, col_names=["input_size", "output_size", "num_params"])



In [None]:
# Critic (value) MLP trunk on the same features. This previously grabbed `.actor` by copy-paste
# and so summarised the actor a second time.
value = agent._task_behavior.value
summary(value.layers, input_data=(feat,), depth=3, col_names=["input_size", "output_size", "num_params"])

In [None]:
# Length of the flat "state" vector assumed by the next cell: a 9 x 11 x 83 map plus 51 inventory values.
9 * 11 * 83 + 51

In [None]:
# Split the craftax symbolic observation into its map and inventory parts.
# `obs` is a list of per-env dicts, so the old `obs['state']` raised TypeError. The flat layout
# the old cell assumed is 9 * 11 * 83 = 8217 map values followed by the inventory (8268 in
# total). The current observations expose separate 'state_map' / 'state_inventory' keys, so
# those are used when there is no flat 'state' key.
MAP_SHAPE = (9, 11, 83)
MAP_SIZE = int(np.prod(MAP_SHAPE))  # 8217
FLAT_OBS_SIZE = 8268
if "state" in obs[0]:
    flat_state = np.stack([o["state"] for o in obs]).reshape((-1, FLAT_OBS_SIZE))
    state_map = flat_state[:, :MAP_SIZE].reshape((-1, *MAP_SHAPE))
    inventories = flat_state[:, MAP_SIZE:]
else:
    state_map = np.stack([o["state_map"] for o in obs])
    inventories = np.stack([o["state_inventory"] for o in obs])
# Show shapes rather than dumping the full arrays; `state_map` avoids shadowing the builtin `map`.
state_map.shape, inventories.shape