In [1]:
%%html
<style>
/* Let cell outputs take their natural height instead of being capped. */
.output_wrapper, .output {
    height: auto !important;
    /* `none` removes any max-height cap. A number with a space before its
       unit (e.g. `999 in`) is invalid CSS and would be ignored entirely. */
    max-height: none !important;
}
/* Drop the box shadow drawn around scrollable output areas. */
.output_scroll {
    box-shadow: none !important;
    -webkit-box-shadow: none !important;
}
</style>

In [3]:
%matplotlib inline
import ast
import os
import os.path as osp
from glob import glob

from IPython.display import display, Image, Markdown
from ipywidgets import interact, interactive
import ipywidgets as widgets 
import matplotlib.pylab as plt
import numpy as np
import pickle
import gin
import torch

from attackgraph import settings
from attackgraph.soccer.envs.gridworld_soccer import GridWorldSoccer
from attackgraph.rl.dqn.dqn import DQN
import attackgraph.gambit_analysis as gambit_ops
import attackgraph.common.plot_ops as plot_ops
import attackgraph.soccer.policies as player2_policies
import attackgraph.soccer.qmixing_main


# Root directory that the relative policy paths typed into the widget are joined onto.
RESULTS_DIR = settings.get_results_dir()

# Show numpy arrays with two decimal places.
np.set_printoptions(precision=2)

# Gin bindings for the soccer Q-mixing setup, read from the source tree. This
# presumably relies on the imports above (e.g. ``attackgraph.soccer.qmixing_main``)
# having registered the referenced configurables -- confirm before removing them.
GIN_CONFIG_PATH = osp.join(settings.SRC_DIR, "configs", "soccer_qmix.gin")
gin.parse_config_file(GIN_CONFIG_PATH)

In [10]:
def render_trajectory(p1_path, p2, max_steps=100, env=None):
    """Roll out one episode of player 1's policy against ``p2``, printing every frame.

    After the episode ends (or ``max_steps`` is reached), each player's
    undiscounted return is printed and returned.

    :param p1_path: Path to a policy object saved with ``torch.save``, or an
        already-loaded policy (anything exposing ``eval()`` and accepting the
        keyword arguments used below).
    :param p2: Opponent policy; called with the same keyword arguments as player 1.
    :param max_steps: Hard cap on environment steps, used when the episode never
        reports ``done``. (The previous hard-coded check allowed up to 102 steps.)
    :param env: Environment to roll out in; a fresh ``GridWorldSoccer`` is
        created when omitted.
    :return: ``(p1_return, p2_return)``, the summed rewards of both players.
    """
    # SECURITY: torch.load unpickles arbitrary objects -- only load trusted files.
    if isinstance(p1_path, (str, os.PathLike)):
        p1 = torch.load(p1_path)
    else:
        p1 = p1_path
    p1.eval()

    if env is None:
        env = GridWorldSoccer()
    o = env.reset()
    print(env.render())

    p1_return = 0.0
    p2_return = 0.0
    # Inference only: no need for autograd bookkeeping during the rollout.
    with torch.no_grad():
        for _ in range(max_steps):
            # Player 1's observation gets a leading axis added (presumably the
            # network expects a batch); player 2 receives the raw observation.
            # NOTE(review): update_eps=-1 presumably means "leave exploration
            # epsilon unchanged" -- confirm against the policy implementation.
            a1 = p1(observation=o[1][None], stochastic=True, update_eps=-1, mask=None, training_attacker=False)
            a2 = p2(observation=o[2], stochastic=True, update_eps=-1, mask=None, training_attacker=False)

            o, r, d, _ = env.step({1: a1, 2: a2})
            print(env.render(), '\n')

            p1_return += r[1]
            p2_return += r[2]

            if d:
                break
    print(f"\nFinal Returns:\n\t- P1: {p1_return}\n\t- P2: {p2_return}")
    return p1_return, p2_return


def display_result(policy_path, opponent_id):
    """Render one episode of a saved player-1 policy against a scripted opponent.

    Used as the callback of the ``interactive`` widget in the next cell.

    :param policy_path: Path of the saved policy, relative to ``RESULTS_DIR``.
        When empty or whitespace-only (e.g. the text box has not been filled
        in), a prompt is printed instead of attempting to load anything.
    :type policy_path: str
    :param opponent_id: Index of the scripted opponent, 0-4, selecting
        ``Player2v0`` ... ``Player2v4``.
    :type opponent_id: int
    """
    if not policy_path or not policy_path.strip():
        # Joining '' onto RESULTS_DIR would yield the directory itself, which
        # torch.load cannot open.
        print("Enter a policy path (relative to the results directory) to render a trajectory.")
        return

    opponent_cls = {
        0: player2_policies.Player2v0,
        1: player2_policies.Player2v1,
        2: player2_policies.Player2v2,
        3: player2_policies.Player2v3,
        4: player2_policies.Player2v4,
    }[opponent_id]
    # Construct only the selected opponent rather than all five on every call.
    opponent = opponent_cls()

    policy_path = osp.join(RESULTS_DIR, policy_path)
    print(f"Policy: {policy_path}")
    print(f"Opponent: {opponent}")

    render_trajectory(policy_path, opponent)
    

In [9]:
# ``value`` (not ``default``, which Text does not define) sets the box's initial
# contents. ``continuous_update=False`` makes the rollout re-run only on Enter or
# when the box loses focus, instead of calling torch.load on every partial path
# typed.
policy_path_widget = widgets.Text(
    value="soccer1/qmix.pkl",
    description="Policy Path: ",
    continuous_update=False)
# Label each choice with the opponent class it selects in ``display_result``.
opponent_widget = widgets.Dropdown(
    options=[(f"Player2v{i}", i) for i in range(5)],
    value=0,
    description='Opponent: ')

display(interactive(
    display_result,
    policy_path=policy_path_widget,
    opponent_id=opponent_widget))

interactive(children=(Text(value='', description='Policy Path: '), Dropdown(description='Opponent: ', options=…