In [1]:
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.env_util import make_vec_env
from stk_actor.agent import UnifiedSACPolicy
from stk_actor.wrappers import PreprocessObservationWrapper
import torch

from pystk2_gymnasium import AgentSpec

# Build a temporary env only so PPO / UnifiedSACPolicy know the observation and
# action spaces; it is closed again once the weights have been loaded.
vec_env = make_vec_env(
    "supertuxkart/flattened_multidiscrete-v0",
    seed=0,
    # wrapper_class=lambda x: ObsTimeExtensionWrapper(PreprocessObservationWrapper(x)),
    wrapper_class=lambda x: PreprocessObservationWrapper(x),
    n_envs=1,
    env_kwargs={
        'render_mode': None,
        'agent': AgentSpec(use_ai=False, name="walid"),
        # 'track': 'abyss', 'num_kart': 2, 'difficulty': 0,
    },
)

# The architecture below must match the one the checkpoint in `filename` was
# trained with, otherwise load_state_dict (strict by default) raises.
# net_arch = [512, 512, 512]
# activation_fn = torch.nn.SiLU
# filename = 'policy_512_512_512_SiLU_2_statedict'

net_arch = [512, 512, 512, 512]
activation_fn = torch.nn.SiLU
# filename = 'policy_512_512_512_512_SiLU_3_statedict'
# filename = 'airl_policy_512_512_512_512_SiLU_statedict'
filename = 'policy_l3_512_512_512_512_SiLU_statedict'

model = PPO(
    "MlpPolicy",
    vec_env,
    verbose=1,
    policy_kwargs=dict(net_arch=net_arch, activation_fn=activation_fn),
    device='cpu',
    # Tiny rollout size; in this notebook the model is only used via predict().
    n_steps=5,
    tensorboard_log="./outputs/",
)

# One categorical head per MultiDiscrete action component.
action_dims = [space.n for space in vec_env.action_space]
unified_policy = UnifiedSACPolicy(
    vec_env.observation_space,
    action_dims,
    net_arch=net_arch,
    activation_fn=activation_fn,
)
# map_location='cpu' so a checkpoint saved from a GPU run still loads here,
# consistent with device='cpu' above and the buffer loads below.
unified_policy.load_state_dict(torch.load(filename, map_location='cpu'))
# Transplant the `shared` sub-module's weights into the PPO policy.
# NOTE(review): assumes its key layout matches SB3's ActorCriticPolicy — strict
# loading will raise if it does not.
model.policy.load_state_dict(unified_policy.shared.state_dict())
# model.policy.load_state_dict(torch.load(filename))
vec_env.close()

# model = PPO.load("ppti_ppo_1500", custom_objects={'policy_kwargs' :  dict(net_arch=net_arch, activation_fn=activation_fn), })
# model = PPO.load("ppti_ppo_3000_batch128_clip0001_ent0001", custom_objects={'policy_kwargs' :  dict(net_arch=net_arch, activation_fn=activation_fn), })

# Presumably per-feature observation statistics from the training buffer —
# TODO confirm. Only needed for manual normalisation of raw observations.
mean = torch.load('buffer_mean', map_location='cpu')
std = torch.load('buffer_std', map_location='cpu')

..:: Antarctica Rendering Engine 2.0 ::..
Using cpu device


We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=5 and n_envs=1)
  unified_policy.load_state_dict(torch.load(filename))
  mean = torch.load('buffer_mean', map_location='cpu')
  std = torch.load('buffer_std', map_location='cpu')


In [2]:
# Inspect the per-feature buffer means as a plain Python list.
mean_values = mean.tolist()
mean_values

[1.2639206647872925,
 0.0,
 -0.07462655007839203,
 -0.07617677748203278,
 3.325659990310669,
 -0.6260609030723572,
 686.5313720703125,
 0.90716952085495,
 3.248101347708143e-05,
 -2.2646768229606096e-06,
 0.7184554934501648,
 -0.00979185476899147,
 -0.018949387595057487,
 42.55070114135742,
 -0.0018095501000061631,
 -0.0037681907415390015,
 50.57154083251953,
 0.006202289834618568,
 0.0015260628424584866,
 60.573333740234375,
 -0.0015078107826411724,
 -0.008679982274770737,
 83.0267333984375,
 -0.004257954191416502,
 -0.052518781274557114,
 92.90336608886719,
 -0.04041219875216484,
 -0.07217442244291306,
 29.27079963684082,
 -0.06311116367578506,
 -0.09916721284389496,
 44.90357208251953,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.10005245357751846,
 682.83642578125,
 690.8487548828125,
 685.5414428710938,
 693.1297607421875,
 687.2418212890625,
 694.6862182617188,
 688.3353271484375,
 695.63671875,
 689.5924682617188,
 696.7789306640625,
 0.036972709000110626,
 0.0329776

In [3]:
# Inspect the per-feature buffer standard deviations as a plain Python list.
# Entries equal to 0.0 mark features that never varied in the buffer.
std_values = std.tolist()
std_values

[4.425504207611084,
 0.0,
 1.603779673576355,
 1.7964426279067993,
 9.993358612060547,
 10.51357650756836,
 512.2705078125,
 1.6997371912002563,
 0.026459209620952606,
 0.010543138720095158,
 0.02026776410639286,
 0.7250668406486511,
 0.39837801456451416,
 34.88729476928711,
 0.7514939904212952,
 0.44161456823349,
 35.87752914428711,
 0.8336662650108337,
 0.539232075214386,
 41.78753662109375,
 1.0223793983459473,
 0.7299631237983704,
 47.53340530395508,
 1.098395586013794,
 0.8274492025375366,
 47.81513214111328,
 1.6523897647857666,
 1.749129295349121,
 28.97491455078125,
 2.2789063453674316,
 2.5216972827911377,
 40.0125846862793,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.06416473537683487,
 512.6343994140625,
 511.6667175292969,
 511.3175964355469,
 510.3323059082031,
 509.7705383300781,
 508.8245544433594,
 508.5116271972656,
 507.6471252441406,
 507.43463134765625,
 506.58172607421875,
 0.6846545338630676,
 0.5445173382759094,
 13.64755916595459,
 0.447633206844329

In [None]:
# Tracks on which to watch the loaded policy drive (uncomment to add more).
tracks = [
    # 'black_forest',
    'olivermath',
    # 'minigolf',
    # 'gran_paradiso_island',
    # 'candela_city',
    # 'fortmagma',
    # 'ravenbridge_mansion',
    # 'mines',
    # 'snowmountain',
    # 'abyss',
    # 'cocoa_temple',
    # 'cornfield_crossing',
    # 'hacienda',
    # 'lighthouse',
    # 'sandtrack',
    # 'scotland',
    # 'snowtuxpeak',
    # 'stk_enterprise',
    # 'volcano_island',
    # 'xr591',
    # 'zengarden',
]
MAX_STEPS_PER_TRACK = 256  # rendered env steps per track before moving on

for track in tracks:
    vec_env = make_vec_env(
        "supertuxkart/flattened_multidiscrete-v0",  # seed=0,
        # wrapper_class=lambda x: ObsTimeExtensionWrapper(PreprocessObservationWrapper(x)),
        # norm=True: presumably the wrapper normalises observations itself, so no
        # manual (obs - mean) / std step is applied below — TODO confirm.
        wrapper_class=lambda x: PreprocessObservationWrapper(x, ret_dict=False, norm=True),
        n_envs=1,
        env_kwargs={
            'render_mode': "human",
            'agent': AgentSpec(use_ai=False, name="walid"),
            'laps': 1,
            'track': track,
            # 'num_kart': 2, 'difficulty': 0,
        },
    )
    try:
        obs = vec_env.reset()
        # Bounded loop: the step counter must live outside the loop body,
        # otherwise the limit is never reached and later tracks never run.
        for _ in range(MAX_STEPS_PER_TRACK):
            # SB3's predict() accepts the VecEnv's numpy observation directly.
            action, _states = model.predict(obs, deterministic=False)
            obs, rewards, dones, info = vec_env.step(action)
            vec_env.render("human")
    finally:
        # Always release the render window, even if the run is interrupted.
        vec_env.close()




..:: Antarctica Rendering Engine 2.0 ::..


2025-01-20 15:39:07.170 python[4777:4084785] +[IMKClient subclass]: chose IMKClient_Modern
2025-01-20 15:39:07.170 python[4777:4084785] +[IMKInputSession subclass]: chose IMKInputSession_Modern


normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])


  logger.warn(


normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
normed_array.shape=torch.Size([154])
n

In [5]:
# Optional (disabled): save the PPO policy's weights as a state dict so they can
# later be reloaded with model.policy.load_state_dict(torch.load(...)).
# import torch
# torch.save( model.policy.state_dict(), 'ppo_policy_512_512_512_512_SiLU_4_statedict')