In [1]:
import gymnasium as gym
from stk_actor.wrappers import PreprocessObservationWrapper
import torch.nn as nn
from torch.distributions import Categorical
from typing import Dict, List, Tuple, Union, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from gymnasium import spaces

def get_device(device: Union[torch.device, str] = "auto") -> torch.device:
    """Resolve a device specification to a concrete ``torch.device``.

    ``"auto"`` is treated as a request for CUDA. Any CUDA request falls back to
    the CPU when CUDA is not available; every other device is returned as-is.

    :param device: ``"auto"``, a device string such as ``"cpu"``/``"cuda:0"``,
        or a ``torch.device``
    :return: the resolved ``torch.device``
    """
    requested = torch.device("cuda" if device == "auto" else device)
    if requested.type == "cuda" and not torch.cuda.is_available():
        return torch.device("cpu")
    return requested

class BaseFeaturesExtractor(nn.Module):
    """Base class for modules that map raw observations to a feature vector.

    :param observation_space: the observation space the extractor is built for
        (stored for subclasses; not used by this class)
    :param features_dim: length of the produced feature vector; must be > 0
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
        super().__init__()
        # The default of 0 always fails this check, so callers must supply a real size.
        assert features_dim > 0
        self._features_dim = features_dim
        self._observation_space = observation_space

    @property
    def features_dim(self) -> int:
        """Number of features produced per observation."""
        return self._features_dim

def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
    """Return the size of the flat feature vector for ``observation_space``.

    For ``MultiDiscrete`` spaces this is the sum of every entry of ``nvec``;
    for all other spaces it defers to ``gymnasium.spaces.utils.flatdim``.

    :param observation_space: the gymnasium space to measure
    :return: the flattened dimension as a plain Python ``int``
    """
    if isinstance(observation_space, spaces.MultiDiscrete):
        # ``nvec`` is a numpy array that may have more than one dimension. The builtin
        # ``sum`` only reduces the first axis (returning an array for 2-D ``nvec``) and
        # otherwise yields a numpy scalar; ``.sum()`` reduces all axes and ``int``
        # gives a real Python int, as annotated and as ``nn.Linear`` expects.
        return int(observation_space.nvec.sum())
    return spaces.utils.flatdim(observation_space)

class FlattenExtractor(BaseFeaturesExtractor):
    """Feature extractor that only flattens observations; it has no learnable parameters.

    Its ``features_dim`` is ``get_flattened_obs_dim(observation_space)``.
    """

    def __init__(self, observation_space: gym.Space) -> None:
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Flatten every dimension after the first (``nn.Flatten`` defaults)."""
        return self.flatten(observations)
    
class MlpExtractor(nn.Module):
    """Build separate actor (policy) and critic (value) MLP trunks over a feature vector.

    Each trunk is a sequence of ``Linear -> activation`` pairs; the two trunks never
    share layers.

    :param feature_dim: size of the input feature vector
    :param net_arch: hidden layer sizes. A list is used for both trunks; a dict may give
        separate lists under ``"pi"`` (actor) and ``"vf"`` (critic), where a missing key
        means that trunk has no hidden layers.
    :param activation_fn: activation class, instantiated once per hidden layer
    :param device: accepted for interface compatibility but currently unused; the
        layers are created on the default device.
    """

    def __init__(
        self,
        feature_dim: int,
        net_arch: Union[List[int], Dict[str, List[int]]],
        activation_fn: Type[nn.Module],
        device: Union[torch.device, str] = "auto",
    ) -> None:
        super().__init__()

        if isinstance(net_arch, dict):
            pi_sizes, vf_sizes = net_arch.get("pi", []), net_arch.get("vf", [])
        else:
            pi_sizes = vf_sizes = net_arch

        # The actor trunk is built first; swapping the order would change which random
        # numbers each layer's initialisation consumes under a fixed seed.
        pi_layers, self.latent_dim_pi = self._build_trunk(feature_dim, pi_sizes, activation_fn)
        vf_layers, self.latent_dim_vf = self._build_trunk(feature_dim, vf_sizes, activation_fn)

        # Attribute names (and their order) define the state_dict keys.
        self.policy_net = nn.Sequential(*pi_layers)
        self.value_net = nn.Sequential(*vf_layers)

    @staticmethod
    def _build_trunk(
        in_dim: int, hidden_sizes: List[int], activation_fn: Type[nn.Module]
    ) -> Tuple[List[nn.Module], int]:
        """Return ``([Linear, act, Linear, act, ...], output_width)`` for one trunk."""
        layers: List[nn.Module] = []
        width = in_dim
        for size in hidden_sizes:
            layers += [nn.Linear(width, size), activation_fn()]
            width = size
        return layers, width

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :return: ``(latent_policy, latent_value)`` from the actor and critic trunks.
            A trunk with no hidden layers is an empty ``nn.Sequential`` and returns
            ``features`` unchanged.
        """
        latent_pi = self.forward_actor(features)
        latent_vf = self.forward_critic(features)
        return latent_pi, latent_vf

    def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
        """Run only the actor (policy) trunk."""
        return self.policy_net(features)

    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        """Run only the critic (value) trunk."""
        return self.value_net(features)

    
class Policy(nn.Module):
    """Actor-critic network laid out with stable-baselines3 ``ActorCriticPolicy`` attribute names.

    The submodule names (``features_extractor``, ``mlp_extractor``, ``action_net``,
    ``value_net``) determine the ``state_dict()`` keys, which is what allows weights to
    be exchanged with an SB3 PPO ``model.policy`` via ``load_state_dict``.

    :param observation_space: gymnasium observation space; observations are flattened
    :param action_dims: number of choices for each discrete action component; a single
        linear head of width ``sum(action_dims)`` produces the logits for all components
    :param net_arch: hidden layer sizes forwarded to ``MlpExtractor`` (a list, or a
        ``{"pi": [...], "vf": [...]}`` dict)
    :param activation_fn: activation class used in the hidden layers
    """

    def __init__(self, observation_space, action_dims, net_arch, activation_fn,):
        super().__init__()
        self.features_extractor = FlattenExtractor(observation_space)
        # Actor and critic share the same parameter-free extractor; the aliases presumably
        # exist only to mirror SB3's attribute names.
        self.pi_features_extractor = self.features_extractor
        self.vf_features_extractor = self.features_extractor
        self.mlp_extractor = MlpExtractor(
            self.features_extractor.features_dim,
            net_arch=net_arch,
            activation_fn=activation_fn,
        )
        # Size the heads from the trunks' real output widths. ``net_arch[-1]`` would fail for
        # a dict-style or empty ``net_arch`` and would ignore distinct pi/vf sizes.
        self.action_net = nn.Linear(self.mlp_extractor.latent_dim_pi, sum(action_dims))
        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)


class UnifiedSACPolicy(nn.Module):
    def __init__(self, observation_space, action_dims, net_arch, activation_fn):
        super().__init__()
        
        self.shared = Policy(
            observation_space,
            action_dims,
            net_arch=net_arch,
            activation_fn=activation_fn
        )
        self.action_dims = action_dims
    
    def forward(self, x):
        x = self.shared.features_extractor(x)
        x = self.shared.mlp_extractor.policy_net(x)
        x = self.shared.action_net(x)
        return x
    
    def sample(self, x, deterministic=False):
        logits = self.forward(x)
        
        # Split logits for each action dimension
        split_logits = torch.split(logits, self.action_dims, dim=-1)
        
        actions = []
        log_probs = []
        probs = []
        
        for logit in split_logits:
            distribution = Categorical(logits=logit)
            if deterministic:
                action = torch.argmax(logit, dim=-1)
            else:
                action = distribution.sample()
            
            log_prob = distribution.log_prob(action)
            prob = F.softmax(logit, dim=-1)
            
            actions.append(action)
            log_probs.append(log_prob)
            probs.append(prob)
        
        return (
            torch.stack(actions),
            torch.stack(log_probs),
            probs
        )
    
#policy = torch.load('policy_512_512_512_512_SiLU_3_statedict', map_location='cuda')


from stable_baselines3 import PPO, A2C
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
from pystk2_gymnasium import AgentSpec
from bbrl.agents.gymnasium import ParallelGymAgent, make_env
from functools import partial

# Training tracks; one vectorised env is created per track (see the loop below).
tracks = ['abyss',
 'black_forest',
 'candela_city',
 'cocoa_temple',
 'cornfield_crossing',
 'fortmagma',
 'gran_paradiso_island',
 'hacienda',
 'lighthouse',
 'mines',
 'minigolf',
 'olivermath',
 'ravenbridge_mansion',
 'sandtrack',
 'scotland',
 'snowmountain',
 'snowtuxpeak',
 'stk_enterprise',
 'volcano_island',
 'xr591',
 'zengarden']

vec_env = make_vec_env(
    "supertuxkart/flattened_multidiscrete-v0",
    seed=0,
    n_envs=len(tracks),  # previously a hard-coded 21, i.e. one env per track
    wrapper_class=lambda env: PreprocessObservationWrapper(env, ret_dict=False),
    env_kwargs={
        'render_mode': None,
        'agent': AgentSpec(use_ai=False, name="walid"),
        'track': 'minigolf',
        'laps': 2,
    },
)

# NOTE(review): assumes ``default_track`` on the unwrapped env controls which track the
# next reset() loads, overriding the 'track' kwarg above — confirm against pystk2_gymnasium.
for i, venv in enumerate(vec_env.envs):
    print(i, tracks[i % len(tracks)])
    venv.env.default_track = tracks[i % len(tracks)]


# Architecture must match the one the state-dict file below was saved with.
net_arch = [512, 512, 512, 512]
activation_fn = torch.nn.SiLU
filename = 'policy_512_512_512_512_SiLU_3_statedict'

action_dims = [space.n for space in vec_env.action_space]
unified_policy = UnifiedSACPolicy(
    vec_env.observation_space,
    action_dims,
    net_arch=net_arch,
    activation_fn=activation_fn,
)
# The file is a plain state dict (it is fed straight to load_state_dict), so restrict
# unpickling to tensors/containers instead of executing arbitrary pickled code.
unified_policy.load_state_dict(torch.load(filename, map_location='cpu', weights_only=True))


..:: Antarctica Rendering Engine 2.0 ::..
0 abyss
1 black_forest
2 candela_city
3 cocoa_temple
4 cornfield_crossing
5 fortmagma
6 gran_paradiso_island
7 hacienda
8 lighthouse
9 mines
10 minigolf
11 olivermath
12 ravenbridge_mansion
13 sandtrack
14 scotland
15 snowmountain
16 snowtuxpeak
17 stk_enterprise
18 volcano_island
19 xr591
20 zengarden


  unified_policy.load_state_dict(torch.load(filename, map_location='cpu'))


<All keys matched successfully>

In [None]:
from pathlib import Path

# One entry per training phase: (rollout length per env, total timesteps).
steps = [(1500, 3_000_000)]

for n_steps, total_timesteps in steps:
    model = PPO(
        "MlpPolicy",
        vec_env,
        verbose=1,
        policy_kwargs=dict(net_arch=net_arch, activation_fn=activation_fn),
        device='cpu',
        learning_rate=0.0001,
        # NOTE(review): SB3 warns that 128 does not divide n_steps * n_envs (1500 * 21 = 31500),
        # so every epoch ends with a truncated minibatch; a divisor such as 126 avoids that.
        batch_size=128,
        n_steps=n_steps,
        tensorboard_log="./outputs/",
        ent_coef=0.001,
        # NOTE(review): the saved log output was produced with clip_range 0.01, not this value,
        # and shows approx_kl up to ~5 with ep_rew_mean falling; consider target_kl.
        clip_range=0.001,
    )
    print('DOING', n_steps, total_timesteps)

    checkpoint = f'ppti_ppo_{n_steps}'
    if Path(f'{checkpoint}.zip').exists():
        # Resume: copy only the policy weights from the previous run (optimizer state is
        # not restored). This also covers every parameter, so no warm start is needed.
        resumed = PPO.load(
            checkpoint,
            device='cpu',
            custom_objects={'policy_kwargs': dict(net_arch=net_arch, activation_fn=activation_fn)},
        )
        model.policy.load_state_dict(resumed.policy.state_dict())
    else:
        # First run: warm-start from the pretrained policy loaded from ``filename``.
        model.policy.load_state_dict(unified_policy.shared.state_dict())

    model.learn(total_timesteps=total_timesteps, progress_bar=True)
    # Overwrites the checkpoint resumed from, so the next run continues from here.
    model.save(checkpoint)


Using cpu device


We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=1500 and n_envs=21)


..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
DOING 1500 3000000
Logging to ./outputs/PPO_4


Output()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 937      |
|    ep_rew_mean     | 540      |
| time/              |          |
|    fps             | 88       |
|    iterations      | 1        |
|    time_elapsed    | 357      |
|    total_timesteps | 31500    |
---------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 870       |
|    ep_rew_mean          | 561       |
| time/                   |           |
|    fps                  | 86        |
|    iterations           | 2         |
|    time_elapsed         | 730       |
|    total_timesteps      | 63000     |
| train/                  |           |
|    approx_kl            | 0.4866613 |
|    clip_fraction        | 0.352     |
|    clip_range           | 0.01      |
|    entropy_loss         | -0.122    |
|    explained_variance   | 0.899     |
|    learning_rate        | 0.0001    |
|    loss                 | 79.5      |
|    n_updates            | 10        |
|    policy_gradient_loss | 0.0456    |
|    value_loss           | 114       |
---------------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 850       |
|    ep_rew_mean          | 572       |
| time/                   |           |
|    fps                  | 85        |
|    iterations           | 3         |
|    time_elapsed         | 1101      |
|    total_timesteps      | 94500     |
| train/                  |           |
|    approx_kl            | 1.5720984 |
|    clip_fraction        | 0.402     |
|    clip_range           | 0.01      |
|    entropy_loss         | -0.126    |
|    explained_variance   | 0.885     |
|    learning_rate        | 0.0001    |
|    loss                 | 30.9      |
|    n_updates            | 20        |
|    policy_gradient_loss | 0.0859    |
|    value_loss           | 129       |
---------------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 954       |
|    ep_rew_mean          | 566       |
| time/                   |           |
|    fps                  | 85        |
|    iterations           | 4         |
|    time_elapsed         | 1473      |
|    total_timesteps      | 126000    |
| train/                  |           |
|    approx_kl            | 0.5220405 |
|    clip_fraction        | 0.437     |
|    clip_range           | 0.01      |
|    entropy_loss         | -0.135    |
|    explained_variance   | 0.904     |
|    learning_rate        | 0.0001    |
|    loss                 | 5.04      |
|    n_updates            | 30        |
|    policy_gradient_loss | 0.0898    |
|    value_loss           | 127       |
---------------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 1.07e+03  |
|    ep_rew_mean          | 530       |
| time/                   |           |
|    fps                  | 85        |
|    iterations           | 5         |
|    time_elapsed         | 1846      |
|    total_timesteps      | 157500    |
| train/                  |           |
|    approx_kl            | 0.6807845 |
|    clip_fraction        | 0.229     |
|    clip_range           | 0.01      |
|    entropy_loss         | -0.077    |
|    explained_variance   | 0.919     |
|    learning_rate        | 0.0001    |
|    loss                 | 26.3      |
|    n_updates            | 40        |
|    policy_gradient_loss | 0.0444    |
|    value_loss           | 60.9      |
---------------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 1.2e+03   |
|    ep_rew_mean          | 456       |
| time/                   |           |
|    fps                  | 85        |
|    iterations           | 6         |
|    time_elapsed         | 2220      |
|    total_timesteps      | 189000    |
| train/                  |           |
|    approx_kl            | 0.3089201 |
|    clip_fraction        | 0.274     |
|    clip_range           | 0.01      |
|    entropy_loss         | -0.0807   |
|    explained_variance   | 0.917     |
|    learning_rate        | 0.0001    |
|    loss                 | 3.57      |
|    n_updates            | 50        |
|    policy_gradient_loss | 0.0474    |
|    value_loss           | 51.4      |
---------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.35e+03   |
|    ep_rew_mean          | 369        |
| time/                   |            |
|    fps                  | 84         |
|    iterations           | 7          |
|    time_elapsed         | 2594       |
|    total_timesteps      | 220500     |
| train/                  |            |
|    approx_kl            | 0.33570236 |
|    clip_fraction        | 0.413      |
|    clip_range           | 0.01       |
|    entropy_loss         | -0.168     |
|    explained_variance   | 0.932      |
|    learning_rate        | 0.0001     |
|    loss                 | 47.8       |
|    n_updates            | 60         |
|    policy_gradient_loss | 0.0375     |
|    value_loss           | 38.8       |
----------------------------------------


--------------------------------------
| rollout/                |          |
|    ep_len_mean          | 1.48e+03 |
|    ep_rew_mean          | 312      |
| time/                   |          |
|    fps                  | 84       |
|    iterations           | 8        |
|    time_elapsed         | 2968     |
|    total_timesteps      | 252000   |
| train/                  |          |
|    approx_kl            | 1.465121 |
|    clip_fraction        | 0.413    |
|    clip_range           | 0.01     |
|    entropy_loss         | -0.151   |
|    explained_variance   | 0.955    |
|    learning_rate        | 0.0001   |
|    loss                 | 12.1     |
|    n_updates            | 70       |
|    policy_gradient_loss | 0.0448   |
|    value_loss           | 31.6     |
--------------------------------------


--------------------------------------
| rollout/                |          |
|    ep_len_mean          | 1.5e+03  |
|    ep_rew_mean          | 271      |
| time/                   |          |
|    fps                  | 84       |
|    iterations           | 9        |
|    time_elapsed         | 3340     |
|    total_timesteps      | 283500   |
| train/                  |          |
|    approx_kl            | 5.326346 |
|    clip_fraction        | 0.504    |
|    clip_range           | 0.01     |
|    entropy_loss         | -0.139   |
|    explained_variance   | 0.958    |
|    learning_rate        | 0.0001   |
|    loss                 | 1.81     |
|    n_updates            | 80       |
|    policy_gradient_loss | 0.0509   |
|    value_loss           | 35.4     |
--------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.5e+03    |
|    ep_rew_mean          | 287        |
| time/                   |            |
|    fps                  | 84         |
|    iterations           | 10         |
|    time_elapsed         | 3714       |
|    total_timesteps      | 315000     |
| train/                  |            |
|    approx_kl            | 0.11814453 |
|    clip_fraction        | 0.223      |
|    clip_range           | 0.01       |
|    entropy_loss         | -0.087     |
|    explained_variance   | 0.956      |
|    learning_rate        | 0.0001     |
|    loss                 | 3.81       |
|    n_updates            | 90         |
|    policy_gradient_loss | 0.0213     |
|    value_loss           | 44.4       |
----------------------------------------


In [3]:
# Re-save the PPO model built in the training cell (presumably after interrupting it).
# Relies on ``model`` and ``n_steps`` still being defined in the kernel by that cell.
checkpoint_path = f'ppti_ppo_{n_steps}'
model.save(checkpoint_path)