# VizDoom tournament

In [None]:
# Automatically reload imported modules (e.g. ``arena``) before each cell runs,
# so edits to their source take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import torch
from torch.nn import functional as F
from torchvision import transforms

from arena import VizdoomMPEnv

In [4]:
def stack_dict(x):
    """Concatenate the arrays held in dict ``x`` along axis 1.

    Arrays are joined in the dict's iteration order and must agree on every
    axis except axis 1.
    """
    # np.concatenate instead of np.concat: ``concat`` is only an alias added in
    # NumPy 2.0 and raises AttributeError on NumPy 1.x.
    return np.concatenate(list(x.values()), axis=1)


def to_tensor(x):
    """Return ``x`` as a torch tensor when it is a NumPy array, else unchanged.

    ``torch.from_numpy`` shares memory with the source array (no copy).
    """
    return torch.from_numpy(x) if isinstance(x, np.ndarray) else x


def resize(x, size=(128, 128)):
    """Resize the two trailing (spatial) dims of ``x`` to ``size``.

    ``F.interpolate`` with a 2-element ``size`` expects (N, C, H, W) input, so a
    tensor with fewer than 4 dims gets one leading batch dim added first.
    Interpolation uses ``F.interpolate``'s default mode ("nearest").

    Args:
        x: tensor whose last two dims are the spatial dims.
        size: target (height, width); defaults to the original hard-coded 128x128.
    """
    # add a batch dimension for interpolation
    if x.ndim < 4:
        x = x.unsqueeze(0)
    return F.interpolate(x, size)


def minmax(x):
    """Scale every ``x[i, j]`` slice by its own maximum.

    Despite the name, no minimum is subtracted. Each (dim 0, dim 1) slice is
    divided by ``max + 1e-8`` over all remaining dims, so non-negative inputs
    land in [0, 1]. Every stacked buffer channel is therefore normalized
    independently. ``nan_to_num`` replaces any NaN/inf the division produces.

    Requires ``x.ndim >= 2``.
    """
    # reshape (not view): view raises on non-contiguous tensors, e.g. after a
    # transpose/permute upstream.
    flat = x.reshape(x.shape[0], x.shape[1], -1)
    # restore trailing singleton dims so the max broadcasts over the spatial
    # dims (avoids the Python-3.11-only ``[..., *...]`` subscript syntax)
    x_max = flat.max(-1).values.reshape(*x.shape[:2], *[1] * (x.ndim - 2))
    x = x / (x_max + 1e-8)
    return torch.nan_to_num(x)


# Per-player observation pipeline, applied left to right:
# dict of buffers -> one stacked ndarray -> torch tensor -> max-scaled -> resized.
frame_transform = transforms.Compose(
    [
        stack_dict,
        to_tensor,
        minmax,
        resize,
    ]
)

In [None]:
# Two-player match on map01 with no bots; ``frame_transform`` is passed as the
# per-player observation transform (presumably applied to every player's frames).
env = VizdoomMPEnv(
    doom_map="map01",
    num_players=2,
    num_bots=0,
    player_transform=frame_transform,
    episode_timeout=5000,
)

In [6]:
# Turn on replay recording; presumably this is what populates
# ``env.get_player_replays()`` for the rendering cell — confirm in arena.
env.enable_replay()

## Random policy (2 players)

In [None]:
# Roll out uniformly random actions for at most ``max_steps`` steps and report
# the number of env steps taken and each player's accumulated return.
max_steps = 300
for episode in range(1):
    ep_return = {k: 0.0 for k in range(env.num_players)}
    ep_step = 0
    obs = env.reset()
    for _ in range(max_steps):
        act = env.action_space.sample()
        obs, rwd, done, info = env.step(act)
        # count every step, including the terminal one (previously it was skipped)
        ep_step += 1
        ep_return = {k: ep_return[k] + rwd[i] for i, k in enumerate(ep_return)}
        if done:
            break
    # report also when the step cap is hit before the episode ends,
    # instead of silently printing nothing
    print("ep steps: {}; ep return: {}".format(ep_step, ep_return))

In [None]:
from IPython.display import HTML

from arena.render import render_episode


# Render the recorded player replays (presumably a matplotlib animation, given
# ``to_html5_video``) and embed them as an HTML5 video. ``HTML(...)`` must stay
# the last expression of the cell so the notebook displays it.
ani = render_episode(env.get_player_replays())
HTML(ani.to_html5_video())

## Evaluate DQN (1 player vs. 4 bots)

In [None]:
# Evaluation setup: a single player against 4 bots on map01.
# NOTE(review): this rebinds ``env`` without shutting down the 2-player env
# created earlier; if VizdoomMPEnv exposes close(), call it first to free the
# game instances — confirm against the arena API.
env = VizdoomMPEnv(
    doom_map="map01",
    num_players=1,
    num_bots=4,
    player_transform=frame_transform,
    episode_timeout=5000,
)
env.enable_replay()

In [None]:
# Placeholder: replace ``...`` with the trained DQN before running evaluation.
# It is called as ``dqn(obs)`` on the single player's observation tensor and the
# argmax of its output is used as the action.
dqn = ...

In [None]:
# Roll out one greedy episode with ``dqn`` and report episode length and
# per-player return.

# ``device`` was never defined in this notebook (NameError). Derive it from the
# model so observations land on the same device as its weights.
try:
    device = next(dqn.parameters()).device
except (AttributeError, StopIteration):
    # dqn is not an nn.Module, or has no parameters: fall back to CPU
    device = torch.device("cpu")

ep_return = {k: 0.0 for k in range(env.num_players)}
ep_step = 0
done = False
obs = env.reset()
with torch.no_grad():  # pure inference: no autograd graph needed
    while not done:
        # single player: act on its observation only
        q_values = dqn(obs[0].to(device))
        act = q_values.argmax().item()
        obs, rwd, done, info = env.step(act)
        # count every step, including the terminal one (previously it was skipped)
        ep_step += 1
        ep_return = {k: ep_return[k] + rwd[i] for i, k in enumerate(ep_return)}
print("ep steps: {}; ep return: {}".format(ep_step, ep_return))

In [None]:
from IPython.display import HTML

from arena.render import render_episode


# Render the recorded replays of the DQN evaluation episode (presumably a
# matplotlib animation, given ``to_html5_video``) as an embedded HTML5 video.
# ``HTML(...)`` must stay the last expression so the notebook displays it.
ani = render_episode(env.get_player_replays())
HTML(ani.to_html5_video())