In [1]:
%load_ext autoreload

# --------------- #
# region: Imports #
import os
import sys

# Resolve the directory three levels above the current working directory and
# put it at the front of the import search path, unless it is listed already.
# NOTE(review): presumably this is the repository root, which would let the
# `examples.*` imports in the next cell resolve — confirm.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir, os.pardir))
if sys.path.count(module_path) == 0:
    sys.path.insert(0, module_path)
# endregion       #
# --------------- #

In [2]:
import torch
from sorrel.config import load_config, argparse
from sorrel.models.pytorch import PyTorchIQN
from examples.cleanup.env import Cleanup
from examples.cleanup.agents import CleanupAgent
from sorrel.utils.visualization import visual_field_sprite, image_from_array, animate

# Read the experiment configuration from the YAML file next to this notebook.
cfg = load_config(argparse.Namespace(config="../configs/config.yaml"))

# torch.random.seed() reseeds torch's RNG with a non-deterministic value and
# returns it, so every fresh run uses a different seed.
# NOTE(review): pin a fixed seed here if reproducible runs are wanted.
seed = torch.random.seed()

N_AGENTS = 1

# Build one IQN model per agent and wrap each in a CleanupAgent. Model and
# agent are created back to back, one pair per iteration.
models: list[PyTorchIQN] = []
agents: list[CleanupAgent] = []
for _agent_idx in range(N_AGENTS):
    iqn = PyTorchIQN(
        input_size=(984,),
        seed=seed,
        num_frames=cfg.agent.agent.obs.num_frames,
        **cfg.model.iqn.parameters.to_dict(),
    )
    models.append(iqn)
    agents.append(CleanupAgent(cfg, model=iqn))

# Create the environment around the agents and reset it before use.
env = Cleanup(cfg, agents=agents, mode="DEFAULT")
env.reset()

In [3]:
def _capture_frame(world):
    """Return visual_field_sprite(world) passed through image_from_array."""
    return image_from_array(visual_field_sprite(world))


# Capture a frame before each turn, then advance the environment by one turn.
imgs = []
for _turn in range(cfg.experiment.max_turns):
    imgs.append(_capture_frame(env))
    env.take_turn()

# One more capture after the last turn so the end state is included too.
img = _capture_frame(env)
imgs.append(img)

In [4]:
# Hand every collected frame to the `animate` helper imported from
# sorrel.utils.visualization.
# NOTE(review): "test" and "../data/" are presumably the output name and the
# destination directory — confirm against animate's signature.
animate(imgs, "test", "../data/")

In [6]:
# Sample from the first agent's model memory and unpack the six values that
# come back.
# NOTE(review): 64 is presumably the batch size, and stacked_frames=5 is
# hard-coded rather than read from cfg — confirm both against the memory API.
first_agent_memory = agents[0].model.memory
states, actions, rewards, next_states, dones, valids = first_agent_memory.sample(
    64,
    stacked_frames=5,
)

In [9]:
# Display the shape of what agents[0].pov(env) returns. The output recorded
# for this cell is (1, 5, 984); its last dimension equals the input_size=(984,)
# given to PyTorchIQN when the models were built.
# NOTE(review): the middle 5 presumably corresponds to the stacked frames —
# confirm.
agents[0].pov(env).shape

(1, 5, 984)