In [None]:
import nest
import torch
import matplotlib.pyplot as plt
import math

from torchvision.utils import make_grid, save_image

from torchbeast import utils
from torchbeast.core import models
from torchbeast import polybeast_learner as polybeast

# Path to the checkpoint file, i.e. flags.savedir/flags.xpid/model.tar
checkpointpath = "/root/logs/torchbeast/latest/model.tar"

# This first load is only needed to read the saved flags. Mapping storages to
# the CPU lets it succeed on machines that lack the device the checkpoint was
# saved from; the checkpoint is loaded again onto flags.learner_device once
# the flags are known.
checkpoint_states = torch.load(checkpointpath, map_location="cpu")

# Flags stored alongside the weights in the checkpoint.
flags = checkpoint_states["flags"]

class dotdict(dict):
    """Dictionary whose entries can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # Mirrors dict.get: a key that is not present reads as None.
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment stores the value as a regular dictionary item.
        self[name] = value

    def __delattr__(self, name):
        # Mirrors dict.__delitem__: deleting an absent key raises KeyError.
        del self[name]
    
# Wrap the saved flags so they can be read as attributes (flags.dataset, ...).
# Keys that are absent read as None.
flags = dotdict(flags)

# Load the checkpoint again, this time mapped onto flags.learner_device.
checkpoint_states = torch.load(checkpointpath, map_location=flags.learner_device)

dataset_uses_color = flags.dataset not in ["mnist", "omniglot"]
# NOTE(review): this value is always overwritten by the if/else below before
# it is used -- confirm whether the dataset was meant to receive this one.
grayscale = dataset_uses_color and not flags.use_color

env_uses_color = flags.use_color or flags.env_type == "fluid"
if not env_uses_color:
    grayscale = True
else:
    # This line previously read `is_color`, a name that is never defined in
    # the notebook, so it raised NameError whenever the env used color.
    grayscale = env_uses_color and not dataset_uses_color

# A dataset is only created when the flags enable conditioning.
if flags.condition:
    dataset = utils.create_dataset(flags.dataset, grayscale)
else:
    dataset = None

env_name, config = utils.parse_flags(flags)
# NOTE(review): grayscale is hard-coded to True here instead of using the
# value computed above -- confirm this is intended.
env = utils.create_env(env_name, config, grayscale=True, dataset=dataset)

grid_width = 32

# Shapes the network needs, read from the environment.
obs_shape = env.observation_space["canvas"].shape
action_shape = env.action_space.nvec
order = env.order

# Build the network in eval mode and place it on flags.learner_device.
model = models.Net(
    obs_shape=obs_shape,
    order=order,
    action_shape=action_shape,
    grid_shape=(grid_width, grid_width),
).eval()
model = model.to(flags.learner_device)

In [None]:
# Restore the checkpointed weights into the model.
checkpoint_states = torch.load(checkpointpath, map_location=flags.learner_device)
model.load_state_dict(checkpoint_states["model_state_dict"])

obs = env.reset()

# Add two leading singleton dimensions to every observation entry.
for k in obs.keys():
    obs[k] = torch.from_numpy(obs[k]).unsqueeze(0).unsqueeze(0)

done = torch.tensor(False).view(1, 1)

core_state = model.initial_state()

# Inference only: no_grad stops autograd from recording a graph through the
# core_state that is carried from one step to the next.
with torch.no_grad():
    for i in range(flags.episode_length - 1):
        (action, logits, baseline), core_state = model(obs, done, core_state)
        # Merge the two leading dimensions and hand the env a NumPy int array.
        action = torch.flatten(action.cpu(), 0, 1).int().numpy()

        obs, reward, done, info = env.step(action)

        for k in obs.keys():
            tensor = torch.from_numpy(obs[k])
            # prev_action's original shape has batch dimension.
            if k == "prev_action":
                obs[k] = tensor.unsqueeze(0)
            else:
                obs[k] = tensor.unsqueeze(0).unsqueeze(0)

        done = torch.tensor(done).view(1, 1)

# Drop the leading singleton dimensions and move the first axis last for
# imshow (assumes obs_shape is channels-first -- TODO confirm).
img = obs["canvas"].view(obs_shape).permute(1, 2, 0).numpy()

# With conditioning, the last axis holds the canvas at index 0 and the target
# at index 1.
if flags.condition:
    target = img[..., 1:2]
    img = img[..., 0:1]

plt.figure(figsize=(5, 5))
plt.imshow(img, cmap='gray', vmin=0.0, vmax=1.0, interpolation="nearest")

if flags.condition:
    plt.figure(figsize=(5, 5))
    plt.imshow(target, cmap='gray', vmin=0.0, vmax=1.0, interpolation="nearest")

In [None]:
def sample():
    """Run one episode with the current model and return the final canvas.

    Reads the notebook-level ``env``, ``model``, ``flags`` and ``obs_shape``.

    Returns:
        The "canvas" entry of the last observation, as a tensor viewed to
        ``obs_shape``.
    """
    obs = env.reset()

    # Add two leading singleton dimensions to every observation entry.
    for k in obs.keys():
        obs[k] = torch.from_numpy(obs[k]).unsqueeze(0).unsqueeze(0)

    done = torch.tensor(False).view(1, 1)

    core_state = model.initial_state()

    # Inference only: no_grad stops autograd from recording a graph through
    # the core_state that is carried from one step to the next.
    with torch.no_grad():
        for _ in range(flags.episode_length - 1):
            (action, _logits, _baseline), core_state = model(obs, done, core_state)
            # Merge the two leading dimensions and hand the env a NumPy int array.
            action = torch.flatten(action.cpu(), 0, 1).int().numpy()

            obs, _reward, done, _info = env.step(action)

            for k in obs.keys():
                tensor = torch.from_numpy(obs[k])
                # prev_action's original shape has batch dimension.
                if k == "prev_action":
                    obs[k] = tensor.unsqueeze(0)
                else:
                    obs[k] = tensor.unsqueeze(0).unsqueeze(0)

            done = torch.tensor(done).view(1, 1)

    return obs["canvas"].view(obs_shape)

# Restore the checkpointed weights before drawing samples.
checkpoint_states = torch.load(checkpointpath, map_location=flags.learner_device)
model.load_state_dict(checkpoint_states["model_state_dict"])

# One finished canvas per batch entry, stacked along a new leading axis.
renders = torch.stack([sample() for _ in range(flags.batch_size)])

# With conditioning, index 0 along axis 1 is the canvas and index 1 the target.
if flags.condition:
    targets, renders = renders[:, 1:2, ...], renders[:, 0:1, ...]

# Number of images per grid row; shared by both grids below.
grid_cols = math.ceil(flags.batch_size ** 0.5)

img = make_grid(renders, nrow=grid_cols).permute(1, 2, 0).numpy()

plt.figure(figsize=(7, 7))
plt.imshow(img, cmap='gray', interpolation="nearest")

if flags.condition:
    target_img = make_grid(targets, nrow=grid_cols).permute(1, 2, 0).numpy()

    plt.figure(figsize=(7, 7))
    plt.imshow(target_img, cmap='gray', interpolation="nearest")

In [None]:
def sample_history():
    """Run one episode with the current model and return every canvas along the way.

    Reads the notebook-level ``env``, ``model``, ``flags`` and ``obs_shape``.

    Returns:
        A stacked tensor holding the canvas after reset followed by the canvas
        after each step (``flags.episode_length`` canvases in total). When
        ``flags.condition`` is set, only index 0 along the first axis of each
        canvas is kept, and the target taken from the reset observation is
        appended as the final entry.
    """
    obs = env.reset()

    # Add two leading singleton dimensions to every observation entry.
    for k in obs.keys():
        obs[k] = torch.from_numpy(obs[k]).unsqueeze(0).unsqueeze(0)

    done = torch.tensor(False).view(1, 1)

    core_state = model.initial_state()

    render = obs["canvas"].view(obs_shape)

    # With conditioning, index 0 is the canvas and index 1 is the target.
    if flags.condition:
        target = render[1:2]
        render = render[0:1]

    renders = [render]

    # Inference only: no_grad stops autograd from recording a graph through
    # the core_state that is carried from one step to the next.
    with torch.no_grad():
        for _ in range(flags.episode_length - 1):
            (action, _logits, _baseline), core_state = model(obs, done, core_state)
            # Merge the two leading dimensions and hand the env a NumPy int array.
            action = torch.flatten(action.cpu(), 0, 1).int().numpy()

            obs, _reward, done, _info = env.step(action)

            for k in obs.keys():
                tensor = torch.from_numpy(obs[k])
                # prev_action's original shape has batch dimension.
                if k == "prev_action":
                    obs[k] = tensor.unsqueeze(0)
                else:
                    obs[k] = tensor.unsqueeze(0).unsqueeze(0)

            done = torch.tensor(done).view(1, 1)

            render = obs["canvas"].view(obs_shape)
            if flags.condition:
                render = render[0:1]

            renders.append(render)

    # The target goes last so it closes each episode's run of frames.
    if flags.condition:
        renders.append(target)

    return torch.stack(renders)

# Restore the checkpointed weights before drawing samples.
checkpoint_states = torch.load(checkpointpath, map_location=flags.learner_device)
model.load_state_dict(checkpoint_states["model_state_dict"])

renders = [sample_history() for _ in range(flags.batch_size)]

# Each sample_history() result holds flags.episode_length canvases, plus one
# target frame when conditioning is enabled.
frames_per_episode = (
    (flags.episode_length + 1) if flags.condition else flags.episode_length
)

# Row width is a whole number of episodes, so no episode wraps mid-row.
nrow = math.ceil(
    (flags.batch_size / frames_per_episode) ** 0.5
) * frames_per_episode

img = make_grid(torch.cat(renders), nrow=nrow).permute(1, 2, 0).numpy()

plt.figure(figsize=(20, 20))
plt.imshow(img, cmap='gray', interpolation="nearest")

In [None]:
# Rebuild the replay buffer from the state stored in the checkpoint; relies on
# `checkpoint_states` loaded in an earlier cell.
replay_buffer = polybeast.ReplayBuffer(flags.replay_buffer_size)
replay_buffer.load_checkpoint(checkpoint_states["replay_buffer"])

# Deliberately not named `sample`: that would overwrite the sample() function
# defined in an earlier cell and break it for any later call.
replay_sample = replay_buffer.sample(flags.batch_size)

# With conditioning, index 0 along axis 1 is the canvas and index 1 the target.
if flags.condition:
    replay_target = replay_sample[:, 1:2]
    replay_sample = replay_sample[:, 0:1]

# Number of images per grid row; shared by both grids below.
replay_nrow = math.ceil(flags.batch_size ** 0.5)

img = make_grid(
    replay_sample,
    nrow=replay_nrow
).permute(1, 2, 0).numpy()

plt.figure(figsize=(7, 7))
plt.imshow(img, cmap="gray", interpolation="nearest")

if flags.condition:
    img = make_grid(
        replay_target,
        nrow=replay_nrow
    ).permute(1, 2, 0).numpy()

    plt.figure(figsize=(7, 7))
    plt.imshow(img, cmap="gray", interpolation="nearest")

In [None]:
# Dump the flags restored from the checkpoint, for reference.
print(flags)