# üß† Deep Reinforcement Learning ‚Äî Doom Agent (SS2025)

Welcome to the last assignment for the **Deep Reinforcement Learning** course (SS2025). In this notebook, you'll implement and train a reinforcement learning agent to play **Doom**.

You will:
- Set up a custom VizDoom environment with shaped rewards
- Train an agent using an approach of your choice
- Track reward components across episodes
- Evaluate the best model
- Visualize performance with replays and GIFs
- Export the trained agent to ONNX to submit to the evaluation server

In [1]:
# Install the dependencies
!python -m pip install --upgrade pip
!pip install --upgrade notebook ipywidgets ipykernel -q
!pip install torch numpy matplotlib vizdoom portpicker gym onnx wandb stable-baselines3 stable-baselines3[extra] Shimmy einops torchvision -q

[0m

In [2]:
import os
import subprocess

base_dir = os.path.abspath(os.getcwd())
dir_path = os.path.join(base_dir, "jku.wad")

if os.path.isdir(dir_path):
    os.chdir(dir_path)
    subprocess.run(["git", "pull", "origin", "main"])
else:
    subprocess.run(["git", "clone", "https://github.com/syseitz/jku.wad.git", dir_path])
    os.chdir(dir_path)

Cloning into '/jku.wad'...


## Environment configuration

ViZDoom supports multiple visual buffers that can be used as input for training agents. Each buffer provides different information about the game environment, as seen from left to right:


Screen
- The default first-person RGB view seen by the agent.

Labels
- A semantic map where each pixel is tagged with an object ID (e.g., enemy, item, wall).

Depth
- A grayscale map showing the distance from the agent to surfaces in the scene.

Automap
- A top-down schematic view of the map, useful for global navigation tasks.

![buffers gif](https://vizdoom.farama.org/_images/vizdoom-demo.gif)

In [3]:
import wandb
from typing import Dict, Sequence, Tuple

import torch
from collections import deque, OrderedDict
from copy import deepcopy
import random
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from matplotlib import pyplot as plt
from PIL import Image

from gym import Env
import gymnasium as gym
from torch import nn
from einops import rearrange

from doom_arena import VizdoomMPEnv
from doom_arena.reward import VizDoomReward
from doom_arena.render import render_episode
from IPython.display import HTML

from vizdoom import ScreenFormat
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

In [None]:
USE_GRAYSCALE = False # ‚Üê flip to False for RGB

PLAYER_CONFIG = {
    "n_stack_frames": 1,
    "extra_state": ["depth"],
    "hud": "none",
    "crosshair": True,
    "screen_format": ScreenFormat.CRCGCB if not USE_GRAYSCALE else ScreenFormat.GRAY8,
}

## Reward function
In this task, you will define a reward function to guide the agent's learning. The function is called at every step and receives the current and previous game variables (e.g., number of frags, hits taken, health).

Your goal is to combine these into a meaningful reward, encouraging desirable behavior, such as:

- Rewarding frags (enemy kills)

- Rewarding accuracy (hitting enemies)

- Penalizing damage taken

- (Optional) Encouraging survival, ammo efficiency, etc.

You can return multiple reward components, which are summed during training. Consider the class below as a great starting point!

In [5]:
# TODO: environment training paramters
N_STACK_FRAMES = 4
NUM_BOTS = 4
EPISODE_TIMEOUT = 1000
# TODO: model hyperparams
GAMMA = 0.95
EPISODES = 1000 
BATCH_SIZE = 256
REPLAY_BUFFER_SIZE = 100000
LEARNING_RATE = 5e-4
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_DECAY = 0.999
FEATURES_DIM = 512
#N_EPOCHS = 50 # Not used with stable_baseline3
TOTAL_TIMESTEPS = 1000000
EXPLORATION_FRACTION = 0.2


In [6]:
class YourReward(VizDoomReward):
    def __init__(self, num_players: int):
        super().__init__(num_players)
        # Initialize a list to store the last computed reward components for each player
        self.last_rewards = [None] * num_players

    def __call__(
        self,
        vizdoom_reward: float,
        game_var: Dict[str, float],
        game_var_old: Dict[str, float],
        player_id: int,
    ) -> Tuple[float, ...]:
        """
        Custom reward function for training and evaluation:
        * +0.01  for each damage point dealt (rwd_damage)
        * -0.1   for each missed shot (rwd_missed)
        * +1.0   for each new frag (rwd_frag)
        * +0.02  for each health point gained (rwd_health_pickup)
        * +0.001 for surviving each step (rwd_survival)
        * -0.5   if the player dies (rwd_dead)
        * -0.01  for shooting without dealing damage (rwd_spam_penalty)
        * +0.00005 for moving, -0.0025 for staying still (rwd_movement)
        """
        # Increment internal step counter
        self._step += 1
        # Ignore vizdoom_reward and player_id for internal use
        _ = vizdoom_reward, player_id

        # Calculate reward for damage dealt
        damage_done = game_var["DAMAGECOUNT"] - game_var_old["DAMAGECOUNT"]
        rwd_damage = 0.01 * damage_done

        # Calculate reward for achieving a frag
        rwd_frag = 1.0 * (game_var["FRAGCOUNT"] - game_var_old["FRAGCOUNT"])

        # Calculate shots fired and missed shots
        ammo_delta = game_var_old["SELECTED_WEAPON_AMMO"] - game_var["SELECTED_WEAPON_AMMO"]
        if ammo_delta > 0:
            shots_fired = ammo_delta
            hits = game_var["HITCOUNT"] - game_var_old["HITCOUNT"]
            missed_shots = max(0, shots_fired - hits)
            rwd_missed = -0.1 * missed_shots
        else:
            rwd_missed = 0

        # Calculate survival reward and death penalty
        rwd_survival = 0.001
        rwd_dead = -0.5 if game_var["DEAD"] == 1 else 0.0

        # Calculate penalty for spamming shots without dealing damage
        rwd_spam_penalty = -0.01 if ammo_delta > 0 and damage_done <= 0 else 0.0

        # Calculate reward for gaining health
        health_delta = game_var["HEALTH"] - game_var_old["HEALTH"]
        health_gained = max(0, health_delta)
        rwd_health_pickup = 0.02 * health_gained

        # Calculate reward for movement
        position_changed = (game_var["POSITION_X"] != game_var_old["POSITION_X"]) or (game_var["POSITION_Y"] != game_var_old["POSITION_Y"])
        rwd_movement = 0.00005 if position_changed else -0.0025

        # Store the computed reward components for the current player
        rewards = (rwd_damage, rwd_frag, rwd_missed, rwd_survival, rwd_dead, rwd_spam_penalty, rwd_health_pickup, rwd_movement)
        self.last_rewards[player_id] = rewards
        
        # Return the tuple of reward components
        return rewards

In [None]:
device = "cuda"
DTYPE = torch.float32

reward_fn = YourReward(num_players=1)

from vizdoom import Button, ScreenFormat

env = VizdoomMPEnv(
    num_players=1,
    num_bots=NUM_BOTS,
    bot_skill=0,
    doom_map="ROOM",
    extra_state=PLAYER_CONFIG["extra_state"],
    episode_timeout=EPISODE_TIMEOUT,
    n_stack_frames=PLAYER_CONFIG["n_stack_frames"],
    crosshair=PLAYER_CONFIG["crosshair"],
    hud=PLAYER_CONFIG["hud"],
    reward_fn=None,
    available_buttons=[Button.ATTACK, Button.TURN_LEFT, Button.TURN_RIGHT],
    frame_skip=4,
)

Host 43263
Player 43263


## Agent

Implement **your own agent** in the code cell that follows.

* In `agents/dqn.py` and `agents/ppo.py` you‚Äôll find very small **skeletons**‚Äîthey compile but are meant only as reference or quick tests.  
  Feel free to open them, borrow ideas, extend them, or ignore them entirely.
* The notebook does **not** import those files automatically; whatever class you define in the next cell is the one that will be trained.
* You may keep the DQN interface, switch to PPO, or try something else.
* Tweak any hyper-parameters (`PLAYER_CONFIG`, Œµ-schedule, optimiser, etc.) and document what you tried.


In [8]:
class CustomCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super(CustomCNN, self).__init__(observation_space, features_dim)
        assert isinstance(observation_space, gym.spaces.Box), "Observation space must be a Box"

        # Annahme: Die Kan√§le sind wie folgt geordnet: Bildschirm (3), Tiefe (1), Labels (1)
        c = observation_space.shape[2]
        screen_ch = 3
        depth_ch = 1
        labels_ch = 1
        assert screen_ch + depth_ch + labels_ch == c, "Kanal-Mismatch"

        # Definiere CNNs f√ºr jeden Teil
        self.cnn_screen = nn.Sequential(
            nn.Conv2d(screen_ch, 16, kernel_size=8, stride=4, padding=0),  
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0),         
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=0),   
            nn.ReLU(),
            nn.Flatten(),
        )

        self.cnn_depth = nn.Sequential(
            nn.Conv2d(depth_ch, 16, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        self.cnn_labels = nn.Sequential(
            nn.Conv2d(labels_ch, 16, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Berechne die flachen Gr√∂√üen
        with torch.no_grad():
            sample = torch.as_tensor(observation_space.sample()[None]).float()
            sample = rearrange(sample, 'n h w c -> n c h w')
            n_flatten_screen = self.cnn_screen(sample[:, :screen_ch, :, :]).shape[1]
            n_flatten_depth = self.cnn_depth(sample[:, screen_ch:screen_ch+depth_ch, :, :]).shape[1]
            n_flatten_labels = self.cnn_labels(sample[:, screen_ch+depth_ch:, :, :]).shape[1]

        total_flatten = n_flatten_screen + n_flatten_depth + n_flatten_labels
        self.linear = nn.Sequential(nn.Linear(total_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # observations shape: (N, H, W, C)
        observations = rearrange(observations, 'n h w c -> n c h w')
        screen = observations[:, :3, :, :]
        depth = observations[:, 3:4, :, :]
        labels = observations[:, 4:5, :, :]
        features_screen = self.cnn_screen(screen)
        features_depth = self.cnn_depth(depth)
        features_labels = self.cnn_labels(labels)
        combined = torch.cat((features_screen, features_depth, features_labels), dim=1)
        return self.linear(combined)

In [None]:
# ================================================================
# Initialise your networks and training utilities
# ================================================================

# main Q-network
in_channels = env.observation_space.shape[0]   # 1 if grayscale, else 3/4
#model = DQN(
#    input_dim    = in_channels,
#    action_space = env.action_space.n,
#    hidden       = 64,   # change or ignore
#).to(device, dtype=DTYPE)

policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=FEATURES_DIM),
)

model = DQN(
    "CnnPolicy",
    env,
    learning_rate=0.0001,          
    buffer_size=10000,            
    batch_size=32,                 
    gamma=0.99,                    
    exploration_fraction=0.1,     
    exploration_initial_eps=1.0,   
    exploration_final_eps=0.02,   
    verbose=1,
    policy_kwargs=policy_kwargs,
)

# TODO ------------------------------------------------------------
# 1. Create a target network (hard-copy or EMA)
# 2. Choose an optimiser + learning-rate schedule
# 3. Instantiate a replay buffer and set the initial epsilon value
#
# Hints:
#   model_tgt  = deepcopy(model).to(device)y
#   optimiser  = torch.optim.Adam(...)
#   scheduler  = torch.optim.lr_scheduler.ExponentialLR(...)
#   replay_buf = collections.deque(maxlen=...)
# ---------------------------------------------------------------

#model_tgt = deepcopy(model).to(device, dtype=DTYPE)
#optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
#scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
#replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
#epsilon = EPSILON_START


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




In [10]:
class EpisodeCallback(BaseCallback):
    def __init__(self):
        super(EpisodeCallback, self).__init__()
        self.episode_reward = 0
        self.episode_num = 0
        self.episode_rwd_components = {
            "hit": 0.0,
            "hit_taken": 0.0,
            "frag": 0.0,
            "missed": 0.0,
            "survival": 0.0,
            "dead": 0.0,
            "spam_penalty": 0.0,
            "health": 0.0
        }

    def _on_step(self) -> bool:
        # Accumulate the total reward for the current step
        self.episode_reward += self.locals['rewards'][0]

        # Access the actual environment (VizdoomMPEnv) through DummyVecEnv and Monitor
        monitor_env = self.locals['env'].envs[0]  # This is the Monitor wrapper
        actual_env = monitor_env.env              # This is the VizdoomMPEnv
        last_rewards = actual_env.reward_fn.last_rewards[0]

        # Update reward components if rewards are available
        if last_rewards is not None:
            self.episode_rwd_components["hit"] += last_rewards[0]
            self.episode_rwd_components["hit_taken"] += last_rewards[1]
            self.episode_rwd_components["frag"] += last_rewards[2]
            self.episode_rwd_components["missed"] += last_rewards[3]
            self.episode_rwd_components["survival"] += last_rewards[4]
            self.episode_rwd_components["dead"] += last_rewards[5]
            self.episode_rwd_components["spam_penalty"] += last_rewards[6]
            self.episode_rwd_components["health"] += last_rewards[7]

        # Log data if the episode ends
        if self.locals['dones'][0]:
            self.episode_num += 1
            wandb.log({
                "episode": self.episode_num,
                "return": self.episode_reward,
                "rwd_hit": self.episode_rwd_components["hit"],
                "rwd_hit_taken": self.episode_rwd_components["hit_taken"],
                "rwd_frag": self.episode_rwd_components["frag"],
                "rwd_missed": self.episode_rwd_components["missed"],
                "rwd_survival": self.episode_rwd_components["survival"],
                "rwd_dead": self.episode_rwd_components["dead"],
                "rwd_spam_penalty": self.episode_rwd_components["spam_penalty"],
                "rwd_health": self.episode_rwd_components["health"],
            })
            # Reset accumulators
            self.episode_reward = 0
            for key in self.episode_rwd_components:
                self.episode_rwd_components[key] = 0.0

        return True

## Training loop

In [None]:
# ---------------------  TRAINING LOOP  ----------------------
# Feel free to change EVERYTHING below:
#   ‚Ä¢ choose your own reward function
#   ‚Ä¢ track different episode statistics in `ep_metrics`
#   ‚Ä¢ switch optimiser, scheduler, update rules, etc.
run = wandb.init(project="doom-rl", entity="soerenseitz-university-of-vienna", config={
    "gamma": GAMMA,
    "episodes": EPISODES,
    "batch_size": BATCH_SIZE,
    "replay_buffer_size": REPLAY_BUFFER_SIZE,
    "learning_rate": LEARNING_RATE,
    "epsilon_start": EPSILON_START,
    "epsilon_end": EPSILON_END,
    "epsilon_decay": EPSILON_DECAY,
    "num_bots": NUM_BOTS,
    "episode_timeout": EPISODE_TIMEOUT,
    "use_grayscale": USE_GRAYSCALE,
    "extra_state": PLAYER_CONFIG["extra_state"],
    "hud": PLAYER_CONFIG["hud"],
    "crosshair": PLAYER_CONFIG["crosshair"],
    "screen_format": PLAYER_CONFIG["screen_format"].name,
    "doom_map": "ROOM",
})
callback = EpisodeCallback()

model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=callback, progress_bar=True)
final_model = model


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msoerenseitz[0m ([33msoerenseitz-university-of-vienna[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Output()

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -3.54    |
|    exploration_rate | 0.982    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 54       |
|    time_elapsed     | 72       |
|    total_timesteps  | 4000     |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.00022  |
|    n_updates        | 974      |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -3.47    |
|    exploration_rate | 0.964    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 54       |
|    time_elapsed     | 147      |
|    total_timesteps  | 8000     |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000193 |
|    n_updates        | 1974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -3.44    |
|    exploration_rate | 0.946    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 54       |
|    time_elapsed     | 220      |
|    total_timesteps  | 12000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000177 |
|    n_updates        | 2974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -3.46    |
|    exploration_rate | 0.928    |
| time/               |          |
|    episodes         | 16       |
|    fps              | 53       |
|    time_elapsed     | 297      |
|    total_timesteps  | 16000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000114 |
|    n_updates        | 3974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -8.22    |
|    exploration_rate | 0.91     |
| time/               |          |
|    episodes         | 20       |
|    fps              | 53       |
|    time_elapsed     | 375      |
|    total_timesteps  | 20000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000126 |
|    n_updates        | 4974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -9.37    |
|    exploration_rate | 0.892    |
| time/               |          |
|    episodes         | 24       |
|    fps              | 53       |
|    time_elapsed     | 450      |
|    total_timesteps  | 24000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000309 |
|    n_updates        | 5974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -10.5    |
|    exploration_rate | 0.874    |
| time/               |          |
|    episodes         | 28       |
|    fps              | 53       |
|    time_elapsed     | 525      |
|    total_timesteps  | 28000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000189 |
|    n_updates        | 6974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -12.6    |
|    exploration_rate | 0.856    |
| time/               |          |
|    episodes         | 32       |
|    fps              | 53       |
|    time_elapsed     | 601      |
|    total_timesteps  | 32000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000181 |
|    n_updates        | 7974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -14.1    |
|    exploration_rate | 0.838    |
| time/               |          |
|    episodes         | 36       |
|    fps              | 50       |
|    time_elapsed     | 712      |
|    total_timesteps  | 36000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000164 |
|    n_updates        | 8974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -13.1    |
|    exploration_rate | 0.82     |
| time/               |          |
|    episodes         | 40       |
|    fps              | 50       |
|    time_elapsed     | 789      |
|    total_timesteps  | 40000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000199 |
|    n_updates        | 9974     |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -12.2    |
|    exploration_rate | 0.802    |
| time/               |          |
|    episodes         | 44       |
|    fps              | 50       |
|    time_elapsed     | 864      |
|    total_timesteps  | 44000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000207 |
|    n_updates        | 10974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -11.5    |
|    exploration_rate | 0.784    |
| time/               |          |
|    episodes         | 48       |
|    fps              | 51       |
|    time_elapsed     | 939      |
|    total_timesteps  | 48000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.00046  |
|    n_updates        | 11974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -10.9    |
|    exploration_rate | 0.766    |
| time/               |          |
|    episodes         | 52       |
|    fps              | 51       |
|    time_elapsed     | 1015     |
|    total_timesteps  | 52000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000593 |
|    n_updates        | 12974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -10.3    |
|    exploration_rate | 0.748    |
| time/               |          |
|    episodes         | 56       |
|    fps              | 51       |
|    time_elapsed     | 1093     |
|    total_timesteps  | 56000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000167 |
|    n_updates        | 13974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -9.86    |
|    exploration_rate | 0.73     |
| time/               |          |
|    episodes         | 60       |
|    fps              | 49       |
|    time_elapsed     | 1202     |
|    total_timesteps  | 60000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000131 |
|    n_updates        | 14974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -9.44    |
|    exploration_rate | 0.712    |
| time/               |          |
|    episodes         | 64       |
|    fps              | 20       |
|    time_elapsed     | 3073     |
|    total_timesteps  | 64000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000108 |
|    n_updates        | 15974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -9.07    |
|    exploration_rate | 0.694    |
| time/               |          |
|    episodes         | 68       |
|    fps              | 17       |
|    time_elapsed     | 3901     |
|    total_timesteps  | 68000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000331 |
|    n_updates        | 16974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -8.82    |
|    exploration_rate | 0.676    |
| time/               |          |
|    episodes         | 72       |
|    fps              | 12       |
|    time_elapsed     | 5679     |
|    total_timesteps  | 72000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000237 |
|    n_updates        | 17974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -8.53    |
|    exploration_rate | 0.658    |
| time/               |          |
|    episodes         | 76       |
|    fps              | 13       |
|    time_elapsed     | 5806     |
|    total_timesteps  | 76000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 8.57e-05 |
|    n_updates        | 18974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -9.44    |
|    exploration_rate | 0.64     |
| time/               |          |
|    episodes         | 80       |
|    fps              | 13       |
|    time_elapsed     | 5898     |
|    total_timesteps  | 80000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000601 |
|    n_updates        | 19974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -9.16    |
|    exploration_rate | 0.622    |
| time/               |          |
|    episodes         | 84       |
|    fps              | 13       |
|    time_elapsed     | 6012     |
|    total_timesteps  | 84000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000197 |
|    n_updates        | 20974    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1e+03    |
|    ep_rew_mean      | -8.89    |
|    exploration_rate | 0.604    |
| time/               |          |
|    episodes         | 88       |
|    fps              | 14       |
|    time_elapsed     | 6134     |
|    total_timesteps  | 88000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.000173 |
|    n_updates        | 21974    |
----------------------------------


## Dump to ONNX

In [None]:
import onnx
import json

def onnx_dump(env, model, config, run, filename_prefix="model"):
    # Create dummy input on CPU
    dummy_input = torch.randn(1, *env.observation_space.shape).float().to('cpu')
    print("Dummy input shape:", dummy_input.shape)  # Debug output
    
    # Ensure policy network is on CPU
    policy_net = model.policy.to('cpu')
    
    # Generate unique filename using wandb run ID
    run_id = run.id
    filename = f"{filename_prefix}_{run_id}.onnx"
    
    # Export the model to ONNX
    torch.onnx.export(
        policy_net,
        args=dummy_input,
        f=filename,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
    
    # Add metadata to the ONNX model
    onnx_model = onnx.load(filename)
    meta = onnx_model.metadata_props.add()
    meta.key = "config"
    meta.value = json.dumps(config)
    onnx.save(onnx_model, filename)
    
    return filename

# Usage
export_config = {
    **{k: str(v) if isinstance(v, ScreenFormat) else v for k, v in PLAYER_CONFIG.items()},
    "algo_type": "Q",
}

# Assuming 'run' is the wandb run object
filename = onnx_dump(env, final_model, export_config, run, filename_prefix="model")
print(f"Best network exported to {filename}")

# Upload to wandb
artifact = wandb.Artifact('model', type='model')
artifact.add_file(filename)
run.log_artifact(artifact)
artifact.wait()
run.finish()

Dummy input shape: torch.Size([1, 128, 128, 5])
Best network exported to model.onnx


0,1
episode,‚ñÅ
return,‚ñÅ
rwd_dead,‚ñÅ
rwd_frag,‚ñÅ
rwd_health,‚ñÅ
rwd_hit,‚ñÅ
rwd_hit_taken,‚ñÅ
rwd_missed,‚ñÅ
rwd_spam_penalty,‚ñÅ
rwd_survival,‚ñÅ

0,1
episode,1.0
return,-1.4305
rwd_dead,-0.2
rwd_frag,-2.0
rwd_health,-0.2305
rwd_hit,0.0
rwd_hit_taken,0.0
rwd_missed,1.0
rwd_spam_penalty,0.0
rwd_survival,0.0


### Evaluation and Visualization

In this final section, you can evaluate your trained agent, inspect its performance visually, and analyze reward components over time.


In [None]:
# ---------------------------------------------------------------
# üìà  Reward-plot helper  (feel free to edit / extend)
# ---------------------------------------------------------------
import pandas as pd
import matplotlib.pyplot as plt

def plot_reward_components(reward_log, smooth_window: int = 5):
    """
    Plot raw and smoothed episode-level reward components.

    Parameters
    ----------
    reward_log : list[dict]
        Append a dict for each episode, e.g. {"frag": ‚Ä¶, "hit": ‚Ä¶, "hittaken": ‚Ä¶}
    smooth_window : int
        Rolling-mean window size for the smoothed curve.
    """
    if not reward_log:
        print("reward_log is empty ‚Äì nothing to plot.")
        return

    df = pd.DataFrame(reward_log)
    df_smooth = df.rolling(window=smooth_window, min_periods=1).mean()

    # raw
    plt.figure(figsize=(12, 5))
    for col in df.columns:
        plt.plot(df.index, df[col], label=col)
    plt.title("Raw episode reward components")
    plt.legend(); plt.grid(True); plt.tight_layout()
    plt.show()

    # smoothed
    plt.figure(figsize=(12, 5))
    for col in df.columns:
        plt.plot(df.index, df_smooth[col], label=f"{col} (avg)")
    plt.title(f"Smoothed (window={smooth_window})")
    plt.legend(); plt.grid(True); plt.tight_layout()
    plt.show()


# ----------------------------------------------------------------
# üîç  Hint for replay visualisation:
# ----------------------------------------------------------------
# env.enable_replay()
# ... run an evaluation episode ...
# env.disable_replay()
# replays = env.get_player_replays()
#
# from doom_arena.render import render_episode
# from IPython.display import HTML
# HTML(render_episode(replays, subsample=5).to_html5_video())
#
# Feel free to adapt or write your own GIF/MP4 export.
