In [1]:
import gymnasium as gym
import torch
import torch.nn.functional as F

class PreprocessObservationWrapper(gym.ObservationWrapper):
    """Flatten a mixed (continuous + discrete) observation into one normalised tensor.

    The continuous part is kept as-is, every discrete component is one-hot
    encoded, everything is concatenated, and the result is standardised with
    per-feature mean/std statistics (``DEFAULT_MEAN`` / ``DEFAULT_STD`` unless
    overridden in the constructor).
    """

    # Hard-coded per-feature statistics, one entry per flattened feature
    # (154 entries each). Presumably computed from rollouts of this
    # environment -- TODO record how/where they were produced.
    DEFAULT_MEAN = (
        1.2639206647872925,
        0.0,
        -0.07462655007839203,
        -0.07617677748203278,
        3.325659990310669,
        -0.6260609030723572,
        686.5313720703125,
        0.90716952085495,
        3.248101347708143e-05,
        -2.2646768229606096e-06,
        0.7184554934501648,
        -0.00979185476899147,
        -0.018949387595057487,
        42.55070114135742,
        -0.0018095501000061631,
        -0.0037681907415390015,
        50.57154083251953,
        0.006202289834618568,
        0.0015260628424584866,
        60.573333740234375,
        -0.0015078107826411724,
        -0.008679982274770737,
        83.0267333984375,
        -0.004257954191416502,
        -0.052518781274557114,
        92.90336608886719,
        -0.04041219875216484,
        -0.07217442244291306,
        29.27079963684082,
        -0.06311116367578506,
        -0.09916721284389496,
        44.90357208251953,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.10005245357751846,
        682.83642578125,
        690.8487548828125,
        685.5414428710938,
        693.1297607421875,
        687.2418212890625,
        694.6862182617188,
        688.3353271484375,
        695.63671875,
        689.5924682617188,
        696.7789306640625,
        0.036972709000110626,
        0.032977644354104996,
        7.680975437164307,
        0.023797351866960526,
        0.004713739734143019,
        14.885013580322266,
        0.01685260608792305,
        -0.003614815417677164,
        22.13359260559082,
        0.010710783302783966,
        -0.006463498808443546,
        29.126462936401367,
        0.005819129757583141,
        -0.007047213148325682,
        35.78960418701172,
        0.08564247190952301,
        0.011537699960172176,
        6.701535701751709,
        0.03396379202604294,
        0.0338205024600029,
        8.146260261535645,
        0.022045161575078964,
        0.0064937700517475605,
        15.329230308532715,
        0.015207406133413315,
        -0.001814433140680194,
        22.549062728881836,
        0.009987653233110905,
        -0.004953742492944002,
        29.498981475830078,
        10.452672958374023,
        10.426018714904785,
        10.403871536254883,
        10.385071754455566,
        10.37171459197998,
        0.2544378638267517,
        1.0111130475997925,
        -0.006203718949109316,
        0.047264423221349716,
        17.174535751342773,
        0.011489897966384888,
        0.003900744253769517,
        0.04766101762652397,
        0.030281564220786095,
        0.0,
        0.0,
        0.04146302491426468,
        0.0,
        0.0,
        0.8652037382125854,
        0.4444541931152344,
        0.23419933021068573,
        0.03508015722036362,
        0.26018109917640686,
        0.026085197925567627,
        0.0,
        0.0,
        0.4107459783554077,
        0.26497969031333923,
        0.03951539844274521,
        0.26854822039604187,
        0.01621069572865963,
        0.0,
        0.0,
        0.4971160888671875,
        0.18721193075180054,
        0.039390929043293,
        0.2605591118335724,
        0.015721958130598068,
        0.0,
        0.0,
        0.41856396198272705,
        0.22104644775390625,
        0.029045993462204933,
        0.30909326672554016,
        0.022250350564718246,
        0.0,
        0.0,
        0.41806697845458984,
        0.23936127126216888,
        0.04056426137685776,
        0.28310322761535645,
        0.018904240801930428,
        0.0,
        0.0,
        0.9955912828445435,
        0.0044087013229727745,
        0.0,
        0.0,
        0.0,
        1.0,
        1.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
    )
    DEFAULT_STD = (
        4.425504207611084,
        0.0,
        1.603779673576355,
        1.7964426279067993,
        9.993358612060547,
        10.51357650756836,
        512.2705078125,
        1.6997371912002563,
        0.026459209620952606,
        0.010543138720095158,
        0.02026776410639286,
        0.7250668406486511,
        0.39837801456451416,
        34.88729476928711,
        0.7514939904212952,
        0.44161456823349,
        35.87752914428711,
        0.8336662650108337,
        0.539232075214386,
        41.78753662109375,
        1.0223793983459473,
        0.7299631237983704,
        47.53340530395508,
        1.098395586013794,
        0.8274492025375366,
        47.81513214111328,
        1.6523897647857666,
        1.749129295349121,
        28.97491455078125,
        2.2789063453674316,
        2.5216972827911377,
        40.0125846862793,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.06416473537683487,
        512.6343994140625,
        511.6667175292969,
        511.3175964355469,
        510.3323059082031,
        509.7705383300781,
        508.8245544433594,
        508.5116271972656,
        507.6471252441406,
        507.43463134765625,
        506.58172607421875,
        0.6846545338630676,
        0.5445173382759094,
        13.64755916595459,
        0.44763320684432983,
        0.452096164226532,
        17.358463287353516,
        0.42862439155578613,
        0.4299456775188446,
        20.870670318603516,
        0.451620489358902,
        0.43150416016578674,
        24.120229721069336,
        0.4922274351119995,
        0.4394797086715698,
        27.007665634155273,
        2.4632468223571777,
        2.7507951259613037,
        13.712183952331543,
        0.689609706401825,
        0.5509278774261475,
        17.15859603881836,
        0.45551878213882446,
        0.4613153338432312,
        20.009708404541016,
        0.43676477670669556,
        0.4395506978034973,
        22.933229446411133,
        0.45909419655799866,
        0.4403861463069916,
        25.679624557495117,
        2.8094654083251953,
        2.8031206130981445,
        2.80706524848938,
        2.8221468925476074,
        2.842970848083496,
        1.5037912130355835,
        0.11496783792972565,
        0.9299623370170593,
        1.1188806295394897,
        6.216769218444824,
        0.10657340288162231,
        0.06233403459191322,
        0.2130480855703354,
        0.1713610738515854,
        0.0,
        0.0,
        0.19935867190361023,
        0.0,
        0.0,
        0.3415059745311737,
        0.4969053268432617,
        0.4234975576400757,
        0.18398252129554749,
        0.4387334883213043,
        0.15938878059387207,
        0.0,
        0.0,
        0.4919694662094116,
        0.44132259488105774,
        0.19481778144836426,
        0.44320452213287354,
        0.12628509104251862,
        0.0,
        0.0,
        0.49999192357063293,
        0.3900817334651947,
        0.19452330470085144,
        0.4389398992061615,
        0.12439771741628647,
        0.0,
        0.0,
        0.49332383275032043,
        0.414951890707016,
        0.1679355502128601,
        0.4621199071407318,
        0.14749675989151,
        0.0,
        0.0,
        0.4932415187358856,
        0.4266938269138336,
        0.19727858901023865,
        0.4505063593387604,
        0.13618695735931396,
        0.0,
        0.0,
        0.06625155359506607,
        0.06625155359506607,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
    )

    def __init__(self, env, mean=None, std=None):
        """
        Wrap ``env`` so that observations come out flat and normalised.

        Args:
            env: The Gym environment to wrap. Its observation space must be
                indexable by 'continuous' (a space with ``.shape``) and
                'discrete' (an iterable of spaces exposing ``.n``).
            mean: Optional per-feature means; defaults to ``DEFAULT_MEAN``.
            std: Optional per-feature standard deviations; defaults to
                ``DEFAULT_STD``.

        Raises:
            ValueError: if ``mean`` or ``std`` does not have exactly one entry
                per flattened feature.
        """
        super().__init__(env)
        self.observation_space = self._get_flat_observation_space(env.observation_space)
        self.mean = torch.as_tensor(self.DEFAULT_MEAN if mean is None else mean, dtype=torch.float32)
        self.std = torch.as_tensor(self.DEFAULT_STD if std is None else std, dtype=torch.float32)

        # Fail at construction rather than with an opaque broadcasting error
        # on the first observation.
        flat_dim = self.observation_space.shape[0]
        for label, stats in (('mean', self.mean), ('std', self.std)):
            if stats.shape != (flat_dim,):
                raise ValueError(
                    f"{label} has shape {tuple(stats.shape)}, expected ({flat_dim},) "
                    "to match the flattened observation size"
                )

    def _get_flat_observation_space(self, observation_space):
        """
        Create a flat observation space based on the original observation space.

        Args:
            observation_space: Original observation space with 'continuous' and 'discrete' components.

        Returns:
            An unbounded 1-D Box whose length is the continuous dimension plus
            the sum of the discrete cardinalities (one-hot width).
        """
        continuous_dim = observation_space['continuous'].shape[0]
        discrete_dims = sum(space.n for space in observation_space['discrete'])
        flat_dim = continuous_dim + discrete_dims
        return gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(flat_dim,), dtype=float)

    def observation(self, obs):
        """
        Process the observation into a flat, standardised tensor.

        Args:
            obs: The raw observation, indexable by 'continuous' and 'discrete'.

        Returns:
            A float32 torch tensor: ``(concat(continuous, one_hots) - mean) / scale``.
        """
        continuous_tensor = torch.as_tensor(obs['continuous'], dtype=torch.float32)

        discrete_tensors = [
            F.one_hot(torch.tensor(x), num_classes=space.n).float()
            for x, space in zip(obs['discrete'], self.env.observation_space['discrete'])
        ]

        flat_tensor = torch.cat([continuous_tensor] + discrete_tensors)
        # A zero std marks a feature that was (presumably) constant when the
        # statistics were computed. Dividing by ~1e-8 would turn any deviation
        # there into a ~1e8 network input, so such features are divided by 1.
        # Non-zero stds keep the original (std + 1e-8) denominator.
        scale = torch.where(self.std > 0, self.std + 1e-8, torch.ones_like(self.std))
        return (flat_tensor - self.mean) / scale


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from typing import Dict, List, Tuple, Union, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

import gymnasium as gym
from gymnasium import spaces

def get_device(device: Union[torch.device, str] = "auto") -> torch.device:
    """Resolve ``device`` to a concrete ``torch.device``.

    ``"auto"`` means "cuda". Any CUDA request falls back to the CPU when CUDA is
    not available; every other device is returned unchanged.
    """
    resolved = torch.device("cuda" if device == "auto" else device)
    needs_gpu = resolved.type == "cuda"
    if needs_gpu and not torch.cuda.is_available():
        return torch.device("cpu")
    return resolved

class BaseFeaturesExtractor(nn.Module):
    """Minimal base class for feature extractors.

    :param observation_space: the space the extractor consumes (stored only).
    :param features_dim: width of the produced feature vector; must be > 0.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
        super().__init__()
        assert features_dim > 0
        self._observation_space, self._features_dim = observation_space, features_dim

    @property
    def features_dim(self) -> int:
        """Number of features produced by this extractor."""
        return self._features_dim

def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
    """Return the size of ``observation_space`` once flattened.

    MultiDiscrete spaces are sized as the sum of their cardinalities (one-hot
    width). Every other space defers to gymnasium's ``flatdim``.

    The result is cast to a plain ``int``. ``sum(nvec)`` would otherwise yield a
    numpy integer, contradicting the annotation.
    """
    if isinstance(observation_space, spaces.MultiDiscrete):
        return int(sum(observation_space.nvec))
    return int(spaces.utils.flatdim(observation_space))

class FlattenExtractor(BaseFeaturesExtractor):
    """Feature extractor that only flattens each observation.

    ``features_dim`` equals the flattened size of ``observation_space``.
    """

    def __init__(self, observation_space: gym.Space) -> None:
        flat_size = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_size)
        self.flatten = nn.Flatten()

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Flatten every dimension after the leading (batch) one."""
        return self.flatten(observations)
    
class MlpExtractor(nn.Module):
    """Separate actor (``policy_net``) and critic (``value_net``) MLP trunks.

    Each trunk is a ``Sequential`` of ``Linear`` + activation pairs. Its
    parameter names (``policy_net.0.weight``, ``value_net.2.bias``, ...) must
    stay stable, because the enclosing policy's ``state_dict`` is loaded into
    a stable-baselines3 PPO policy.

    :param feature_dim: size of the input feature vector.
    :param net_arch: either a list of hidden sizes shared by both trunks, or a
        dict with optional ``"pi"`` / ``"vf"`` lists (missing key = no layers).
    :param activation_fn: activation module class inserted after every Linear.
    :param device: accepted for signature compatibility; currently ignored
        (modules are created on the default device).
    """

    def __init__(
        self,
        feature_dim: int,
        net_arch: Union[List[int], Dict[str, List[int]]],
        activation_fn: Type[nn.Module],
        device: Union[torch.device, str] = "auto",
    ) -> None:
        super().__init__()

        if isinstance(net_arch, dict):
            pi_layers_dims = net_arch.get("pi", [])
            vf_layers_dims = net_arch.get("vf", [])
        else:
            pi_layers_dims = vf_layers_dims = net_arch

        # Actor first, then critic: same construction order as before, so
        # parameter initialisation consumes the RNG identically.
        policy_net, latent_dim_pi = self._build_trunk(feature_dim, pi_layers_dims, activation_fn)
        value_net, latent_dim_vf = self._build_trunk(feature_dim, vf_layers_dims, activation_fn)

        self.latent_dim_pi = latent_dim_pi
        self.latent_dim_vf = latent_dim_vf
        self.policy_net = policy_net
        self.value_net = value_net

    @staticmethod
    def _build_trunk(
        in_dim: int, layer_dims: List[int], activation_fn: Type[nn.Module]
    ) -> Tuple[nn.Sequential, int]:
        """Return (Sequential of Linear+activation pairs, output width)."""
        layers: List[nn.Module] = []
        for out_dim in layer_dims:
            layers += [nn.Linear(in_dim, out_dim), activation_fn()]
            in_dim = out_dim
        # With no hidden layers the trunk is the identity and keeps in_dim.
        return nn.Sequential(*layers), in_dim

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :return: latent_policy, latent_value of the specified network.
            If both trunks are empty, both outputs equal ``features``.
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
        """Run the actor trunk only."""
        return self.policy_net(features)

    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        """Run the critic trunk only."""
        return self.value_net(features)

    
class Policy(nn.Module):
    """Actor-critic parameter container with stable-baselines3-style attribute names.

    The submodule names (``features_extractor``, ``mlp_extractor``,
    ``action_net``, ``value_net``) are what make this module's ``state_dict``
    loadable into an SB3 PPO ``MlpPolicy``. The class defines no ``forward`` of
    its own; callers compose the submodules.

    :param observation_space: space passed to ``FlattenExtractor``.
    :param action_dims: number of choices per discrete action component; the
        action head outputs ``sum(action_dims)`` logits.
    :param net_arch: hidden-layer spec forwarded to ``MlpExtractor``.
    :param activation_fn: activation class forwarded to ``MlpExtractor``.
    """

    def __init__(self, observation_space, action_dims, net_arch, activation_fn,):
        super().__init__()
        self.features_extractor = FlattenExtractor(observation_space)
        # Actor and critic share one extractor, exposed under both names.
        self.pi_features_extractor = self.features_extractor
        self.vf_features_extractor = self.features_extractor
        self.mlp_extractor = MlpExtractor(
            self.features_extractor.features_dim,
            net_arch=net_arch,
            activation_fn=activation_fn,
        )
        # Size the heads from the trunks' real output widths. ``net_arch[-1]``
        # was wrong for dict specs and for an empty list (trunk = identity).
        self.action_net = nn.Linear(self.mlp_extractor.latent_dim_pi, sum(action_dims))
        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)


class UnifiedSACPolicy(nn.Module):
    """Categorical policy over several discrete action components.

    Wraps a ``Policy`` (stored as ``self.shared``) and turns its action logits
    into one ``Categorical`` distribution per entry of ``action_dims``.
    """

    def __init__(self, observation_space, action_dims, net_arch, activation_fn):
        super().__init__()
        self.shared = Policy(
            observation_space,
            action_dims,
            net_arch=net_arch,
            activation_fn=activation_fn,
        )
        self.action_dims = action_dims

    def forward(self, x):
        """Return the concatenated action logits (``sum(action_dims)`` wide on the last axis)."""
        shared = self.shared
        features = shared.features_extractor(x)
        latent_pi = shared.mlp_extractor.policy_net(features)
        return shared.action_net(latent_pi)

    def sample(self, x, deterministic=False):
        """Pick one action per component.

        Returns ``(actions, log_probs, probs)``. ``actions`` and ``log_probs``
        are stacked along a new leading axis (one row per component), and
        ``probs`` is a list of per-component softmax tensors. With
        ``deterministic=True`` the arg-max action is taken instead of sampling.
        """
        per_component_logits = torch.split(self.forward(x), self.action_dims, dim=-1)

        chosen, chosen_log_probs, component_probs = [], [], []
        for head_logits in per_component_logits:
            dist = Categorical(logits=head_logits)
            picked = head_logits.argmax(dim=-1) if deterministic else dist.sample()
            chosen.append(picked)
            chosen_log_probs.append(dist.log_prob(picked))
            component_probs.append(F.softmax(head_logits, dim=-1))

        return torch.stack(chosen), torch.stack(chosen_log_probs), component_probs
    
#policy = torch.load('policy_512_512_512_512_SiLU_3_statedict', map_location='cuda')


from stable_baselines3 import PPO, A2C
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
from pystk2_gymnasium import AgentSpec
from bbrl.agents.gymnasium import ParallelGymAgent, make_env
from functools import partial

# 21 SuperTuxKart envs, each wrapped by PreprocessObservationWrapper
# (make_vec_env calls wrapper_class(env), so the class itself suffices).
# Options tried earlier and left disabled: 'difficulty': 0, 'num_kart': 2.
vec_env = make_vec_env(
    "supertuxkart/flattened_multidiscrete-v0",
    seed=0,
    n_envs=21,
    wrapper_class=PreprocessObservationWrapper,
    env_kwargs={
        'render_mode': None,
        'agent': AgentSpec(use_ai=False, name="walid"),
        'track': 'minigolf',
        'laps': 1,
    },
)

import warnings

# One track per vectorised env, cycling through the list if there are more envs.
tracks = [
    'abyss',
    'black_forest',
    'candela_city',
    'cocoa_temple',
    'cornfield_crossing',
    'fortmagma',
    'gran_paradiso_island',
    'hacienda',
    'lighthouse',
    'mines',
    'minigolf',
    'olivermath',
    'ravenbridge_mansion',
    'sandtrack',
    'scotland',
    'snowmountain',
    'snowtuxpeak',
    'stk_enterprise',
    'volcano_island',
    'xr591',
    'zengarden',
]
for i, venv in enumerate(vec_env.envs):
    track = tracks[i % len(tracks)]
    print(i, track)
    # gymnasium wrappers forward attribute *reads* but not *assignment*.
    # Writing `venv.env.default_track` only created an attribute on the
    # intermediate wrapper, which the race env never reads. Set it on the
    # innermost env instead.
    base_env = venv.unwrapped
    if not hasattr(base_env, 'default_track'):
        warnings.warn(
            f"{type(base_env).__name__} has no 'default_track' attribute; "
            f"env {i} may ignore the requested track {track!r}"
        )
    base_env.default_track = track


# Architecture of the pretrained checkpoint; must match `filename` exactly.
net_arch = [512, 512, 512, 512]
activation_fn = torch.nn.SiLU
filename = 'policy_512_512_512_512_SiLU_noise1_statedict'

# One entry per discrete action component (its number of choices).
action_dims = [space.n for space in vec_env.action_space]
unified_policy = UnifiedSACPolicy(
    vec_env.observation_space,
    action_dims,
    net_arch=net_arch,
    activation_fn=activation_fn
)
# The checkpoint is a plain state_dict of tensors, so weights_only=True loads it
# without unpickling arbitrary objects. Recent PyTorch versions also warn when
# this argument is omitted.
unified_policy.load_state_dict(torch.load(filename, map_location='cpu', weights_only=True))


..:: Antarctica Rendering Engine 2.0 ::..
0 abyss
1 black_forest
2 candela_city
3 cocoa_temple
4 cornfield_crossing
5 fortmagma
6 gran_paradiso_island
7 hacienda
8 lighthouse
9 mines
10 minigolf
11 olivermath
12 ravenbridge_mansion
13 sandtrack
14 scotland
15 snowmountain
16 snowtuxpeak
17 stk_enterprise
18 volcano_island
19 xr591
20 zengarden


  unified_policy.load_state_dict(torch.load(filename, map_location='cpu'))


<All keys matched successfully>

In [2]:
# PPO fine-tuning of the pretrained policy; each entry is (n_steps, total_timesteps).
steps = [(
    3000,
    3_000_000,
)]
for n_steps, total_timesteps in steps:
    # NOTE(review): a rollout holds n_steps * n_envs samples (3000 * 21 = 63000),
    # which batch_size=128 does not divide. SB3 warns about this, and the last
    # minibatch of every epoch ends up smaller.
    model = PPO(
        "MlpPolicy",
        vec_env,
        verbose=1,
        policy_kwargs=dict(net_arch=net_arch, activation_fn=activation_fn,),
        device='cpu',
        learning_rate=0.0001,
        batch_size=128,
        n_steps=n_steps,
        tensorboard_log="./outputs/",
        ent_coef=0.001,
        clip_range=0.1,
    )
    print('DOING', n_steps, total_timesteps)
    # Start from the pretrained weights; Policy mirrors SB3's parameter names.
    model.policy.load_state_dict(unified_policy.shared.state_dict())
    try:
        model.learn(total_timesteps=total_timesteps, progress_bar=True)
    finally:
        # Always persist progress, including on KeyboardInterrupt, so hours of
        # rollouts are not thrown away when a run is stopped early.
        # NOTE(review): the name says "clip0001" although clip_range is 0.1;
        # kept unchanged in case other code loads this path.
        model.save(f'ppti_ppo2_{n_steps}_batch128_clip0001_ent0001')


Using cpu device


We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=3000 and n_envs=21)


..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
DOING 3000 3000000
Logging to ./outputs/PPO_8


Output()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 754      |
|    ep_rew_mean     | 197      |
| time/              |          |
|    fps             | 67       |
|    iterations      | 1        |
|    time_elapsed    | 930      |
|    total_timesteps | 63000    |
---------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 693       |
|    ep_rew_mean          | 223       |
| time/                   |           |
|    fps                  | 65        |
|    iterations           | 2         |
|    time_elapsed         | 1916      |
|    total_timesteps      | 126000    |
| train/                  |           |
|    approx_kl            | 0.8266921 |
|    clip_fraction        | 0.514     |
|    clip_range           | 0.1       |
|    entropy_loss         | -0.618    |
|    explained_variance   | 0.00874   |
|    learning_rate        | 0.0001    |
|    loss                 | 1.11      |
|    n_updates            | 10        |
|    policy_gradient_loss | -0.00931  |
|    value_loss           | 16.3      |
---------------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 723       |
|    ep_rew_mean          | 233       |
| time/                   |           |
|    fps                  | 64        |
|    iterations           | 3         |
|    time_elapsed         | 2915      |
|    total_timesteps      | 189000    |
| train/                  |           |
|    approx_kl            | 1.9385153 |
|    clip_fraction        | 0.492     |
|    clip_range           | 0.1       |
|    entropy_loss         | -0.647    |
|    explained_variance   | 0.695     |
|    learning_rate        | 0.0001    |
|    loss                 | 12.3      |
|    n_updates            | 20        |
|    policy_gradient_loss | 0.00327   |
|    value_loss           | 22        |
---------------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 614       |
|    ep_rew_mean          | 253       |
| time/                   |           |
|    fps                  | 64        |
|    iterations           | 4         |
|    time_elapsed         | 3935      |
|    total_timesteps      | 252000    |
| train/                  |           |
|    approx_kl            | 0.6516664 |
|    clip_fraction        | 0.447     |
|    clip_range           | 0.1       |
|    entropy_loss         | -0.791    |
|    explained_variance   | 0.838     |
|    learning_rate        | 0.0001    |
|    loss                 | 10.8      |
|    n_updates            | 30        |
|    policy_gradient_loss | -0.00165  |
|    value_loss           | 32.1      |
---------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 672        |
|    ep_rew_mean          | 265        |
| time/                   |            |
|    fps                  | 39         |
|    iterations           | 5          |
|    time_elapsed         | 8008       |
|    total_timesteps      | 315000     |
| train/                  |            |
|    approx_kl            | 0.07683474 |
|    clip_fraction        | 0.34       |
|    clip_range           | 0.1        |
|    entropy_loss         | -0.757     |
|    explained_variance   | 0.883      |
|    learning_rate        | 0.0001     |
|    loss                 | 12.2       |
|    n_updates            | 40         |
|    policy_gradient_loss | -0.00384   |
|    value_loss           | 42.4       |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 607        |
|    ep_rew_mean          | 290        |
| time/                   |            |
|    fps                  | 42         |
|    iterations           | 6          |
|    time_elapsed         | 8823       |
|    total_timesteps      | 378000     |
| train/                  |            |
|    approx_kl            | 0.19709535 |
|    clip_fraction        | 0.376      |
|    clip_range           | 0.1        |
|    entropy_loss         | -0.985     |
|    explained_variance   | 0.914      |
|    learning_rate        | 0.0001     |
|    loss                 | 6.28       |
|    n_updates            | 50         |
|    policy_gradient_loss | 0.00187    |
|    value_loss           | 38.7       |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 676        |
|    ep_rew_mean          | 285        |
| time/                   |            |
|    fps                  | 45         |
|    iterations           | 7          |
|    time_elapsed         | 9640       |
|    total_timesteps      | 441000     |
| train/                  |            |
|    approx_kl            | 0.17149784 |
|    clip_fraction        | 0.353      |
|    clip_range           | 0.1        |
|    entropy_loss         | -0.814     |
|    explained_variance   | 0.845      |
|    learning_rate        | 0.0001     |
|    loss                 | 4.08       |
|    n_updates            | 60         |
|    policy_gradient_loss | -0.00634   |
|    value_loss           | 72.1       |
----------------------------------------


KeyboardInterrupt: 