# `jku.wad` Tournament

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import random
import yaml
from collections import defaultdict

from doom_arena import VizdoomMPEnv

In [None]:
# Whether matches are played with replay recording; passed as `record_replay`
# to `play_round` in the tournament loop below.
STORE_REPLAYS = True
# Directory handed to `Tournament`, which lists it for `.onnx` submissions.
SUBMISSION_ROOT = ""
# Directory that rendered replay videos (`round{r}_match{m}.mp4`) are joined onto.
REPLAY_ROOT = ""

## Set up tree

In [None]:
class Tournament:
    """Single-elimination bracket over the ``.onnx`` files of a directory.

    Players are grouped into matches of up to ``max_players_per_match``
    entries. Matches that cannot be filled are padded with ``None``. The
    match size is lowered when the default size would leave a single player
    alone in the last match.

    Attributes are plain and public:
        submissions: maps each ``.onnx`` file name to its joined path.
        round: index of the current round, starting at 0.
        init_max_players_per_match: match size requested at construction.
        max_players_per_match: match size currently in use (may be lower).
    """

    def __init__(self, submission_dir: str, max_players_per_match: int = 4):
        """Collect every ``*.onnx`` entry of ``submission_dir`` as a player.

        Args:
            submission_dir: directory listed with ``os.listdir``.
            max_players_per_match: preferred number of players per match.
        """
        self.submissions = {}
        # TODO what is the default naming pattern?
        for filename in os.listdir(submission_dir):
            if filename.endswith(".onnx"):
                self.submissions[filename] = os.path.join(submission_dir, filename)

        self.round = 0
        self.init_max_players_per_match = max_players_per_match
        self.max_players_per_match = max_players_per_match
        self.adjust_max_players_per_match(len(self.submissions))

    def adjust_max_players_per_match(self, players_left: int):
        """Shrink the match size so the last match is not a single player.

        Nothing changes when ``players_left`` divides evenly. Otherwise the
        largest size between the current one and 2 whose remainder is not 1
        is selected; if every candidate leaves a remainder of 1 the size
        falls back to 2.

        Args:
            players_left: number of players to distribute over matches.
        """
        if players_left % self.max_players_per_match == 0:
            return

        for i in range(self.max_players_per_match, 1, -1):
            if players_left % i != 1:
                self.max_players_per_match = i
                break
        else:
            self.max_players_per_match = 2

    def _group_into_matches(self, players):
        """Pad ``players`` with ``None`` and split into matches of the current size.

        The input sequence is not modified.
        """
        size = self.max_players_per_match
        # -len % size is the number of empty slots needed to fill the last match.
        padded = list(players) + [None] * (-len(players) % size)
        return [padded[i : i + size] for i in range(0, len(padded), size)]

    def create_initial_round(self):
        """Reset the bracket and return the shuffled first-round matches.

        Returns:
            A list of matches; each match is a list of submission names,
            padded with ``None`` up to the match size.
        """
        self.round = 0
        self.max_players_per_match = self.init_max_players_per_match
        players = list(self.submissions.keys())
        # The reset above discards the adjustment made in __init__ (and any
        # made by advance_round), so it has to be re-applied here; otherwise
        # e.g. 9 players with a size of 4 would be split 4/4/1.
        self.adjust_max_players_per_match(len(players))
        random.shuffle(players)
        return self._group_into_matches(players)

    def advance_round(self, winners):
        """Increment the round counter and build the next round's matches.

        Args:
            winners: names of the players that advance, in bracket order.

        Returns:
            The list of next-round matches (same shape as
            ``create_initial_round``). When exactly one winner is given,
            ``winners`` itself is returned unchanged, i.e. a flat list with
            one name rather than a list of matches.
        """
        self.round += 1
        if len(winners) == 1:
            return winners
        self.adjust_max_players_per_match(len(winners))
        return self._group_into_matches(winners)

In [None]:
def mock_listdir(_):
    """Return ten fake submission file names, ignoring the directory argument."""
    fake_files = []
    for number in range(1, 11):
        fake_files.append(f"student_{number:02}.onnx")
    return fake_files


# Replace os.listdir for the whole process with the mock above, so that
# Tournament sees the fake submissions instead of a real directory listing.
# NOTE(review): this also affects every later os.listdir call in this session.
os.listdir = mock_listdir

## Run tournament

In [None]:
# Build the bracket from the entries listed for SUBMISSION_ROOT, using the
# default match size of Tournament.
tourn = Tournament(SUBMISSION_ROOT)

In [None]:
def load_onnx(path):
    """Return a policy callable for the submission at ``path``.

    Placeholder: ``path`` is not read yet. The returned policy ignores its
    observation and always yields action ``2``.
    """
    # TODO
    fixed_action = 2

    def rnd_policy(obs):
        return fixed_action

    return rnd_policy


def play_match(
    players, submissions, num_episodes: int = 1, record_replay: bool = False
):
    """Play one match and rank its players by mean episode return.

    Args:
        players: names taking part; ``None`` entries (padding) are ignored.
        submissions: mapping from player name to the path given to ``load_onnx``.
        num_episodes: number of episodes to play and average over.
        record_replay: when true, replay recording is enabled for the first
            episode only and the replays are returned as well.

    Returns:
        ``leaderboard`` - a list of ``(player, mean_return)`` pairs sorted by
        descending return - or ``(leaderboard, player_replays)`` when
        ``record_replay`` is true, with replays keyed by player name.

    NOTE(review): every policy is called with the complete ``obs`` returned by
    the environment; confirm whether per-player observations are expected.
    """
    # Padding slots carry no player.
    contenders = [name for name in players if name is not None]
    # TODO student env configs
    # player_configs = [yaml.safe_load(submissions[p]) for p in players]
    configs = [{} for _ in contenders]
    agents = {name: load_onnx(submissions[name]) for name in contenders}
    env = VizdoomMPEnv(
        num_players=len(contenders),
        num_bots=0,
        doom_map="TRNM",
        extra_state=[cfg.get("extra_state", None) for cfg in configs],
        crosshair=[cfg.get("crosshair", False) for cfg in configs],
        hud=[cfg.get("hud", "full") for cfg in configs],
        episode_timeout=2000,
    )

    per_episode = []
    for episode in range(num_episodes):
        # Only the first episode is recorded.
        if record_replay and episode == 0:
            env.enable_replay()
        else:
            env.disable_replay()
        returns = dict.fromkeys(contenders, 0.0)
        obs = env.reset()
        finished = False
        while not finished:
            actions = [policy(obs) for policy in agents.values()]
            obs, rewards, finished, _ = env.step(actions)
            # Rewards are indexed by player position in `contenders`.
            for slot, name in enumerate(returns):
                returns[name] = returns[name] + rewards[slot]

        per_episode.append(returns)

    # average episodes per player
    totals = dict.fromkeys(contenders, 0.0)
    for returns in per_episode:
        for name in contenders:
            totals[name] = totals[name] + returns[name]
    averages = {name: total / num_episodes for name, total in totals.items()}
    leaderboard = sorted(averages.items(), key=lambda item: item[1], reverse=True)
    if not record_replay:
        return leaderboard

    raw_replays = env.get_player_replays()
    # rename with actual player names
    named_replays = {
        contenders[slot]: raw_replays[key] for slot, key in enumerate(raw_replays)
    }
    return leaderboard, named_replays


def play_round(
    matches, submissions, num_episodes: int = 1, record_replay: bool = False
):
    """Play every match of a round via ``play_match``.

    Returns:
        A triple ``(winners, leaderboards, replays)`` with one entry per match:
        the top-ranked player, the full leaderboard, and (only when
        ``record_replay`` is true, otherwise the list stays empty) the replays.
    """
    winners = []
    boards = []
    replays_per_match = []
    for match_players in matches:
        outcome = play_match(match_players, submissions, num_episodes, record_replay)
        # play_match returns a pair only when replays were requested.
        if record_replay:
            board, match_replays = outcome
            replays_per_match.append(match_replays)
        else:
            board = outcome
        boards.append(board)
        winners.append(board[0][0])

    return winners, boards, replays_per_match


def leaderboard_reduce(leaderboards):
    """Sum scores and count appearances per player over all rounds.

    Args:
        leaderboards: list of rounds; each round is a list of match
            leaderboards; each leaderboard is a list of ``(player, score)``.

    Returns:
        Dict mapping each player to ``(total_score, matches_played)``.
    """
    tally = {}
    for round_boards in leaderboards:
        for board in round_boards:
            for name, value in board:
                entry = tally.setdefault(name, [0.0, 0])
                entry[0] += float(value)  # running score
                entry[1] += 1  # number of matches seen

    # Expose immutable (total, count) pairs.
    return {name: (entry[0], entry[1]) for name, entry in tally.items()}

In [None]:
# Play the bracket round by round until a single match is left, then play
# that match as the final. Per-round results are collected for plotting.
current_matches = tourn.create_initial_round()

return_leaderboard = []  # per round: list of match leaderboards
round_matches = [current_matches]  # per round: list of matches
round_winners = []  # per round: list of match winners
round_replays = []  # per round: list of match replays (empty if not recorded)
while len(current_matches) > 1:
    print(f"\nNow playing: round {tourn.round}:")
    for match in current_matches:
        print("\tMatch:", match)
    winners, leaderboards, replays = play_round(
        current_matches, tourn.submissions, record_replay=STORE_REPLAYS
    )
    return_leaderboard.append(leaderboards)
    round_winners.append(winners)
    round_replays.append(replays)
    print("\tWinners:", winners)
    # advance tree to next round
    current_matches = tourn.advance_round(winners)
    round_matches.append(current_matches)

print("\nNow playing: final round:")
final_match = current_matches[0]
# Print the final match itself, not the loop variable left over from the
# previous round (which is also undefined when the loop body never ran).
print("\tMatch:", final_match)
winners, leaderboards, replays = play_round(
    current_matches, tourn.submissions, record_replay=STORE_REPLAYS
)
return_leaderboard.append(leaderboards)
round_winners.append(winners)
round_replays.append(replays)
print("\tWinner:", winners)

## Plotting

In [None]:
from doom_arena.render import render_episode

# Render one video per match. The replay lists are empty when the rounds were
# played without recording, in which case nothing is rendered.
# Loop variables are named *_idx so the builtin `round` is not shadowed in
# the notebook's global namespace.
if round_replays and round_replays[0]:
    for round_idx, match_replays in enumerate(round_replays):
        for match_idx, player_replays in enumerate(match_replays):
            render_episode(
                player_replays,
                subsample=10,
                replay_path=os.path.join(
                    REPLAY_ROOT, f"round{round_idx}_match{match_idx}.mp4"
                ),
            )

In [None]:
import matplotlib.pyplot as plt
import matplotlib


def plot_tournament_tree(round_matches, round_winners):
    """Draw the tournament bracket as a sequence of progressively revealed figures.

    One figure is created per value of ``plot_idx`` in
    ``0 .. len(round_matches)``; in figure ``plot_idx`` the winners of all
    rounds with index below ``plot_idx`` are highlighted, so the first figure
    shows no winners and the last shows all of them.

    Args:
        round_matches: per round, the list of matches; each match is a list of
            player names, where ``None`` entries are skipped when drawing.
        round_winners: per round, the winner of each match, indexed like
            ``round_matches`` (an entry is read for every drawn match).

    Returns:
        List of the created matplotlib figures, in creation order.
    """
    num_rounds = len(round_matches)
    num_matches0 = len(round_matches[0])
    cmap = matplotlib.colormaps["tab10"]

    # positions[round][match] = (x, y): x is the round index; first-round
    # matches are spaced 2 apart in y, later matches sit at the mean y of the
    # previous-round matches whose winner plays in them.
    positions = []
    for round_idx in range(num_rounds):
        if round_idx == 0:
            current_positions = [
                (round_idx, idx * 2) for idx in range(len(round_matches[round_idx]))
            ]
        else:
            current_positions = []
            for match_idx in range(len(round_matches[round_idx])):
                # Previous-round matches feeding into this match.
                parents = []
                for parent_idx in range(len(round_matches[round_idx - 1])):
                    parent_winner = round_winners[round_idx - 1][parent_idx]
                    if parent_winner in round_matches[round_idx][match_idx]:
                        parents.append(parent_idx)
                if parents:
                    avg_y = sum(positions[round_idx - 1][p][1] for p in parents) / len(
                        parents
                    )
                else:
                    # No parent found: fall back to the bottom row.
                    avg_y = 0
                current_positions.append((round_idx, avg_y))
        positions.append(current_positions)

    figs = []
    for plot_idx in range(len(round_matches) + 1):
        fig, ax = plt.subplots(figsize=(7 * num_rounds, 2 * num_matches0))
        ax.axis("off")

        # Draw connections: a gray line from each match to the first
        # next-round match that contains its winner.
        for i in range(num_rounds - 1):
            for match_idx in range(len(round_matches[i])):
                winner = round_winners[i][match_idx]
                for next_match_idx in range(len(round_matches[i + 1])):
                    if winner in round_matches[i + 1][next_match_idx]:
                        from_x, from_y = positions[i][match_idx]
                        to_x, to_y = positions[i + 1][next_match_idx]
                        # plt.plot draws on pyplot's current axes.
                        plt.plot(
                            [from_x, to_x],
                            [from_y, to_y],
                            color="gray",
                            linewidth=2,
                            zorder=1,
                        )
                        break

        # Draw matches: one text box per match listing its participants.
        for round_idx, match_list in enumerate(round_matches):
            for match_idx, participants in enumerate(match_list):
                x, y = positions[round_idx][match_idx]
                winner = round_winners[round_idx][match_idx]
                lines = []
                for p in participants:
                    if p is None:
                        continue
                    # Winners are only revealed for rounds before plot_idx.
                    if p == winner and round_idx < plot_idx:
                        # Escape underscores before embedding the name in mathtext.
                        p = p.replace("_", r"\_")
                        lines.append(r"♛$\mathbf{ " + p + "}$♛")
                    else:
                        lines.append(p)
                text_str = "\n".join(lines)

                # Box edge color identifies the round.
                color = cmap(round_idx)
                bbox_props = dict(
                    facecolor="white",
                    edgecolor=color,
                    boxstyle="round,pad=0.5",
                    linewidth=2,
                )

                # zorder=2 keeps the boxes above the connection lines (zorder=1).
                ax.text(
                    x,
                    y,
                    text_str,
                    ha="center",
                    va="center",
                    bbox=bbox_props,
                    fontsize=12,
                    zorder=2,
                )

        figs.append(fig)
        # Round labels, placed 2 units above the highest first-round match.
        for round_idx in range(num_rounds):
            max_y = max(y for _, y in positions[0])
            ax.text(
                round_idx,
                max_y + 2,
                f"Round {round_idx + 1}",
                ha="center",
                va="bottom",
                fontsize=10,
                fontweight="bold",
                color=cmap(round_idx),
            )
        plt.show()
    return figs

In [None]:
# Draw the bracket from the matches and winners collected while playing.
plot_tournament_tree(round_matches, round_winners)