In [1]:
from pystk2_gymnasium.stk_wrappers import ConstantSizedObservations, PolarObservations, DiscreteActionsWrapper
from pystk2_gymnasium.wrappers import FlattenerWrapper
from stk_actor.wrappers import PreprocessObservationWrapper

import ipyparallel
import torch
import tqdm
import pandas as pd
import seaborn as sns


In [2]:





def eval_agent(args_list):

    from stk_actor.agent import UnifiedSACPolicy
    from stk_actor.wrappers import PreprocessObservationWrapper
    from pystk2_gymnasium.stk_wrappers import ConstantSizedObservations, PolarObservations, DiscreteActionsWrapper
    from pystk2_gymnasium.wrappers import FlattenerWrapper
    from pystk2_gymnasium import MonoAgentWrapperAdapter
    from stk_actor.wrappers import PreprocessObservationWrapper
    from pystk2_gymnasium import AgentSpec
    from bbrl.agents import Agents
    from bbrl.agents.gymnasium import ParallelGymAgent, make_env
    from bbrl.workspace import Workspace
    from functools import partial
    from stk_actor.actors import Actor, ArgmaxActor

    agents = []
    agents_spec = []
    player_names = []

    def get_action(workspace: Workspace, t: int):
        name = "action"

        if name in workspace.variables:
            # Action is a tensor
            action = workspace.get(name, t)
        else:
            # Action is a dictionary
            action = {}
            prefix = f"{name}/"
            len_prefix = len(prefix)
            for varname in workspace.variables:
                if not varname.startswith(prefix):
                    continue
                keys = varname[len_prefix:].split("/")
                current = action
                for key in keys[:-1]:
                    current = current.setdefault(key, {})
                current[keys[-1]] = workspace.get(varname, t)

        return action


    def dict_slice(k: int, object):
        if isinstance(object, dict):
            return {key: dict_slice(k, value) for key, value in object.items()}
        return object[k]

    def is_integer(n):
        try:
            float(n)
        except ValueError:
            return False
        else:
            return float(n).is_integer()

    def wrapper_func(state_items, state_karts, state_paths, norm, agent_name, env):
        return PreprocessObservationWrapper(
            FlattenerWrapper(
                DiscreteActionsWrapper(
                    PolarObservations(
                        ConstantSizedObservations(
                            env,
                            state_items = state_items,
                            state_karts = state_karts, 
                            state_paths = state_paths,
                        )
                    )
                )
            ), 
            ret_dict=True, 
            norm=norm,
            agent_name=agent_name
        )
    
    interactive=False

    wrapper_factories = {}
    for agent_ix, (track, agent_name, statedict_path, net_arch, activation_fn, mean, std, state_items, state_karts, state_paths) in enumerate(args_list):
        wrapper_factories[str(agent_ix)]=partial(wrapper_func, state_items, state_karts, state_paths, (std is not None and mean is not None), agent_name)
        agents_spec.append(AgentSpec(name=agent_name))
        player_names.append(agent_name)
        
    n_agents = len(args_list)

    env = make_env(
        "supertuxkart/multi-full-v0",
        render_mode=None,
        agents=agents_spec,
        num_kart=n_agents,
        track=track,
        wrappers=[
            partial(
                    MonoAgentWrapperAdapter,
                    keep_original=interactive,
                    wrapper_factories=wrapper_factories,
            )
        ]
    )

    for agent_ix, (track, agent_name, statedict_path, net_arch, activation_fn, mean, std, state_items, state_karts, state_paths) in enumerate(args_list):

        actor = Actor(
            env.observation_space[str(agent_ix)], env.action_space[str(agent_ix)],
            net_arch=net_arch, 
            activation_fn=activation_fn,
            state_dict_path=statedict_path,
        )
        agent = Agents(actor, ArgmaxActor())
        agents.append(agent)

    for agent in agents:
        agent.eval()

    workspaces = [Workspace() for _ in range(n_agents)]
    print("Starting a race")

    done = False
    obs, _ = env.reset()
    choice = ""

    # List possible keys
    keys = []
    for key, item in obs.items():
        for subkey in item.keys():
            keys.append((key, subkey))
    keys.sort()

    t = 0
    while not done:

        actions = {}
        for ix in range(n_agents):
            key = str(ix)
            obs_agent = ParallelGymAgent._format_frame(obs[key])
            for var_key, var_value in obs_agent.items():
                workspaces[ix].set(f"env/{var_key}", t, var_value)
            agents[ix](workspaces[ix], t=t)
            action = get_action(workspaces[ix], t=t)
            if isinstance(action, dict):
                action = dict_slice(0, action)
            else:
                action = action[0]

            actions[key] = action

        obs, reward, terminated, truncated, info = env.step(actions)
        done = terminated or truncated
        t += 1

    records = []
    rewards = info["reward"]
    for ix in range(n_agents):
        key = str(ix)
        print(  # noqa: T201
            f"{rewards[key]}\t{info['infos'][key]['position']}"
            f"\t{ix}\t{player_names[ix]}"
        )
        records.append({
            'agent_name': player_names[ix],
            'reward' : rewards[str(ix)],
            'position' : info['infos'][str(ix)]['position'],
            'track' : track,
        })

    return records




In [3]:
path = 'stk_actor/trained_agents/'

agents = [
    [
        'normed_behavioral_cloning_num5',
        path+'normed_behavioral_cloning_num5/statedict',
        path+'normed_behavioral_cloning_num5/buffer_mean',
        path+'normed_behavioral_cloning_num5/buffer_std',
        5, 5 ,5,
    ],
    [
        'normed_behavioral_cloning_num10',
        path+'normed_behavioral_cloning_num10/statedict',
        path+'normed_behavioral_cloning_num10/buffer_mean',
        path+'normed_behavioral_cloning_num10/buffer_std',
        10,10,10,
    ],
    [
        'non_normed_behavioral_cloning_num5',
        path+'non_normed_behavioral_cloning_num5/statedict',
        None,
        None,
        5, 5 ,5,
    ],
    [
        'non_normed_behavioral_cloning_num10',
        path+'non_normed_behavioral_cloning_num10/statedict',
        None,
        None,
        10,10,10,
    ],
    [
        'normed_a2c_num5_no_init',
        path+'normed_a2c_num5_no_init/statedict',
        path+'normed_a2c_num5_no_init/buffer_mean',
        path+'normed_a2c_num5_no_init/buffer_std',
        5,5,5,
    ],
    [
        'normed_ppo_num5_no_init',
        path+'normed_ppo_num5_no_init/statedict',
        path+'normed_ppo_num5_no_init/buffer_mean',
        path+'normed_ppo_num5_no_init/buffer_std',
        5,5,5,
    ],
    [
        'normed_a2c_num5_best',
        path+'normed_a2c_num5_best/statedict',
        path+'normed_a2c_num5_best/buffer_mean',
        path+'normed_a2c_num5_best/buffer_std',
        5,5,5,
    ],
    [
        'normed_ppo_num5_best',
        path+'normed_ppo_num5_best/statedict',
        path+'normed_ppo_num5_best/buffer_mean',
        path+'normed_ppo_num5_best/buffer_std',
        5,5,5,
    ],
]

net_arch=[1024,1024,1024]
activation_fn=torch.nn.Tanh

tracks = [
    'black_forest','olivermath','minigolf','gran_paradiso_island', 'candela_city', 'mines', 'snowmountain', 'abyss', 'cornfield_crossing', 'hacienda','lighthouse', 'snowtuxpeak', 'zengarden', 'fortmagma','ravenbridge_mansion', 'cocoa_temple', 'sandtrack', 'scotland', 'stk_enterprise', 'volcano_island','xr591',       
]

args_list_list = []

for track in tracks:

    args_list = []

    for agent_name, statedict_path, mean_path, std_path, state_items, state_karts, state_paths in agents:
        
        mean, std = None, None
        if mean_path is not None:   
            mean = torch.load(mean_path, map_location='cpu', weights_only=True)
        if std_path is not None:
            std = torch.load(std_path, map_location='cpu', weights_only=True)

        args_list.append(
            (
                track, agent_name, statedict_path, net_arch, activation_fn, mean, std, state_items, state_karts, state_paths,
            )
        )
    args_list_list.append(args_list)


In [4]:
def parallel_run_episodes(args,):
    client = ipyparallel.Client()
    dview = client[:]

    dview.push({'eval_agent': eval_agent})
    results = dview.map(eval_agent, args, )

    print('running:', len(args))
    
    records = []
    for rec in tqdm.tqdm(results, total=len(args)):
        records.extend(rec)

    client.close()
    return records


In [5]:
records = parallel_run_episodes(args_list_list)

running: 21


100%|██████████| 21/21 [16:50<00:00, 48.10s/it]   


In [9]:
df = pd.DataFrame.from_records(records)
df = df.sort_values('position')
df

Unnamed: 0,agent_name,reward,position,track
55,normed_ppo_num5_best,18.718530,1,snowmountain
57,normed_behavioral_cloning_num10,18.718872,1,abyss
118,normed_a2c_num5_best,18.717261,1,ravenbridge_mansion
97,normed_behavioral_cloning_num10,18.712756,1,zengarden
23,normed_ppo_num5_best,18.697437,1,minigolf
...,...,...,...,...
141,normed_ppo_num5_no_init,-0.099998,8,scotland
130,non_normed_behavioral_cloning_num5,0.048630,8,sandtrack
76,normed_a2c_num5_no_init,-0.100318,8,hacienda
44,normed_a2c_num5_no_init,-0.099998,8,mines


In [10]:
def std(x):
    return x.std(ddof=0)


df_grouped = df.groupby('agent_name', as_index=False)[['position', 'reward']].agg({'position':['mean',std],'reward':['mean', std]})
df_grouped.columns = ['agent_name'] + [f"{col[0]}_{col[1]}" for col in df_grouped.columns[1:]]

df_grouped

Unnamed: 0,agent_name,position_mean,position_std,reward_mean,reward_std
0,non_normed_behavioral_cloning_num10,6.285714,1.277753,0.548363,0.485289
1,non_normed_behavioral_cloning_num5,6.380952,1.396465,0.519989,0.526661
2,normed_a2c_num5_best,2.095238,1.376841,15.411088,5.788502
3,normed_a2c_num5_no_init,5.904762,1.230747,0.696442,0.477738
4,normed_behavioral_cloning_num10,3.285714,1.749636,12.881247,6.777753
5,normed_behavioral_cloning_num5,3.714286,1.419016,10.420558,7.314071
6,normed_ppo_num5_best,2.047619,1.587936,14.864902,6.634027
7,normed_ppo_num5_no_init,6.285714,1.277753,0.546978,0.493031


In [12]:
print(
df_grouped.sort_values('position_mean').replace(
    {
        'normed_a2c_num5_best':'AC2+Normalization+Size5+InitBHC',
        'normed_a2c_num5_no_init':'AC2+Normalization+Size5+NoInit',
        'normed_behavioral_cloning_num10':'BHC+Normalization+Size10',
        'normed_behavioral_cloning_num5':'BHC+Normalization+Size5',
        'normed_ppo_num5_best':'PPO+Normalization+Size5+InitBHC',
        'normed_ppo_num5_no_init':'PPO+Normalization+Size5+NoInit',
        'non_normed_behavioral_cloning_num5':'BHC+Size10',
        'non_normed_behavioral_cloning_num10':'BHC+Size5',
    }
).to_latex(index=False)
)

\begin{tabular}{lrrrr}
\toprule
agent_name & position_mean & position_std & reward_mean & reward_std \\
\midrule
PPO+Normalization+Size5+InitBHC & 2.047619 & 1.587936 & 14.864902 & 6.634027 \\
AC2+Normalization+Size5+InitBHC & 2.095238 & 1.376841 & 15.411088 & 5.788502 \\
BHC+Normalization+Size10 & 3.285714 & 1.749636 & 12.881247 & 6.777753 \\
BHC+Normalization+Size5 & 3.714286 & 1.419016 & 10.420558 & 7.314071 \\
AC2+Normalization+Size5+NoInit & 5.904762 & 1.230747 & 0.696442 & 0.477738 \\
BHC+Size5 & 6.285714 & 1.277753 & 0.548363 & 0.485289 \\
PPO+Normalization+Size5+NoInit & 6.285714 & 1.277753 & 0.546978 & 0.493031 \\
BHC+Size10 & 6.380952 & 1.396465 & 0.519989 & 0.526661 \\
\bottomrule
\end{tabular}

